release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .head import HEAD_REGISTRY, build_head
from .network import NETWORK_REGISTRY, build_network
from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone

View File

@@ -0,0 +1,27 @@
from .build import build_backbone, BACKBONE_REGISTRY # isort:skip
from .backbone import Backbone # isort:skip
from .vgg import vgg16
from .resnet import (
resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,
resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,
resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,
resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,
resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,
resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123
)
from .alexnet import alexnet
from .mobilenetv2 import mobilenetv2
from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2
from .cnn_digitsdg import cnn_digitsdg
from .efficientnet import (
efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,
efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
)
from .shufflenetv2 import (
shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,
shufflenet_v2_x2_0
)
from .cnn_digitsingle import cnn_digitsingle
from .preact_resnet18 import preact_resnet18
from .cnn_digit5_m3sda import cnn_digit5_m3sda

View File

@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
# URL of the ImageNet-pretrained AlexNet checkpoint hosted by torchvision.
model_urls = {
    "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
}
class AlexNet(Backbone):
    """AlexNet backbone that returns 4096-d features (no classification head).

    The module layout mirrors torchvision's AlexNet so that the pretrained
    ImageNet checkpoint loads by parameter name.
    """

    def __init__(self):
        super().__init__()
        # Five-conv feature extractor, identical to torchvision's AlexNet.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Note that self.classifier outputs features rather than logits:
        # the final Linear layer of the original classifier is omitted.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )

        # Feature dimension exposed via Backbone.out_features.
        self._out_features = 4096

    def forward(self, x):
        """Return a (batch, 4096) feature tensor for input images ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
def init_pretrained_weights(model, model_url):
    """Download the checkpoint at ``model_url`` and copy its matching
    entries into ``model`` (non-matching keys are silently skipped)."""
    state_dict = model_zoo.load_url(model_url)
    model.load_state_dict(state_dict, strict=False)
@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
    """Build an AlexNet backbone, optionally loading ImageNet weights."""
    backbone = AlexNet()
    if pretrained:
        init_pretrained_weights(backbone, model_urls["alexnet"])
    return backbone

View File

@@ -0,0 +1,17 @@
import torch.nn as nn
class Backbone(nn.Module):
    """Base class for all backbone networks.

    Subclasses should assign ``self._out_features`` to the dimensionality
    of the features produced by ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self):
        pass

    @property
    def out_features(self):
        """Output feature dimension (``None`` if the subclass never set it)."""
        return self.__dict__.get("_out_features")

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
BACKBONE_REGISTRY = Registry("BACKBONE")


def build_backbone(name, verbose=True, **kwargs):
    """Instantiate the backbone registered under ``name``.

    ``check_availability`` raises a descriptive error for unknown names;
    extra keyword arguments are forwarded to the backbone constructor.
    """
    check_availability(name, BACKBONE_REGISTRY.registered_names())
    if verbose:
        print("Backbone: {}".format(name))
    backbone_cls = BACKBONE_REGISTRY.get(name)
    return backbone_cls(**kwargs)

View File

@@ -0,0 +1,58 @@
"""
Reference
https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA
"""
import torch.nn as nn
from torch.nn import functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class FeatureExtractor(Backbone):
    """Digit-5 CNN backbone used by M3SDA: three conv-BN stages followed
    by two FC-BN stages, producing 2048-d features."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2_fc = nn.BatchNorm1d(2048)
        self._out_features = 2048

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert H == 32 and W == 32, (
            "Input to network must be 32x32, but got {}x{}".format(H, W)
        )

    def forward(self, x):
        self._check_input(x)
        # First two conv stages each end with a 3x3/stride-2 max pool.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
            x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
        x = F.relu(self.bn3(self.conv3(x)))
        # 128 channels x 8 x 8 -> 8192-d vector.
        x = x.view(x.size(0), 8192)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        return x
@BACKBONE_REGISTRY.register()
def cnn_digit5_m3sda(**kwargs):
    """Backbone for the Digit-5 benchmark.

    Reference:
        Peng et al. Moment Matching for Multi-Source Domain Adaptation.
        ICCV 2019.
    """
    return FeatureExtractor()

View File

@@ -0,0 +1,61 @@
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class Convolution(nn.Module):
    """3x3 same-padding convolution followed by an in-place ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        out = self.conv(x)
        return self.relu(out)
class ConvNet(Backbone):
    """Four-stage CNN for 32x32 images; each stage halves the resolution,
    so the output is a flat ``2*2*c_hidden``-d vector."""

    def __init__(self, c_hidden=64):
        super().__init__()
        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)
        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert H == 32 and W == 32, (
            "Input to network must be 32x32, but got {}x{}".format(H, W)
        )

    def forward(self, x):
        self._check_input(x)
        # Resolution shrinks 32 -> 16 -> 8 -> 4 -> 2 across the stages.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(conv(x), 2)
        return x.view(x.size(0), -1)
@BACKBONE_REGISTRY.register()
def cnn_digitsdg(**kwargs):
    """Backbone for the Digits-DG benchmark.

    Reference:
        Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """
    backbone = ConvNet(c_hidden=64)
    init_network_weights(backbone, init_type="kaiming")
    return backbone

View File

@@ -0,0 +1,56 @@
"""
This model is built based on
https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
"""
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class CNN(Backbone):
    """Two-conv / two-FC network mapping 32x32 RGB images to 1024-d
    features (after flattening a 5x5x128 grid)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.fc3 = nn.Linear(5 * 5 * 128, 1024)
        self.fc4 = nn.Linear(1024, 1024)
        self._out_features = 1024

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert H == 32 and W == 32, (
            "Input to network must be 32x32, but got {}x{}".format(H, W)
        )

    def forward(self, x):
        self._check_input(x)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc3(x))
        return F.relu(self.fc4(x))
@BACKBONE_REGISTRY.register()
def cnn_digitsingle(**kwargs):
    """Backbone used for single-source digit experiments (Volpi et al.)."""
    backbone = CNN()
    init_network_weights(backbone, init_type="kaiming")
    return backbone

View File

@@ -0,0 +1,12 @@
"""
Source: https://github.com/lukemelas/EfficientNet-PyTorch.
"""
__version__ = "0.6.4"
from .model import (
EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
efficientnet_b7
)
from .utils import (
BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
)

View File

@@ -0,0 +1,371 @@
import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
get_model_params, efficientnet_params, get_same_padding_conv2d,
load_pretrained_weights, calculate_output_image_size
)
from ..build import BACKBONE_REGISTRY
from ..backbone import Backbone
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, see above
        global_params (namedtuple): GlobalParams, see above
        image_size (int | list | None): input resolution, used only to pick
            static vs dynamic "same" padding.

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch BN momentum is 1 - TF momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        # Squeeze-and-excitation only for a valid ratio in (0, 1].
        self.has_se = (self._block_args.se_ratio
                       is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Expansion phase: 1x1 conv widening channels by expand_ratio
        # (skipped entirely when expand_ratio == 1).
        inp = self._block_args.input_filters  # number of input channels
        oup = (
            self._block_args.input_filters * self._block_args.expand_ratio
        )  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired. The squeeze count is
        # derived from the *input* filters, as in the reference impl.
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1,
                int(
                    self._block_args.input_filters * self._block_args.se_ratio
                )
            )
            self._se_reduce = Conv2d(
                in_channels=oup,
                out_channels=num_squeezed_channels,
                kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels,
                out_channels=oup,
                kernel_size=1
            )

        # Output phase: linear 1x1 projection (no activation after BN).
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation: channel-wise gating of the main branch.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(
                self._swish(self._se_reduce(x_squeezed))
            )
            x = torch.sigmoid(x_squeezed) * x

        # Linear projection back to output_filters channels.
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect.
        # NOTE(review): ``stride`` may be a one-element list here (see
        # BlockDecoder), in which case ``stride == 1`` is False and the
        # residual is skipped for the first block of a stage — confirm
        # this matches the upstream behavior relied upon by checkpoints.
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            if drop_connect_rate:
                # Stochastic depth: randomly drop the residual branch.
                x = drop_connect(
                    x, p=drop_connect_rate, training=self.training
                )
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(Backbone):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum = 1 - TF momentum).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(
                    block_args.num_repeat, self._global_params
                ),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(
                    block_args, self._global_params, image_size=image_size
                )
            )
            image_size = calculate_output_image_size(
                image_size, block_args.stride
            )
            if block_args.num_repeat > 1:
                # Remaining repeats keep the channel count and use stride 1.
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(
                        block_args, self._global_params, image_size=image_size
                    )
                )
                # image_size = calculate_output_image_size(image_size, block_args.stride) # ?

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )

        # Final pooling + dropout. The classification layer is
        # intentionally omitted: this backbone returns features.
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()
        self._out_features = out_channels

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """Returns output of the final convolution layer"""
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale drop-connect linearly with depth (0 at first block).
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        return x

    def forward(self, inputs):
        """
        Calls extract_features, then applies global average pooling and
        dropout, returning a (batch, out_features) feature tensor.
        (The final linear layer is omitted, so these are features, not
        logits.)
        """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        # x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None):
        """Create an untrained model from a name such as 'efficientnet-b0'.

        ``override_params`` may override any ``GlobalParams`` field.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(
            model_name, override_params
        )
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(
        cls, model_name, advprop=False, num_classes=1000, in_channels=3
    ):
        """Create a model and load its ImageNet (or AdvProp) checkpoint."""
        model = cls.from_name(
            model_name, override_params={"num_classes": num_classes}
        )
        load_pretrained_weights(
            model, model_name, load_fc=(num_classes == 1000), advprop=advprop
        )
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the native input resolution for ``model_name``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name; raises ValueError for unknown names."""
        valid_models = ["efficientnet-b" + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError(
                "model_name should be one of: " + ", ".join(valid_models)
            )

    # NOTE(review): instance method whose first parameter is named
    # ``model`` instead of ``self``; kept as-is for compatibility.
    def _change_in_channels(model, in_channels):
        # Rebuild the stem conv when the input is not 3-channel RGB.
        # The replacement stem is randomly initialized.
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=model._global_params.image_size
            )
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
def build_efficientnet(name, pretrained):
    """Create the EfficientNet variant ``name`` (e.g. "b0"), optionally
    loading its ImageNet-pretrained checkpoint."""
    model_name = "efficientnet-{}".format(name)
    if pretrained:
        return EfficientNet.from_pretrained(model_name)
    return EfficientNet.from_name(model_name)
# Registered builders for the eight EfficientNet variants. ``pretrained``
# controls ImageNet weight loading; extra kwargs are accepted but unused.
@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
    return build_efficientnet("b0", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
    return build_efficientnet("b1", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
    return build_efficientnet("b2", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
    return build_efficientnet("b3", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
    return build_efficientnet("b4", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
    return build_efficientnet("b5", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
    return build_efficientnet("b6", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
    return build_efficientnet("b7", pretrained)

View File

@@ -0,0 +1,477 @@
"""
This file contains helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""
import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################

