release code
Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py (new file, 329 lines)
@@ -0,0 +1,329 @@
"""
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
import functools
import torch
import torch.nn as nn
from torch.nn import functional as F

from .build import NETWORK_REGISTRY


def init_network_weights(model, init_type="normal", gain=0.02):

    def _init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (
            classname.find("Conv") != -1 or classname.find("Linear") != -1
        ):
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            nn.init.constant_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("InstanceNorm2d") != -1:
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)


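# Note: get_norm_layer below returns a functools.partial that is used like a
# constructor, e.g. norm_layer(nc) builds
# nn.InstanceNorm2d(nc, affine=False, track_running_stats=False).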
def get_norm_layer(norm_type="instance"):
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    elif norm_type == "none":
        norm_layer = None
    else:
        raise NotImplementedError(
            "normalization layer [%s] is not found" % norm_type
        )
    return norm_layer


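# Note: ResnetBlock below is a two-conv residual block; each 3x3 conv uses a
# padding of 1 (reflect, replicate or zero), so the spatial size is preserved
# and the skip connection x + conv_block(x) in forward() is shape-compatible.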
class ResnetBlock(nn.Module):

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super().__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias
        )

    def build_conv_block(
        self, dim, padding_type, norm_layer, use_dropout, use_bias
    ):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError(
                "padding [%s] is not implemented" % padding_type
            )

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError(
                "padding [%s] is not implemented" % padding_type
            )
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


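# Note (illustrative, assuming a 3-channel input and the defaults nc=32,
# n_blocks=3, image_size=32): in LocNet below, the stride-2 conv plus the three
# max-pools reduce a (B, 3, 32, 32) input to (B, 32, 2, 2); this is flattened to
# (B, 128) and mapped by fc_loc to 4 values, reshaped to (B, 2, 2). The returned
# theta is the (B, 2, 3) affine matrix expected by F.affine_grid, with the
# translation column fixed at zero.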
class LocNet(nn.Module):
    """Localization network."""

    def __init__(
        self,
        input_nc,
        nc=32,
        n_blocks=3,
        use_dropout=False,
        padding_type="zero",
        image_size=32,
    ):
        super().__init__()

        backbone = []
        backbone += [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
            )
        ]
        backbone += [nn.BatchNorm2d(nc)]
        backbone += [nn.ReLU(True)]
        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=nn.BatchNorm2d,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            ]
            backbone += [nn.MaxPool2d(2, stride=2)]
        self.backbone = nn.Sequential(*backbone)
        reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
        self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)

    def forward(self, x):
        x = self.backbone(x)
        x = x.view(x.size(0), -1)
        x = self.fc_loc(x)
        x = torch.tanh(x)
        x = x.view(-1, 2, 2)
        theta = x.data.new_zeros(x.size(0), 2, 3)
        theta[:, :, :2] = x
        return theta


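# Note: FCN below is the perturbation generator (cf. the module name ddaig_fcn).
# An optional LocNet/STN first warps the input; the backbone keeps the spatial
# resolution; with gctx=True the feature map is concatenated with its global
# average (broadcast back to full size) and fused by a 1x1 conv; regress() maps
# features to a Tanh-bounded perturbation p, and the output is input + lmda * p.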
class FCN(nn.Module):
    """Fully convolutional network."""

    def __init__(
        self,
        input_nc,
        output_nc,
        nc=32,
        n_blocks=3,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
        padding_type="reflect",
        gctx=True,
        stn=False,
        image_size=32,
    ):
        super().__init__()

        backbone = []

        p = 0
        if padding_type == "reflect":
            backbone += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            backbone += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError
        backbone += [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
            )
        ]
        backbone += [norm_layer(nc)]
        backbone += [nn.ReLU(True)]

        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            ]
        self.backbone = nn.Sequential(*backbone)

        # global context fusion layer
        self.gctx_fusion = None
        if gctx:
            self.gctx_fusion = nn.Sequential(
                nn.Conv2d(
                    2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
                ),
                norm_layer(nc),
                nn.ReLU(True),
            )

        self.regress = nn.Sequential(
            nn.Conv2d(
                nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
            ),
            nn.Tanh(),
        )

        self.locnet = None
        if stn:
            self.locnet = LocNet(
                input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
            )

    def init_loc_layer(self):
        """Initialize the weights/bias with identity transformation."""
        if self.locnet is not None:
            self.locnet.fc_loc.weight.data.zero_()
            self.locnet.fc_loc.bias.data.copy_(
                torch.tensor([1, 0, 0, 1], dtype=torch.float)
            )

    def stn(self, x):
        """Spatial transformer network."""
        theta = self.locnet(x)
        grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, grid), theta

    def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
        """
        Args:
            x (torch.Tensor): input mini-batch.
            lmda (float): multiplier for perturbation.
            return_p (bool): return perturbation.
            return_stn_output (bool): return the output of stn.
        """
        theta = None
        if self.locnet is not None:
            x, theta = self.stn(x)
        input = x

        x = self.backbone(x)
        if self.gctx_fusion is not None:
            c = F.adaptive_avg_pool2d(x, (1, 1))
            c = c.expand_as(x)
            x = torch.cat([x, c], 1)
            x = self.gctx_fusion(x)

        p = self.regress(x)
        x_p = input + lmda*p

        if return_stn_output:
            return x_p, p, input

        if return_p:
            return x_p, p

        return x_p


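# Note: the builders below are registered in NETWORK_REGISTRY so a config can
# request them by name (e.g. "fcn_3x32_gctx"); extra keyword arguments from the
# registry call are accepted via **kwargs and ignored unless listed explicitly
# (image_size for the *_stn variants).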
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
    norm_layer = get_norm_layer(norm_type="instance")
    net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)
    init_network_weights(net, init_type="normal", gain=0.02)
    return net


@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
    norm_layer = get_norm_layer(norm_type="instance")
    net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)
    init_network_weights(net, init_type="normal", gain=0.02)
    return net


@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
    norm_layer = get_norm_layer(norm_type="instance")
    net = FCN(
        3,
        3,
        nc=32,
        n_blocks=3,
        norm_layer=norm_layer,
        stn=True,
        image_size=image_size
    )
    init_network_weights(net, init_type="normal", gain=0.02)
    net.init_loc_layer()
    return net


@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
    norm_layer = get_norm_layer(norm_type="instance")
    net = FCN(
        3,
        3,
        nc=64,
        n_blocks=3,
        norm_layer=norm_layer,
        stn=True,
        image_size=image_size
    )
    init_network_weights(net, init_type="normal", gain=0.02)
    net.init_loc_layer()
    return net
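

# Minimal usage sketch (added for illustration; not part of the original commit).
# Assumes a 4-image CIFAR-sized batch; the builders above can also be looked up
# by name through NETWORK_REGISTRY.
#
#   net = fcn_3x32_gctx()
#   x = torch.randn(4, 3, 32, 32)
#   x_p, p = net(x, lmda=0.3, return_p=True)  # perturbed images and raw perturbation
#   assert x_p.shape == x.shape and p.shape == x.shape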