330 lines
9.4 KiB
Python
330 lines
9.4 KiB
Python
"""
|
|
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
|
"""
|
|
import functools
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .build import NETWORK_REGISTRY
|
|
|
|
|
|
def init_network_weights(model, init_type="normal", gain=0.02):
|
|
|
|
def _init_func(m):
|
|
classname = m.__class__.__name__
|
|
if hasattr(m, "weight") and (
|
|
classname.find("Conv") != -1 or classname.find("Linear") != -1
|
|
):
|
|
if init_type == "normal":
|
|
nn.init.normal_(m.weight.data, 0.0, gain)
|
|
elif init_type == "xavier":
|
|
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
|
elif init_type == "kaiming":
|
|
nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
|
|
elif init_type == "orthogonal":
|
|
nn.init.orthogonal_(m.weight.data, gain=gain)
|
|
else:
|
|
raise NotImplementedError(
|
|
"initialization method {} is not implemented".
|
|
format(init_type)
|
|
)
|
|
if hasattr(m, "bias") and m.bias is not None:
|
|
nn.init.constant_(m.bias.data, 0.0)
|
|
elif classname.find("BatchNorm2d") != -1:
|
|
nn.init.constant_(m.weight.data, 1.0)
|
|
nn.init.constant_(m.bias.data, 0.0)
|
|
elif classname.find("InstanceNorm2d") != -1:
|
|
if m.weight is not None and m.bias is not None:
|
|
nn.init.constant_(m.weight.data, 1.0)
|
|
nn.init.constant_(m.bias.data, 0.0)
|
|
|
|
model.apply(_init_func)
|
|
|
|
|
|
def get_norm_layer(norm_type="instance"):
|
|
if norm_type == "batch":
|
|
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
|
elif norm_type == "instance":
|
|
norm_layer = functools.partial(
|
|
nn.InstanceNorm2d, affine=False, track_running_stats=False
|
|
)
|
|
elif norm_type == "none":
|
|
norm_layer = None
|
|
else:
|
|
raise NotImplementedError(
|
|
"normalization layer [%s] is not found" % norm_type
|
|
)
|
|
return norm_layer
|
|
|
|
|
|
class ResnetBlock(nn.Module):
|
|
|
|
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
|
super().__init__()
|
|
self.conv_block = self.build_conv_block(
|
|
dim, padding_type, norm_layer, use_dropout, use_bias
|
|
)
|
|
|
|
def build_conv_block(
|
|
self, dim, padding_type, norm_layer, use_dropout, use_bias
|
|
):
|
|
conv_block = []
|
|
p = 0
|
|
if padding_type == "reflect":
|
|
conv_block += [nn.ReflectionPad2d(1)]
|
|
elif padding_type == "replicate":
|
|
conv_block += [nn.ReplicationPad2d(1)]
|
|
elif padding_type == "zero":
|
|
p = 1
|
|
else:
|
|
raise NotImplementedError(
|
|
"padding [%s] is not implemented" % padding_type
|
|
)
|
|
|
|
conv_block += [
|
|
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
|
norm_layer(dim),
|
|
nn.ReLU(True),
|
|
]
|
|
if use_dropout:
|
|
conv_block += [nn.Dropout(0.5)]
|
|
|
|
p = 0
|
|
if padding_type == "reflect":
|
|
conv_block += [nn.ReflectionPad2d(1)]
|
|
elif padding_type == "replicate":
|
|
conv_block += [nn.ReplicationPad2d(1)]
|
|
elif padding_type == "zero":
|
|
p = 1
|
|
else:
|
|
raise NotImplementedError(
|
|
"padding [%s] is not implemented" % padding_type
|
|
)
|
|
conv_block += [
|
|
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
|
norm_layer(dim),
|
|
]
|
|
|
|
return nn.Sequential(*conv_block)
|
|
|
|
def forward(self, x):
|
|
return x + self.conv_block(x)
|
|
|
|
|
|
class LocNet(nn.Module):
|
|
"""Localization network."""
|
|
|
|
def __init__(
|
|
self,
|
|
input_nc,
|
|
nc=32,
|
|
n_blocks=3,
|
|
use_dropout=False,
|
|
padding_type="zero",
|
|
image_size=32,
|
|
):
|
|
super().__init__()
|
|
|
|
backbone = []
|
|
backbone += [
|
|
nn.Conv2d(
|
|
input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
|
|
)
|
|
]
|
|
backbone += [nn.BatchNorm2d(nc)]
|
|
backbone += [nn.ReLU(True)]
|
|
for _ in range(n_blocks):
|
|
backbone += [
|
|
ResnetBlock(
|
|
nc,
|
|
padding_type=padding_type,
|
|
norm_layer=nn.BatchNorm2d,
|
|
use_dropout=use_dropout,
|
|
use_bias=False,
|
|
)
|
|
]
|
|
backbone += [nn.MaxPool2d(2, stride=2)]
|
|
self.backbone = nn.Sequential(*backbone)
|
|
reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
|
|
self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)
|
|
|
|
def forward(self, x):
|
|
x = self.backbone(x)
|
|
x = x.view(x.size(0), -1)
|
|
x = self.fc_loc(x)
|
|
x = torch.tanh(x)
|
|
x = x.view(-1, 2, 2)
|
|
theta = x.data.new_zeros(x.size(0), 2, 3)
|
|
theta[:, :, :2] = x
|
|
return theta
|
|
|
|
|
|
class FCN(nn.Module):
|
|
"""Fully convolutional network."""
|
|
|
|
def __init__(
|
|
self,
|
|
input_nc,
|
|
output_nc,
|
|
nc=32,
|
|
n_blocks=3,
|
|
norm_layer=nn.BatchNorm2d,
|
|
use_dropout=False,
|
|
padding_type="reflect",
|
|
gctx=True,
|
|
stn=False,
|
|
image_size=32,
|
|
):
|
|
super().__init__()
|
|
|
|
backbone = []
|
|
|
|
p = 0
|
|
if padding_type == "reflect":
|
|
backbone += [nn.ReflectionPad2d(1)]
|
|
elif padding_type == "replicate":
|
|
backbone += [nn.ReplicationPad2d(1)]
|
|
elif padding_type == "zero":
|
|
p = 1
|
|
else:
|
|
raise NotImplementedError
|
|
backbone += [
|
|
nn.Conv2d(
|
|
input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
|
|
)
|
|
]
|
|
backbone += [norm_layer(nc)]
|
|
backbone += [nn.ReLU(True)]
|
|
|
|
for _ in range(n_blocks):
|
|
backbone += [
|
|
ResnetBlock(
|
|
nc,
|
|
padding_type=padding_type,
|
|
norm_layer=norm_layer,
|
|
use_dropout=use_dropout,
|
|
use_bias=False,
|
|
)
|
|
]
|
|
self.backbone = nn.Sequential(*backbone)
|
|
|
|
# global context fusion layer
|
|
self.gctx_fusion = None
|
|
if gctx:
|
|
self.gctx_fusion = nn.Sequential(
|
|
nn.Conv2d(
|
|
2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
|
|
),
|
|
norm_layer(nc),
|
|
nn.ReLU(True),
|
|
)
|
|
|
|
self.regress = nn.Sequential(
|
|
nn.Conv2d(
|
|
nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
|
|
),
|
|
nn.Tanh(),
|
|
)
|
|
|
|
self.locnet = None
|
|
if stn:
|
|
self.locnet = LocNet(
|
|
input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
|
|
)
|
|
|
|
def init_loc_layer(self):
|
|
"""Initialize the weights/bias with identity transformation."""
|
|
if self.locnet is not None:
|
|
self.locnet.fc_loc.weight.data.zero_()
|
|
self.locnet.fc_loc.bias.data.copy_(
|
|
torch.tensor([1, 0, 0, 1], dtype=torch.float)
|
|
)
|
|
|
|
def stn(self, x):
|
|
"""Spatial transformer network."""
|
|
theta = self.locnet(x)
|
|
grid = F.affine_grid(theta, x.size())
|
|
return F.grid_sample(x, grid), theta
|
|
|
|
def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
|
|
"""
|
|
Args:
|
|
x (torch.Tensor): input mini-batch.
|
|
lmda (float): multiplier for perturbation.
|
|
return_p (bool): return perturbation.
|
|
return_stn_output (bool): return the output of stn.
|
|
"""
|
|
theta = None
|
|
if self.locnet is not None:
|
|
x, theta = self.stn(x)
|
|
input = x
|
|
|
|
x = self.backbone(x)
|
|
if self.gctx_fusion is not None:
|
|
c = F.adaptive_avg_pool2d(x, (1, 1))
|
|
c = c.expand_as(x)
|
|
x = torch.cat([x, c], 1)
|
|
x = self.gctx_fusion(x)
|
|
|
|
p = self.regress(x)
|
|
x_p = input + lmda*p
|
|
|
|
if return_stn_output:
|
|
return x_p, p, input
|
|
|
|
if return_p:
|
|
return x_p, p
|
|
|
|
return x_p
|
|
|
|
|
|
@NETWORK_REGISTRY.register()
|
|
def fcn_3x32_gctx(**kwargs):
|
|
norm_layer = get_norm_layer(norm_type="instance")
|
|
net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)
|
|
init_network_weights(net, init_type="normal", gain=0.02)
|
|
return net
|
|
|
|
|
|
@NETWORK_REGISTRY.register()
|
|
def fcn_3x64_gctx(**kwargs):
|
|
norm_layer = get_norm_layer(norm_type="instance")
|
|
net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)
|
|
init_network_weights(net, init_type="normal", gain=0.02)
|
|
return net
|
|
|
|
|
|
@NETWORK_REGISTRY.register()
|
|
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
|
|
norm_layer = get_norm_layer(norm_type="instance")
|
|
net = FCN(
|
|
3,
|
|
3,
|
|
nc=32,
|
|
n_blocks=3,
|
|
norm_layer=norm_layer,
|
|
stn=True,
|
|
image_size=image_size
|
|
)
|
|
init_network_weights(net, init_type="normal", gain=0.02)
|
|
net.init_loc_layer()
|
|
return net
|
|
|
|
|
|
@NETWORK_REGISTRY.register()
|
|
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
|
|
norm_layer = get_norm_layer(norm_type="instance")
|
|
net = FCN(
|
|
3,
|
|
3,
|
|
nc=64,
|
|
n_blocks=3,
|
|
norm_layer=norm_layer,
|
|
stn=True,
|
|
image_size=image_size
|
|
)
|
|
init_network_weights(net, init_type="normal", gain=0.02)
|
|
net.init_loc_layer()
|
|
return net
|