Upload to Main
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
from .alexnet import *
|
||||
from .inceptionv3 import *
|
||||
from .lenet import *
|
||||
from .mlp import *
|
||||
from .mobilenetv3 import *
|
||||
from .resnet import *
|
||||
from .vgg import *
|
||||
from .wideresnet import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,100 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled
|
||||
from torchvision import models
|
||||
import torch
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
class AlexNet_32x32(nn.Module):
    """AlexNet-style CNN for small images.

    The trunk contains three 2x2 max-pools and the classifier expects a
    192 x 4 x 4 feature map, i.e. a 32 x 32 map after the first convolution.
    For single-channel input the first convolution pads by 4 instead of 2,
    which grows a 28 x 28 image to 32 x 32.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        record_embedding: initial recording flag handed to the
            ``EmbeddingRecorder`` that sits in front of the final layer.
        no_grad: when true, ``forward`` runs with autograd disabled.
    """

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        super().__init__()
        first_padding = 4 if channel == 1 else 2
        trunk = [
            nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=first_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*trunk)
        self.fc = nn.Linear(192 * 4 * 4, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def forward(self, x):
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            feature_maps = self.features(x)
            flat = feature_maps.view(feature_maps.size(0), -1)
            # The recorder passes the flattened features through unchanged.
            flat = self.embedding_recorder(flat)
            return self.fc(flat)
|
||||
|
||||
|
||||
class AlexNet_224x224(models.AlexNet):
    """torchvision AlexNet with an embedding recorder before the final layer.

    The last entry of ``classifier`` is replaced by the recorder and the
    original final ``Linear`` is appended again under the name ``"fc"``; the
    same layer is also reachable as ``self.fc``.

    Args:
        channel: number of input channels; anything other than 3 replaces
            the first convolution.
        num_classes: number of output logits.
        record_embedding: initial recording flag of the recorder.
        no_grad: when true, ``forward`` runs with autograd disabled.
        **kwargs: forwarded to ``torchvision.models.AlexNet``.
    """

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            self.features[0] = nn.Conv2d(channel, 64, kernel_size=11, stride=4, padding=2)

        # Move the classification head behind the recorder.
        head = self.classifier[-1]
        self.fc = head
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", head)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            flat = torch.flatten(pooled, 1)
            return self.classifier(flat)
|
||||
|
||||
|
||||
def AlexNet(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Build an AlexNet matching the dataset geometry.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        im_size: ``(height, width)`` of the input images.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, the network's forward pass disables autograd.
        pretrained: load the torchvision ImageNet checkpoint (224x224 only).

    Returns:
        An ``AlexNet_224x224`` for 224x224 inputs, or an ``AlexNet_32x32``
        for 1x28x28 / 3x32x32 inputs.

    Raises:
        NotImplementedError: for unsupported channel/size combinations, or
            when ``pretrained`` is requested for a non-224x224 input.
    """
    is_224 = im_size[0] == 224 and im_size[1] == 224

    if pretrained:
        if not is_224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = AlexNet_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-7be5be79.pth',
                                              progress=True)

        # AlexNet_224x224 moves the head from the last numeric classifier slot
        # (now occupied by the parameter-free recorder) to ``classifier.fc`` and
        # also registers it as ``fc``. Rename the checkpoint entries to match,
        # otherwise the strict load reports missing and unexpected keys.
        old_prefix = "classifier.{}.".format(len(net.classifier) - 2)
        for key in [k for k in state_dict if k.startswith(old_prefix)]:
            tensor = state_dict.pop(key)
            suffix = key[len(old_prefix):]
            state_dict["classifier.fc." + suffix] = tensor
            state_dict["fc." + suffix] = tensor
        net.load_state_dict(state_dict)

        # Adapt the pretrained network to the requested input/output shape.
        if channel != 3:
            net.features[0] = nn.Conv2d(channel, 64, kernel_size=11, stride=4, padding=2)
        if num_classes != 1000:
            net.fc = nn.Linear(4096, num_classes)
            net.classifier[-1] = net.fc

    elif is_224:
        net = AlexNet_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                              no_grad=no_grad)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = AlexNet_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                            no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
@@ -0,0 +1,426 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models import inception
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by the convolution.
        **kwargs: forwarded to ``nn.Conv2d`` (kernel size, stride, padding, ...).
    """

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU."""
        return self.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
# same naive inception module
|
||||
class InceptionA(nn.Module):
    """Size-preserving Inception block with four parallel branches.

    Output channels are ``64 + 64 + 96 + pool_features``.

    Args:
        input_channels: channels of the incoming feature map.
        pool_features: output channels of the pooling branch.
    """

    def __init__(self, input_channels, pool_features):
        super().__init__()
        # 1x1 branch
        self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)

        # 1x1 -> 5x5 branch
        reduce_5x5 = BasicConv2d(input_channels, 48, kernel_size=1)
        conv_5x5 = BasicConv2d(48, 64, kernel_size=5, padding=2)
        self.branch5x5 = nn.Sequential(reduce_5x5, conv_5x5)

        # 1x1 -> 3x3 -> 3x3 branch
        reduce_3x3 = BasicConv2d(input_channels, 64, kernel_size=1)
        conv_3x3_a = BasicConv2d(64, 96, kernel_size=3, padding=1)
        conv_3x3_b = BasicConv2d(96, 96, kernel_size=3, padding=1)
        self.branch3x3 = nn.Sequential(reduce_3x3, conv_3x3_a, conv_3x3_b)

        # average pool -> 3x3 conv branch
        pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        pool_conv = BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
        self.branchpool = nn.Sequential(pool, pool_conv)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them along channels."""
        branches = [
            self.branch1x1(x),
            self.branch5x5(x),
            self.branch3x3(x),
            self.branchpool(x),
        ]
        return torch.cat(branches, 1)
|
||||
|
||||
|
||||
# downsample
|
||||
# Factorization into smaller convolutions
|
||||
class InceptionB(nn.Module):
    """Down-sampling Inception block (three stride-2 branches).

    Output channels are ``384 + 96 + input_channels``.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()

        # single stride-2 3x3 convolution
        self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)

        # 1x1 -> 3x3 -> stride-2 3x3
        stack = [
            BasicConv2d(input_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        ]
        self.branch3x3stack = nn.Sequential(*stack)

        # stride-2 max pooling keeps the input channel count
        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Concatenate the two convolutional branches with the pooled input."""
        # Parallel stride-2 convolution and pooling branches whose outputs
        # are concatenated along the channel axis.
        reduced = [
            self.branch3x3(x),
            self.branch3x3stack(x),
            self.branchpool(x),
        ]
        return torch.cat(reduced, 1)
|
||||
|
||||
|
||||
# Factorizing Convolutions with Large Filter Size
|
||||
class InceptionC(nn.Module):
    """Size-preserving Inception block built from factorised 7x7 convolutions.

    Every branch emits 192 channels, so the output has 768 channels.

    Args:
        input_channels: channels of the incoming feature map.
        channels_7x7: width of the intermediate 7x1 / 1x7 convolutions.
    """

    def __init__(self, input_channels, channels_7x7):
        super().__init__()
        width = channels_7x7
        tall = dict(kernel_size=(7, 1), padding=(3, 0))
        wide = dict(kernel_size=(1, 7), padding=(0, 3))

        self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)

        # An n x n convolution is replaced by an n x 1 and a 1 x n convolution.
        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, width, kernel_size=1),
            BasicConv2d(width, width, **tall),
            BasicConv2d(width, 192, **wide),
        )

        # Same factorisation applied twice.
        self.branch7x7stack = nn.Sequential(
            BasicConv2d(input_channels, width, kernel_size=1),
            BasicConv2d(width, width, **tall),
            BasicConv2d(width, width, **wide),
            BasicConv2d(width, width, **tall),
            BasicConv2d(width, 192, **wide),
        )

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them along channels."""
        branches = [
            self.branch1x1(x),
            self.branch7x7(x),
            self.branch7x7stack(x),
            self.branch_pool(x),
        ]
        return torch.cat(branches, 1)
|
||||
|
||||
|
||||
class InceptionD(nn.Module):
    """Down-sampling Inception block (three stride-2 branches).

    Output channels are ``320 + 192 + input_channels``.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()

        # 1x1 -> stride-2 3x3
        self.branch3x3 = nn.Sequential(
            BasicConv2d(input_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        )

        # 1x1 -> 1x7 -> 7x1 -> stride-2 3x3
        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )

        # stride-2 average pooling keeps the input channel count
        self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Concatenate the two convolutional branches with the pooled input."""
        reduced = [
            self.branch3x3(x),
            self.branch7x7(x),
            self.branchpool(x),
        ]
        return torch.cat(reduced, 1)
|
||||
|
||||
|
||||
# same
|
||||
class InceptionE(nn.Module):
    """Size-preserving Inception block with expanded filter banks.

    Two of the branches end in a parallel 1x3 / 3x1 pair whose outputs are
    concatenated, giving ``320 + 768 + 768 + 192 = 2048`` output channels.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()
        row = dict(kernel_size=(1, 3), padding=(0, 1))
        col = dict(kernel_size=(3, 1), padding=(1, 0))

        self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)

        self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, **row)
        self.branch3x3_2b = BasicConv2d(384, 384, **col)

        self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
        self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3stack_3a = BasicConv2d(384, 384, **row)
        self.branch3x3stack_3b = BasicConv2d(384, 384, **col)

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them along channels."""
        out_1x1 = self.branch1x1(x)

        # 1x1, then the 1x3 and 3x1 convolutions side by side.
        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        # 1x1 -> 3x3, then the 1x3 and 3x1 convolutions side by side.
        stack_stem = self.branch3x3stack_2(self.branch3x3stack_1(x))
        out_stack = torch.cat(
            [self.branch3x3stack_3a(stack_stem), self.branch3x3stack_3b(stack_stem)], 1)

        out_pool = self.branch_pool(x)

        return torch.cat([out_1x1, out_3x3, out_stack, out_pool], 1)
|
||||
|
||||
|
||||
class InceptionV3_32x32(nn.Module):
    """Inception-v3 adapted to small images.

    The stem has no striding or pooling. With a 32 x 32 map after the first
    convolution (a 28 x 28 single-channel image is padded by 3 to reach that),
    the feature map goes 32 -> 30 in the stem, 30 -> 14 in ``Mixed_6a`` and
    14 -> 6 in ``Mixed_7a`` before global average pooling.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, ``forward`` runs with autograd disabled.
    """

    # Sub-modules applied in this order up to and including the dropout.
    _STAGES = (
        "Conv2d_1a_3x3", "Conv2d_2a_3x3", "Conv2d_2b_3x3", "Conv2d_3b_1x1", "Conv2d_4a_3x3",
        "Mixed_5b", "Mixed_5c", "Mixed_5d",
        "Mixed_6a",
        "Mixed_6b", "Mixed_6c", "Mixed_6d", "Mixed_6e",
        "Mixed_7a",
        "Mixed_7b", "Mixed_7c",
        "avgpool", "dropout",
    )

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        super().__init__()
        stem_padding = 3 if channel == 1 else 1
        self.Conv2d_1a_3x3 = BasicConv2d(channel, 32, kernel_size=3, padding=stem_padding)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)

        # size-preserving blocks
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)

        # first grid reduction
        self.Mixed_6a = InceptionB(288)

        # factorised 7x7 blocks
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        # second grid reduction
        self.Mixed_7a = InceptionD(768)

        # expanded filter-bank blocks on the coarsest grid
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        # classifier head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d()
        self.linear = nn.Linear(2048, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.linear

    def forward(self, x):
        """Return class logits for a batch of images."""
        with torch.set_grad_enabled(not self.no_grad):
            for stage_name in self._STAGES:
                x = getattr(self, stage_name)(x)
            flat = x.view(x.size(0), -1)
            # The recorder passes the pooled features through unchanged.
            flat = self.embedding_recorder(flat)
            return self.linear(flat)
|
||||
|
||||
|
||||
class InceptionV3_224x224(inception.Inception3):
    """torchvision Inception-v3 with an embedding recorder before ``fc``.

    Args:
        channel: number of input channels; anything other than 3 replaces
            the first stem convolution.
        num_classes: number of output logits.
        record_embedding: initial recording flag of the recorder.
        no_grad: when true, ``_forward`` runs with autograd disabled.
        **kwargs: forwarded to ``torchvision.models.inception.Inception3``.
    """

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(num_classes=num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            # torchvision exposes its conv+BN+ReLU block as ``BasicConv2d``;
            # ``conv_block`` is only a local name inside ``Inception3.__init__``.
            self.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def _forward(self, x):
        """Return ``(logits, aux)``; ``aux`` is ``None`` unless training with aux logits."""
        with torch.set_grad_enabled(not self.no_grad):
            # Stem and Mixed_5*/Mixed_6* blocks, in order.
            for stage_name in ("Conv2d_1a_3x3", "Conv2d_2a_3x3", "Conv2d_2b_3x3", "maxpool1",
                               "Conv2d_3b_1x1", "Conv2d_4a_3x3", "maxpool2",
                               "Mixed_5b", "Mixed_5c", "Mixed_5d",
                               "Mixed_6a", "Mixed_6b", "Mixed_6c", "Mixed_6d", "Mixed_6e"):
                x = getattr(self, stage_name)(x)

            # The auxiliary classifier is only evaluated in training mode.
            aux = None
            if self.AuxLogits is not None:
                if self.training:
                    aux = self.AuxLogits(x)

            for stage_name in ("Mixed_7a", "Mixed_7b", "Mixed_7c", "avgpool", "dropout"):
                x = getattr(self, stage_name)(x)

            x = torch.flatten(x, 1)
            # The recorder passes the pooled features through unchanged.
            x = self.embedding_recorder(x)
            x = self.fc(x)
            return x, aux
|
||||
|
||||
|
||||
def InceptionV3(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                pretrained: bool = False):
    """Build an Inception-v3 matching the dataset geometry.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        im_size: ``(height, width)`` of the input images.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, the network's forward pass disables autograd.
        pretrained: load the torchvision ImageNet checkpoint (224x224 only).

    Returns:
        An ``InceptionV3_224x224`` for 224x224 inputs, or an
        ``InceptionV3_32x32`` for 1x28x28 / 3x32x32 inputs.

    Raises:
        NotImplementedError: for unsupported channel/size combinations, or
            when ``pretrained`` is requested for a non-224x224 input.
    """
    is_224 = im_size[0] == 224 and im_size[1] == 224

    if pretrained:
        if not is_224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = InceptionV3_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        # NOTE(review): ``inception.model_urls`` may be absent in newer
        # torchvision releases - confirm against the pinned version.
        state_dict = load_state_dict_from_url(inception.model_urls["inception_v3_google"], progress=True)
        net.load_state_dict(state_dict)

        # Adapt the pretrained network to the requested input/output shape.
        if channel != 3:
            # ``inception.conv_block`` is not a module-level name in
            # torchvision; ``BasicConv2d`` is the block it refers to.
            net.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif is_224:
        net = InceptionV3_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad)
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = InceptionV3_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

    return net
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
class LeNet(nn.Module):
    """LeNet-5 style CNN with an embedding recorder before the last layer.

    The input width of the first fully connected layer is derived from
    ``im_size`` and the first convolution's padding, so any image large
    enough for the two conv/pool stages is supported. For the 1x28x28,
    3x32x32 and 3x224x224 cases this gives the same layer sizes as the
    previously hard-coded values.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        im_size: ``(height, width)`` of the input images.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, ``forward`` runs with autograd disabled.
        pretrained: must be false; no pretrained weights exist.

    Raises:
        NotImplementedError: if ``pretrained`` is requested, or if the image
            is too small to survive the two conv/pool stages.
    """

    def __init__(self, channel, num_classes, im_size, record_embedding: bool = False, no_grad: bool = False,
                 pretrained: bool = False):
        if pretrained:
            raise NotImplementedError("torchvison pretrained models not available.")
        super(LeNet, self).__init__()

        first_padding = 2 if channel == 1 else 0

        def _feature_side(side):
            # 5x5 conv (with first_padding) -> 2x2 pool -> 5x5 conv -> 2x2 pool
            after_first = (side + 2 * first_padding - 4) // 2
            return (after_first - 4) // 2

        feat_h, feat_w = _feature_side(im_size[0]), _feature_side(im_size[1])
        if feat_h < 1 or feat_w < 1:
            raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

        self.features = nn.Sequential(
            nn.Conv2d(channel, 6, kernel_size=5, padding=first_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_1 = nn.Linear(16 * feat_h * feat_w, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc_3

    def forward(self, x):
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc_1(x))
            x = F.relu(self.fc_2(x))
            # The recorder passes the 84-d features through unchanged.
            x = self.embedding_recorder(x)
            x = self.fc_3(x)
        return x
|
||||
@@ -0,0 +1,37 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
|
||||
''' MLP '''
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Three-layer perceptron operating on flattened images.

    Args:
        channel: number of input image channels.
        num_classes: number of output logits.
        im_size: ``(height, width)`` of the input images.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, ``forward`` runs with autograd disabled.
        pretrained: must be false; no pretrained weights exist.

    Raises:
        NotImplementedError: if ``pretrained`` is requested.
    """

    def __init__(self, channel, num_classes, im_size, record_embedding: bool = False, no_grad: bool = False,
                 pretrained: bool = False):
        if pretrained:
            raise NotImplementedError("torchvison pretrained models not available.")
        super().__init__()
        input_dim = im_size[0] * im_size[1] * channel
        self.fc_1 = nn.Linear(input_dim, 128)
        self.fc_2 = nn.Linear(128, 128)
        self.fc_3 = nn.Linear(128, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc_3

    def forward(self, x):
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            hidden = x.view(x.size(0), -1)
            for layer in (self.fc_1, self.fc_2):
                hidden = F.relu(layer(hidden))
            # The recorder passes the hidden features through unchanged.
            hidden = self.embedding_recorder(hidden)
            return self.fc_3(hidden)
|
||||
@@ -0,0 +1,304 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from torchvision.models import mobilenetv3
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
import math
|
||||
|
||||
'''MobileNetV3 in PyTorch.
|
||||
Paper: "Searching for MobileNetV3" (Howard et al., 2019)
|
||||
|
||||
Acknowledgement to:
|
||||
https://github.com/d-li14/mobilenetv3.pytorch/blob/master/mobilenetv3.py
|
||||
'''
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
    """Round ``v`` to a multiple of ``divisor`` that is at least ``min_value``.

    Taken from the original TensorFlow MobileNet code:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    Args:
        v: the value to round (typically a channel count).
        divisor: the result is a multiple of this, unless ``min_value`` wins.
        min_value: lower bound for the result; defaults to ``divisor``.

    Returns:
        The nearest multiple of ``divisor``, bumped up by one ``divisor``
        when rounding would otherwise lose more than 10% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = (int(v + divisor / 2) // divisor) * divisor
    result = max(lower_bound, nearest_multiple)
    # Make sure that rounding down does not shrink the value by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result
|
||||
|
||||
|
||||
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``.

    Args:
        inplace: forwarded to the underlying ``nn.ReLU6``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the hard-sigmoid activation of ``x``."""
        shifted = x + 3
        return self.relu(shifted) / 6
|
||||
|
||||
|
||||
class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``.

    Args:
        inplace: forwarded to the inner ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        """Return the hard-swish activation of ``x``."""
        gate = self.sigmoid(x)
        return x * gate
|
||||
|
||||
|
||||
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate with a hard-sigmoid output.

    Args:
        channel: channels of the feature map being gated.
        reduction: bottleneck ratio; the hidden width is
            ``channel // reduction`` rounded by ``_make_divisible(..., 8)``.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        squeezed = _make_divisible(channel // reduction, 8)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel),
            h_sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of a 4-D input ``x`` by its learned gate."""
        batch, chans, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, chans)
        gate = self.fc(pooled).view(batch, chans, 1, 1)
        return x * gate
|
||||
|
||||
|
||||
def conv_3x3_bn(inp, oup, stride, padding=1):
    """Return a bias-free 3x3 convolution followed by batch norm and h-swish.

    Args:
        inp: input channels.
        oup: output channels.
        stride: stride of the convolution.
        padding: padding of the convolution.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=padding, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm, h_swish())
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
    """Return a bias-free 1x1 convolution followed by batch norm and h-swish.

    Args:
        inp: input channels.
        oup: output channels.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm, h_swish())
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """MobileNetV3 inverted-residual block.

    When ``inp == hidden_dim`` the 1x1 expansion is skipped and the block is
    depthwise conv -> BN -> activation -> SE -> 1x1 projection -> BN.
    Otherwise it is 1x1 expansion -> BN -> activation -> depthwise conv ->
    BN -> SE -> activation -> 1x1 projection -> BN.

    Args:
        inp: input channels.
        hidden_dim: channels of the depthwise convolution.
        oup: output channels.
        kernel_size: depthwise kernel size.
        stride: depthwise stride, 1 or 2.
        use_se: truthy to insert an ``SELayer`` (otherwise ``nn.Identity``).
        use_hs: truthy to use ``h_swish`` (otherwise ``nn.ReLU``).
    """

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super().__init__()
        assert stride in [1, 2]

        # A skip connection is only possible when the shape is preserved.
        self.identity = stride == 1 and inp == oup

        def activation():
            return h_swish() if use_hs else nn.ReLU(inplace=True)

        def depthwise():
            return nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2,
                             groups=hidden_dim, bias=False)

        def squeeze_excite():
            return SELayer(hidden_dim) if use_se else nn.Identity()

        layers = []
        if inp == hidden_dim:
            layers.append(depthwise())
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(activation())
            layers.append(squeeze_excite())
        else:
            # pointwise expansion
            layers.append(nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(activation())
            layers.append(depthwise())
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(squeeze_excite())
            layers.append(activation())
        # linear pointwise projection
        layers.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
        layers.append(nn.BatchNorm2d(oup))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, adding the input back when shapes allow it."""
        if not self.identity:
            return self.conv(x)
        return x + self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV3_32x32(nn.Module):
    """MobileNetV3 built from an explicit block configuration table.

    Args:
        cfgs: rows of ``(k, t, c, use_se, use_hs, s)``: depthwise kernel
            size, expansion factor, output channels, SE flag, h-swish flag
            and stride of each ``InvertedResidual`` block.
        mode: ``'mobilenet_v3_large'`` or ``'mobilenet_v3_small'``; selects
            the width of the penultimate classifier layer.
        channel: number of input image channels.
        num_classes: number of output logits.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, ``forward`` runs with autograd disabled.
        width_mult: multiplier applied to the channel counts.
    """

    def __init__(self, cfgs, mode, channel=3, num_classes=1000, record_embedding=False,
                 no_grad=False, width_mult=1.):
        super().__init__()
        self.cfgs = cfgs
        assert mode in ['mobilenet_v3_large', 'mobilenet_v3_small']

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

        # Stem: stride-2 3x3 convolution; single-channel input pads by 3.
        in_ch = _make_divisible(16 * width_mult, 8)
        stem_padding = 3 if channel == 1 else 1
        blocks = [conv_3x3_bn(channel, in_ch, 2, padding=stem_padding)]

        # Inverted residual blocks described by the configuration table.
        for kernel, expand, out_base, use_se, use_hs, stride in self.cfgs:
            out_ch = _make_divisible(out_base * width_mult, 8)
            exp_size = _make_divisible(in_ch * expand, 8)
            blocks.append(InvertedResidual(in_ch, exp_size, out_ch, kernel, stride, use_se, use_hs))
            in_ch = out_ch
        self.features = nn.Sequential(*blocks)

        # Head: the 1x1 convolution widens to the last block's expansion size.
        self.conv = conv_1x1_bn(in_ch, exp_size)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        head_widths = {'mobilenet_v3_large': 1280, 'mobilenet_v3_small': 1024}
        if width_mult > 1.0:
            head_width = _make_divisible(head_widths[mode] * width_mult, 8)
        else:
            head_width = head_widths[mode]
        self.classifier = nn.Sequential(
            nn.Linear(exp_size, head_width),
            h_swish(),
            nn.Dropout(0.2),
            self.embedding_recorder,
            nn.Linear(head_width, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.conv(self.features(x)))
            flat = pooled.view(pooled.size(0), -1)
            return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                nn.init.normal_(module.weight, 0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.classifier[-1]
|
||||
|
||||
|
||||
class MobileNetV3_224x224(mobilenetv3.MobileNetV3):
    """torchvision MobileNetV3 with an embedding recorder before the last layer.

    The last entry of ``classifier`` is replaced by the recorder and the
    original final ``Linear`` is appended again under the name ``"fc"``; the
    same layer is also reachable as ``self.fc``.

    Args:
        inverted_residual_setting: block configuration for torchvision.
        last_channel: width of the penultimate classifier layer.
        channel: accepted for interface symmetry; not used by this class.
        num_classes: number of output logits.
        record_embedding: initial recording flag of the recorder.
        no_grad: when true, the forward pass disables autograd.
        **kwargs: forwarded to ``torchvision.models.mobilenetv3.MobileNetV3``.
    """

    def __init__(self, inverted_residual_setting, last_channel,
                 channel=3, num_classes=1000, record_embedding=False, no_grad=False, **kwargs):
        super().__init__(inverted_residual_setting, last_channel, num_classes=num_classes, **kwargs)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)

        # Move the classification head behind the recorder.
        head = self.classifier[-1]
        self.fc = head
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", head)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            flat = flatten(pooled, 1)
            return self.classifier(flat)
|
||||
|
||||
|
||||
def MobileNetV3(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False,
                no_grad: bool = False,
                pretrained: bool = False, **kwargs):
    """Build a MobileNetV3 matching the dataset geometry.

    Args:
        arch: ``"mobilenet_v3_large"`` or ``"mobilenet_v3_small"`` (case-insensitive).
        channel: number of input image channels.
        num_classes: number of output logits.
        im_size: ``(height, width)`` of the input images.
        record_embedding: initial recording flag of the embedding recorder.
        no_grad: when true, the network's forward pass disables autograd.
        pretrained: load the torchvision ImageNet checkpoint (3 channels only).
        **kwargs: forwarded to ``MobileNetV3_224x224``; ignored for the
            small-input variant.

    Returns:
        A ``MobileNetV3_224x224`` when pretrained or for 224x224 inputs,
        otherwise a ``MobileNetV3_32x32`` for 1x28x28 / 3x32x32 inputs.

    Raises:
        NotImplementedError: for unsupported channel/size combinations.
        ValueError: for an unknown ``arch`` on the small-input path.
    """
    arch = arch.lower()
    if pretrained:
        if channel != 3:
            raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

        inverted_residual_setting, last_channel = mobilenetv3._mobilenet_v3_conf(arch)
        net = MobileNetV3_224x224(inverted_residual_setting=inverted_residual_setting, last_channel=last_channel,
                                  channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad,
                                  **kwargs)

        from torch.hub import load_state_dict_from_url
        # NOTE(review): ``mobilenetv3.model_urls`` may be absent in newer
        # torchvision releases - confirm against the pinned version.
        state_dict = load_state_dict_from_url(mobilenetv3.model_urls[arch], progress=True)

        # MobileNetV3_224x224 moves the head from the last numeric classifier
        # slot (now occupied by the parameter-free recorder) to
        # ``classifier.fc`` and also registers it as ``fc``. Rename the
        # checkpoint entries to match, otherwise the strict load reports
        # missing and unexpected keys.
        old_prefix = "classifier.{}.".format(len(net.classifier) - 2)
        for key in [k for k in state_dict if k.startswith(old_prefix)]:
            tensor = state_dict.pop(key)
            suffix = key[len(old_prefix):]
            state_dict["classifier.fc." + suffix] = tensor
            state_dict["fc." + suffix] = tensor
        net.load_state_dict(state_dict)

        if num_classes != 1000:
            net.fc = nn.Linear(last_channel, num_classes)
            net.classifier[-1] = net.fc

    elif im_size[0] == 224 and im_size[1] == 224:
        if channel != 3:
            raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
        inverted_residual_setting, last_channel = mobilenetv3._mobilenet_v3_conf(arch)
        net = MobileNetV3_224x224(inverted_residual_setting=inverted_residual_setting, last_channel=last_channel,
                                  channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad, **kwargs)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        # Rows are: k, t, c, SE, HS, s
        small_input_cfgs = {
            "mobilenet_v3_large": [
                [3, 1, 16, 0, 0, 1],
                [3, 4, 24, 0, 0, 2],
                [3, 3, 24, 0, 0, 1],
                [5, 3, 40, 1, 0, 2],
                [5, 3, 40, 1, 0, 1],
                [5, 3, 40, 1, 0, 1],
                [3, 6, 80, 0, 1, 2],
                [3, 2.5, 80, 0, 1, 1],
                [3, 2.3, 80, 0, 1, 1],
                [3, 2.3, 80, 0, 1, 1],
                [3, 6, 112, 1, 1, 1],
                [3, 6, 112, 1, 1, 1],
                [5, 6, 160, 1, 1, 2],
                [5, 6, 160, 1, 1, 1],
                [5, 6, 160, 1, 1, 1],
            ],
            "mobilenet_v3_small": [
                [3, 1, 16, 1, 0, 2],
                [3, 4.5, 24, 0, 0, 2],
                [3, 3.67, 24, 0, 0, 1],
                [5, 4, 40, 1, 1, 2],
                [5, 6, 40, 1, 1, 1],
                [5, 6, 40, 1, 1, 1],
                [5, 3, 48, 1, 1, 1],
                [5, 3, 48, 1, 1, 1],
                [5, 6, 96, 1, 1, 2],
                [5, 6, 96, 1, 1, 1],
                [5, 6, 96, 1, 1, 1],
            ],
        }
        if arch not in small_input_cfgs:
            raise ValueError("Model architecture not found.")
        net = MobileNetV3_32x32(small_input_cfgs[arch], arch, channel=channel, num_classes=num_classes,
                                record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def MobileNetV3Large(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                     pretrained: bool = False, **kwargs):
    """Build a MobileNetV3 with the architecture fixed to ``"mobilenet_v3_large"``.

    Every argument (including ``**kwargs``) is forwarded unchanged to the
    ``MobileNetV3`` factory, which decides the concrete model class.
    """
    arch = "mobilenet_v3_large"
    return MobileNetV3(arch, channel, num_classes, im_size, record_embedding, no_grad, pretrained, **kwargs)
|
||||
|
||||
|
||||
def MobileNetV3Small(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                     pretrained: bool = False, **kwargs):
    """Build a MobileNetV3 with the architecture fixed to ``"mobilenet_v3_small"``.

    Every argument (including ``**kwargs``) is forwarded unchanged to the
    ``MobileNetV3`` factory, which decides the concrete model class.
    """
    arch = "mobilenet_v3_small"
    return MobileNetV3(arch, channel, num_classes, im_size, record_embedding, no_grad, pretrained, **kwargs)
|
||||
@@ -0,0 +1,2 @@
|
||||
from .parallel import *
|
||||
from .recorder import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class MyDataParallel(DataParallel):
    """``DataParallel`` that transparently exposes attributes of the wrapped module.

    Reads that the wrapper cannot resolve are retried on ``self.module``, and
    the ``no_grad`` flag is always written to ``self.module`` instead of the wrapper.
    """

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper, falling back to the wrapped module."""
        try:
            found = super().__getattr__(name)
        except AttributeError:
            # Not a parameter/buffer/submodule of the wrapper: ask the wrapped model.
            found = getattr(self.module, name)
        return found

    def __setattr__(self, name, value):
        """Set ``name`` on the wrapper, except ``no_grad`` which goes to the wrapped module."""
        try:
            if name != "no_grad":
                return super().__setattr__(name, value)
            return setattr(self.module, name, value)
        except AttributeError:
            # The wrapper refused the assignment: store it on the wrapped model instead.
            return setattr(self.module, name, value)
|
||||
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class EmbeddingRecorder(nn.Module):
    """Identity module that can keep a reference to whatever passes through it.

    While ``record_embedding`` is true, ``forward`` stores its input on
    ``self.embedding`` before returning it unchanged. The attribute does not
    exist until the first recorded forward pass.

    The instance is also a context manager: recording is switched on when the
    ``with`` block is entered and switched off when it is left.
    """

    def __init__(self, record_embedding: bool = False):
        super().__init__()
        self.record_embedding = record_embedding

    def forward(self, x):
        """Return ``x`` unchanged, remembering it first when recording is enabled."""
        if self.record_embedding:
            self.embedding = x
        return x

    def __enter__(self):
        """Enable recording and return the recorder so ``with rec as r`` binds it."""
        self.record_embedding = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Disable recording; exceptions raised inside the block are not suppressed."""
        self.record_embedding = False
|
||||
@@ -0,0 +1,241 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import resnet
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut).

    The shortcut is the identity unless the block changes resolution
    (``stride != 1``) or width, in which case a strided 1x1 conv + BN projects
    the input to the output shape.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Same layers the ``conv3x3`` helper builds: bias-free 3x3 convs with padding 1.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual branch, add the shortcut and finish with a ReLU."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 reduce, 3x3 (optionally strided), 1x1 expand.

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity unless the stride or the channel count changes, in which case a
    strided 1x1 conv + BN projects the input.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Apply the bottleneck branch, add the shortcut and finish with a ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
|
||||
|
||||
|
||||
class ResNet_32x32(nn.Module):
    """ResNet for small images: a 3x3 stem, four residual stages and a linear head.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        channel: number of input channels.
        num_classes: size of the output layer.
        record_embedding: initial state of the ``EmbeddingRecorder`` placed
            right before the linear head.
        no_grad: when true, ``forward`` runs with gradient tracking disabled.
    """

    def __init__(self, block, num_blocks, channel=3, num_classes=10, record_embedding: bool = False,
                 no_grad: bool = False):
        super().__init__()
        # Running input width for the next block; updated by ``_make_layer``.
        self.in_planes = 64

        self.conv1 = conv3x3(channel, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, planes=64, num_blocks=num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, planes=128, num_blocks=num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, planes=256, num_blocks=num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, planes=512, num_blocks=num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.linear

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use stride 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Compute class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            feat = F.relu(self.bn1(self.conv1(x)))
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                feat = stage(feat)
            feat = F.avg_pool2d(feat, 4)
            feat = feat.view(feat.size(0), -1)
            feat = self.embedding_recorder(feat)
            logits = self.linear(feat)
        return logits
|
||||
|
||||
|
||||
class ResNet_224x224(resnet.ResNet):
    """torchvision ``ResNet`` with an embedding recorder and an optional no-grad forward.

    The parent is always constructed with its default stem and head; the first
    convolution is then replaced when ``channel != 3`` and the final linear
    layer when ``num_classes != 1000``.

    Args:
        block, layers, **kwargs: forwarded to ``torchvision.models.resnet.ResNet``.
        channel: number of input channels.
        num_classes: size of the output layer.
        record_embedding: initial state of the ``EmbeddingRecorder`` placed
            right before ``fc``.
        no_grad: when true, the forward pass runs with gradient tracking disabled.
    """

    def __init__(self, block, layers, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(block, layers, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)

        if channel != 3:
            self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if num_classes != 1000:
            in_features = self.fc.in_features
            self.fc = nn.Linear(in_features, num_classes)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, residual stages, pooling, the embedding recorder and the classifier."""
        with set_grad_enabled(not self.no_grad):
            stem = (self.conv1, self.bn1, self.relu, self.maxpool)
            stages = (self.layer1, self.layer2, self.layer3, self.layer4)
            for module in stem + stages:
                x = module(x)

            x = flatten(self.avgpool(x), 1)
            x = self.embedding_recorder(x)
            x = self.fc(x)

        return x
|
||||
|
||||
|
||||
def ResNet(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Create a ResNet suited to the given input size.

    Args:
        arch: one of ``resnet18/34/50/101/152`` (case-insensitive).
        channel: number of input channels.
        num_classes: size of the output layer.
        im_size: indexable ``(height, width)``; not inspected when ``pretrained`` is true.
        record_embedding: forwarded to the model's embedding recorder.
        no_grad: forwarded to the model; disables gradient tracking in ``forward``.
        pretrained: build the torchvision-style model, load the weights listed
            in ``resnet.model_urls`` and then replace the stem / head if
            ``channel`` / ``num_classes`` differ from 3 / 1000.

    Returns:
        ``ResNet_224x224`` when ``pretrained`` or the input is 224x224;
        ``ResNet_32x32`` for 1x28x28 or 3x32x32 inputs.

    Raises:
        NotImplementedError: the input size / channel combination is unsupported.
        ValueError: ``arch`` is not a known architecture.
    """
    arch = arch.lower()

    # arch -> (uses bottleneck blocks?, number of blocks per stage)
    layouts = {
        "resnet18": (False, [2, 2, 2, 2]),
        "resnet34": (False, [3, 4, 6, 3]),
        "resnet50": (True, [3, 4, 6, 3]),
        "resnet101": (True, [3, 4, 23, 3]),
        "resnet152": (True, [3, 8, 36, 3]),
    }

    # Pick the model family first so an unsupported input size is reported
    # before an unknown architecture, and im_size is untouched when pretrained.
    if pretrained:
        torchvision_style = True
    elif im_size[0] == 224 and im_size[1] == 224:
        torchvision_style = True
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        torchvision_style = False
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

    if arch not in layouts:
        raise ValueError("Model architecture not found.")
    use_bottleneck, layers = layouts[arch]

    if not torchvision_style:
        block = Bottleneck if use_bottleneck else BasicBlock
        return ResNet_32x32(block, layers, channel=channel, num_classes=num_classes,
                            record_embedding=record_embedding, no_grad=no_grad)

    block = resnet.Bottleneck if use_bottleneck else resnet.BasicBlock
    if not pretrained:
        return ResNet_224x224(block, layers, channel=channel, num_classes=num_classes,
                              record_embedding=record_embedding, no_grad=no_grad)

    # The model is built with a 3-channel stem and 1000-way head so the
    # downloaded weights fit; both are swapped afterwards if needed.
    net = ResNet_224x224(block, layers, channel=3, num_classes=1000,
                         record_embedding=record_embedding, no_grad=no_grad)
    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url(resnet.model_urls[arch], progress=True)
    net.load_state_dict(state_dict)

    if channel != 3:
        net.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
    if num_classes != 1000:
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
|
||||
|
||||
|
||||
def ResNet18(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Shortcut for ``ResNet`` with the architecture fixed to ``"resnet18"``."""
    return ResNet("resnet18", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet34(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Shortcut for ``ResNet`` with the architecture fixed to ``"resnet34"``."""
    return ResNet("resnet34", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet50(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Shortcut for ``ResNet`` with the architecture fixed to ``"resnet50"``."""
    return ResNet("resnet50", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet101(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
              pretrained: bool = False):
    """Shortcut for ``ResNet`` with the architecture fixed to ``"resnet101"``."""
    return ResNet("resnet101", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet152(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
              pretrained: bool = False):
    """Shortcut for ``ResNet`` with the architecture fixed to ``"resnet152"``."""
    return ResNet("resnet152", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
@@ -0,0 +1,128 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import vgg
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
# Layer layout per VGG variant, one stage per row: an int is the output width
# of a 3x3 convolution, 'M' is a 2x2 max-pooling step.
cfg_vgg = {
    'vgg11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'vgg19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
|
||||
|
||||
|
||||
class VGG_32x32(nn.Module):
    """VGG for small images: conv/BN/ReLU stacks from ``cfg_vgg`` and a single linear head.

    Args:
        vgg_name: key into ``cfg_vgg`` selecting the layer layout.
        channel: number of input channels.
        num_classes: size of the output layer.
        record_embedding: initial state of the ``EmbeddingRecorder`` placed
            right before the classifier.
        no_grad: when true, ``forward`` runs with gradient tracking disabled.
    """

    def __init__(self, vgg_name, channel, num_classes, record_embedding=False, no_grad=False):
        super().__init__()
        self.channel = channel
        self.features = self._make_layers(cfg_vgg[vgg_name])
        feature_dim = 128 if vgg_name == 'VGGS' else 512
        self.classifier = nn.Linear(feature_dim, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def forward(self, x):
        """Compute class logits for a batch of images."""
        with set_grad_enabled(not self.no_grad):
            feat = self.features(x)
            feat = feat.view(feat.size(0), -1)
            feat = self.embedding_recorder(feat)
            logits = self.classifier(feat)
        return logits

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.classifier

    def _make_layers(self, cfg):
        """Translate a layout list ('M' = max-pool, int = conv width) into an ``nn.Sequential``."""
        modules = []
        prev_channels = self.channel
        for position, spec in enumerate(cfg):
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            # For single-channel input the very first conv pads by 3 instead of 1,
            # which enlarges each spatial dimension by 4 (e.g. 28 -> 32).
            widen_input = self.channel == 1 and position == 0
            modules.extend([
                nn.Conv2d(prev_channels, spec, kernel_size=3, padding=3 if widen_input else 1),
                nn.BatchNorm2d(spec),
                nn.ReLU(inplace=True),
            ])
            prev_channels = spec
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)
|
||||
|
||||
|
||||
class VGG_224x224(vgg.VGG):
    """torchvision ``VGG`` with an embedding recorder and an optional no-grad forward.

    After the parent is built, the last entry of ``self.classifier`` is
    replaced by the embedding recorder and the original final layer is
    re-appended under the name ``"fc"`` (and also exposed as ``self.fc``), so
    the recorder sees the input of the final layer.

    Args:
        features: convolutional feature extractor passed to the parent.
        channel: number of input channels; if not 3, ``features[0]`` is
            replaced by a new 3x3 convolution with 64 outputs.
        num_classes: forwarded to the parent constructor.
        record_embedding: initial state of the ``EmbeddingRecorder``.
        no_grad: when true, ``forward`` runs with gradient tracking disabled.
        **kwargs: forwarded to the parent constructor.
    """

    def __init__(self, features: nn.Module, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(features, num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)

        if channel != 3:
            self.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)

        # Splice the recorder in front of the final layer: the old last slot
        # now holds the recorder and the final layer follows it as "fc".
        head = self.classifier[-1]
        self.fc = head
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", head)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.fc

    def forward(self, x: Tensor) -> Tensor:
        """Compute class logits; the recorder runs inside ``self.classifier``."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            logits = self.classifier(flatten(pooled, 1))
        return logits
|
||||
|
||||
|
||||
def VGG(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
        pretrained: bool = False):
    """Create a VGG suited to the given input size.

    Args:
        arch: key of ``cfg_vgg`` (``vgg11/13/16/19``, case-insensitive).
        channel: number of input channels.
        num_classes: size of the output layer.
        im_size: indexable ``(height, width)``.
        record_embedding: forwarded to the model's embedding recorder.
        no_grad: forwarded to the model; disables gradient tracking in ``forward``.
        pretrained: load torchvision weights (224x224 inputs only), then
            replace the first conv / final layer if ``channel`` /
            ``num_classes`` differ from 3 / 1000.

    Returns:
        ``VGG_224x224`` when ``pretrained`` or the input is 224x224;
        ``VGG_32x32`` for 1x28x28 or 3x32x32 inputs.

    Raises:
        NotImplementedError: the input size / channel combination is
            unsupported, or ``pretrained`` is requested for a non-224x224 input.
    """
    arch = arch.lower()
    if pretrained:
        if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=3, num_classes=1000,
                          record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        # make_layers(..., True) builds the batch-norm variant, so the matching
        # checkpoint is the "<arch>_bn" one; the plain checkpoint has no BN
        # tensors and different feature indices.
        # NOTE(review): relies on `vgg.model_urls`, which may be absent in
        # newer torchvision releases - confirm against the pinned version.
        state_dict = load_state_dict_from_url(vgg.model_urls[arch + "_bn"], progress=True)

        # VGG_224x224 moved the final Linear from classifier.6 to classifier.fc
        # and also registers it as net.fc, so the checkpoint keys must follow.
        for suffix in ("weight", "bias"):
            tensor = state_dict.pop("classifier.6." + suffix)
            state_dict["classifier.fc." + suffix] = tensor
            state_dict["fc." + suffix] = tensor
        net.load_state_dict(state_dict)

        if channel != 3:
            net.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)

        if num_classes != 1000:
            net.fc = nn.Linear(4096, num_classes)
            # classifier[-1] is the "fc" slot appended by VGG_224x224.
            net.classifier[-1] = net.fc

    elif im_size[0] == 224 and im_size[1] == 224:
        net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=channel, num_classes=num_classes,
                          record_embedding=record_embedding, no_grad=no_grad)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = VGG_32x32(arch, channel, num_classes=num_classes, record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def VGG11(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Shortcut for ``VGG`` with the architecture fixed to ``"vgg11"``."""
    return VGG("vgg11", channel, num_classes, im_size,
               record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def VGG13(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Shortcut for ``VGG`` with the architecture fixed to ``'vgg13'``."""
    return VGG('vgg13', channel, num_classes, im_size,
               record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def VGG16(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Shortcut for ``VGG`` with the architecture fixed to ``'vgg16'``."""
    return VGG('vgg16', channel, num_classes, im_size,
               record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def VGG19(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Shortcut for ``VGG`` with the architecture fixed to ``'vgg19'``."""
    return VGG('vgg19', channel, num_classes, im_size,
               record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import resnet
|
||||
from .resnet import ResNet_224x224
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/xternalz/WideResNet-pytorch
|
||||
|
||||
class BasicBlock(nn.Module):
    """Wide-ResNet residual block in BN -> ReLU -> conv order, with optional dropout.

    When the input and output widths differ, the shortcut is a 1x1 convolution
    applied to the activated input; otherwise the raw input is added back.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first 3x3 convolution (and of the 1x1 shortcut).
        dropRate: dropout probability applied between the two convolutions.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        """Return shortcut + residual branch for input ``x``."""
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # A width change projects the *activated* input; otherwise the raw
        # input is the identity shortcut.
        shortcut = x if self.equalInOut else self.convShortcut(activated)
        return torch.add(shortcut, out)
|
||||
|
||||
|
||||
class NetworkBlock(nn.Module):
    """A group of ``nb_layers`` consecutive blocks built from the same block class.

    Only the first block receives ``in_planes`` and ``stride``; every later
    block maps ``out_planes`` to ``out_planes`` with stride 1.

    Args:
        nb_layers: number of blocks; converted with ``int()``, so an integral
            float is accepted.
        in_planes: number of channels entering the group.
        out_planes: number of channels produced by every block.
        block: callable ``block(in_planes, out_planes, stride, dropRate)``
            returning an ``nn.Module``.
        stride: stride of the first block.
        dropRate: dropout rate forwarded to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Instantiate the blocks and chain them in an ``nn.Sequential``."""
        layers = []
        for i in range(int(nb_layers)):
            is_first = i == 0
            # Conditional expressions rather than `cond and a or b`, which
            # silently falls through to the second value when `a` is falsy.
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through all blocks of the group in order."""
        return self.layer(x)
|
||||
|
||||
|
||||
class WideResNet_32x32(nn.Module):
    """Wide ResNet for small images: a 3x3 stem, three block groups and a linear head.

    Args:
        depth: total depth; ``(depth - 4)`` must be divisible by 6, and each
            group gets ``(depth - 4) / 6`` blocks.
        num_classes: size of the output layer.
        channel: number of input channels. With ``channel == 1`` the stem pads
            by 3 instead of 1, enlarging each spatial dimension by 4 (e.g. 28 -> 32).
        widen_factor: multiplier applied to the group widths 16, 32 and 64.
        drop_rate: dropout rate forwarded to every block.
        record_embedding: initial state of the ``EmbeddingRecorder`` placed
            right before ``fc``.
        no_grad: when true, ``forward`` runs with gradient tracking disabled.
    """

    def __init__(self, depth, num_classes, channel=3, widen_factor=1, drop_rate=0.0, record_embedding=False,
                 no_grad=False):
        super().__init__()
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]
        assert ((depth - 4) % 6 == 0)
        blocks_per_group = (depth - 4) // 6
        block = BasicBlock

        # Stem convolution in front of the block groups.
        stem_padding = 3 if channel == 1 else 1
        self.conv1 = nn.Conv2d(channel, widths[0], kernel_size=3, stride=1,
                               padding=stem_padding, bias=False)
        # Three groups; the second and third halve the resolution.
        self.block1 = NetworkBlock(blocks_per_group, widths[0], widths[1], block, 1, drop_rate)
        self.block2 = NetworkBlock(blocks_per_group, widths[1], widths[2], block, 2, drop_rate)
        self.block3 = NetworkBlock(blocks_per_group, widths[2], widths[3], block, 2, drop_rate)
        # Final normalisation, pooling and classifier.
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        # Re-initialise conv weights, BN affine parameters and linear biases.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear classification layer."""
        return self.fc

    def forward(self, x):
        """Compute class logits for a batch of images."""
        with torch.set_grad_enabled(not self.no_grad):
            feat = self.conv1(x)
            for group in (self.block1, self.block2, self.block3):
                feat = group(feat)
            feat = self.relu(self.bn1(feat))
            # The reshape below to (-1, nChannels) matches the batch size only
            # if this 8x8 pooling leaves a 1x1 map.
            feat = F.avg_pool2d(feat, 8)
            feat = feat.view(-1, self.nChannels)
            feat = self.embedding_recorder(feat)
            return self.fc(feat)
|
||||
|
||||
|
||||
def WideResNet(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False,
               no_grad: bool = False, pretrained: bool = False):
    """Create a Wide ResNet suited to the given input size.

    Args:
        arch: ``wrn502`` / ``wrn1012`` for 224x224 (torchvision-style) models,
            ``wrn168`` / ``wrn2810`` / ``wrn282`` for 1x28x28 or 3x32x32
            inputs (case-insensitive).
        channel: number of input channels.
        num_classes: size of the output layer.
        im_size: indexable ``(height, width)``.
        record_embedding: forwarded to the model's embedding recorder.
        no_grad: forwarded to the model; disables gradient tracking in ``forward``.
        pretrained: load torchvision weights (224x224 inputs only), then
            replace the stem / head if ``channel`` / ``num_classes`` differ
            from 3 / 1000.

    Returns:
        ``ResNet_224x224`` with doubled ``width_per_group`` for 224x224
        inputs, ``WideResNet_32x32`` for small inputs.

    Raises:
        NotImplementedError: the input size / channel combination is
            unsupported, or ``pretrained`` is requested for a non-224x224 input.
        ValueError: ``arch`` is not available for the selected input size.
    """
    arch = arch.lower()

    # arch -> (key into torchvision's resnet.model_urls, bottleneck blocks per stage)
    torchvision_archs = {
        "wrn502": ("wide_resnet50_2", [3, 4, 6, 3]),
        "wrn1012": ("wide_resnet101_2", [3, 4, 23, 3]),
    }
    # arch -> (depth, widen factor)
    small_archs = {
        "wrn168": (16, 8),
        "wrn2810": (28, 10),
        "wrn282": (28, 2),
    }

    if pretrained:
        if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        if arch not in torchvision_archs:
            raise ValueError("Model architecture not found.")
        url_key, layers = torchvision_archs[arch]
        # Built with a 3-channel stem and 1000-way head so the checkpoint fits.
        net = ResNet_224x224(resnet.Bottleneck, layers, channel=3, num_classes=1000,
                             record_embedding=record_embedding, no_grad=no_grad, width_per_group=64 * 2)
        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(resnet.model_urls[url_key], progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            net.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif im_size[0] == 224 and im_size[1] == 224:
        # Use torchvision models without pretrained parameters
        if arch not in torchvision_archs:
            raise ValueError("Model architecture not found.")
        _, layers = torchvision_archs[arch]
        net = ResNet_224x224(resnet.Bottleneck, layers, channel=channel, num_classes=num_classes,
                             record_embedding=record_embedding, no_grad=no_grad, width_per_group=64 * 2)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        if arch not in small_archs:
            raise ValueError("Model architecture not found.")
        depth, widen_factor = small_archs[arch]
        # Forward the recorder / no-grad flags as every other branch does;
        # previously they were silently dropped for the small-input models.
        net = WideResNet_32x32(depth, num_classes, channel, widen_factor,
                               record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def WRN168(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Shortcut for ``WideResNet`` with the architecture fixed to ``"wrn168"``."""
    return WideResNet("wrn168", channel, num_classes, im_size,
                      record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def WRN2810(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Shortcut for ``WideResNet`` with the architecture fixed to ``"wrn2810"``."""
    return WideResNet("wrn2810", channel, num_classes, im_size,
                      record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def WRN282(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Shortcut for ``WideResNet`` with the architecture fixed to ``'wrn282'``."""
    return WideResNet('wrn282', channel, num_classes, im_size,
                      record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def WRN502(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Shortcut for ``WideResNet`` with the architecture fixed to ``"wrn502"``."""
    return WideResNet("wrn502", channel, num_classes, im_size,
                      record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def WRN1012(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Shortcut for ``WideResNet`` with the architecture fixed to ``"wrn1012"``."""
    return WideResNet("wrn1012", channel, num_classes, im_size,
                      record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
Reference in New Issue
Block a user