# Parameters for the entire model (stem, all blocks, and head).
# Every field defaults to None.
GlobalParams = collections.namedtuple(
    "GlobalParams",
    [
        "batch_norm_momentum",
        "batch_norm_epsilon",
        "dropout_rate",
        "num_classes",
        "width_coefficient",
        "depth_coefficient",
        "depth_divisor",
        "min_depth",
        "drop_connect_rate",
        "image_size",
    ],
)
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)

# Parameters for an individual model block.
# Every field defaults to None.
BlockArgs = collections.namedtuple(
    "BlockArgs",
    [
        "kernel_size",
        "num_repeat",
        "input_filters",
        "output_filters",
        "expand_ratio",
        "id_skip",
        "stride",
        "se_ratio",
    ],
)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
class SwishImplementation(torch.autograd.Function):
    """Memory-efficient Swish: saves only the input and recomputes the
    sigmoid in backward, trading a little compute for activation memory."""

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # BUGFIX: ``ctx.saved_variables`` is a long-deprecated alias that
        # has been removed from modern PyTorch; use ``saved_tensors``.
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
    """Swish activation backed by the custom autograd function above;
    used during training (see ``set_swish`` for the export-friendly swap)."""

    def forward(self, x):
        return SwishImplementation.apply(x)
class Swish(nn.Module):
    """Plain Swish activation, ``x * sigmoid(x)``, using only standard
    autograd (suitable for export)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
def round_filters(filters, global_params):
    """Scale ``filters`` by the width multiplier and round to a multiple
    of ``depth_divisor``, never rounding down by more than 10%."""
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth or divisor
    scaled = filters * multiplier
    rounded = max(min_depth, int(scaled + divisor/2) // divisor * divisor)
    if rounded < 0.9 * scaled:  # prevent rounding by more than 10%
        rounded += divisor
    return int(rounded)
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier, rounding up."""
    multiplier = global_params.depth_coefficient
    return int(math.ceil(multiplier * repeats)) if multiplier else repeats
