import torch
import torch.nn as nn
from torchvision.models import inception

from .nets_utils import EmbeddingRecorder


class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm -> ReLU building block used by every Inception branch."""

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


# Naive Inception module: every branch preserves spatial size ("same")
class InceptionA(nn.Module):

    def __init__(self, input_channels, pool_features):
        super().__init__()
        self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)

        self.branch5x5 = nn.Sequential(
            BasicConv2d(input_channels, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=5, padding=2)
        )

        self.branch3x3 = nn.Sequential(
            BasicConv2d(input_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1)
        )

        self.branchpool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # x -> 1x1 (same)
        branch1x1 = self.branch1x1(x)

        # x -> 1x1 -> 5x5 (same)
        branch5x5 = self.branch5x5(x)

        # x -> 1x1 -> 3x3 -> 3x3 (same)
        branch3x3 = self.branch3x3(x)

        # x -> avgpool -> 3x3 (same)
        branchpool = self.branchpool(x)

        outputs = [branch1x1, branch5x5, branch3x3, branchpool]
        return torch.cat(outputs, 1)

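
# Illustrative sanity check (an addition, not part of the original file): InceptionA
# preserves spatial size, and its output width is the sum of the four branches,
# 64 + 64 + 96 + pool_features. For Mixed_5b below, 64 + 64 + 96 + 32 = 256.
def _check_inception_a():
    m = InceptionA(192, pool_features=32)
    out = m(torch.randn(2, 192, 32, 32))
    assert out.shape == (2, 256, 32, 32)

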
# Downsample block: factorization into smaller convolutions
class InceptionB(nn.Module):

    def __init__(self, input_channels):
        super().__init__()
        self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)

        self.branch3x3stack = nn.Sequential(
            BasicConv2d(input_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2)
        )

        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        # x -> 3x3 (downsample)
        branch3x3 = self.branch3x3(x)

        # x -> 1x1 -> 3x3 -> 3x3 (downsample)
        branch3x3stack = self.branch3x3stack(x)

        # x -> maxpool (downsample)
        branchpool = self.branchpool(x)

        # """We can use two parallel stride 2 blocks: P and C. P is a pooling
        # layer (either average or maximum pooling) of the activation, both of
        # them are stride 2, the filter banks of which are concatenated as in
        # figure 10."""
        outputs = [branch3x3, branch3x3stack, branchpool]
        return torch.cat(outputs, 1)

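
# Illustrative sketch (an addition, not original code): InceptionB halves the grid
# ((H - 3) // 2 + 1 with no padding) and concatenates 384 + 96 + input_channels
# channels, which is how Mixed_6a turns 288 channels at 30x30 into 768 at 14x14.
def _check_inception_b():
    m = InceptionB(288)
    out = m(torch.randn(2, 288, 30, 30))
    assert out.shape == (2, 768, 14, 14)

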
# Factorizing convolutions with large filter size
class InceptionC(nn.Module):

    def __init__(self, input_channels, channels_7x7):
        super().__init__()
        self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)

        c7 = channels_7x7

        # """In theory, we could go even further and argue that one can replace any n × n
        # convolution by a 1 × n convolution followed by a n × 1 convolution and the
        # computational cost saving increases dramatically as n grows (see figure 6)."""
        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, c7, kernel_size=1),
            BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        )

        self.branch7x7stack = nn.Sequential(
            BasicConv2d(input_channels, c7, kernel_size=1),
            BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        )

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1)
        )

    def forward(self, x):
        # x -> 1x1 (same)
        branch1x1 = self.branch1x1(x)

        # x -> 1x1 -> 7x1 -> 1x7 (same): one factorized 7x7
        branch7x7 = self.branch7x7(x)

        # x -> 1x1 -> (7x1 -> 1x7) x 2 (same): two factorized 7x7s
        branch7x7stack = self.branch7x7stack(x)

        # x -> avgpool -> 1x1 (same)
        branchpool = self.branch_pool(x)

        outputs = [branch1x1, branch7x7, branch7x7stack, branchpool]
        return torch.cat(outputs, 1)

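
# A minimal cost comparison (illustrative assumption: bias-free convolutions with
# equal channel width c): a full 7x7 convolution needs 49 * c * c weights, while
# the factorized 7x1 followed by 1x7 pair used in InceptionC needs only 14 * c * c.
def _factorization_cost_demo(c=128):
    full = nn.Conv2d(c, c, kernel_size=7, padding=3, bias=False)
    asym = nn.Sequential(
        nn.Conv2d(c, c, kernel_size=(7, 1), padding=(3, 0), bias=False),
        nn.Conv2d(c, c, kernel_size=(1, 7), padding=(0, 3), bias=False),
    )
    n_full = sum(p.numel() for p in full.parameters())  # 49 * c * c
    n_asym = sum(p.numel() for p in asym.parameters())  # 14 * c * c
    assert n_full == 49 * c * c and n_asym == 14 * c * c

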
class InceptionD(nn.Module):

    def __init__(self, input_channels):
        super().__init__()
        self.branch3x3 = nn.Sequential(
            BasicConv2d(input_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2)
        )

        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2)
        )

        self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        # x -> 1x1 -> 3x3 (downsample)
        branch3x3 = self.branch3x3(x)

        # x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample)
        branch7x7 = self.branch7x7(x)

        # x -> avgpool (downsample)
        branchpool = self.branchpool(x)

        outputs = [branch3x3, branch7x7, branchpool]
        return torch.cat(outputs, 1)

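
# Illustrative sanity check (not in the original source): InceptionD is the second
# grid reduction, with 320 + 192 + input_channels output channels, so Mixed_7a
# maps 768 channels at 14x14 to 1280 at 6x6.
def _check_inception_d():
    m = InceptionD(768)
    out = m(torch.randn(2, 768, 14, 14))
    assert out.shape == (2, 1280, 6, 6)

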
# Expanded filter banks; grid size preserved ("same")
class InceptionE(nn.Module):

    def __init__(self, input_channels):
        super().__init__()
        self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)

        self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
        self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1)
        )

    def forward(self, x):
        # x -> 1x1 (same)
        branch1x1 = self.branch1x1(x)

        # x -> 1x1 -> 1x3
        # x -> 1x1 -> 3x1
        # concatenate(1x3, 3x1)
        # """Figure 7. Inception modules with expanded filter bank outputs.
        # This architecture is used on the coarsest (8 × 8) grids to promote
        # high dimensional representations, as suggested by principle
        # 2 of Section 2."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3)
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        # x -> 1x1 -> 3x3 -> 1x3
        # x -> 1x1 -> 3x3 -> 3x1
        # concatenate(1x3, 3x1)
        branch3x3stack = self.branch3x3stack_1(x)
        branch3x3stack = self.branch3x3stack_2(branch3x3stack)
        branch3x3stack = [
            self.branch3x3stack_3a(branch3x3stack),
            self.branch3x3stack_3b(branch3x3stack)
        ]
        branch3x3stack = torch.cat(branch3x3stack, 1)

        # x -> avgpool -> 1x1 (same)
        branchpool = self.branch_pool(x)

        outputs = [branch1x1, branch3x3, branch3x3stack, branchpool]
        return torch.cat(outputs, 1)

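
# Illustrative sanity check (an addition): InceptionE's expanded filter banks give
# 320 + (384 + 384) + (384 + 384) + 192 = 2048 output channels regardless of the
# input width, which is why both Mixed_7b and Mixed_7c below end at 2048.
def _check_inception_e():
    m = InceptionE(1280)
    out = m(torch.randn(2, 1280, 6, 6))
    assert out.shape == (2, 2048, 6, 6)

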
class InceptionV3_32x32(nn.Module):

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        super().__init__()
        self.Conv2d_1a_3x3 = BasicConv2d(channel, 32, kernel_size=3, padding=3 if channel == 1 else 1)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)

        # naive inception modules
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)

        # downsample
        self.Mixed_6a = InceptionB(288)

        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        # downsample
        self.Mixed_7a = InceptionD(768)

        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        # 6x6 feature size
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d()
        self.linear = nn.Linear(2048, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        return self.linear

    def forward(self, x):
        with torch.set_grad_enabled(not self.no_grad):
            # 32 -> 30
            x = self.Conv2d_1a_3x3(x)
            x = self.Conv2d_2a_3x3(x)
            x = self.Conv2d_2b_3x3(x)
            x = self.Conv2d_3b_1x1(x)
            x = self.Conv2d_4a_3x3(x)

            # 30 -> 30
            x = self.Mixed_5b(x)
            x = self.Mixed_5c(x)
            x = self.Mixed_5d(x)

            # 30 -> 14
            # Efficient grid size reduction to avoid a representational
            # bottleneck
            x = self.Mixed_6a(x)

            # 14 -> 14
            # """In practice, we have found that employing this factorization does not
            # work well on early layers, but it gives very good results on medium
            # grid-sizes (On m × m feature maps, where m ranges between 12 and 20).
            # On that level, very good results can be achieved by using 1 × 7 convolutions
            # followed by 7 × 1 convolutions."""
            x = self.Mixed_6b(x)
            x = self.Mixed_6c(x)
            x = self.Mixed_6d(x)
            x = self.Mixed_6e(x)

            # 14 -> 6
            # Efficient grid size reduction
            x = self.Mixed_7a(x)

            # 6 -> 6
            # """We are using this solution only on the coarsest grid,
            # since that is the place where producing high dimensional
            # sparse representation is the most critical as the ratio of
            # local processing (by 1 × 1 convolutions) is increased compared
            # to the spatial aggregation."""
            x = self.Mixed_7b(x)
            x = self.Mixed_7c(x)

            # 6 -> 1
            x = self.avgpool(x)
            x = self.dropout(x)
            x = x.view(x.size(0), -1)
            x = self.embedding_recorder(x)
            x = self.linear(x)
        return x

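
# Minimal usage sketch (hypothetical dataset shapes, not original code): the 32x32
# variant accepts CIFAR-style 3x32x32 (or, via the stem's extra padding, MNIST-style
# 1x28x28) inputs and returns one logit per class.
def _demo_inception_v3_32x32():
    net = InceptionV3_32x32(channel=3, num_classes=10)
    logits = net(torch.randn(4, 3, 32, 32))
    assert logits.shape == (4, 10)

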
class InceptionV3_224x224(inception.Inception3):

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(num_classes=num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            self.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        self.no_grad = no_grad

    def get_last_layer(self):
        return self.fc

    def _forward(self, x):
        with torch.set_grad_enabled(not self.no_grad):
            # N x 3 x 299 x 299
            x = self.Conv2d_1a_3x3(x)
            # N x 32 x 149 x 149
            x = self.Conv2d_2a_3x3(x)
            # N x 32 x 147 x 147
            x = self.Conv2d_2b_3x3(x)
            # N x 64 x 147 x 147
            x = self.maxpool1(x)
            # N x 64 x 73 x 73
            x = self.Conv2d_3b_1x1(x)
            # N x 80 x 73 x 73
            x = self.Conv2d_4a_3x3(x)
            # N x 192 x 71 x 71
            x = self.maxpool2(x)
            # N x 192 x 35 x 35
            x = self.Mixed_5b(x)
            # N x 256 x 35 x 35
            x = self.Mixed_5c(x)
            # N x 288 x 35 x 35
            x = self.Mixed_5d(x)
            # N x 288 x 35 x 35
            x = self.Mixed_6a(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6b(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6c(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6d(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6e(x)
            # N x 768 x 17 x 17
            aux = None
            if self.AuxLogits is not None and self.training:
                aux = self.AuxLogits(x)
            # N x 768 x 17 x 17
            x = self.Mixed_7a(x)
            # N x 1280 x 8 x 8
            x = self.Mixed_7b(x)
            # N x 2048 x 8 x 8
            x = self.Mixed_7c(x)
            # N x 2048 x 8 x 8
            # Adaptive average pooling
            x = self.avgpool(x)
            # N x 2048 x 1 x 1
            x = self.dropout(x)
            # N x 2048 x 1 x 1
            x = torch.flatten(x, 1)
            # N x 2048
            x = self.embedding_recorder(x)
            x = self.fc(x)
            # N x num_classes
        return x, aux

def InceptionV3(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                pretrained: bool = False):
    if pretrained:
        if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvision pretrained models only accept inputs of size 224x224")
        net = InceptionV3_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(inception.model_urls["inception_v3_google"], progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            net.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif im_size[0] == 224 and im_size[1] == 224:
        net = InceptionV3_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad)
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = InceptionV3_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

    return net
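

# Hedged usage example (dataset values are assumptions): the factory dispatches on
# im_size and channel, so a 3-channel 32x32 dataset gets the InceptionV3_32x32 variant.
def _demo_factory():
    net = InceptionV3(channel=3, num_classes=10, im_size=(32, 32))
    logits = net(torch.randn(2, 3, 32, 32))
    assert logits.shape == (2, 10)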