Upload to Main
This commit is contained in:
128
deepcore/nets/vgg.py
Normal file
128
deepcore/nets/vgg.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import vgg
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
cfg_vgg = {
|
||||
'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
||||
}
|
||||
|
||||
|
||||
class VGG_32x32(nn.Module):
|
||||
def __init__(self, vgg_name, channel, num_classes, record_embedding=False, no_grad=False):
|
||||
super(VGG_32x32, self).__init__()
|
||||
self.channel = channel
|
||||
self.features = self._make_layers(cfg_vgg[vgg_name])
|
||||
self.classifier = nn.Linear(512 if vgg_name != 'VGGS' else 128, num_classes)
|
||||
|
||||
self.embedding_recorder = EmbeddingRecorder(record_embedding)
|
||||
self.no_grad = no_grad
|
||||
|
||||
def forward(self, x):
|
||||
with set_grad_enabled(not self.no_grad):
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.embedding_recorder(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.classifier
|
||||
|
||||
def _make_layers(self, cfg):
|
||||
layers = []
|
||||
in_channels = self.channel
|
||||
for ic, x in enumerate(cfg):
|
||||
if x == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel == 1 and ic == 0 else 1),
|
||||
nn.BatchNorm2d(x),
|
||||
nn.ReLU(inplace=True)]
|
||||
in_channels = x
|
||||
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class VGG_224x224(vgg.VGG):
|
||||
def __init__(self, features: nn.Module, channel: int, num_classes: int, record_embedding: bool = False,
|
||||
no_grad: bool = False, **kwargs):
|
||||
super(VGG_224x224, self).__init__(features, num_classes, **kwargs)
|
||||
self.embedding_recorder = EmbeddingRecorder(record_embedding)
|
||||
if channel != 3:
|
||||
self.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)
|
||||
self.fc = self.classifier[-1]
|
||||
self.classifier[-1] = self.embedding_recorder
|
||||
self.classifier.add_module("fc", self.fc)
|
||||
|
||||
self.no_grad = no_grad
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.fc
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
with set_grad_enabled(not self.no_grad):
|
||||
x = self.features(x)
|
||||
x = self.avgpool(x)
|
||||
x = flatten(x, 1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def VGG(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
|
||||
pretrained: bool = False):
|
||||
arch = arch.lower()
|
||||
if pretrained:
|
||||
if im_size[0] != 224 or im_size[1] != 224:
|
||||
raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
|
||||
net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=3, num_classes=1000,
|
||||
record_embedding=record_embedding, no_grad=no_grad)
|
||||
|
||||
from torch.hub import load_state_dict_from_url
|
||||
state_dict = load_state_dict_from_url(vgg.model_urls[arch], progress=True)
|
||||
net.load_state_dict(state_dict)
|
||||
|
||||
if channel != 3:
|
||||
net.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)
|
||||
|
||||
if num_classes != 1000:
|
||||
net.fc = nn.Linear(4096, num_classes)
|
||||
net.classifier[-1] = net.fc
|
||||
|
||||
elif im_size[0] == 224 and im_size[1] == 224:
|
||||
net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=channel, num_classes=num_classes,
|
||||
record_embedding=record_embedding, no_grad=no_grad)
|
||||
|
||||
elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
|
||||
channel == 3 and im_size[0] == 32 and im_size[1] == 32):
|
||||
net = VGG_32x32(arch, channel, num_classes=num_classes, record_embedding=record_embedding, no_grad=no_grad)
|
||||
else:
|
||||
raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
|
||||
return net
|
||||
|
||||
|
||||
def VGG11(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
|
||||
pretrained: bool = False):
|
||||
return VGG("vgg11", channel, num_classes, im_size, record_embedding, no_grad, pretrained)
|
||||
|
||||
|
||||
def VGG13(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
|
||||
pretrained: bool = False):
|
||||
return VGG('vgg13', channel, num_classes, im_size, record_embedding, no_grad, pretrained)
|
||||
|
||||
|
||||
def VGG16(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
|
||||
pretrained: bool = False):
|
||||
return VGG('vgg16', channel, num_classes, im_size, record_embedding, no_grad, pretrained)
|
||||
|
||||
|
||||
def VGG19(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
|
||||
pretrained: bool = False):
|
||||
return VGG('vgg19', channel, num_classes, im_size, record_embedding, no_grad, pretrained)
|
||||
Reference in New Issue
Block a user