def drop_connect(inputs, p, training):
    """Drop whole samples with probability ``p`` and rescale survivors by
    ``1/(1-p)`` (stochastic depth). No-op when not training."""
    if not training:
        return inputs
    keep_prob = 1 - p
    # One Bernoulli draw per sample, broadcast over the C/H/W dimensions.
    noise = torch.rand(
        [inputs.shape[0], 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_mask = torch.floor(keep_prob + noise)
    return inputs / keep_prob * binary_mask
def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models."""
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
def get_width_and_height_from_size(x):
    """Return ``x`` as a (height, width) pair.

    Args:
        x (int | list | tuple): a square size (int) or an explicit pair.

    Returns:
        A ``(size, size)`` tuple for an int, or ``x`` unchanged for a
        list/tuple.

    Raises:
        TypeError: if ``x`` is neither an int nor a list/tuple.
    """
    if isinstance(x, int):
        return x, x
    # Idiom: one isinstance call with a tuple of types.
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(
        "size must be an int or a list/tuple, got {}".format(type(x).__name__)
    )
def calculate_output_image_size(input_image_size, stride):
    """
    Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    if not isinstance(stride, int):
        stride = stride[0]
    # TF "SAME" output size: ceil(input / stride).
    return [int(math.ceil(height / stride)), int(math.ceil(width / stride))]
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # padding=0: the "same" padding is computed per forward call.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation,
            groups, bias
        )
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        # TF "SAME": the output is ceil(input / stride) in each dimension.
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh-1) * sh + (kh-1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow-1) * sw + (kw-1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # Pad asymmetrically (extra pixel on the bottom/right).
            x = F.pad(
                x,
                [pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a fixed image size"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        if isinstance(image_size, int):
            ih = iw = image_size
        else:
            ih, iw = image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh-1) * sh + (kh-1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow-1) * sw + (kw-1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
            )
        else:
            self.static_padding = Identity()

    def forward(self, x):
        return F.conv2d(
            self.static_padding(x),
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class Identity(nn.Module):
    """Pass-through module that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input
########################################################################
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
########################################################################
def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Returns a ``(width_coefficient, depth_coefficient, resolution,
    dropout_rate)`` tuple; raises ``KeyError`` for unknown names.
    """
    params_dict = {
        # Coefficients: width,depth,res,dropout
        "efficientnet-b0": (1.0, 1.0, 224, 0.2),
        "efficientnet-b1": (1.0, 1.1, 240, 0.2),
        "efficientnet-b2": (1.1, 1.2, 260, 0.3),
        "efficientnet-b3": (1.2, 1.4, 300, 0.3),
        "efficientnet-b4": (1.4, 1.8, 380, 0.4),
        "efficientnet-b5": (1.6, 2.2, 456, 0.4),
        "efficientnet-b6": (1.8, 2.6, 528, 0.5),
        "efficientnet-b7": (2.0, 3.1, 600, 0.5),
        "efficientnet-b8": (2.2, 3.6, 672, 0.5),
        "efficientnet-l2": (4.3, 5.3, 800, 0.5),
    }
    return params_dict[model_name]
class BlockDecoder(object):
    """Block Decoder for readability, straight from the official TensorFlow repository"""

    @staticmethod
    def _decode_block_string(block_string):
        """Gets a block through a string notation of arguments."""
        assert isinstance(block_string, str)
        ops = block_string.split("_")
        options = {}
        for op in ops:
            # Split "k3" -> ("k", "3"), "se0.25" -> ("se", "0.25"), etc.
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: a single digit, or two identical digits.
        assert ("s" in options and len(options["s"]) == 1) or (
            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
        )

        return BlockArgs(
            kernel_size=int(options["k"]),
            num_repeat=int(options["r"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            expand_ratio=int(options["e"]),
            id_skip=("noskip" not in block_string),
            se_ratio=float(options["se"]) if "se" in options else None,
            stride=[int(options["s"][0])],
        )

    @staticmethod
    def _encode_block_string(block):
        """Encodes a block to a string (inverse of ``_decode_block_string``)."""
        # BUGFIX: the original read ``block.strides`` — a field that does
        # not exist on BlockArgs — and crashed comparing ``se_ratio`` when
        # it was None. Use the single ``stride`` value for both axes and
        # guard the SE ratio.
        stride = block.stride
        if isinstance(stride, (list, tuple)):
            stride = stride[0]
        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (stride, stride),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """
        Decodes a list of string notations to specify blocks inside the network.

        :param string_list: a list of strings, each string is a notation of block
        :return: a list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings
def efficientnet(
    width_coefficient=None,
    depth_coefficient=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    image_size=None,
    num_classes=1000,
):
    """Create the block args and global params for an EfficientNet.

    The block strings describe the baseline (B0) architecture; the width
    and depth coefficients scale it to the larger variants.
    """
    block_strings = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        num_classes=num_classes,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        image_size=image_size,
    )
    return BlockDecoder.decode(block_strings), global_params
def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    ``override_params`` (dict or None) may override any GlobalParams field.
    """
    if not model_name.startswith("efficientnet"):
        raise NotImplementedError(
            "model name is not pre-defined: %s" % model_name
        )
    w, d, s, p = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2
    blocks_args, global_params = efficientnet(
        width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s
    )
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
# Checkpoints trained with standard ImageNet preprocessing.
url_map = {
    "efficientnet-b0":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
    "efficientnet-b1":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
    "efficientnet-b2":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
    "efficientnet-b3":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
    "efficientnet-b4":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
    "efficientnet-b5":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
    "efficientnet-b6":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
    "efficientnet-b7":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}

# Checkpoints trained with AdvProp; these expect a different input
# normalization than the standard checkpoints above.
url_map_advprop = {
    "efficientnet-b0":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
    "efficientnet-b1":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
    "efficientnet-b2":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
    "efficientnet-b3":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
    "efficientnet-b4":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
    "efficientnet-b5":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
    "efficientnet-b6":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
    "efficientnet-b7":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
    "efficientnet-b8":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Loads pretrained weights, and downloads if loading for the first time.

    Args:
        model: the EfficientNet instance to populate.
        model_name (str): key into the checkpoint URL maps, e.g.
            "efficientnet-b0".
        load_fc (bool): accepted for API compatibility but currently
            ignored — the checkpoint is always loaded with ``strict=False``
            so the absent ``_fc`` head is simply skipped.
        advprop (bool): load the AdvProp checkpoints (which require a
            different input preprocessing) instead of the standard ones.
    """
    # AutoAugment or Advprop (different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    state_dict = model_zoo.load_url(url_map_[model_name])
    # strict=False: this backbone omits the classifier, so the
    # checkpoint's ``_fc`` weights are ignored rather than raising.
    model.load_state_dict(state_dict, strict=False)

View File

@@ -0,0 +1,217 @@
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
# URL of the ImageNet-pretrained MobileNetV2 checkpoint hosted by torchvision.
model_urls = {
    "mobilenet_v2":
    "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor/2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU6, with 'same'-style padding."""

    def __init__(
        self, in_planes, out_planes, kernel_size=3, stride=1, groups=1
    ):
        # Pad so that stride=1 preserves the spatial size.
        pad = (kernel_size-1) // 2
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            pad,
            groups=groups,
            bias=False,
        )
        super().__init__(
            conv, nn.BatchNorm2d(out_planes), nn.ReLU6(inplace=True)
        )
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Structure: optional pointwise expansion -> depthwise 3x3 -> pointwise
    linear projection. A skip connection is added only when stride == 1
    and the input/output channel counts match.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        blocks = []
        if expand_ratio != 1:
            # pointwise expansion
            blocks.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        # depthwise conv (groups == channels)
        blocks.append(
            ConvBNReLU(
                hidden_dim, hidden_dim, stride=stride, groups=hidden_dim
            )
        )
        # pointwise linear projection (no activation)
        blocks.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
        blocks.append(nn.BatchNorm2d(oup))
        self.conv = nn.Sequential(*blocks)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class MobileNetV2(Backbone):
    """MobileNetV2 feature extractor (no classification head).

    forward() returns globally average-pooled features of dimension
    ``self.last_channel``.
    """
    def __init__(
        self,
        width_mult=1.0,
        inverted_residual_setting=None,
        round_nearest=8,
        block=None,
    ):
        """
        MobileNet V2.
        Args:
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure; list of
                [t, c, n, s] rows (expand ratio, output channels,
                repeats, stride of the first block in the group)
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
        """
        super().__init__()
        if block is None:
            block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        if inverted_residual_setting is None:
            # Default ImageNet configuration from the MobileNetV2 paper.
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
        # only check the first element, assuming user knows t,c,n,s are required
        if (
            len(inverted_residual_setting) == 0
            or len(inverted_residual_setting[0]) != 4
        ):
            raise ValueError(
                "inverted_residual_setting should be non-empty "
                "or a 4-element list, got {}".
                format(inverted_residual_setting)
            )
        # building first layer
        input_channel = _make_divisible(
            input_channel * width_mult, round_nearest
        )
        # Width multipliers < 1 do not shrink the final 1280-d layer.
        self.last_channel = _make_divisible(
            last_channel * max(1.0, width_mult), round_nearest
        )
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of each group downsamples.
                stride = s if i == 0 else 1
                features.append(
                    block(
                        input_channel, output_channel, stride, expand_ratio=t
                    )
                )
                input_channel = output_channel
        # building last several layers
        features.append(
            ConvBNReLU(input_channel, self.last_channel, kernel_size=1)
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = self.last_channel
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Global average pooling over the spatial dimensions.
        x = x.mean([2, 3])
        return x
    def forward(self, x):
        return self._forward_impl(x)
def init_pretrained_weights(model, model_url):
    """Initialize *model* with pretrained weights downloaded from *model_url*.

    Parameters whose name or shape does not match the checkpoint are left
    at their current values. A ``None`` URL only emits a warning.
    """
    if model_url is None:
        import warnings
        warnings.warn(
            "ImageNet pretrained weights are unavailable for this model"
        )
        return
    current = model.state_dict()
    pretrained = model_zoo.load_url(model_url)
    # Keep only entries whose name and shape both match the current model.
    compatible = {}
    for key, tensor in pretrained.items():
        if key in current and current[key].size() == tensor.size():
            compatible[key] = tensor
    current.update(compatible)
    model.load_state_dict(current)
@BACKBONE_REGISTRY.register()
def mobilenetv2(pretrained=True, **kwargs):
    """Build MobileNetV2; optionally load ImageNet weights (shape-matched only)."""
    model = MobileNetV2(**kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["mobilenet_v2"])
    return model

View File

@@ -0,0 +1,135 @@
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class PreActBlock(nn.Module):
    """Pre-activation basic residual block: BN -> ReLU precede each conv."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            # 1x1 projection so the residual add type-checks; created only
            # when needed, so hasattr(self, "shortcut") doubles as the flag
            # in forward().
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                )
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # The projection consumes the pre-activated input; the identity
        # path uses the raw input.
        identity = self.shortcut(pre) if hasattr(self, "shortcut") else x
        out = self.conv2(F.relu(self.bn2(self.conv1(pre))))
        return out + identity
class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck block: BN -> ReLU precede each of the
    1x1 -> 3x3 -> 1x1 convolutions; output has expansion*planes channels."""
    expansion = 4
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Spatial downsampling (if any) happens in the 3x3 conv.
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        # Projection shortcut only when the residual's shape changes; its
        # absence is used as a flag in forward() via hasattr.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                )
            )
    def forward(self, x):
        out = F.relu(self.bn1(x))
        # The projection consumes the pre-activated input; the identity
        # path uses the raw input.
        shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out
class PreActResNet(Backbone):
    """Pre-activation ResNet backbone for small images (CIFAR/SVHN style).

    forward() returns flattened features of size 512 * block.expansion.
    """
    def __init__(self, block, num_blocks):
        super().__init__()
        self.in_planes = 64
        # 3x3 stem with stride 1 (no 7x7 conv / maxpool): small-image variant.
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = 512 * block.expansion
    def _make_layer(self, block, planes, num_blocks, stride):
        # First block of the stage may downsample; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Fixed 4x4 average pooling: presumably assumes 32x32 inputs, which
        # yield a 4x4 map here -- TODO confirm for other input sizes.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out
"""
Preact-ResNet18 was used for the CIFAR10 and
SVHN datasets (both are SSL tasks) in
- Wang et al. Semi-Supervised Learning by
Augmented Distribution Alignment. ICCV 2019.
"""
@BACKBONE_REGISTRY.register()
def preact_resnet18(**kwargs):
    # PreAct-ResNet18 for small images; no pretrained weights are loaded.
    # NOTE(review): **kwargs is accepted but ignored.
    return PreActResNet(PreActBlock, [2, 2, 2, 2])

View File

@@ -0,0 +1,589 @@
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
# Torchvision ImageNet checkpoints keyed by architecture name.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution that preserves spatial size
    when stride is 1 (padding=1)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
class BasicBlock(nn.Module):
    """Two-conv residual block used by ResNet-18/34."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # 3x3 convs inlined (padding=1, no bias); the first may downsample.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection applied to the identity path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4)."""
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Spatial downsampling (if any) happens in the 3x3 conv.
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the identity path.
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(Backbone):
    """ResNet backbone; forward() returns pooled 512*expansion-d features.

    Optionally inserts a feature-statistics mixing module (MixStyle or
    EFDMix, passed as ``ms_class``) after the stages named in ``ms_layers``.
    """
    def __init__(
        self,
        block,
        layers,
        ms_class=None,
        ms_layers=[],
        ms_p=0.5,
        ms_a=0.1,
        **kwargs
    ):
        # NOTE(review): **kwargs is accepted but ignored. The [] default is
        # safe here since ms_layers is never mutated.
        self.inplanes = 64
        super().__init__()
        # backbone network
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = 512 * block.expansion
        self.mixstyle = None
        if ms_layers:
            # ms_class must be provided whenever ms_layers is non-empty;
            # a single instance is shared by all listed stages.
            self.mixstyle = ms_class(p=ms_p, alpha=ms_a)
            for layer_name in ms_layers:
                assert layer_name in ["layer1", "layer2", "layer3"]
            print(f"Insert MixStyle after {ms_layers}")
        self.ms_layers = ms_layers
        self._init_params()
    def _make_layer(self, block, planes, blocks, stride=1):
        # 1x1 projection on the identity path when the shape changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)
    def _init_params(self):
        # Kaiming for convs, unit/zero affine for norms, small normal for linears.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        if "layer1" in self.ms_layers:
            x = self.mixstyle(x)
        x = self.layer2(x)
        if "layer2" in self.ms_layers:
            x = self.mixstyle(x)
        x = self.layer3(x)
        if "layer3" in self.ms_layers:
            x = self.mixstyle(x)
        return self.layer4(x)
    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        return v.view(v.size(0), -1)
def init_pretrained_weights(model, model_url):
    """Download and load ImageNet weights non-strictly (mismatched or
    missing keys are skipped)."""
    pretrain_dict = model_zoo.load_url(model_url)
    model.load_state_dict(pretrain_dict, strict=False)
"""
Residual network configurations:
--
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
"""
@BACKBONE_REGISTRY.register()
def resnet18(pretrained=True, **kwargs):
    # NOTE(review): **kwargs is accepted but ignored by these builders.
    model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])
    return model
@BACKBONE_REGISTRY.register()
def resnet34(pretrained=True, **kwargs):
    model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3])
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet34"])
    return model
@BACKBONE_REGISTRY.register()
def resnet50(pretrained=True, **kwargs):
    model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
@BACKBONE_REGISTRY.register()
def resnet101(pretrained=True, **kwargs):
    model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3])
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet101"])
    return model
