release code
This commit is contained in:
3
Dassl.ProGrad.pytorch/dassl/modeling/__init__.py
Normal file
3
Dassl.ProGrad.pytorch/dassl/modeling/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .head import HEAD_REGISTRY, build_head
|
||||
from .network import NETWORK_REGISTRY, build_network
|
||||
from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone
|
||||
27
Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py
Normal file
27
Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from .build import build_backbone, BACKBONE_REGISTRY # isort:skip
|
||||
from .backbone import Backbone # isort:skip
|
||||
|
||||
from .vgg import vgg16
|
||||
from .resnet import (
|
||||
resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,
|
||||
resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,
|
||||
resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,
|
||||
resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,
|
||||
resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,
|
||||
resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123
|
||||
)
|
||||
from .alexnet import alexnet
|
||||
from .mobilenetv2 import mobilenetv2
|
||||
from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2
|
||||
from .cnn_digitsdg import cnn_digitsdg
|
||||
from .efficientnet import (
|
||||
efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,
|
||||
efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
|
||||
)
|
||||
from .shufflenetv2 import (
|
||||
shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,
|
||||
shufflenet_v2_x2_0
|
||||
)
|
||||
from .cnn_digitsingle import cnn_digitsingle
|
||||
from .preact_resnet18 import preact_resnet18
|
||||
from .cnn_digit5_m3sda import cnn_digit5_m3sda
|
||||
64
Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py
Normal file
64
Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
model_urls = {
|
||||
"alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
|
||||
}
|
||||
|
||||
|
||||
class AlexNet(Backbone):
    """AlexNet trunk used as a feature extractor.

    Unlike torchvision's AlexNet, the final 4096->num_classes linear layer
    is removed, so ``forward`` returns 4096-d features (``_out_features``)
    rather than logits. Pretrained ImageNet weights can still be loaded
    with ``strict=False`` since parameter names match torchvision's.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk (identical layout to torchvision AlexNet).
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Pool to a fixed 6x6 map so arbitrary input sizes reach the FC layers.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Note that self.classifier outputs features rather than logits
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )

        # Advertised feature dimension, read via Backbone.out_features.
        self._out_features = 4096

    def forward(self, x):
        """Return 4096-d features for a batch of RGB images."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
    """Download weights from ``model_url`` and load them into ``model``.

    Uses ``strict=False`` so checkpoint entries with no matching
    parameter (e.g. the original classification head) are skipped.
    """
    state_dict = model_zoo.load_url(model_url)
    model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
    """Build an AlexNet backbone, optionally loading ImageNet weights."""
    net = AlexNet()
    if pretrained:
        init_pretrained_weights(net, model_urls["alexnet"])
    return net
|
||||
17
Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py
Normal file
17
Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Backbone(nn.Module):
    """Base class for all backbone networks.

    Subclasses implement ``forward`` and assign an int to
    ``self._out_features`` so callers can query the output feature
    dimension through the ``out_features`` property.
    """

    def __init__(self):
        super().__init__()

    def forward(self):
        pass

    @property
    def out_features(self):
        """Output feature dimension, or ``None`` if the subclass never set it."""
        # Look directly in the instance dict: plain ints land there (not in
        # nn.Module's parameter/buffer tables), and a missing key maps to None.
        return vars(self).get("_out_features")
|
||||
11
Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dassl.utils import Registry, check_availability
|
||||
|
||||
BACKBONE_REGISTRY = Registry("BACKBONE")
|
||||
|
||||
|
||||
def build_backbone(name, verbose=True, **kwargs):
    """Instantiate the backbone registered under ``name``.

    ``check_availability`` raises if ``name`` is unknown; remaining
    keyword arguments are forwarded to the backbone's constructor.
    """
    check_availability(name, BACKBONE_REGISTRY.registered_names())
    if verbose:
        print("Backbone: {}".format(name))
    backbone_factory = BACKBONE_REGISTRY.get(name)
    return backbone_factory(**kwargs)
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Reference
|
||||
|
||||
https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class FeatureExtractor(Backbone):
    """Three conv blocks plus two FC layers; the M3SDA Digit-5 extractor.

    Expects 32x32 RGB input (enforced by ``_check_input``) and produces a
    2048-d feature vector.
    """

    def __init__(self):
        super().__init__()
        # Conv stack: 3 -> 64 -> 64 -> 128 channels, 5x5 kernels, same padding.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(128)
        # FC head: 8192 (= 128 * 8 * 8 after two pools) -> 3072 -> 2048.
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2_fc = nn.BatchNorm1d(2048)

        # Feature dimension exposed via Backbone.out_features.
        self._out_features = 2048

    def _check_input(self, x):
        # Only the fixed 32x32 resolution matches the hard-coded fc1 size.
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, " "but got {}x{}".format(H, W)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 2048) features."""
        self._check_input(x)
        x = F.relu(self.bn1(self.conv1(x)))
        # 3x3 pool with padding=1 halves the map: 32 -> 16 (then 16 -> 8).
        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten the 128x8x8 map to the fixed 8192-d vector fc1 expects.
        x = x.view(x.size(0), 8192)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        # Dropout (default p=0.5) only active in training mode.
        x = F.dropout(x, training=self.training)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        return x
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def cnn_digit5_m3sda(**kwargs):
    """CNN backbone for the Digit-5 benchmark.

    Reference: Peng et al., "Moment Matching for Multi-Source Domain
    Adaptation", ICCV 2019.
    """
    return FeatureExtractor()
|
||||
@@ -0,0 +1,61 @@
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.utils import init_network_weights
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class Convolution(nn.Module):
    """3x3 same-padding convolution followed by an in-place ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        out = self.conv(x)
        return self.relu(out)
|
||||
|
||||
|
||||
class ConvNet(Backbone):
    """Four 3x3 conv blocks, each followed by 2x2 max-pooling.

    Expects 32x32 input; after four poolings the map is 2x2, so the
    flattened output has ``2**2 * c_hidden`` features.
    """

    def __init__(self, c_hidden=64):
        super().__init__()
        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)

        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        # The flattened feature size assumes exactly 32x32 input.
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, " "but got {}x{}".format(H, W)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 4 * c_hidden) features."""
        self._check_input(x)
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(block(x), 2)
        return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def cnn_digitsdg(**kwargs):
    """CNN backbone for the Digits-DG benchmark.

    Reference: Zhou et al., "Deep Domain-Adversarial Image Generation
    for Domain Generalisation", AAAI 2020.
    """
    net = ConvNet(c_hidden=64)
    init_network_weights(net, init_type="kaiming")
    return net
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
This model is built based on
|
||||
https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.utils import init_network_weights
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class CNN(Backbone):
    """Two conv layers plus two FC layers; 1024-d output features.

    Built after the model in
    https://github.com/ricvolpi/generalize-unseen-domains; requires
    32x32 RGB input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 128, 5)
        # 5x5x128 is the map size left after two valid convs + two pools.
        self.fc3 = nn.Linear(5 * 5 * 128, 1024)
        self.fc4 = nn.Linear(1024, 1024)

        self._out_features = 1024

    def _check_input(self, x):
        # fc3's input size is only valid for 32x32 images.
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, " "but got {}x{}".format(H, W)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 1024) features."""
        self._check_input(x)

        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        x = x.view(x.size(0), -1)

        x = F.relu(self.fc3(x))
        return F.relu(self.fc4(x))
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def cnn_digitsingle(**kwargs):
    """CNN backbone for single-source digit experiments."""
    net = CNN()
    init_network_weights(net, init_type="kaiming")
    return net
