# DAPT/deepcore/nets/inceptionv3.py
import torch
import torch.nn as nn
from torchvision.models import inception
from .nets_utils import EmbeddingRecorder
class BasicConv2d(nn.Module):
def __init__(self, input_channels, output_channels, **kwargs):
super().__init__()
self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(output_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
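# Note: BasicConv2d is the Conv -> BatchNorm -> ReLU building block used
# throughout this file; the conv bias is disabled because BatchNorm's
# learned shift makes a separate bias redundant.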
# Naive Inception module; "same" means the output spatial size equals the input's
class InceptionA(nn.Module):
def __init__(self, input_channels, pool_features):
super().__init__()
self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)
self.branch5x5 = nn.Sequential(
BasicConv2d(input_channels, 48, kernel_size=1),
BasicConv2d(48, 64, kernel_size=5, padding=2)
)
self.branch3x3 = nn.Sequential(
BasicConv2d(input_channels, 64, kernel_size=1),
BasicConv2d(64, 96, kernel_size=3, padding=1),
BasicConv2d(96, 96, kernel_size=3, padding=1)
)
self.branchpool = nn.Sequential(
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
)
def forward(self, x):
# x -> 1x1(same)
branch1x1 = self.branch1x1(x)
# x -> 1x1 -> 5x5(same)
branch5x5 = self.branch5x5(x)
# branch5x5 = self.branch5x5_2(branch5x5)
# x -> 1x1 -> 3x3 -> 3x3(same)
branch3x3 = self.branch3x3(x)
        # x -> avgpool -> 3x3 (same)
branchpool = self.branchpool(x)
outputs = [branch1x1, branch5x5, branch3x3, branchpool]
return torch.cat(outputs, 1)
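# Note: InceptionA output channels = 64 + 64 + 96 + pool_features
# (e.g. 192 -> 256 for Mixed_5b with pool_features=32); spatial size preserved.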
# downsample
# Factorization into smaller convolutions
class InceptionB(nn.Module):
def __init__(self, input_channels):
super().__init__()
self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)
self.branch3x3stack = nn.Sequential(
BasicConv2d(input_channels, 64, kernel_size=1),
BasicConv2d(64, 96, kernel_size=3, padding=1),
BasicConv2d(96, 96, kernel_size=3, stride=2)
)
self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)
def forward(self, x):
# x - > 3x3(downsample)
branch3x3 = self.branch3x3(x)
# x -> 3x3 -> 3x3(downsample)
branch3x3stack = self.branch3x3stack(x)
        # x -> maxpool (downsample)
branchpool = self.branchpool(x)
# """We can use two parallel stride 2 blocks: P and C. P is a pooling
# layer (either average or maximum pooling) the activation, both of
# them are stride 2 the filter banks of which are concatenated as in
# figure 10."""
outputs = [branch3x3, branch3x3stack, branchpool]
return torch.cat(outputs, 1)
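# Note: InceptionB output channels = 384 + 96 + input_channels (the pool
# branch keeps its input channels), e.g. 288 -> 768 for Mixed_6a; the
# stride-2 branches halve the spatial size.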
# Factorizing Convolutions with Large Filter Size
class InceptionC(nn.Module):
def __init__(self, input_channels, channels_7x7):
super().__init__()
self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)
c7 = channels_7x7
# In theory, we could go even further and argue that one can replace any n × n
# convolution by a 1 × n convolution followed by a n × 1 convolution and the
# computational cost saving increases dramatically as n grows (see figure 6).
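        # Rough weight count for a c -> c convolution (ignoring bias/BN): a full
        # 7x7 conv uses 49*c^2 weights, while a 7x1 + 1x7 pair like the one
        # below uses 14*c^2, about 3.5x fewer for the same receptive field.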
self.branch7x7 = nn.Sequential(
BasicConv2d(input_channels, c7, kernel_size=1),
BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
)
self.branch7x7stack = nn.Sequential(
BasicConv2d(input_channels, c7, kernel_size=1),
BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)),
BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
)
self.branch_pool = nn.Sequential(
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(input_channels, 192, kernel_size=1),
)
def forward(self, x):
# x -> 1x1(same)
branch1x1 = self.branch1x1(x)
        # x -> 1x1 -> one factorized 7x1/1x7 pair (same)
branch7x7 = self.branch7x7(x)
        # x -> 1x1 -> two factorized 7x1/1x7 pairs (same)
branch7x7stack = self.branch7x7stack(x)
        # x -> avgpool -> 1x1 (same)
branchpool = self.branch_pool(x)
outputs = [branch1x1, branch7x7, branch7x7stack, branchpool]
return torch.cat(outputs, 1)
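# Note: InceptionC output channels = 192 * 4 = 768 regardless of
# channels_7x7; spatial size preserved.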
class InceptionD(nn.Module):
def __init__(self, input_channels):
super().__init__()
self.branch3x3 = nn.Sequential(
BasicConv2d(input_channels, 192, kernel_size=1),
BasicConv2d(192, 320, kernel_size=3, stride=2)
)
self.branch7x7 = nn.Sequential(
BasicConv2d(input_channels, 192, kernel_size=1),
BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(192, 192, kernel_size=3, stride=2)
)
self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)
def forward(self, x):
# x -> 1x1 -> 3x3(downsample)
branch3x3 = self.branch3x3(x)
# x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample)
branch7x7 = self.branch7x7(x)
# x -> avgpool (downsample)
branchpool = self.branchpool(x)
outputs = [branch3x3, branch7x7, branchpool]
return torch.cat(outputs, 1)
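# Note: InceptionD output channels = 320 + 192 + input_channels,
# e.g. 768 -> 1280 for Mixed_7a; the stride-2 branches halve the
# spatial size.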
# same
class InceptionE(nn.Module):
def __init__(self, input_channels):
super().__init__()
self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)
self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch_pool = nn.Sequential(
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(input_channels, 192, kernel_size=1)
)
def forward(self, x):
# x -> 1x1 (same)
branch1x1 = self.branch1x1(x)
# x -> 1x1 -> 3x1
# x -> 1x1 -> 1x3
# concatenate(3x1, 1x3)
# """7. Inception modules with expanded the filter bank outputs.
# This architecture is used on the coarsest (8 × 8) grids to promote
# high dimensional representations, as suggested by principle
# 2 of Section 2."""
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3)
]
branch3x3 = torch.cat(branch3x3, 1)
# x -> 1x1 -> 3x3 -> 1x3
# x -> 1x1 -> 3x3 -> 3x1
# concatenate(1x3, 3x1)
branch3x3stack = self.branch3x3stack_1(x)
branch3x3stack = self.branch3x3stack_2(branch3x3stack)
branch3x3stack = [
self.branch3x3stack_3a(branch3x3stack),
self.branch3x3stack_3b(branch3x3stack)
]
branch3x3stack = torch.cat(branch3x3stack, 1)
branchpool = self.branch_pool(x)
outputs = [branch1x1, branch3x3, branch3x3stack, branchpool]
return torch.cat(outputs, 1)
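# Note: InceptionE output channels = 320 + (384 + 384) + (384 + 384) + 192
# = 2048; spatial size preserved.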
class InceptionV3_32x32(nn.Module):
def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
super().__init__()
        # padding=3 lifts 1-channel 28x28 inputs (e.g. MNIST) to 32x32, so the
        # rest of the network sees the same spatial sizes as for CIFAR input
        self.Conv2d_1a_3x3 = BasicConv2d(channel, 32, kernel_size=3, padding=3 if channel == 1 else 1)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
# naive inception module
self.Mixed_5b = InceptionA(192, pool_features=32)
self.Mixed_5c = InceptionA(256, pool_features=64)
self.Mixed_5d = InceptionA(288, pool_features=64)
# downsample
self.Mixed_6a = InceptionB(288)
self.Mixed_6b = InceptionC(768, channels_7x7=128)
self.Mixed_6c = InceptionC(768, channels_7x7=160)
self.Mixed_6d = InceptionC(768, channels_7x7=160)
self.Mixed_6e = InceptionC(768, channels_7x7=192)
# downsample
self.Mixed_7a = InceptionD(768)
self.Mixed_7b = InceptionE(1280)
self.Mixed_7c = InceptionE(2048)
# 6*6 feature size
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout2d()
self.linear = nn.Linear(2048, num_classes)
self.embedding_recorder = EmbeddingRecorder(record_embedding)
self.no_grad = no_grad
def get_last_layer(self):
return self.linear
def forward(self, x):
with torch.set_grad_enabled(not self.no_grad):
            # stem: 32 -> 30 (only Conv2d_4a_3x3, which has no padding, shrinks the map)
x = self.Conv2d_1a_3x3(x)
x = self.Conv2d_2a_3x3(x)
x = self.Conv2d_2b_3x3(x)
x = self.Conv2d_3b_1x1(x)
x = self.Conv2d_4a_3x3(x)
# 30 -> 30
x = self.Mixed_5b(x)
x = self.Mixed_5c(x)
x = self.Mixed_5d(x)
# 30 -> 14
# Efficient Grid Size Reduction to avoid representation
# bottleneck
x = self.Mixed_6a(x)
# 14 -> 14
# """In practice, we have found that employing this factorization does not
# work well on early layers, but it gives very good results on medium
# grid-sizes (On m × m feature maps, where m ranges between 12 and 20).
# On that level, very good results can be achieved by using 1 × 7 convolutions
# followed by 7 × 1 convolutions."""
x = self.Mixed_6b(x)
x = self.Mixed_6c(x)
x = self.Mixed_6d(x)
x = self.Mixed_6e(x)
# 14 -> 6
# Efficient Grid Size Reduction
x = self.Mixed_7a(x)
# 6 -> 6
            # """We are using this solution only on the coarsest grid,
# since that is the place where producing high dimensional
# sparse representation is the most critical as the ratio of
# local processing (by 1 × 1 convolutions) is increased compared
# to the spatial aggregation."""
x = self.Mixed_7b(x)
x = self.Mixed_7c(x)
# 6 -> 1
x = self.avgpool(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.embedding_recorder(x)
x = self.linear(x)
return x
class InceptionV3_224x224(inception.Inception3):
def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
no_grad: bool = False, **kwargs):
super().__init__(num_classes=num_classes, **kwargs)
self.embedding_recorder = EmbeddingRecorder(record_embedding)
if channel != 3:
            # Swap the stem conv for a non-RGB channel count. BasicConv2d is
            # torchvision's Conv-BN-ReLU block; inception.conv_block is not a
            # module-level attribute of torchvision.models.inception.
            self.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
self.no_grad = no_grad
def get_last_layer(self):
return self.fc
def _forward(self, x):
with torch.set_grad_enabled(not self.no_grad):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = self.maxpool1(x)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = self.maxpool2(x)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
aux = None
if self.AuxLogits is not None:
if self.training:
aux = self.AuxLogits(x)
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
# Adaptive average pooling
x = self.avgpool(x)
# N x 2048 x 1 x 1
x = self.dropout(x)
# N x 2048 x 1 x 1
x = torch.flatten(x, 1)
# N x 2048
x = self.embedding_recorder(x)
x = self.fc(x)
# N x 1000 (num_classes)
return x, aux
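# Note: the inherited Inception3.forward() calls _forward() above and, when
# training with aux_logits enabled, returns an InceptionOutputs namedtuple of
# (logits, aux_logits); in eval mode it returns only the main logits.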
def InceptionV3(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
pretrained: bool = False):
if pretrained:
if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvision pretrained models only accept 224x224 inputs")
net = InceptionV3_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)
from torch.hub import load_state_dict_from_url
        # Note: newer torchvision releases removed inception.model_urls; this
        # lookup assumes an older torchvision where that dict still exists.
        state_dict = load_state_dict_from_url(inception.model_urls["inception_v3_google"], progress=True)
net.load_state_dict(state_dict)
if channel != 3:
            # As in InceptionV3_224x224.__init__, use torchvision's BasicConv2d
            # rather than the nonexistent inception.conv_block.
            net.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
if num_classes != 1000:
net.fc = nn.Linear(net.fc.in_features, num_classes)
elif im_size[0] == 224 and im_size[1] == 224:
net = InceptionV3_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
no_grad=no_grad)
elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
channel == 3 and im_size[0] == 32 and im_size[1] == 32):
net = InceptionV3_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
no_grad=no_grad)
else:
raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
return net
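if __name__ == "__main__":
    # Minimal smoke test sketch: build the 32x32 variant for a CIFAR-10-style
    # setup and check the output shape. Run as a module (python -m ...) so the
    # relative nets_utils import above resolves.
    net = InceptionV3(channel=3, num_classes=10, im_size=(32, 32))
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(2, 3, 32, 32))
    print(out.shape)  # expected: torch.Size([2, 10])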