@BACKBONE_REGISTRY.register()
def resnet152(pretrained=True, **kwargs):
    model = ResNet(block=Bottleneck, layers=[3, 8, 36, 3])
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet152"])
    return model
"""
Residual networks with mixstyle
"""
def _resnet_mixstyle(arch, block, layers, ms_layers, pretrained):
    """Build a ResNet with MixStyle applied after the given stages.

    Args:
        arch (str): key into ``model_urls`` for the ImageNet checkpoint.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers (list[int]): number of blocks per stage.
        ms_layers (list[str]): stages ("layer1"/"layer2"/"layer3") after
            which MixStyle is applied.
        pretrained (bool): load ImageNet weights (non-strict).
    """
    # Imported lazily, as in the original per-variant builders.
    from dassl.modeling.ops import MixStyle

    model = ResNet(
        block=block, layers=layers, ms_class=MixStyle, ms_layers=ms_layers
    )
    if pretrained:
        init_pretrained_weights(model, model_urls[arch])
    return model


@BACKBONE_REGISTRY.register()
def resnet18_ms_l123(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet18", BasicBlock, [2, 2, 2, 2],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet18_ms_l12(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet18", BasicBlock, [2, 2, 2, 2], ["layer1", "layer2"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet18_ms_l1(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet18", BasicBlock, [2, 2, 2, 2], ["layer1"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_ms_l123(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet50", Bottleneck, [3, 4, 6, 3],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_ms_l12(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet50", Bottleneck, [3, 4, 6, 3], ["layer1", "layer2"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_ms_l1(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet50", Bottleneck, [3, 4, 6, 3], ["layer1"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_ms_l123(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet101", Bottleneck, [3, 4, 23, 3],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_ms_l12(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet101", Bottleneck, [3, 4, 23, 3], ["layer1", "layer2"],
        pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_ms_l1(pretrained=True, **kwargs):
    return _resnet_mixstyle(
        "resnet101", Bottleneck, [3, 4, 23, 3], ["layer1"], pretrained
    )
"""
Residual networks with efdmix
"""
def _resnet_efdmix(arch, block, layers, ms_layers, pretrained):
    """Build a ResNet with EFDMix applied after the given stages.

    Args:
        arch (str): key into ``model_urls`` for the ImageNet checkpoint.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers (list[int]): number of blocks per stage.
        ms_layers (list[str]): stages ("layer1"/"layer2"/"layer3") after
            which EFDMix is applied.
        pretrained (bool): load ImageNet weights (non-strict).
    """
    # Imported lazily, as in the original per-variant builders.
    from dassl.modeling.ops import EFDMix

    model = ResNet(
        block=block, layers=layers, ms_class=EFDMix, ms_layers=ms_layers
    )
    if pretrained:
        init_pretrained_weights(model, model_urls[arch])
    return model


@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l123(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet18", BasicBlock, [2, 2, 2, 2],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l12(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet18", BasicBlock, [2, 2, 2, 2], ["layer1", "layer2"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l1(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet18", BasicBlock, [2, 2, 2, 2], ["layer1"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l123(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet50", Bottleneck, [3, 4, 6, 3],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l12(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet50", Bottleneck, [3, 4, 6, 3], ["layer1", "layer2"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l1(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet50", Bottleneck, [3, 4, 6, 3], ["layer1"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l123(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet101", Bottleneck, [3, 4, 23, 3],
        ["layer1", "layer2", "layer3"], pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l12(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet101", Bottleneck, [3, 4, 23, 3], ["layer1", "layer2"],
        pretrained
    )


@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l1(pretrained=True, **kwargs):
    return _resnet_efdmix(
        "resnet101", Bottleneck, [3, 4, 23, 3], ["layer1"], pretrained
    )

View File

@@ -0,0 +1,229 @@
"""
Code source: https://github.com/pytorch/vision
"""
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
# Torchvision checkpoints; the 1.5x/2.0x variants have none released,
# so their URLs are None (init_pretrained_weights only warns for those).
model_urls = {
    "shufflenetv2_x0.5":
    "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
    "shufflenetv2_x1.0":
    "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
    "shufflenetv2_x1.5": None,
    "shufflenetv2_x2.0": None,
}
def channel_shuffle(x, groups):
    """Interleave channels across *groups* (ShuffleNet channel shuffle).

    For groups=2, channels [a0, a1, ..., b0, b1, ...] become
    [a0, b0, a1, b1, ...].

    Args:
        x: 4-D tensor of shape (N, C, H, W) with C divisible by *groups*.
        groups (int): number of groups to interleave.

    Returns:
        Tensor of the same shape with channels shuffled.
    """
    # Fix: use x.size() instead of the deprecated x.data.size() -- .data
    # bypasses autograd tracking and is discouraged in modern PyTorch.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape to (N, groups, C/groups, H, W), swap the two channel axes,
    # then flatten back so channels from different groups interleave
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
class InvertedResidual(nn.Module):
    """ShuffleNetV2 unit.

    stride == 1: the input is split channel-wise; one half passes through
    unchanged while the other goes through branch2, then the halves are
    concatenated. stride > 1: both branch1 (depthwise downsample) and
    branch2 process the full input, halving resolution. A channel shuffle
    mixes the two halves afterwards.
    """
    def __init__(self, inp, oup, stride):
        super().__init__()
        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride
        branch_features = oup // 2
        # For stride 1 the pass-through half must already have oup//2 channels.
        assert (self.stride != 1) or (inp == branch_features << 1)
        if self.stride > 1:
            # Downsampling branch: depthwise 3x3 then pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(
                    inp, inp, kernel_size=3, stride=self.stride, padding=1
                ),
                nn.BatchNorm2d(inp),
                nn.Conv2d(
                    inp,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False
                ),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        # Main branch: 1x1 -> depthwise 3x3 -> 1x1.
        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
            ),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )
    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        # groups=i makes this a depthwise convolution.
        return nn.Conv2d(
            i, o, kernel_size, stride, padding, bias=bias, groups=i
        )
    def forward(self, x):
        if self.stride == 1:
            # Split channels: first half is the identity path.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(Backbone):
    """ShuffleNetV2 backbone; forward() returns pooled conv5 features.

    NOTE(review): **kwargs is accepted but ignored.
    """
    def __init__(self, stages_repeats, stages_out_channels, **kwargs):
        super().__init__()
        if len(stages_repeats) != 3:
            raise ValueError(
                "expected stages_repeats as list of 3 positive ints"
            )
        if len(stages_out_channels) != 5:
            raise ValueError(
                "expected stages_out_channels as list of 5 positive ints"
            )
        self._stage_out_channels = stages_out_channels
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Stages 2-4: the first unit downsamples (stride 2), the rest keep size.
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
            stage_names, stages_repeats, self._stage_out_channels[1:]
        ):
            seq = [InvertedResidual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(
                    InvertedResidual(output_channels, output_channels, 1)
                )
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels
        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = output_channels
    def featuremaps(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        return x
    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        return v.view(v.size(0), -1)
def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.
    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    # Some variants (x1.5, x2.0) have no released checkpoint (URL is None).
    if model_url is None:
        import warnings
        warnings.warn(
            "ImageNet pretrained weights are unavailable for this model"
        )
        return
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    # Keep only entries whose name and shape both match the current model.
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x0_5(pretrained=True, **kwargs):
    # 0.5x width multiplier.
    model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["shufflenetv2_x0.5"])
    return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_0(pretrained=True, **kwargs):
    # 1.0x width multiplier.
    model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["shufflenetv2_x1.0"])
    return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_5(pretrained=True, **kwargs):
    # 1.5x width; no checkpoint available (URL is None -> warning only).
    model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["shufflenetv2_x1.5"])
    return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x2_0(pretrained=True, **kwargs):
    # 2.0x width; no checkpoint available (URL is None -> warning only).
    model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["shufflenetv2_x2.0"])
    return model

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Torchvision ImageNet checkpoints; only "vgg16" is used by this module.
model_urls = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}
class VGG(Backbone):
    """VGG backbone whose classifier stops at 4096-d features (no logits)."""
    def __init__(self, features, init_weights=True):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Note that self.classifier outputs features rather than logits
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
        )
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = 4096
        if init_weights:
            self._initialize_weights()
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
    def _initialize_weights(self):
        # Kaiming for convs, unit/zero for BN, small normal for linears.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False):
    """Build the VGG feature extractor from a config list.

    Each int in *cfg* adds a 3x3 conv (+ optional BN) + ReLU block; the
    marker "M" adds a 2x2 max-pool.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == "M":
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)
# VGG configurations: "A"=vgg11, "B"=vgg13, "D"=vgg16, "E"=vgg19.
# Ints are conv output channels; "M" marks a 2x2 max-pool.
cfgs = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B":
    [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        "M",
    ],
    "E": [
        64,
        64,
        "M",
        128,
        128,
        "M",
        256,
        256,
        256,
        256,
        "M",
        512,
        512,
        512,
        512,
        "M",
        512,
        512,
        512,
        512,
        "M",
    ],
}
def _vgg(arch, cfg, batch_norm, pretrained):
    """Construct a VGG backbone.

    Args:
        arch (str): key into ``model_urls`` for the checkpoint.
        cfg (str): key into ``cfgs`` ("A"/"B"/"D"/"E").
        batch_norm (bool): insert BatchNorm after each conv.
        pretrained (bool): load ImageNet weights (non-strict, so the
            checkpoint's final logits layer, absent here, is skipped).
    """
    # Skip random init when pretrained weights will overwrite it anyway
    # (idiomatic `not pretrained` replaces `False if pretrained else True`).
    init_weights = not pretrained
    model = VGG(
        make_layers(cfgs[cfg], batch_norm=batch_norm),
        init_weights=init_weights
    )
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
        model.load_state_dict(state_dict, strict=False)
    return model
@BACKBONE_REGISTRY.register()
def vgg16(pretrained=True, **kwargs):
    # Config "D" without batch norm = classic VGG-16; kwargs are ignored.
    return _vgg("vgg16", "D", False, pretrained)

View File

@@ -0,0 +1,150 @@
"""
Modified from https://github.com/xternalz/WideResNet-pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class BasicBlock(nn.Module):
    """Wide-ResNet basic block (pre-activation: BN -> act -> conv, twice).

    NOTE(review): activations are LeakyReLU(0.01) rather than the plain
    ReLU of the original WRN code -- presumably intentional; confirm.
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.LeakyReLU(0.01, inplace=True)
        self.conv1 = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.LeakyReLU(0.01, inplace=True)
        self.conv2 = nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False
        )
        self.droprate = dropRate
        self.equalInOut = in_planes == out_planes
        # `and/or` idiom: 1x1 projection when channel counts differ, else None.
        self.convShortcut = (
            (not self.equalInOut) and nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False,
            ) or None
        )
    def forward(self, x):
        # When shapes differ, the pre-activated input feeds BOTH the conv
        # path and the 1x1 shortcut (x is rebound); otherwise the raw input
        # is the identity and `out` holds the pre-activation.
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
    """A stage of *nb_layers* blocks; only the first may change
    channels or downsample, the rest run at stride 1."""

    def __init__(
        self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
    ):
        super().__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, dropRate
        )

    def _make_layer(
        self, block, in_planes, out_planes, nb_layers, stride, dropRate
    ):
        blocks = []
        # int() because WideResNet passes nb_layers as a float.
        for idx in range(int(nb_layers)):
            first = idx == 0
            blocks.append(
                block(
                    in_planes if first else out_planes,
                    out_planes,
                    stride if first else 1,
                    dropRate,
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, x):
        return self.layer(x)
class WideResNet(Backbone):
    """Wide ResNet (WRN-depth-widen_factor) for small (32x32) inputs.

    forward() returns globally average-pooled features of size
    64 * widen_factor.
    """
    def __init__(self, depth, widen_factor, dropRate=0.0):
        super().__init__()
        nChannels = [
            16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
        ]
        # Depth must be 6n+4: n blocks per group, 3 groups, 4 other layers.
        assert (depth-4) % 6 == 0
        # Float is fine: NetworkBlock applies int() before range().
        n = (depth-4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(
            3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
        )
        # 1st block
        self.block1 = NetworkBlock(
            n, nChannels[0], nChannels[1], block, 1, dropRate
        )
        # 2nd block
        self.block2 = NetworkBlock(
            n, nChannels[1], nChannels[2], block, 2, dropRate
        )
        # 3rd block
        self.block3 = NetworkBlock(
            n, nChannels[2], nChannels[3], block, 2, dropRate
        )
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.LeakyReLU(0.01, inplace=True)
        # Feature dimension exposed to downstream heads/classifiers.
        self._out_features = nChannels[3]
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.adaptive_avg_pool2d(out, 1)
        return out.view(out.size(0), -1)
@BACKBONE_REGISTRY.register()
def wide_resnet_28_2(**kwargs):
    # WRN-28-2: depth 28, widening factor 2. NOTE(review): kwargs ignored.
    return WideResNet(28, 2)
@BACKBONE_REGISTRY.register()
def wide_resnet_16_4(**kwargs):
    # WRN-16-4: depth 16, widening factor 4. NOTE(review): kwargs ignored.
    return WideResNet(16, 4)

View File

@@ -0,0 +1,3 @@
from .build import build_head, HEAD_REGISTRY # isort:skip
from .mlp import mlp

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
# Global registry mapping head names to their builder functions.
HEAD_REGISTRY = Registry("HEAD")
def build_head(name, verbose=True, **kwargs):
    """Instantiate the head registered under *name*.

    check_availability raises if the name is unknown; extra kwargs are
    forwarded to the registered builder.
    """
    avai_heads = HEAD_REGISTRY.registered_names()
    check_availability(name, avai_heads)
    if verbose:
        print("Head: {}".format(name))
    return HEAD_REGISTRY.get(name)(**kwargs)

View File

@@ -0,0 +1,50 @@
import functools
import torch.nn as nn
from .build import HEAD_REGISTRY
class MLP(nn.Module):
    """Multi-layer perceptron head.

    Each hidden layer is Linear -> (BatchNorm1d) -> activation ->
    (Dropout). The last hidden size is exposed as ``out_features``.
    """

    def __init__(
        self,
        in_features=2048,
        hidden_layers=[],
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        assert len(hidden_layers) > 0
        self.out_features = hidden_layers[-1]

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError

        layers = []
        dim_in = in_features
        for dim_out in hidden_layers:
            layers.append(nn.Linear(dim_in, dim_out))
            if bn:
                layers.append(nn.BatchNorm1d(dim_out))
            layers.append(act_fn())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            dim_in = dim_out
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
@HEAD_REGISTRY.register()
def mlp(**kwargs):
    """Build an MLP head from keyword arguments."""
    head = MLP(**kwargs)
    return head

View File

@@ -0,0 +1,5 @@
from .build import build_network, NETWORK_REGISTRY # isort:skip
from .ddaig_fcn import (
fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn
)

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
# Global registry mapping network names to their builder functions; entries
# are added via the @NETWORK_REGISTRY.register() decorator.
NETWORK_REGISTRY = Registry("NETWORK")
def build_network(name, verbose=True, **kwargs):
    """Instantiate the network registered under `name`, forwarding kwargs."""
    available = NETWORK_REGISTRY.registered_names()
    check_availability(name, available)
    if verbose:
        print("Network: {}".format(name))
    network_cls = NETWORK_REGISTRY.get(name)
    return network_cls(**kwargs)

View File

@@ -0,0 +1,329 @@
"""
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
import functools
import torch
import torch.nn as nn
from torch.nn import functional as F
from .build import NETWORK_REGISTRY
def init_network_weights(model, init_type="normal", gain=0.02):
    """Initialize conv/linear and norm layers of `model` in place.

    Args:
        model (nn.Module): network to initialize.
        init_type (str): 'normal', 'xavier', 'kaiming' or 'orthogonal'.
        gain (float): std for 'normal'; gain for 'xavier'/'orthogonal'.
    """

    def _init_func(m):
        cls_name = m.__class__.__name__
        is_conv_or_linear = hasattr(m, "weight") and (
            "Conv" in cls_name or "Linear" in cls_name
        )

        if is_conv_or_linear:
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        elif "BatchNorm2d" in cls_name:
            nn.init.constant_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)

        elif "InstanceNorm2d" in cls_name:
            # InstanceNorm2d may be constructed with affine=False.
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)
def get_norm_layer(norm_type="instance"):
    """Return a constructor for the requested 2-D normalization layer.

    'none' returns None; unknown types raise NotImplementedError.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    if norm_type == "none":
        return None
    raise NotImplementedError(
        "normalization layer [%s] is not found" % norm_type
    )
class ResnetBlock(nn.Module):
    """Residual block: returns x + conv_block(x), spatial size preserved."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super().__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias
        )

    @staticmethod
    def _padding(padding_type):
        """Return (explicit pad modules, conv padding arg) for one conv."""
        if padding_type == "reflect":
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == "replicate":
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == "zero":
            return [], 1
        raise NotImplementedError(
            "padding [%s] is not implemented" % padding_type
        )

    def build_conv_block(
        self, dim, padding_type, norm_layer, use_dropout, use_bias
    ):
        """Assemble pad-conv-norm-relu (+dropout) pad-conv-norm."""
        pad_mods, p = self._padding(padding_type)
        layers = list(pad_mods)
        layers += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            layers += [nn.Dropout(0.5)]

        pad_mods, p = self._padding(padding_type)
        layers += list(pad_mods)
        layers += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
class LocNet(nn.Module):
    """Localization network: predicts a 2x3 affine matrix for an STN.

    Only the 2x2 linear part is predicted (tanh-squashed); the translation
    column is fixed to zero.
    """

    def __init__(
        self,
        input_nc,
        nc=32,
        n_blocks=3,
        use_dropout=False,
        padding_type="zero",
        image_size=32,
    ):
        super().__init__()
        layers = [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(nc),
            nn.ReLU(True),
        ]
        for _ in range(n_blocks):
            layers.append(
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=nn.BatchNorm2d,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            )
            layers.append(nn.MaxPool2d(2, stride=2))
        self.backbone = nn.Sequential(*layers)

        # The stride-2 conv plus the n_blocks max-pools halve the spatial
        # size (n_blocks + 1) times in total.
        reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
        self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)

    def forward(self, x):
        feats = self.backbone(x)
        feats = feats.view(feats.size(0), -1)
        pred = torch.tanh(self.fc_loc(feats))
        pred = pred.view(-1, 2, 2)
        # Embed the predicted 2x2 matrix into a 2x3 affine matrix whose
        # translation column stays zero.
        theta = pred.data.new_zeros(pred.size(0), 2, 3)
        theta[:, :, :2] = pred
        return theta
class FCN(nn.Module):
    """Fully convolutional network.

    Regresses a perturbation p(x) in [-1, 1] (tanh head) and returns the
    perturbed input x + lmda * p(x). Optionally applies a spatial
    transformer to the input first (stn) and/or fuses a global
    average-pooled context vector into the features (gctx).
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        nc=32,
        n_blocks=3,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
        padding_type="reflect",
        gctx=True,
        stn=False,
        image_size=32,
    ):
        super().__init__()
        backbone = []

        # reflect/replicate pad explicitly (conv padding 0); 'zero' relies on
        # the conv's own padding of 1.
        p = 0
        if padding_type == "reflect":
            backbone += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            backbone += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError
        backbone += [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
            )
        ]
        backbone += [norm_layer(nc)]
        backbone += [nn.ReLU(True)]
        # Residual trunk; spatial size is preserved throughout.
        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            ]
        self.backbone = nn.Sequential(*backbone)
        # global context fusion layer
        self.gctx_fusion = None
        if gctx:
            self.gctx_fusion = nn.Sequential(
                nn.Conv2d(
                    2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
                ),
                norm_layer(nc),
                nn.ReLU(True),
            )
        # 1x1 conv head producing the perturbation, squashed to [-1, 1].
        self.regress = nn.Sequential(
            nn.Conv2d(
                nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
            ),
            nn.Tanh(),
        )
        self.locnet = None
        if stn:
            self.locnet = LocNet(
                input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
            )

    def init_loc_layer(self):
        """Initialize the weights/bias with identity transformation."""
        if self.locnet is not None:
            self.locnet.fc_loc.weight.data.zero_()
            self.locnet.fc_loc.bias.data.copy_(
                torch.tensor([1, 0, 0, 1], dtype=torch.float)
            )

    def stn(self, x):
        """Spatial transformer network."""
        theta = self.locnet(x)
        grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, grid), theta

    def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
        """
        Args:
            x (torch.Tensor): input mini-batch.
            lmda (float): multiplier for perturbation.
            return_p (bool): return perturbation.
            return_stn_output (bool): return the output of stn.
        """
        theta = None
        if self.locnet is not None:
            # NOTE(review): theta is computed but never returned; presumably
            # kept for debugging — confirm.
            x, theta = self.stn(x)
        input = x  # (possibly transformed) input kept for the residual add
        x = self.backbone(x)
        if self.gctx_fusion is not None:
            # Concatenate a broadcast global-average context vector, then
            # fuse back down to nc channels.
            c = F.adaptive_avg_pool2d(x, (1, 1))
            c = c.expand_as(x)
            x = torch.cat([x, c], 1)
            x = self.gctx_fusion(x)
        p = self.regress(x)
        x_p = input + lmda*p
        if return_stn_output:
            return x_p, p, input
        if return_p:
            return x_p, p
        return x_p
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
    """32-channel, 3-block FCN with global-context fusion."""
    instance_norm = get_norm_layer(norm_type="instance")
    model = FCN(3, 3, nc=32, n_blocks=3, norm_layer=instance_norm)
    init_network_weights(model, init_type="normal", gain=0.02)
    return model
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
    """64-channel, 3-block FCN with global-context fusion."""
    instance_norm = get_norm_layer(norm_type="instance")
    model = FCN(3, 3, nc=64, n_blocks=3, norm_layer=instance_norm)
    init_network_weights(model, init_type="normal", gain=0.02)
    return model
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
    """32-channel FCN with global-context fusion and a spatial transformer."""
    instance_norm = get_norm_layer(norm_type="instance")
    model = FCN(
        3,
        3,
        nc=32,
        n_blocks=3,
        norm_layer=instance_norm,
        stn=True,
        image_size=image_size
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    # Start the STN from the identity transform.
    model.init_loc_layer()
    return model
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
    """64-channel FCN with global-context fusion and a spatial transformer."""
    instance_norm = get_norm_layer(norm_type="instance")
    model = FCN(
        3,
        3,
        nc=64,
        n_blocks=3,
        norm_layer=instance_norm,
        stn=True,
        image_size=image_size
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    # Start the STN from the identity transform.
    model.init_loc_layer()
    return model

View File

@@ -0,0 +1,16 @@
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance

View File

@@ -0,0 +1,30 @@
import torch
from torch.nn import functional as F
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss with optional label smoothing.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated. Default is 'mean'.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    # Build the one-hot target directly on input's device/dtype. The original
    # allocated a CPU tensor and moved the target to CPU, forcing a
    # host/device round-trip on every call when running on GPU.
    onehot = torch.zeros_like(log_prob)
    onehot.scatter_(1, target.unsqueeze(1), 1)
    # Smooth: (1 - s) on the true class, s/num_classes spread uniformly.
    target = (1-label_smooth) * onehot + label_smooth/num_classes
    loss = (-target * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError

View File

@@ -0,0 +1,45 @@
import torch.nn as nn
class _DSBN(nn.Module):
"""Domain Specific Batch Normalization.
Args:
num_features (int): number of features.
n_domain (int): number of domains.
bn_type (str): type of bn. Choices are ['1d', '2d'].
"""
def __init__(self, num_features, n_domain, bn_type):
super().__init__()
if bn_type == "1d":
BN = nn.BatchNorm1d
elif bn_type == "2d":
BN = nn.BatchNorm2d
else:
raise ValueError
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
self.valid_domain_idxs = list(range(n_domain))
self.n_domain = n_domain
self.domain_idx = 0
def select_bn(self, domain_idx=0):
assert domain_idx in self.valid_domain_idxs
self.domain_idx = domain_idx
def forward(self, x):
return self.bn[self.domain_idx](x)
class DSBN1d(_DSBN):
    # Domain-specific BatchNorm1d: one BatchNorm1d branch per domain.
    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "1d")
class DSBN2d(_DSBN):
    # Domain-specific BatchNorm2d: one BatchNorm2d branch per domain.
    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "2d")

View File

@@ -0,0 +1,118 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_efdmix(m):
    """Switch an EFDMix module off (identity); meant for Module.apply()."""
    if type(m) is EFDMix:
        m.set_activation_status(False)
def activate_efdmix(m):
    """Switch an EFDMix module on; meant for Module.apply()."""
    if type(m) is EFDMix:
        m.set_activation_status(True)
def random_efdmix(m):
    """Set an EFDMix module to 'random' mixing; meant for Module.apply()."""
    if type(m) is EFDMix:
        m.update_mix_method("random")
def crossdomain_efdmix(m):
    """Set an EFDMix module to 'crossdomain' mixing; for Module.apply()."""
    if type(m) is EFDMix:
        m.update_mix_method("crossdomain")
@contextmanager
def run_without_efdmix(model):
    """Temporarily disable every EFDMix module inside `model`.

    Assumes EFDMix was initially activated; reactivates on exit.
    """
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        model.apply(activate_efdmix)
@contextmanager
def run_with_efdmix(model, mix=None):
    """Temporarily enable every EFDMix module inside `model`.

    Assumes EFDMix was initially deactivated. Optionally switches the
    mixing strategy to 'random' or 'crossdomain' first; deactivates on exit.
    """
    if mix == "random":
        model.apply(random_efdmix)
    elif mix == "crossdomain":
        model.apply(crossdomain_efdmix)

    try:
        model.apply(activate_efdmix)
        yield
    finally:
        model.apply(deactivate_efdmix)
class EFDMix(nn.Module):
    """EFDMix: exact feature distribution matching augmentation.

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style
      Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using EFDMix.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        # Fixed: previously reported itself as "MixStyle" (copy-paste bug).
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        # Toggled by activate_efdmix/deactivate_efdmix via Module.apply.
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        # Identity at eval time, when deactivated, or with prob (1 - p).
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError

        # Graft the sorted values of the permuted sample back through this
        # sample's sort order; gradients flow only through the mixed part.
        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)