|
||||
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Source: https://github.com/lukemelas/EfficientNet-PyTorch.
|
||||
"""
|
||||
__version__ = "0.6.4"
|
||||
from .model import (
|
||||
EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
|
||||
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
|
||||
efficientnet_b7
|
||||
)
|
||||
from .utils import (
|
||||
BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
|
||||
)
|
||||
@@ -0,0 +1,371 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .utils import (
|
||||
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
|
||||
get_model_params, efficientnet_params, get_same_padding_conv2d,
|
||||
load_pretrained_weights, calculate_output_image_size
|
||||
)
|
||||
from ..build import BACKBONE_REGISTRY
|
||||
from ..backbone import Backbone
|
||||
|
||||
|
||||
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block.

    Expand (1x1) -> depthwise conv -> optional squeeze-and-excitation ->
    project (1x1), with an identity skip + drop-connect when the shapes allow.

    Args:
        block_args (namedtuple): BlockArgs, see above
        global_params (namedtuple): GlobalParam, see above
        image_size (int or tuple): input spatial size; used to choose
            static SAME padding (None -> dynamic padding).

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch momentum is 1 - TF momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        # SE is enabled only for a valid ratio in (0, 1].
        self.has_se = (self._block_args.se_ratio
                       is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Expansion phase: 1x1 conv widening channels by expand_ratio
        # (skipped entirely when expand_ratio == 1).
        inp = self._block_args.input_filters  # number of input channels
        oup = (
            self._block_args.input_filters * self._block_args.expand_ratio
        )  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        # Track the (possibly strided) output size for later static padding.
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired. Operates on the 1x1
        # globally pooled map, hence image_size=(1, 1).
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1,
                int(
                    self._block_args.input_filters * self._block_args.se_ratio
                )
            )
            self._se_reduce = Conv2d(
                in_channels=oup,
                out_channels=num_squeezed_channels,
                kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels,
                out_channels=oup,
                kernel_size=1
            )

        # Output phase: 1x1 projection back down to output_filters.
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation: channel-wise gating from global average pool.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(
                self._swish(self._se_reduce(x_squeezed))
            )
            x = torch.sigmoid(x_squeezed) * x

        # Projection: no activation after the final BN (linear bottleneck).
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect — only when spatial size and
        # channel count are unchanged by this block.
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            if drop_connect_rate:
                x = drop_connect(
                    x, p=drop_connect_rate, training=self.training
                )
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
||||
|
||||
|
||||
class EfficientNet(Backbone):
    """
    An EfficientNet feature extractor. Most easily loaded with the
    .from_name or .from_pretrained class methods.

    The classification head is removed (``self._fc`` is commented out), so
    ``forward`` returns pooled features of dimension ``self._out_features``.

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')

    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum = 1 - TF momentum).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Stem: stride-2 3x3 conv.
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(
                    block_args.num_repeat, self._global_params
                ),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(
                    block_args, self._global_params, image_size=image_size
                )
            )
            image_size = calculate_output_image_size(
                image_size, block_args.stride
            )
            # Remaining repeats keep the channel count and use stride 1.
            if block_args.num_repeat > 1:
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(
                        block_args, self._global_params, image_size=image_size
                    )
                )
                # image_size = calculate_output_image_size(image_size, block_args.stride) # ?

        # Head: 1x1 conv up to the final width (1280 scaled by width mult).
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )

        # Final linear layer is intentionally omitted: this backbone
        # outputs features, not logits.
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

        # Feature dimension exposed via Backbone.out_features.
        self._out_features = out_channels

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """Returns output of the final convolution layer"""

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks: drop-connect rate scales linearly with block depth.
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """
        Calls extract_features to extract features, applies global average
        pooling and dropout, and returns the flattened feature vector.
        (The commented-out ``self._fc`` would have produced logits.)
        """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        # x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None):
        """Build a randomly-initialized model from a name like 'efficientnet-b0'."""
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(
            model_name, override_params
        )
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(
        cls, model_name, advprop=False, num_classes=1000, in_channels=3
    ):
        """Build a model and load pretrained weights (FC only when num_classes == 1000)."""
        model = cls.from_name(
            model_name, override_params={"num_classes": num_classes}
        )
        load_pretrained_weights(
            model, model_name, load_fc=(num_classes == 1000), advprop=advprop
        )
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the native training resolution for ``model_name``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name."""
        valid_models = ["efficientnet-b" + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError(
                "model_name should be one of: " + ", ".join(valid_models)
            )

    # NOTE(review): instance method whose self parameter is named ``model``
    # (called as model._change_in_channels(in_channels) above) — works, but
    # the unconventional name is worth confirming/renaming upstream.
    def _change_in_channels(model, in_channels):
        # Rebuild the stem conv when the input is not 3-channel RGB.
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=model._global_params.image_size
            )
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
|
||||
|
||||
|
||||
def build_efficientnet(name, pretrained):
    """Create the EfficientNet variant ``name`` (e.g. "b0"), optionally pretrained."""
    model_name = "efficientnet-{}".format(name)
    if pretrained:
        return EfficientNet.from_pretrained(model_name)
    return EfficientNet.from_name(model_name)
|
||||
|
||||
|
||||
# Registry entries for the eight standard EfficientNet variants (b0-b7).
# Each accepts pretrained=True to load ImageNet weights; **kwargs is ignored.


@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
    return build_efficientnet("b0", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
    return build_efficientnet("b1", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
    return build_efficientnet("b2", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
    return build_efficientnet("b3", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
    return build_efficientnet("b4", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
    return build_efficientnet("b5", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
    return build_efficientnet("b6", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
    return build_efficientnet("b7", pretrained)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
This file contains helper functions for building the model and for loading model parameters.
|
||||
These helper functions are built to mirror those in the official TensorFlow implementation.
|
||||
"""
|
||||
|
||||
import re
|
||||
import math
|
||||
import collections
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils import model_zoo
|
||||
|
||||
########################################################################
|
||||
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
|
||||
########################################################################
|
||||
|
||||
# Parameters for the entire model (stem, all blocks, and head)
|
||||
GlobalParams = collections.namedtuple(
|
||||
"GlobalParams",
|
||||
[
|
||||
"batch_norm_momentum",
|
||||
"batch_norm_epsilon",
|
||||
"dropout_rate",
|
||||
"num_classes",
|
||||
"width_coefficient",
|
||||
"depth_coefficient",
|
||||
"depth_divisor",
|
||||
"min_depth",
|
||||
"drop_connect_rate",
|
||||
"image_size",
|
||||
],
|
||||
)
|
||||
|
||||
# Parameters for an individual model block
|
||||
BlockArgs = collections.namedtuple(
|
||||
"BlockArgs",
|
||||
[
|
||||
"kernel_size",
|
||||
"num_repeat",
|
||||
"input_filters",
|
||||
"output_filters",
|
||||
"expand_ratio",
|
||||
"id_skip",
|
||||
"stride",
|
||||
"se_ratio",
|
||||
],
|
||||
)
|
||||
|
||||
# Change namedtuple defaults
|
||||
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
|
||||
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
|
||||
|
||||
|
||||
class SwishImplementation(torch.autograd.Function):
    """Memory-efficient autograd function for Swish (x * sigmoid(x)).

    Saves only the input tensor for backward and recomputes sigmoid there,
    instead of keeping the forward activations alive.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        i = ctx.saved_variables[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))
|
||||
|
||||
|
||||
class MemoryEfficientSwish(nn.Module):
    """Swish via the custom autograd function (training default).

    See ``set_swish``: swap to plain ``Swish`` for export.
    """

    def forward(self, x):
        return SwishImplementation.apply(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
    """Plain Swish activation, ``x * sigmoid(x)`` (export-friendly)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
def round_filters(filters, global_params):
    """Scale ``filters`` by the width multiplier and round to the divisor.

    Mirrors the official TF rounding: snap to the nearest multiple of
    ``depth_divisor`` (at least ``min_depth``), bumping up one divisor if
    rounding would shrink the channel count by more than 10%.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters

    divisor = global_params.depth_divisor
    floor = global_params.min_depth or divisor
    scaled = filters * multiplier
    rounded = max(floor, int(scaled + divisor/2) // divisor * divisor)
    # prevent rounding by more than 10%
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
|
||||
|
||||
|
||||
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier, rounding up."""
    multiplier = global_params.depth_coefficient
    return int(math.ceil(multiplier * repeats)) if multiplier else repeats
|
||||
|
||||
|
||||
def drop_connect(inputs, p, training):
    """Drop whole samples with probability ``p`` (stochastic depth).

    At inference (``training=False``) inputs pass through unchanged;
    during training, survivors are rescaled by ``1 / (1 - p)`` so the
    expected value is preserved.
    """
    if not training:
        return inputs

    keep_prob = 1 - p
    batch_size = inputs.shape[0]
    # One Bernoulli(keep_prob) draw per sample, broadcast over C/H/W.
    noise = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_mask = torch.floor(noise)
    return inputs / keep_prob * binary_mask
|
||||
|
||||
|
||||
def get_same_padding_conv2d(image_size=None):
    """Choose between static- and dynamic-padding SAME conv classes.

    Static padding requires a known image size and is necessary for ONNX
    export; with no size, padding is computed per-forward from the input.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    return partial(Conv2dStaticSamePadding, image_size=image_size)
|
||||
|
||||
|
||||
def get_width_and_height_from_size(x):
    """Return a (height, width) pair from an int or a 2-sequence.

    An int yields a square ``(x, x)``; a list/tuple is returned as-is.

    Raises:
        TypeError: if ``x`` is neither an int nor a list/tuple.
        (Previously raised a bare ``TypeError()`` with no message.)
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(
        "size must be an int or a list/tuple, got {}".format(type(x).__name__)
    )
|
||||
|
||||
|
||||
def calculate_output_image_size(input_image_size, stride):
    """
    Output spatial size of a Conv2dSamePadding layer with ``stride``.

    Needed to propagate sizes for static padding; returns ``None`` when
    the input size is unknown (dynamic padding).
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    if not isinstance(stride, int):
        stride = stride[0]
    return [int(math.ceil(height / stride)), int(math.ceil(width / stride))]
|
||||
|
||||
|
||||
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style SAME padding, computed
    per-forward from the actual input size (works for any resolution,
    but is not ONNX-exportable)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # padding=0 here; the real padding is applied manually in forward.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation,
            groups, bias
        )
        # Normalize stride to a 2-element [sh, sw] list.
        self.stride = self.stride if len(self.stride
                                         ) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        # Target output size under SAME semantics: ceil(input / stride).
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        # Total padding needed so the conv yields exactly (oh, ow).
        pad_h = max(
            (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            # Split padding evenly; the extra pixel (odd total) goes to
            # the right/bottom, matching TF.
            x = F.pad(
                x,
                [pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
||||
|
||||
|
||||
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style SAME padding precomputed for a
    fixed image size (required for ONNX export)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Normalize stride to a 2-element [sh, sw] list.
        self.stride = self.stride if len(self.stride
                                         ) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size,
                  image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        # Target output size under SAME semantics: ceil(input / stride).
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
        )
        pad_w = max(
            (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
        )
        if pad_h > 0 or pad_w > 0:
            # Fixed ZeroPad2d baked in at construction time; extra pixel
            # (odd total) goes to the right/bottom, matching TF.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
            )
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x
|
||||
|
||||
|
||||
class Identity(nn.Module):
    """Pass-through module, used when no static padding is required."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input