View File

@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_mixstyle(m):
    """Switch a MixStyle module off (identity); meant for Module.apply()."""
    if type(m) is MixStyle:
        m.set_activation_status(False)
def activate_mixstyle(m):
    """Switch a MixStyle module on; meant for Module.apply()."""
    if type(m) is MixStyle:
        m.set_activation_status(True)
def random_mixstyle(m):
    """Set a MixStyle module to 'random' mixing; meant for Module.apply()."""
    if type(m) is MixStyle:
        m.update_mix_method("random")
def crossdomain_mixstyle(m):
    """Set a MixStyle module to 'crossdomain' mixing; for Module.apply()."""
    if type(m) is MixStyle:
        m.update_mix_method("crossdomain")
@contextmanager
def run_without_mixstyle(model):
    """Temporarily disable every MixStyle module inside `model`.

    Assumes MixStyle was initially activated; reactivates on exit.
    """
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)
@contextmanager
def run_with_mixstyle(model, mix=None):
    """Temporarily enable every MixStyle module inside `model`.

    Assumes MixStyle was initially deactivated. Optionally switches the
    mixing strategy to 'random' or 'crossdomain' first; deactivates on exit.
    """
    if mix == "random":
        model.apply(random_mixstyle)
    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)
class MixStyle(nn.Module):
    """MixStyle: mixes per-instance channel statistics (mean/std) between
    random pairs of samples as a feature-level style augmentation.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        # Toggled by activate_mixstyle/deactivate_mixstyle via Module.apply.
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        # Identity at eval time, when deactivated, or with prob (1 - p).
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Per-instance channel statistics; detached so gradients do not flow
        # through the style statistics themselves.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        # Re-style the normalized features with interpolated statistics.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix

View File

@@ -0,0 +1,23 @@
import torch
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup two batches of images and their (soft) labels.

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda=max(lmda, 1-lmda).
            Default is False.
    """
    # One mixing coefficient per sample, broadcastable over (c, h, w).
    lam = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        lam = torch.max(lam, 1 - lam)
    lam = lam.to(x1.device)
    xmix = lam*x1 + (1 - lam)*x2
    # Collapse to (b, 1) so the same coefficient weights the labels.
    lam2d = lam[:, :, 0, 0]
    ymix = lam2d*y1 + (1 - lam2d)*y2
    return xmix, ymix

View File

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class MaximumMeanDiscrepancy(nn.Module):
    """Squared maximum mean discrepancy between two feature batches.

    MMD^2(x, y) = E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')], with a linear,
    polynomial or RBF-mixture kernel. Self-similarities (the diagonal) are
    excluded from the k(x, x') and k(y, y') terms.
    """

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        """x, y: two (batch, dim) feature matrices."""
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)

        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        if self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        if self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        raise NotImplementedError

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        gram_xx = self.remove_self_distance(torch.mm(x, x.t()))
        gram_yy = self.remove_self_distance(torch.mm(y, y.t()))
        gram_xy = torch.mm(x, y.t())
        return gram_xx.mean() + gram_yy.mean() - 2 * gram_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        gram_xx = self.remove_self_distance(torch.mm(x, x.t()))
        gram_xx = (alpha*gram_xx + c).pow(d)
        gram_yy = self.remove_self_distance(torch.mm(y, y.t()))
        gram_yy = (alpha*gram_yy + c).pow(d)
        gram_xy = torch.mm(x, y.t())
        gram_xy = (alpha*gram_xy + c).pow(d)
        return gram_xx.mean() + gram_yy.mean() - 2 * gram_xy.mean()

    def rbf_mmd(self, x, y):
        # Each term: squared distances -> multi-bandwidth RBF kernel.
        kernel_xx = self.rbf_kernel_mixture(
            self.remove_self_distance(self.euclidean_squared_distance(x, x))
        )
        kernel_yy = self.rbf_kernel_mixture(
            self.remove_self_distance(self.euclidean_squared_distance(y, y))
        )
        kernel_xy = self.rbf_kernel_mixture(
            self.euclidean_squared_distance(x, y)
        )
        return kernel_xx.mean() + kernel_yy.mean() - 2 * kernel_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):
        # Sum of RBF kernels over several bandwidths.
        mixture = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            mixture = mixture + torch.exp(-gamma * exponent)
        return mixture

    @staticmethod
    def remove_self_distance(distmat):
        # Drop the diagonal: row i loses entry i, giving an (n, n-1) matrix.
        rows = [
            torch.cat([row[:i], row[i + 1:]])
            for i, row in enumerate(distmat)
        ]
        return torch.stack(rows)

    @staticmethod
    def euclidean_squared_distance(x, y):
        """Pairwise squared euclidean distances, shape (x_n, y_n)."""
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        return distmat
if __name__ == "__main__":
    # Smoke test: RBF-kernel MMD between two random (3, 100) batches.
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
    d = mmd(input1, input2)
    print(d.item())

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class OptimalTransport(nn.Module):
    """Base class providing pairwise cost matrices for OT solvers."""

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Return the (n1, n2) pairwise cost matrix between two batches.

        Args:
            batch1 (torch.Tensor): shape (n1, dim).
            batch2 (torch.Tensor): shape (n2, dim).
            dist_metric (str): 'cosine', 'euclidean' or 'fast_euclidean'.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # squared euclidean distance. Keyword beta/alpha form: the old
            # positional-scalar addmm_(1, -2, ...) signature is deprecated
            # and removed in recent PyTorch (this also matches mmd.py).
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean | fast_euclidean]".
                format(dist_metric)
            )
        return dist_mat