|
||||
|
||||
|
||||
########################################################################
|
||||
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
|
||||
########################################################################
|
||||
|
||||
|
||||
def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Returns a tuple (width_coefficient, depth_coefficient, resolution,
    dropout_rate); raises KeyError for unknown names.
    """
    coefficients = {
        # Coefficients: width,depth,res,dropout
        "efficientnet-b0": (1.0, 1.0, 224, 0.2),
        "efficientnet-b1": (1.0, 1.1, 240, 0.2),
        "efficientnet-b2": (1.1, 1.2, 260, 0.3),
        "efficientnet-b3": (1.2, 1.4, 300, 0.3),
        "efficientnet-b4": (1.4, 1.8, 380, 0.4),
        "efficientnet-b5": (1.6, 2.2, 456, 0.4),
        "efficientnet-b6": (1.8, 2.6, 528, 0.5),
        "efficientnet-b7": (2.0, 3.1, 600, 0.5),
        "efficientnet-b8": (2.2, 3.6, 672, 0.5),
        "efficientnet-l2": (4.3, 5.3, 800, 0.5),
    }
    return coefficients[model_name]
|
||||
|
||||
|
||||
class BlockDecoder(object):
    """Block Decoder for readability, straight from the official TensorFlow repository.

    Translates between compact block-notation strings (e.g.
    "r1_k3_s11_e1_i32_o16_se0.25") and BlockArgs namedtuples.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Gets a block through a string notation of arguments."""
        assert isinstance(block_string, str)

        ops = block_string.split("_")
        options = {}
        for op in ops:
            # Split "k3" into ("k", "3"); pieces without digits are flags.
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: one digit, or two equal digits ("s2" or "s22")
        assert ("s" in options and len(options["s"]) == 1) or (
            len(options["s"]) == 2 and options["s"][0] == options["s"][1]
        )

        return BlockArgs(
            kernel_size=int(options["k"]),
            num_repeat=int(options["r"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            expand_ratio=int(options["e"]),
            id_skip=("noskip" not in block_string),
            se_ratio=float(options["se"]) if "se" in options else None,
            stride=[int(options["s"][0])],
        )

    @staticmethod
    def _encode_block_string(block):
        """Encodes a block to a string (inverse of _decode_block_string)."""
        # BUG FIX: the namedtuple field is `stride` (a 1- or 2-element list),
        # not `strides`; the old code raised AttributeError when called.
        sh = block.stride[0]
        sw = block.stride[1] if len(block.stride) == 2 else sh
        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (sh, sw),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        # BUG FIX: guard against se_ratio=None, which _decode_block_string can
        # produce; comparing None with 0 is a TypeError in Python 3.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """
        Decodes a list of string notations to specify blocks inside the network.

        :param string_list: a list of strings, each string is a notation of block
        :return: a list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        return [
            BlockDecoder._decode_block_string(block_string)
            for block_string in string_list
        ]

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
|
||||
|
||||
|
||||
def efficientnet(
    width_coefficient=None,
    depth_coefficient=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    image_size=None,
    num_classes=1000,
):
    """Creates a efficientnet model.

    Returns (blocks_args, global_params): the decoded baseline (B0) block
    configuration plus the global parameters scaled by the given coefficients.
    """
    # Baseline B0 architecture; width/depth coefficients scale it per variant.
    base_blocks = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    decoded_blocks = BlockDecoder.decode(base_blocks)

    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        # data format is always channels_last in PyTorch, so it is omitted
        num_classes=num_classes,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        image_size=image_size,
    )
    return decoded_blocks, global_params
|
||||
|
||||
|
||||
def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model.

    Raises NotImplementedError for unknown model names; override_params (a
    dict) may replace any GlobalParams field and raises ValueError on
    unknown fields.
    """
    if not model_name.startswith("efficientnet"):
        raise NotImplementedError(
            "model name is not pre-defined: %s" % model_name
        )

    w, d, s, p = efficientnet_params(model_name)
    # note: all models have drop connect rate = 0.2
    blocks_args, global_params = efficientnet(
        width_coefficient=w,
        depth_coefficient=d,
        dropout_rate=p,
        image_size=s
    )
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
|
||||
|
||||
|
||||
# Download URLs for ImageNet-pretrained EfficientNet checkpoints
# (standard training), keyed by model name.
url_map = {
    "efficientnet-b0":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
    "efficientnet-b1":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
    "efficientnet-b2":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
    "efficientnet-b3":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
    "efficientnet-b4":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
    "efficientnet-b5":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
    "efficientnet-b6":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
    "efficientnet-b7":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}

# Checkpoints trained with Adversarial Propagation (AdvProp).
# NOTE: these expect different input preprocessing than the standard weights.
url_map_advprop = {
    "efficientnet-b0":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
    "efficientnet-b1":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
    "efficientnet-b2":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
    "efficientnet-b3":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
    "efficientnet-b4":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
    "efficientnet-b5":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
    "efficientnet-b6":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
    "efficientnet-b7":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
    "efficientnet-b8":
    "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}
|
||||
|
||||
|
||||
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Loads pretrained weights, and downloads if loading for the first time.

    Args:
        model: the EfficientNet instance to initialize.
        model_name: key into the URL tables, e.g. "efficientnet-b0".
        load_fc: if False, drop the final classifier ("_fc") weights from the
            checkpoint before loading. (Previously this flag was silently
            ignored; the intended behavior was kept as dead commented-out code.)
        advprop: load the AdvProp checkpoints instead — NOTE these expect
            different input preprocessing.
    """
    # AutoAugment or Advprop (different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    state_dict = model_zoo.load_url(url_map_[model_name])
    if not load_fc:
        # Keep only the feature-extractor weights; the classifier head is
        # re-initialized by the caller (e.g. when num_classes differs).
        state_dict.pop("_fc.weight", None)
        state_dict.pop("_fc.bias", None)
    # strict=False tolerates missing/unexpected keys (e.g. a resized head).
    model.load_state_dict(state_dict, strict=False)
|
||||
217
Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
Normal file
217
Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch import nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
# ImageNet-pretrained checkpoint URL for MobileNetV2 (torchvision release).
model_urls = {
    "mobilenet_v2":
    "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor/2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
    """Conv2d (no bias, "same"-style padding) -> BatchNorm2d -> ReLU6."""

    def __init__(
        self, in_planes, out_planes, kernel_size=3, stride=1, groups=1
    ):
        # Symmetric padding keeps the spatial size when stride == 1.
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            pad,
            groups=groups,
            bias=False,
        )
        super().__init__(conv, nn.BatchNorm2d(out_planes), nn.ReLU6(inplace=True))
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Optional 1x1 expansion -> 3x3 depthwise -> 1x1 linear projection, with a
    residual connection when stride is 1 and input/output channels match.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        assert stride in [1, 2]
        self.stride = stride

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pointwise expansion
            stages.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        # depthwise conv (groups == channels)
        stages.append(
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim)
        )
        # pointwise linear projection (no activation)
        stages.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
        stages.append(nn.BatchNorm2d(oup))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