class SinkhornDivergence(OptimalTransport):
    """Debiased entropic OT cost: S(x, y) = 2 W(x, y) - W(x, x) - W(y, y),
    where each W is an entropy-regularized transport cost computed with
    log-domain Sinkhorn iterations.
    """

    # Convergence threshold on the L1 change of the dual variable u.
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric  # cost metric passed to distance()
        self.eps = eps  # entropic regularization strength
        self.max_iter = max_iter  # max Sinkhorn iterations
        self.bp_to_sinkhorn = bp_to_sinkhorn  # backprop through the plan?

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        """Return <pi, C> for the entropic-OT plan pi between x and y."""
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            # Treat the plan as a constant; gradients flow only through C.
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Log-domain Sinkhorn iterations with uniform marginals."""
        nx, ny = C.shape
        # Uniform source/target marginals.
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)  # dual potential for mu
        v = torch.zeros_like(nu)  # dual potential for nu

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        real_iter = 0  # check if algorithm terminates before max_iter
        # Sinkhorn iterations
        for i in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            real_iter += 1
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))
class MinibatchEnergyDistance(SinkhornDivergence):
    """Minibatch energy distance built from pairwise Sinkhorn transport
    costs between the two halves of each batch.
    """

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        # Split each batch into two independent halves.
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        # Cross terms pull x and y together; the self terms debias.
        cost = self.transport_cost(x1, y1)
        cost = cost + self.transport_cost(x1, y2)
        cost = cost + self.transport_cost(x2, y1)
        cost = cost + self.transport_cost(x2, y2)
        cost = cost - 2 * self.transport_cost(x1, x2)
        cost = cost - 2 * self.transport_cost(y1, y2)
        return cost
if __name__ == "__main__":
    # Demo on two parallel rows of points, following
    # https://dfdazac.github.io/sinkhorn.html
    import numpy as np

    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    # Print the results instead of dropping into pdb: a debugger breakpoint
    # must not ship in release code.
    print("transport cost:", dist.item())
    print("transport plan:")
    print(pi)

View File

@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function
class _ReverseGrad(Function):
    """Autograd op: identity forward, negated & scaled gradient backward."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Remember the scale for the backward pass.
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # d(input) = -scale * d(output); grad_scaling itself gets no grad.
        scale = ctx.grad_scaling
        return grad_output * -scale, None