|
||||
|
||||
|
||||
class MobileNetV2(Backbone):
    """MobileNetV2 feature extractor (no classifier head).

    ``forward`` returns globally average-pooled features of dimension
    ``self.last_channel`` (1280 scaled up by ``width_mult``).
    """

    def __init__(
        self,
        width_mult=1.0,
        inverted_residual_setting=None,
        round_nearest=8,
        block=None,
    ):
        """
        MobileNet V2.

        Args:
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
        """
        super().__init__()

        if block is None:
            block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            # Default architecture: t = expansion factor, c = output channels,
            # n = number of blocks in the stage, s = stride of the first block.
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if (
            len(inverted_residual_setting) == 0
            or len(inverted_residual_setting[0]) != 4
        ):
            raise ValueError(
                "inverted_residual_setting should be non-empty "
                "or a 4-element list, got {}".
                format(inverted_residual_setting)
            )

        # building first layer
        input_channel = _make_divisible(
            input_channel * width_mult, round_nearest
        )
        # Width multiplier never shrinks the final channel count below 1280.
        self.last_channel = _make_divisible(
            last_channel * max(1.0, width_mult), round_nearest
        )
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of each stage downsamples.
                stride = s if i == 0 else 1
                features.append(
                    block(
                        input_channel, output_channel, stride, expand_ratio=t
                    )
                )
                input_channel = output_channel
        # building last several layers
        features.append(
            ConvBNReLU(input_channel, self.last_channel, kernel_size=1)
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # Feature dimension exposed to downstream heads (Backbone contract).
        self._out_features = self.last_channel

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Global average pooling over the spatial dims -> (N, last_channel).
        x = x.mean([2, 3])
        return x

    def forward(self, x):
        return self._forward_impl(x)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    if model_url is None:
        import warnings

        warnings.warn(
            "ImageNet pretrained weights are unavailable for this model"
        )
        return

    checkpoint = model_zoo.load_url(model_url)
    current = model.state_dict()
    # Keep only checkpoint entries whose name and shape match this model.
    compatible = {
        k: v
        for k, v in checkpoint.items()
        if k in current and current[k].size() == v.size()
    }
    current.update(compatible)
    model.load_state_dict(current)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def mobilenetv2(pretrained=True, **kwargs):
    """Build a MobileNetV2 backbone; load ImageNet weights if `pretrained`."""
    net = MobileNetV2(**kwargs)
    if pretrained:
        init_pretrained_weights(net, model_urls["mobilenet_v2"])
    return net
|
||||
135
Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
Normal file
135
Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class PreActBlock(nn.Module):
    """Pre-activation basic residual block (BN-ReLU before each conv)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )

        # 1x1 projection so the skip path matches the residual path's shape;
        # only created when the shapes actually differ.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                )
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # Projection shortcut consumes the pre-activated input; identity uses x.
        skip = self.shortcut(pre) if hasattr(self, "shortcut") else x
        out = self.conv2(F.relu(self.bn2(self.conv1(pre))))
        return out + skip
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck block: BN-ReLU before each of the 1x1, 3x3, 1x1 convs."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        # 1x1 projection aligns the skip path when shape changes.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                )
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        skip = self.shortcut(pre) if hasattr(self, "shortcut") else x
        out = self.conv1(pre)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        return out + skip
|
||||
|
||||
|
||||
class PreActResNet(Backbone):
    """Pre-activation ResNet for small (32x32) inputs, e.g. CIFAR/SVHN."""

    def __init__(self, block, num_blocks):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Feature dimension exposed to downstream heads (Backbone contract).
        self._out_features = 512 * block.expansion

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block of the stage may downsample; the rest keep stride 1.
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # 4x4 average pooling flattens a 32x32 input down to a vector.
        out = F.avg_pool2d(out, 4)
        return out.view(out.size(0), -1)
|
||||
|
||||
|
||||
"""
|
||||
Preact-ResNet18 was used for the CIFAR10 and
|
||||
SVHN datasets (both are SSL tasks) in
|
||||
|
||||
- Wang et al. Semi-Supervised Learning by
|
||||
Augmented Distribution Alignment. ICCV 2019.
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def preact_resnet18(**kwargs):
    """Pre-activation ResNet-18 (basic blocks, [2, 2, 2, 2])."""
    net = PreActResNet(PreActBlock, [2, 2, 2, 2])
    return net
|
||||
589
Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
Normal file
589
Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
Normal file
@@ -0,0 +1,589 @@
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
# ImageNet-pretrained checkpoint URLs (torchvision releases), keyed by depth.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two 3x3 conv residual block (conv-BN-ReLU x2 plus identity/downsample skip)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional module matching the skip path to the residual's shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        return self.relu(y + identity)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand residual bottleneck (post-activation)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional module matching the skip path to the residual's shape.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu(y + identity)
|
||||
|
||||
|
||||
class ResNet(Backbone):
    """ResNet backbone that can insert a feature-statistics mixing module
    (e.g. MixStyle or EFDMix) after selected residual stages.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        ms_class: class instantiated as ``ms_class(p=ms_p, alpha=ms_a)``;
            required when ``ms_layers`` is non-empty.
        ms_layers: subset of ``["layer1", "layer2", "layer3"]`` naming the
            stages after which the mixing module is applied.
        ms_p: probability argument forwarded to ``ms_class``.
        ms_a: alpha argument forwarded to ``ms_class``.
    """

    def __init__(
        self,
        block,
        layers,
        ms_class=None,
        ms_layers=None,
        ms_p=0.5,
        ms_a=0.1,
        **kwargs
    ):
        # BUG FIX: `ms_layers` previously defaulted to the mutable `[]`,
        # which is shared across all calls; `None` is the safe equivalent.
        if ms_layers is None:
            ms_layers = []
        self.inplanes = 64
        super().__init__()

        # backbone network
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        # Feature dimension exposed to downstream heads (Backbone contract).
        self._out_features = 512 * block.expansion

        self.mixstyle = None
        if ms_layers:
            # A single mixing module instance is shared by all listed stages.
            self.mixstyle = ms_class(p=ms_p, alpha=ms_a)
            for layer_name in ms_layers:
                assert layer_name in ["layer1", "layer2", "layer3"]
            print(f"Insert MixStyle after {ms_layers}")
        self.ms_layers = ms_layers

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage; a 1x1-conv+BN downsample branch is added
        when the spatial size or channel count changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _init_params(self):
        """Kaiming init for convs, unit/zero init for norm layers,
        small-normal init for linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        """Return the layer4 feature map, applying the mixing module after
        every stage listed in ``self.ms_layers``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        if "layer1" in self.ms_layers:
            x = self.mixstyle(x)
        x = self.layer2(x)
        if "layer2" in self.ms_layers:
            x = self.mixstyle(x)
        x = self.layer3(x)
        if "layer3" in self.ms_layers:
            x = self.mixstyle(x)
        return self.layer4(x)

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        # Flatten (N, C, 1, 1) -> (N, C).
        return v.view(v.size(0), -1)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
    """Download an ImageNet checkpoint and load all matching entries.

    strict=False tolerates keys that differ from the checkpoint (e.g. heads).
    """
    state_dict = model_zoo.load_url(model_url)
    model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
"""
|
||||
Residual network configurations:
|
||||
--
|
||||
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
|
||||
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
|
||||
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
|
||||
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
|
||||
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18(pretrained=True, **kwargs):
    """ResNet-18 backbone (BasicBlock, [2, 2, 2, 2])."""
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet18"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet34(pretrained=True, **kwargs):
    """ResNet-34 backbone (BasicBlock, [3, 4, 6, 3])."""
    net = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet34"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50(pretrained=True, **kwargs):
    """ResNet-50 backbone (Bottleneck, [3, 4, 6, 3])."""
    net = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet50"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101(pretrained=True, **kwargs):
    """ResNet-101 backbone (Bottleneck, [3, 4, 23, 3])."""
    net = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet101"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet152(pretrained=True, **kwargs):
    """ResNet-152 backbone (Bottleneck, [3, 8, 36, 3])."""
    net = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet152"])
    return net
|
||||
|
||||
|
||||
"""
|
||||
Residual networks with mixstyle
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_ms_l123(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle after layer1, layer2 and layer3."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet18"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_ms_l12(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle after layer1 and layer2."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet18"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_ms_l1(pretrained=True, **kwargs):
    """ResNet-18 with MixStyle after layer1 only."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        ms_class=MixStyle,
        ms_layers=["layer1"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet18"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_ms_l123(pretrained=True, **kwargs):
    """ResNet-50 with MixStyle after layer1, layer2 and layer3."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet50"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_ms_l12(pretrained=True, **kwargs):
    """ResNet-50 with MixStyle after layer1 and layer2."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet50"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_ms_l1(pretrained=True, **kwargs):
    """ResNet-50 with MixStyle after layer1 only."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        ms_class=MixStyle,
        ms_layers=["layer1"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet50"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_ms_l123(pretrained=True, **kwargs):
    """ResNet-101 with MixStyle after layer1, layer2 and layer3."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet101"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_ms_l12(pretrained=True, **kwargs):
    """ResNet-101 with MixStyle after layer1 and layer2."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        ms_class=MixStyle,
        ms_layers=["layer1", "layer2"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet101"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_ms_l1(pretrained=True, **kwargs):
    """ResNet-101 with MixStyle after layer1 only."""
    from dassl.modeling.ops import MixStyle

    net = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        ms_class=MixStyle,
        ms_layers=["layer1"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet101"])
    return net
|
||||
|
||||
|
||||
"""
|
||||
Residual networks with efdmix
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l123(pretrained=True, **kwargs):
    """ResNet-18 with EFDMix after layer1, layer2 and layer3."""
    from dassl.modeling.ops import EFDMix

    net = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    if pretrained:
        init_pretrained_weights(net, model_urls["resnet18"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l12(pretrained=True, **kwargs):
    """ResNet-18 with EFDMix inserted after stages layer1-layer2."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l1(pretrained=True, **kwargs):
    """ResNet-18 with EFDMix inserted after stage layer1 only."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        ms_class=EFDMix,
        ms_layers=["layer1"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet18"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l123(pretrained=True, **kwargs):
    """ResNet-50 with EFDMix inserted after stages layer1-layer3."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l12(pretrained=True, **kwargs):
    """ResNet-50 with EFDMix inserted after stages layer1-layer2."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l1(pretrained=True, **kwargs):
    """ResNet-50 with EFDMix inserted after stage layer1 only."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        ms_class=EFDMix,
        ms_layers=["layer1"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l123(pretrained=True, **kwargs):
    """ResNet-101 with EFDMix inserted after stages layer1-layer3."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2", "layer3"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet101"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l12(pretrained=True, **kwargs):
    """ResNet-101 with EFDMix inserted after stages layer1-layer2."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        ms_class=EFDMix,
        ms_layers=["layer1", "layer2"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet101"])
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l1(pretrained=True, **kwargs):
    """ResNet-101 with EFDMix inserted after stage layer1 only."""
    from dassl.modeling.ops import EFDMix

    cfg = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        ms_class=EFDMix,
        ms_layers=["layer1"],
    )
    model = ResNet(**cfg)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet101"])
    return model
|
||||
229
Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
Normal file
229
Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Code source: https://github.com/pytorch/vision
|
||||
"""
|
||||
import torch
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch import nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
model_urls = {
|
||||
"shufflenetv2_x0.5":
|
||||
"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
|
||||
"shufflenetv2_x1.0":
|
||||
"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
|
||||
"shufflenetv2_x1.5": None,
|
||||
"shufflenetv2_x2.0": None,
|
||||
}
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
    """Interleave channels across ``groups`` (ShuffleNetV2 channel shuffle).

    Reshapes (N, C, H, W) into (N, groups, C//groups, H, W), swaps the two
    channel axes and flattens back, so information can flow between branch
    groups. ``C`` must be divisible by ``groups``.

    Args:
        x (torch.Tensor): 4-D feature map of shape (N, C, H, W).
        groups (int): number of channel groups.

    Returns:
        torch.Tensor: tensor of the same shape with channels interleaved.
    """
    # ``x.size()`` instead of the legacy ``x.data.size()``: accessing
    # ``.data`` is deprecated and unnecessary for reading the shape.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape to (N, groups, C/groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # swap the group and per-group axes; contiguous() is required before
    # the final view
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten back to (N, C, H, W)
    x = x.view(batchsize, -1, height, width)

    return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """ShuffleNetV2 building unit.

    For stride 1, the input is split channel-wise: one half bypasses the
    unit, the other goes through ``branch2``. For stride > 1, both
    ``branch1`` and ``branch2`` process the full input and their outputs
    are concatenated (doubling is absorbed by ``oup``). A channel shuffle
    mixes the two halves afterwards.

    Note: attribute names (``branch1``/``branch2``) and the order of
    modules inside each ``nn.Sequential`` define the state_dict keys used
    to load torchvision pretrained weights — do not reorder them.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        # Each branch contributes half of the output channels.
        branch_features = oup // 2
        # Stride-1 units split the input in two, so inp must equal
        # 2 * branch_features.
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            # Downsampling path: depthwise 3x3 (strided) + pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(
                    inp, inp, kernel_size=3, stride=self.stride, padding=1
                ),
                nn.BatchNorm2d(inp),
                nn.Conv2d(
                    inp,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False
                ),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        # Main path: pointwise -> depthwise (strided) -> pointwise.
        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
            ),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        # groups=i makes this a depthwise convolution (one filter per
        # input channel).
        return nn.Conv2d(
            i, o, kernel_size, stride, padding, bias=bias, groups=i
        )

    def forward(self, x):
        if self.stride == 1:
            # Channel split: first half is the identity shortcut.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Shuffle so the next unit's split mixes both halves.
        out = channel_shuffle(out, 2)

        return out
|
||||
|
||||
|
||||
class ShuffleNetV2(Backbone):
    """ShuffleNetV2 backbone (feature extractor; no classification layer).

    Args:
        stages_repeats (list[int]): number of units in stage2, stage3 and
            stage4 (exactly 3 entries).
        stages_out_channels (list[int]): output channels of conv1, the
            three stages, and conv5 (exactly 5 entries).

    Attribute names (``conv1``, ``stage2``..``stage4``, ``conv5``) match
    the torchvision implementation so pretrained weights load by name.
    """

    def __init__(self, stages_repeats, stages_out_channels, **kwargs):
        super().__init__()
        if len(stages_repeats) != 3:
            raise ValueError(
                "expected stages_repeats as list of 3 positive ints"
            )
        if len(stages_out_channels) != 5:
            raise ValueError(
                "expected stages_out_channels as list of 5 positive ints"
            )
        self._stage_out_channels = stages_out_channels

        # Stem: 3x3 stride-2 convolution.
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each stage: one stride-2 downsampling unit, then repeats-1
        # stride-1 units.
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
            stage_names, stages_repeats, self._stage_out_channels[1:]
        ):
            seq = [InvertedResidual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(
                    InvertedResidual(output_channels, output_channels, 1)
                )
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        # Final 1x1 conv before global pooling.
        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Feature dimension exposed through the Backbone API.
        self._out_features = output_channels

    def featuremaps(self, x):
        """Return the pre-pooling conv5 feature map."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        return x

    def forward(self, x):
        """Return a flattened (batch, out_features) global feature vector."""
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        return v.view(v.size(0), -1)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    if model_url is None:
        import warnings

        warnings.warn(
            "ImageNet pretrained weights are unavailable for this model"
        )
        return
    checkpoint = model_zoo.load_url(model_url)
    current = model.state_dict()
    # Keep only checkpoint entries that exist in the model with an
    # identical tensor shape; everything else stays at its current value.
    matched = {
        name: tensor
        for name, tensor in checkpoint.items()
        if name in current and current[name].size() == tensor.size()
    }
    current.update(matched)
    model.load_state_dict(current)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x0_5(pretrained=True, **kwargs):
    """ShuffleNetV2 with a 0.5x width multiplier."""
    net = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(net, model_urls["shufflenetv2_x0.5"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_0(pretrained=True, **kwargs):
    """ShuffleNetV2 with a 1.0x width multiplier."""
    net = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(net, model_urls["shufflenetv2_x1.0"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_5(pretrained=True, **kwargs):
    """ShuffleNetV2 with a 1.5x width multiplier (no pretrained weights)."""
    net = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
    if pretrained:
        init_pretrained_weights(net, model_urls["shufflenetv2_x1.5"])
    return net
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x2_0(pretrained=True, **kwargs):
    """ShuffleNetV2 with a 2.0x width multiplier (no pretrained weights)."""
    net = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
    if pretrained:
        init_pretrained_weights(net, model_urls["shufflenetv2_x2.0"])
    return net
|
||||
147
Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
Normal file
147
Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
model_urls = {
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
|
||||
}
|
||||
|
||||
|
||||
class VGG(Backbone):
    """VGG backbone whose head outputs 4096-d features, not class logits.

    Args:
        features (nn.Module): convolutional feature extractor, typically
            built by ``make_layers``.
        init_weights (bool): run random initialization; disabled when
            pretrained weights will overwrite it anyway.

    Attribute names (``features``, ``classifier``) match torchvision's
    VGG so checkpoints load by name (the missing final FC is skipped via
    strict=False in ``_vgg``).
    """

    def __init__(self, features, init_weights=True):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Note that self.classifier outputs features rather than logits
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
        )

        # Feature dimension exposed through the Backbone API.
        self._out_features = 4096

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _initialize_weights(self):
        # Standard torchvision VGG init: He normal for convs,
        # N(0, 0.01) for linear layers, unit weight / zero bias for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def make_layers(cfg, batch_norm=False):
    """Build the VGG feature extractor described by ``cfg``.

    Each integer entry adds a 3x3 conv (+ optional BatchNorm) + ReLU;
    the string "M" adds a 2x2 max-pool.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        in_channels = v
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
# VGG layer configurations: integers are conv output channels, "M" marks
# a 2x2 max-pool. Letters follow the original VGG paper (A=11, B=13,
# D=16, E=19 weight layers).
cfgs = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [
        64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512,
        "M"
    ],
    "D": [
        64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M",
        512, 512, 512, "M"
    ],
    "E": [
        64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512,
        512, "M", 512, 512, 512, 512, "M"
    ],
}
|
||||
|
||||
|
||||
def _vgg(arch, cfg, batch_norm, pretrained):
    """Build a VGG backbone.

    Args:
        arch (str): key into ``model_urls`` (e.g. "vgg16").
        cfg (str): key into ``cfgs`` (e.g. "D").
        batch_norm (bool): insert BatchNorm after each conv.
        pretrained (bool): load ImageNet weights.
    """
    # ``not pretrained`` replaces the ``False if pretrained else True``
    # anti-idiom; random init is skipped when weights get overwritten.
    model = VGG(
        make_layers(cfgs[cfg], batch_norm=batch_norm),
        init_weights=not pretrained
    )
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
        # strict=False: the checkpoint carries the final classification
        # layer, which this feature-only VGG does not define.
        model.load_state_dict(state_dict, strict=False)
    return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def vgg16(pretrained=True, **kwargs):
    """VGG-16 (configuration "D", without batch norm)."""
    return _vgg("vgg16", "D", batch_norm=False, pretrained=pretrained)
|
||||
150
Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
Normal file
150
Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Modified from https://github.com/xternalz/WideResNet-pytorch
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Pre-activation basic block for WideResNet (BN-LeakyReLU-Conv x2).

    When input and output widths differ, a strided 1x1 conv forms the
    shortcut and the shared pre-activation output feeds both branches.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.LeakyReLU(0.01, inplace=True)
        self.conv1 = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.LeakyReLU(0.01, inplace=True)
        self.conv2 = nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False
        )
        self.droprate = dropRate
        self.equalInOut = in_planes == out_planes
        # 1x1 conv shortcut only when widths differ; identity otherwise.
        # ``cond and X or None`` evaluates to the conv or to None.
        self.convShortcut = (
            (not self.equalInOut) and nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False,
            ) or None
        )

    def forward(self, x):
        if not self.equalInOut:
            # Shared pre-activation: the shortcut also consumes BN+ReLU(x).
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Residual add: identity shortcut when widths match, 1x1 conv
        # projection otherwise.
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
|
||||
|
||||
|
||||
class NetworkBlock(nn.Module):
    """A stack of ``nb_layers`` residual blocks forming one WRN group.

    Only the first block may change the channel count and stride; the
    remaining blocks are width- and resolution-preserving.
    """

    def __init__(
        self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
    ):
        super().__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, dropRate
        )

    def _make_layer(
        self, block, in_planes, out_planes, nb_layers, stride, dropRate
    ):
        layers = []
        # int() cast: callers may pass a float block count (e.g. (depth-4)/6).
        for i in range(int(nb_layers)):
            # Explicit conditional expressions replace the fragile
            # ``cond and a or b`` idiom, which silently breaks when the
            # "a" value is falsy.
            layers.append(
                block(
                    in_planes if i == 0 else out_planes,
                    out_planes,
                    stride if i == 0 else 1,
                    dropRate,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
|
||||
|
||||
|
||||
class WideResNet(Backbone):
    """Wide ResNet (WRN-depth-widen_factor) backbone.

    Args:
        depth (int): total network depth; must satisfy (depth-4) % 6 == 0.
        widen_factor (int): channel multiplier for the three groups.
        dropRate (float): dropout rate inside the basic blocks.
    """

    def __init__(self, depth, widen_factor, dropRate=0.0):
        super().__init__()
        nChannels = [
            16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
        ]
        assert (depth-4) % 6 == 0
        # Blocks per group. True division yields a float; NetworkBlock
        # casts it with int(), so the count is still exact.
        n = (depth-4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(
            3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
        )
        # 1st block
        self.block1 = NetworkBlock(
            n, nChannels[0], nChannels[1], block, 1, dropRate
        )
        # 2nd block
        self.block2 = NetworkBlock(
            n, nChannels[1], nChannels[2], block, 2, dropRate
        )
        # 3rd block
        self.block3 = NetworkBlock(
            n, nChannels[2], nChannels[3], block, 2, dropRate
        )
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.LeakyReLU(0.01, inplace=True)

        # Feature dimension exposed through the Backbone API.
        self._out_features = nChannels[3]

        # He init for convs; unit/zero for BN; zero bias for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        # Final activation before pooling (pre-activation block layout).
        out = self.relu(self.bn1(out))
        out = F.adaptive_avg_pool2d(out, 1)
        return out.view(out.size(0), -1)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def wide_resnet_28_2(**kwargs):
    """WRN-28-2: depth 28, widen factor 2."""
    return WideResNet(depth=28, widen_factor=2)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def wide_resnet_16_4(**kwargs):
    """WRN-16-4: depth 16, widen factor 4."""
    return WideResNet(depth=16, widen_factor=4)
|
||||
3
Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py
Normal file
3
Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .build import build_head, HEAD_REGISTRY # isort:skip
|
||||
|
||||
from .mlp import mlp
|
||||
11
Dassl.ProGrad.pytorch/dassl/modeling/head/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/modeling/head/build.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dassl.utils import Registry, check_availability
|
||||
|
||||
HEAD_REGISTRY = Registry("HEAD")
|
||||
|
||||
|
||||
def build_head(name, verbose=True, **kwargs):
    """Instantiate a registered head by ``name``; kwargs go to its ctor."""
    check_availability(name, HEAD_REGISTRY.registered_names())
    if verbose:
        print("Head: {}".format(name))
    head_cls = HEAD_REGISTRY.get(name)
    return head_cls(**kwargs)
|
||||
50
Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
Normal file
50
Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import torch.nn as nn
|
||||
|
||||
from .build import HEAD_REGISTRY
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron head.

    Args:
        in_features (int): input dimension.
        hidden_layers (int or list[int]): output dimension of each hidden
            layer; the last entry becomes ``self.out_features``. Must be
            non-empty.
        activation (str): "relu" or "leaky_relu".
        bn (bool): insert BatchNorm1d after each linear layer.
        dropout (float): dropout probability after each activation
            (disabled when 0).
    """

    def __init__(
        self,
        in_features=2048,
        hidden_layers=None,
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()
        # ``None`` default replaces the mutable default argument ``[]``
        # (shared across calls); behavior is unchanged — an empty spec
        # still fails the assert below.
        if hidden_layers is None:
            hidden_layers = []
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]

        assert len(hidden_layers) > 0
        self.out_features = hidden_layers[-1]

        mlp = []

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError

        # Linear -> [BN] -> activation -> [Dropout] per hidden layer.
        for hidden_dim in hidden_layers:
            mlp += [nn.Linear(in_features, hidden_dim)]
            if bn:
                mlp += [nn.BatchNorm1d(hidden_dim)]
            mlp += [act_fn()]
            if dropout > 0:
                mlp += [nn.Dropout(dropout)]
            in_features = hidden_dim

        self.mlp = nn.Sequential(*mlp)

    def forward(self, x):
        return self.mlp(x)
|
||||
|
||||
|
||||
@HEAD_REGISTRY.register()
def mlp(**kwargs):
    """Registry entry point for the MLP head."""
    return MLP(**kwargs)
|
||||
5
Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py
Normal file
5
Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .build import build_network, NETWORK_REGISTRY # isort:skip
|
||||
|
||||
from .ddaig_fcn import (
|
||||
fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn
|
||||
)
|
||||
11
Dassl.ProGrad.pytorch/dassl/modeling/network/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/modeling/network/build.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dassl.utils import Registry, check_availability
|
||||
|
||||
NETWORK_REGISTRY = Registry("NETWORK")
|
||||
|
||||
|
||||
def build_network(name, verbose=True, **kwargs):
    """Instantiate a registered network by ``name``; kwargs go to its builder."""
    check_availability(name, NETWORK_REGISTRY.registered_names())
    if verbose:
        print("Network: {}".format(name))
    net_fn = NETWORK_REGISTRY.get(name)
    return net_fn(**kwargs)
|
||||
329
Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py
Normal file
329
Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
||||
"""
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .build import NETWORK_REGISTRY
|
||||
|
||||
|
||||
def init_network_weights(model, init_type="normal", gain=0.02):
    """Initialize ``model``'s conv/linear/norm layers in place.

    Args:
        model (nn.Module): network to initialize.
        init_type (str): "normal", "xavier", "kaiming" or "orthogonal".
        gain (float): std / gain used by the chosen scheme.
    """

    def _init_func(m):
        cls_name = m.__class__.__name__
        is_conv_or_fc = hasattr(m, "weight") and (
            "Conv" in cls_name or "Linear" in cls_name
        )
        if is_conv_or_fc:
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in cls_name:
            nn.init.constant_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)
        elif "InstanceNorm2d" in cls_name:
            # InstanceNorm may be created with affine=False (no params).
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)
|
||||
|
||||
|
||||
def get_norm_layer(norm_type="instance"):
    """Return a normalization-layer constructor for ``norm_type``.

    "batch" -> affine BatchNorm2d; "instance" -> non-affine
    InstanceNorm2d without running stats; "none" -> None.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    if norm_type == "none":
        return None
    raise NotImplementedError(
        "normalization layer [%s] is not found" % norm_type
    )
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with configurable padding.

    Args:
        dim (int): number of channels (kept constant).
        padding_type (str): "reflect", "replicate" or "zero".
        norm_layer (callable): normalization-layer constructor.
        use_dropout (bool): insert Dropout(0.5) between the two convs.
        use_bias (bool): bias on the 3x3 convolutions.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super().__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias
        )

    @staticmethod
    def _padding(padding_type):
        """Return (list of explicit pad modules, conv ``padding`` arg).

        Extracted helper: the original repeated this if/elif chain twice.
        """
        if padding_type == "reflect":
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == "replicate":
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == "zero":
            return [], 1
        raise NotImplementedError(
            "padding [%s] is not implemented" % padding_type
        )

    def build_conv_block(
        self, dim, padding_type, norm_layer, use_dropout, use_bias
    ):
        # First conv sub-block: [pad] -> conv -> norm -> ReLU.
        pad_mods, p = self._padding(padding_type)
        conv_block = list(pad_mods)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Second conv sub-block: [pad] -> conv -> norm (no activation;
        # the residual sum follows in forward()).
        pad_mods, p = self._padding(padding_type)
        conv_block += list(pad_mods)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)
|
||||
|
||||
|
||||
class LocNet(nn.Module):
    """Localization network.

    Predicts a 2x3 affine matrix for a spatial transformer. Only the
    left 2x2 part (rotation/scale/shear) is learned and tanh-bounded;
    the translation column is fixed at zero.
    """

    def __init__(
        self,
        input_nc,
        nc=32,
        n_blocks=3,
        use_dropout=False,
        padding_type="zero",
        image_size=32,
    ):
        super().__init__()

        backbone = []
        # Stride-2 stem halves the spatial resolution once.
        backbone += [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
            )
        ]
        backbone += [nn.BatchNorm2d(nc)]
        backbone += [nn.ReLU(True)]
        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=nn.BatchNorm2d,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            ]
            # Each residual block is followed by another 2x downsampling.
            backbone += [nn.MaxPool2d(2, stride=2)]
        self.backbone = nn.Sequential(*backbone)
        # Spatial size after the stem + n_blocks poolings
        # (n_blocks + 1 halvings in total); image_size must divide evenly
        # for the Linear input dimension to match at runtime.
        reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
        self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)

    def forward(self, x):
        x = self.backbone(x)
        x = x.view(x.size(0), -1)
        x = self.fc_loc(x)
        # Bound the 2x2 linear part to [-1, 1].
        x = torch.tanh(x)
        x = x.view(-1, 2, 2)
        # Assemble the full 2x3 affine matrix with zero translation.
        theta = x.data.new_zeros(x.size(0), 2, 3)
        theta[:, :, :2] = x
        return theta
|
||||
|
||||
|
||||
class FCN(nn.Module):
    """Fully convolutional network.

    DDAIG perturbation generator: produces an additive perturbation map
    ``p`` (tanh-bounded) and returns the perturbed image
    ``x + lmda * p``. Optionally warps the input first with a spatial
    transformer (``stn=True``) and fuses a global-average-pooled context
    vector into the features (``gctx=True``).
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        nc=32,
        n_blocks=3,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
        padding_type="reflect",
        gctx=True,
        stn=False,
        image_size=32,
    ):
        super().__init__()

        backbone = []

        # Stem: [pad] -> 3x3 conv (stride 1, resolution preserved).
        p = 0
        if padding_type == "reflect":
            backbone += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            backbone += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError
        backbone += [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
            )
        ]
        backbone += [norm_layer(nc)]
        backbone += [nn.ReLU(True)]

        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    use_bias=False,
                )
            ]
        self.backbone = nn.Sequential(*backbone)

        # global context fusion layer
        self.gctx_fusion = None
        if gctx:
            # 1x1 conv reduces the [features; context] concat back to nc.
            self.gctx_fusion = nn.Sequential(
                nn.Conv2d(
                    2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
                ),
                norm_layer(nc),
                nn.ReLU(True),
            )

        # Per-pixel perturbation head; tanh bounds output to [-1, 1].
        self.regress = nn.Sequential(
            nn.Conv2d(
                nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
            ),
            nn.Tanh(),
        )

        self.locnet = None
        if stn:
            self.locnet = LocNet(
                input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
            )

    def init_loc_layer(self):
        """Initialize the weights/bias with identity transformation."""
        if self.locnet is not None:
            self.locnet.fc_loc.weight.data.zero_()
            # Bias [1, 0, 0, 1] reshapes to the 2x2 identity matrix.
            self.locnet.fc_loc.bias.data.copy_(
                torch.tensor([1, 0, 0, 1], dtype=torch.float)
            )

    def stn(self, x):
        """Spatial transformer network."""
        theta = self.locnet(x)
        grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, grid), theta

    def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
        """
        Args:
            x (torch.Tensor): input mini-batch.
            lmda (float): multiplier for perturbation.
            return_p (bool): return perturbation.
            return_stn_output (bool): return the output of stn.
        """
        # theta is computed but not returned; kept for potential debugging.
        theta = None
        if self.locnet is not None:
            x, theta = self.stn(x)
        # The (possibly warped) input is kept for the residual addition.
        input = x

        x = self.backbone(x)
        if self.gctx_fusion is not None:
            # Broadcast the global-average context over all positions,
            # concatenate along channels, then fuse back to nc channels.
            c = F.adaptive_avg_pool2d(x, (1, 1))
            c = c.expand_as(x)
            x = torch.cat([x, c], 1)
            x = self.gctx_fusion(x)

        p = self.regress(x)
        x_p = input + lmda*p

        if return_stn_output:
            return x_p, p, input

        if return_p:
            return x_p, p

        return x_p
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
    """FCN perturbation generator: 32 channels, 3 blocks, global context."""
    model = FCN(
        3, 3, nc=32, n_blocks=3, norm_layer=get_norm_layer("instance")
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    return model
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
    """FCN perturbation generator: 64 channels, 3 blocks, global context."""
    model = FCN(
        3, 3, nc=64, n_blocks=3, norm_layer=get_norm_layer("instance")
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    return model
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
    """32-channel FCN generator with a spatial-transformer front end."""
    model = FCN(
        3, 3,
        nc=32,
        n_blocks=3,
        norm_layer=get_norm_layer("instance"),
        stn=True,
        image_size=image_size,
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    # Start the STN at the identity transform.
    model.init_loc_layer()
    return model
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
    """64-channel FCN generator with a spatial-transformer front end."""
    model = FCN(
        3, 3,
        nc=64,
        n_blocks=3,
        norm_layer=get_norm_layer("instance"),
        stn=True,
        image_size=image_size,
    )
    init_network_weights(model, init_type="normal", gain=0.02)
    # Start the STN at the identity transform.
    model.init_loc_layer()
    return model
|
||||
16
Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py
Normal file
16
Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .mmd import MaximumMeanDiscrepancy
|
||||
from .dsbn import DSBN1d, DSBN2d
|
||||
from .mixup import mixup
|
||||
from .efdmix import (
|
||||
EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
|
||||
crossdomain_efdmix, run_without_efdmix
|
||||
)
|
||||
from .mixstyle import (
|
||||
MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
|
||||
deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
|
||||
)
|
||||
from .transnorm import TransNorm1d, TransNorm2d
|
||||
from .sequential2 import Sequential2
|
||||
from .reverse_grad import ReverseGrad
|
||||
from .cross_entropy import cross_entropy
|
||||
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance
|
||||
30
Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py
Normal file
30
Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss with optional label smoothing.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated. One of 'mean', 'sum' or 'none'.
            Default is 'mean'.

    Returns:
        torch.Tensor: scalar loss ('mean'/'sum') or per-sample
        losses of shape (batch,) ('none').

    Raises:
        ValueError: if ``reduction`` is not one of 'mean', 'sum', 'none'.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    # Build the one-hot target directly on input's device/dtype; the
    # original created it on CPU and moved it back, forcing a host sync.
    onehot = torch.zeros_like(log_prob)
    onehot.scatter_(1, target.unsqueeze(1), 1)
    # Smoothed target: (1 - eps) * one_hot + eps / C.
    onehot = (1-label_smooth) * onehot + label_smooth/num_classes
    loss = (-onehot * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError(
            "reduction must be one of ['mean', 'sum', 'none'], "
            "but got {}".format(reduction)
        )
|
||||
45
Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py
Normal file
45
Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class _DSBN(nn.Module):
|
||||
"""Domain Specific Batch Normalization.
|
||||
|
||||
Args:
|
||||
num_features (int): number of features.
|
||||
n_domain (int): number of domains.
|
||||
bn_type (str): type of bn. Choices are ['1d', '2d'].
|
||||
"""
|
||||
|
||||
def __init__(self, num_features, n_domain, bn_type):
|
||||
super().__init__()
|
||||
if bn_type == "1d":
|
||||
BN = nn.BatchNorm1d
|
||||
elif bn_type == "2d":
|
||||
BN = nn.BatchNorm2d
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
|
||||
|
||||
self.valid_domain_idxs = list(range(n_domain))
|
||||
self.n_domain = n_domain
|
||||
self.domain_idx = 0
|
||||
|
||||
def select_bn(self, domain_idx=0):
|
||||
assert domain_idx in self.valid_domain_idxs
|
||||
self.domain_idx = domain_idx
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn[self.domain_idx](x)
|
||||
|
||||
|
||||
class DSBN1d(_DSBN):
    """Domain-specific batch norm for 2-D inputs (batch, features)."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, bn_type="1d")
|
||||
|
||||
|
||||
class DSBN2d(_DSBN):
    """Domain-specific batch norm for 4-D inputs (batch, channels, h, w)."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, bn_type="2d")
|
||||
118
Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py
Normal file
118
Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def deactivate_efdmix(m):
    """For use with nn.Module.apply: switch EFDMix layers off."""
    # `is` instead of `==` for exact type comparison (PEP 8 / E721).
    if type(m) is EFDMix:
        m.set_activation_status(False)


def activate_efdmix(m):
    """For use with nn.Module.apply: switch EFDMix layers on."""
    if type(m) is EFDMix:
        m.set_activation_status(True)


def random_efdmix(m):
    """For use with nn.Module.apply: mix across the whole batch."""
    if type(m) is EFDMix:
        m.update_mix_method("random")


def crossdomain_efdmix(m):
    """For use with nn.Module.apply: mix across domain halves."""
    if type(m) is EFDMix:
        m.update_mix_method("crossdomain")
|
||||
|
||||
|
||||
@contextmanager
def run_without_efdmix(model):
    """Temporarily disable every EFDMix layer in ``model``.

    Assumes EFDMix was initially activated; layers are re-activated on exit.
    """
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        model.apply(activate_efdmix)
|
||||
|
||||
|
||||
@contextmanager
def run_with_efdmix(model, mix=None):
    """Temporarily enable every EFDMix layer in ``model``.

    Assumes EFDMix was initially deactivated. ``mix`` optionally selects
    the mixing strategy ('random' or 'crossdomain') before activation;
    any other value leaves the strategy untouched.
    """
    strategy = {"random": random_efdmix, "crossdomain": crossdomain_efdmix}.get(mix)
    if strategy is not None:
        model.apply(strategy)

    try:
        model.apply(activate_efdmix)
        yield
    finally:
        model.apply(deactivate_efdmix)
|
||||
|
||||
|
||||
class EFDMix(nn.Module):
    """EFDMix.

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of applying the mixing.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        # Bug fix: this previously reported the class as "MixStyle".
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        # Identity in eval mode, when deactivated, or (stochastically)
        # with probability 1 - p.
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        # Replace each element with the same-rank value of its permuted
        # partner, interpolated by (1 - lmda); the detach keeps gradients
        # flowing through x itself rather than the sorted copy.
        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)
|
||||
124
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def deactivate_mixstyle(m):
    """For use with nn.Module.apply: switch MixStyle layers off."""
    # `is` instead of `==` for exact type comparison (PEP 8 / E721).
    if type(m) is MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    """For use with nn.Module.apply: switch MixStyle layers on."""
    if type(m) is MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    """For use with nn.Module.apply: mix across the whole batch."""
    if type(m) is MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    """For use with nn.Module.apply: mix across domain halves."""
    if type(m) is MixStyle:
        m.update_mix_method("crossdomain")
|
||||
|
||||
|
||||
@contextmanager
def run_without_mixstyle(model):
    """Temporarily disable every MixStyle layer in ``model``.

    Assumes MixStyle was initially activated; layers are re-activated on exit.
    """
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)
|
||||
|
||||
|
||||
@contextmanager
def run_with_mixstyle(model, mix=None):
    """Temporarily enable every MixStyle layer in ``model``.

    Assumes MixStyle was initially deactivated. ``mix`` optionally selects
    the mixing strategy ('random' or 'crossdomain') before activation;
    any other value leaves the strategy untouched.
    """
    strategy = {"random": random_mixstyle, "crossdomain": crossdomain_mixstyle}.get(mix)
    if strategy is not None:
        model.apply(strategy)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)
|
||||
|
||||
|
||||
class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        # Identity in eval mode, when deactivated, or (stochastically)
        # with probability 1 - p.
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        batch = x.size(0)

        # Per-instance channel statistics ("style"), detached so no
        # gradient flows through them.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lam = self.beta.sample((batch, 1, 1, 1))
        lam = lam.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(batch)
        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(batch - 1, -1, -1)  # inverse index
            half_b, half_a = perm.chunk(2)
            half_b = half_b[torch.randperm(half_b.shape[0])]
            half_a = half_a[torch.randperm(half_a.shape[0])]
            perm = torch.cat([half_b, half_a], 0)
        else:
            raise NotImplementedError

        # Interpolate statistics with the permuted partner, then
        # re-style the normalized features.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lam + mu2 * (1-lam)
        sig_mix = sig*lam + sig2 * (1-lam)

        return x_normed*sig_mix + mu_mix
|
||||
23
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
Normal file
23
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
|
||||
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup.

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda=max(lmda, 1-lmda).
            Default is False.
    """
    # One mixing coefficient per sample, broadcastable over (c, h, w).
    lam = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        # Keep x1/y1 dominant in the mixture.
        lam = torch.max(lam, 1 - lam)
    lam = lam.to(x1.device)
    xmix = x1*lam + x2 * (1-lam)
    # Collapse to (b, 1) for the label mixture.
    lam = lam[:, :, 0, 0]
    ymix = y1*lam + y2 * (1-lam)
    return xmix, ymix
|
||||
91
Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py
Normal file
91
Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class MaximumMeanDiscrepancy(nn.Module):
    """Maximum Mean Discrepancy between two batches of features.

    MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')

    Args:
        kernel_type (str): one of ['linear', 'poly', 'rbf'].
        normalize (bool): L2-normalize inputs before computing the kernel.
    """

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        """Compute MMD^2 between x and y, each of shape (batch, dim)."""
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError(
                "Unknown kernel type: {}".format(self.kernel_type)
            )

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy (self-pairs are valid here, keep the full matrix)
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
        # Sum of RBF kernels at several bandwidths (tuple default avoids
        # the mutable-default-argument pitfall; lists still work).
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        # Drop the diagonal (self-pair) entries: (n, n) -> (n, n-1).
        # Vectorized boolean-mask version of the original per-row
        # Python loop; produces identical values.
        n = distmat.size(0)
        mask = ~torch.eye(n, dtype=torch.bool, device=distmat.device)
        return distmat[mask].view(n, n - 1)

    @staticmethod
    def euclidean_squared_distance(x, y):
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x^T y
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        return distmat
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: RBF-kernel MMD between two small random batches.
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    batch_a, batch_b = torch.rand(3, 100), torch.rand(3, 100)
    result = mmd(batch_a, batch_b)
    print(result.item())
|
||||
147
Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py
Normal file
147
Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class OptimalTransport(nn.Module):
    """Base class providing pairwise cost-matrix computation."""

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Compute the (n, m) pairwise cost matrix between two batches.

        Args:
            batch1 (torch.Tensor): shape (n, d).
            batch2 (torch.Tensor): shape (m, d).
            dist_metric (str): 'cosine', 'euclidean' or 'fast_euclidean'.

        Raises:
            ValueError: for an unknown ``dist_metric``.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # Squared euclidean distance. Bug fix: the deprecated
            # positional addmm_(beta, alpha, m1, m2) signature was removed
            # from PyTorch; use the keyword form (consistent with mmd.py).
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean]".format(dist_metric)
            )
        return dist_mat
|
||||
|
||||
|
||||
class SinkhornDivergence(OptimalTransport):
    """Sinkhorn divergence: 2*W(x, y) - W(x, x) - W(y, y)."""

    thre = 1e-3  # convergence threshold on the dual-variable update

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        return (
            2 * self.transport_cost(x, y) - self.transport_cost(x, x)
            - self.transport_cost(y, y)
        )

    def transport_cost(self, x, y, return_pi=False):
        """Entropic transport cost <pi, C>; optionally return the plan."""
        cost_mat = self.distance(x, y, dist_metric=self.dist_metric)
        plan = self.sinkhorn_iterate(cost_mat, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            # Treat the plan as a constant w.r.t. autograd.
            plan = plan.detach()
        cost = torch.sum(plan * cost_mat)
        return (cost, plan) if return_pi else cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Log-domain Sinkhorn iterations with uniform marginals."""
        nx, ny = C.shape
        mu = torch.full((nx,), 1.0 / nx, dtype=C.dtype, device=C.device)
        nu = torch.full((ny,), 1.0 / ny, dtype=C.dtype, device=C.device)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        for _ in range(max_iter):
            u_prev = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = eps * (
                torch.log(nu + 1e-8) -
                torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
            ) + v
            # Stop early once the dual variable u stabilizes.
            if (u - u_prev).abs().sum().item() < thre:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))
|
||||
|
||||
|
||||
class MinibatchEnergyDistance(SinkhornDivergence):
    """Minibatch energy distance built from Sinkhorn transport costs."""

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        # Split each batch in two halves and combine cross-batch minus
        # within-batch transport costs.
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        cross = (
            self.transport_cost(x1, y1) + self.transport_cost(x1, y2)
            + self.transport_cost(x2, y1) + self.transport_cost(x2, y2)
        )
        within = self.transport_cost(x1, x2) + self.transport_cost(y1, y2)
        return cross - 2*within
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    import numpy as np

    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    # Bug fix: the original dropped into pdb.set_trace() here (leftover
    # debugger breakpoint); print the results instead.
    print("transport cost:", dist.item())
    print("transport plan:")
    print(pi)
|
||||
34
Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py
Normal file
34
Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class _ReverseGrad(Function):
    """Autograd function: identity forward, negated+scaled gradient backward."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Stash the scaling factor for the backward pass.
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the incoming gradient; grad_scaling itself
        # receives no gradient (None).
        return -ctx.grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply
|
||||
|
||||
|
||||
class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward,
    but reverses the sign of the gradient in
    the backward.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, "
            "but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)
|
||||
15
Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py
Normal file
15
Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Sequential2(nn.Sequential):
    """An alternative sequential container to nn.Sequential,
    which accepts an arbitrary number of input arguments.
    """

    def forward(self, *inputs):
        out = inputs
        for layer in self._modules.values():
            # Unpack tuples so multi-output modules can feed
            # multi-input modules downstream.
            out = layer(*out) if isinstance(out, tuple) else layer(out)
        return out
|
||||
138
Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py
Normal file
138
Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class _TransNorm(nn.Module):
    """Transferable normalization.

    Reference:
        - Wang et al. Transferable Normalization: Towards Improving
        Transferability of Deep Neural Networks. NeurIPS 2019.

    Args:
        num_features (int): number of features.
        eps (float): epsilon.
        momentum (float): value for updating running_mean and running_var.
        adaptive_alpha (bool): apply domain adaptive alpha.
    """

    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.adaptive_alpha = adaptive_alpha

        # Separate running statistics for the source (_s) and target (_t)
        # domains; buffers so they are saved but not trained.
        self.register_buffer("running_mean_s", torch.zeros(num_features))
        self.register_buffer("running_var_s", torch.ones(num_features))
        self.register_buffer("running_mean_t", torch.zeros(num_features))
        self.register_buffer("running_var_t", torch.ones(num_features))

        # Affine parameters shared by both domains.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def resnet_running_stats(self):
        """Reset all running statistics to their initial values.

        NOTE(review): the name looks like a typo for ``reset_running_stats``;
        kept as-is because renaming would break existing callers.
        """
        self.running_mean_s.zero_()
        self.running_var_s.fill_(1)
        self.running_mean_t.zero_()
        self.running_var_t.fill_(1)

    def reset_parameters(self):
        """Re-initialize the affine parameters (weight=1, bias=0)."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _check_input(self, x):
        # Subclasses validate the expected input dimensionality.
        raise NotImplementedError

    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
        # Domain-adaptive channel weights: channels whose normalized mean
        # differs less between domains get a larger weight; weights are
        # scaled so they sum to C.
        C = self.num_features
        ratio_s = mean_s / (var_s + self.eps).sqrt()
        ratio_t = mean_t / (var_t + self.eps).sqrt()
        dist = (ratio_s - ratio_t).abs()
        dist_inv = 1 / (1+dist)
        return C * dist_inv / dist_inv.sum()

    def forward(self, input):
        """Normalize ``input`` per domain.

        In training mode the batch is assumed to contain source examples in
        the first half and target examples in the second half; each half is
        normalized with its own batch statistics. In eval mode the target
        running statistics are used for the whole batch.
        """
        self._check_input(input)
        C = self.num_features
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError

        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)

        if not self.training:
            # Eval path: normalize with target-domain running statistics.
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input-mean_t) / (var_t + self.eps).sqrt()
            output = output*weight + bias

            if self.adaptive_alpha:
                mean_s = self.running_mean_s.view(*new_shape)
                var_s = self.running_var_s.view(*new_shape)
                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
                alpha = alpha.reshape(*new_shape)
                # alpha is detached: it re-weights channels but receives
                # no gradient.
                output = (1 + alpha.detach()) * output

            return output

        # Training path: first half = source domain, second half = target.
        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)

        x_s = input_s.transpose(0, 1).reshape(C, -1)
        mean_s = x_s.mean(1)
        var_s = x_s.var(1)
        # NOTE: the update is running*momentum + (1-momentum)*batch, i.e.
        # ``momentum`` weights the OLD value — reversed relative to the
        # nn.BatchNorm convention.
        self.running_mean_s.mul_(self.momentum)
        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
        self.running_var_s.mul_(self.momentum)
        self.running_var_s.add_((1 - self.momentum) * var_s.data)
        mean_s = mean_s.reshape(*new_shape)
        var_s = var_s.reshape(*new_shape)
        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
        output_s = output_s*weight + bias

        x_t = input_t.transpose(0, 1).reshape(C, -1)
        mean_t = x_t.mean(1)
        var_t = x_t.var(1)
        self.running_mean_t.mul_(self.momentum)
        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
        self.running_var_t.mul_(self.momentum)
        self.running_var_t.add_((1 - self.momentum) * var_t.data)
        mean_t = mean_t.reshape(*new_shape)
        var_t = var_t.reshape(*new_shape)
        output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
        output_t = output_t*weight + bias

        output = torch.cat([output_s, output_t], 0)

        if self.adaptive_alpha:
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output

        return output
|
||||
|
||||
|
||||
class TransNorm1d(_TransNorm):
    """TransNorm over 2-D inputs (batch, features)."""

    def _check_input(self, x):
        if x.dim() != 2:
            raise ValueError(
                "Expected the input to be 2-D, "
                "but got {}-D".format(x.dim())
            )
|
||||
|
||||
|
||||
class TransNorm2d(_TransNorm):
    """TransNorm over 4-D inputs (batch, channels, height, width)."""

    def _check_input(self, x):
        if x.dim() != 4:
            raise ValueError(
                "Expected the input to be 4-D, "
                "but got {}-D".format(x.dim())
            )
|
||||
75
Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py
Normal file
75
Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def sharpen_prob(p, temperature=2):
    """Sharpening probability with a temperature.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes)
        temperature (float): temperature.
    """
    sharpened = p ** temperature
    # Renormalize so each row sums to 1 again.
    return sharpened / sharpened.sum(1, keepdim=True)
|
||||
|
||||
|
||||
def reverse_index(data, label):
    """Reverse order of a (data, label) pair along the batch dimension."""
    idx = torch.arange(data.size(0) - 1, -1, -1).long()
    return data[idx], label[idx]
|
||||
|
||||
|
||||
def shuffle_index(data, label):
    """Shuffle a (data, label) pair with the same random permutation."""
    perm = torch.randperm(data.shape[0])
    return data[perm], label[perm]
|
||||
|
||||
|
||||
def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor.
        num_classes (int): number of classes.
    """
    out = torch.zeros(label.shape[0], num_classes)
    # NOTE: the result lives on CPU regardless of label's device.
    return out.scatter(1, label.unsqueeze(1).data.cpu(), 1)
|
||||
|
||||
|
||||
def sigmoid_rampup(current, rampup_length):
    """Exponential rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped/rampup_length
    return float(np.exp(-5.0 * remaining * remaining))
|
||||
|
||||
|
||||
def linear_rampup(current, rampup_length):
    """Linear rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    # Fraction of the ramp completed, clipped to [0, 1].
    return float(np.clip(current / rampup_length, 0.0, 1.0))
|
||||
|
||||
|
||||
def ema_model_update(model, ema_model, alpha):
    """Exponential moving average of model parameters.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        # In-place: ema = alpha * ema + (1 - alpha) * param
        p_ema.data.mul_(alpha).add_(p.data, alpha=1 - alpha)
|
||||
Reference in New Issue
Block a user