reverse_grad = _ReverseGrad.apply
class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    Identity in the forward pass; multiplies the gradient by
    -grad_scaling in the backward pass.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, "
            "but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)

View File

@@ -0,0 +1,15 @@
import torch.nn as nn
class Sequential2(nn.Sequential):
    """Like nn.Sequential but forwards an arbitrary number of arguments.

    When a module returns a tuple, it is unpacked into the next module's
    positional arguments; otherwise the single value is passed through.
    """

    def forward(self, *inputs):
        out = inputs
        for module in self._modules.values():
            out = module(*out) if isinstance(out, tuple) else module(out)
        return out

View File

@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
class _TransNorm(nn.Module):
    """Transferable normalization.
    Reference:
      - Wang et al. Transferable Normalization: Towards Improving
        Transferability of Deep Neural Networks. NeurIPS 2019.
    Args:
        num_features (int): number of features.
        eps (float): epsilon.
        momentum (float): value for updating running_mean and running_var.
        adaptive_alpha (bool): apply domain adaptive alpha.

    Training assumes each mini-batch is the concatenation of an
    equal-sized source half followed by a target half; separate running
    statistics are kept per domain, and inference normalizes with the
    target-domain statistics.
    """

    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.adaptive_alpha = adaptive_alpha
        # Separate running statistics for source (_s) and target (_t).
        self.register_buffer("running_mean_s", torch.zeros(num_features))
        self.register_buffer("running_var_s", torch.ones(num_features))
        self.register_buffer("running_mean_t", torch.zeros(num_features))
        self.register_buffer("running_var_t", torch.ones(num_features))
        # Affine parameters shared by both domains.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def resnet_running_stats(self):
        # NOTE(review): name looks like a typo for "reset_running_stats";
        # renaming would break external callers, so it is kept as-is.
        self.running_mean_s.zero_()
        self.running_var_s.fill_(1)
        self.running_mean_t.zero_()
        self.running_var_t.fill_(1)

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _check_input(self, x):
        # Implemented by TransNorm1d/TransNorm2d to validate dimensionality.
        raise NotImplementedError

    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
        # Domain-adaptive channel weights: channels whose source/target
        # mean-to-std ratios are close get larger alpha; alpha sums to C.
        C = self.num_features
        ratio_s = mean_s / (var_s + self.eps).sqrt()
        ratio_t = mean_t / (var_t + self.eps).sqrt()
        dist = (ratio_s - ratio_t).abs()
        dist_inv = 1 / (1+dist)
        return C * dist_inv / dist_inv.sum()

    def forward(self, input):
        self._check_input(input)
        C = self.num_features
        # Broadcast shape for per-channel parameters.
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError
        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)
        if not self.training:
            # Inference: normalize with the target-domain running stats.
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input-mean_t) / (var_t + self.eps).sqrt()
            output = output*weight + bias
            if self.adaptive_alpha:
                mean_s = self.running_mean_s.view(*new_shape)
                var_s = self.running_var_s.view(*new_shape)
                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
                alpha = alpha.reshape(*new_shape)
                output = (1 + alpha.detach()) * output
            return output
        # Training: first half of the batch is source, second half target.
        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)
        x_s = input_s.transpose(0, 1).reshape(C, -1)
        mean_s = x_s.mean(1)
        var_s = x_s.var(1)
        # NOTE(review): running stats use
        # running = momentum*running + (1 - momentum)*batch, the reverse of
        # nn.BatchNorm's convention — confirm this is intended.
        self.running_mean_s.mul_(self.momentum)
        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
        self.running_var_s.mul_(self.momentum)
        self.running_var_s.add_((1 - self.momentum) * var_s.data)
        mean_s = mean_s.reshape(*new_shape)
        var_s = var_s.reshape(*new_shape)
        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
        output_s = output_s*weight + bias
        x_t = input_t.transpose(0, 1).reshape(C, -1)
        mean_t = x_t.mean(1)
        var_t = x_t.var(1)
        self.running_mean_t.mul_(self.momentum)
        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
        self.running_var_t.mul_(self.momentum)
        self.running_var_t.add_((1 - self.momentum) * var_t.data)
        mean_t = mean_t.reshape(*new_shape)
        var_t = var_t.reshape(*new_shape)
        output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
        output_t = output_t*weight + bias
        output = torch.cat([output_s, output_t], 0)
        if self.adaptive_alpha:
            # Alpha from the current batch statistics (detached).
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output
        return output
class TransNorm1d(_TransNorm):
    # TransNorm over 2-D inputs of shape (batch, features).
    def _check_input(self, x):
        if x.dim() != 2:
            raise ValueError(
                "Expected the input to be 2-D, "
                "but got {}-D".format(x.dim())
            )
class TransNorm2d(_TransNorm):
    # TransNorm over 4-D inputs of shape (batch, channels, height, width).
    def _check_input(self, x):
        if x.dim() != 4:
            raise ValueError(
                "Expected the input to be 4-D, "
                "but got {}-D".format(x.dim())
            )

View File

@@ -0,0 +1,75 @@
import numpy as np
import torch
def sharpen_prob(p, temperature=2):
    """Sharpen a probability matrix with a temperature and renormalize rows.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes).
        temperature (float): temperature.
    """
    sharpened = p.pow(temperature)
    return sharpened / sharpened.sum(1, keepdim=True)
def reverse_index(data, label):
    """Reverse the order of data and labels together."""
    rev = torch.arange(data.size(0)).flip(0)
    return data[rev], label[rev]
def shuffle_index(data, label):
    """Shuffle data and labels with the same random permutation."""
    perm = torch.randperm(data.shape[0])
    return data[perm], label[perm]
def create_onehot(label, num_classes):
    """Create a float one-hot tensor on CPU.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor.
        num_classes (int): number of classes.
    """
    onehot = torch.zeros(label.shape[0], num_classes)
    index = label.unsqueeze(1).data.cpu()
    return onehot.scatter(1, index, 1)
def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from exp(-5) at step 0 to 1.0 at rampup_length.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    step = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - step/rampup_length
    return float(np.exp(-5.0 * phase * phase))
def linear_rampup(current, rampup_length):
    """Linear rampup from 0.0 at step 0 to 1.0 at rampup_length.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    return float(np.clip(current / rampup_length, 0.0, 1.0))
def ema_model_update(model, ema_model, alpha):
    """Update `ema_model` parameters in place as an exponential moving
    average of `model` parameters.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)