release code
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Source: https://github.com/lukemelas/EfficientNet-PyTorch.
|
||||
"""
|
||||
__version__ = "0.6.4"
|
||||
from .model import (
|
||||
EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
|
||||
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
|
||||
efficientnet_b7
|
||||
)
|
||||
from .utils import (
|
||||
BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
|
||||
)
|
||||
@@ -0,0 +1,371 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .utils import (
|
||||
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
|
||||
get_model_params, efficientnet_params, get_same_padding_conv2d,
|
||||
load_pretrained_weights, calculate_output_image_size
|
||||
)
|
||||
from ..build import BACKBONE_REGISTRY
|
||||
from ..backbone import Backbone
|
||||
|
||||
|
||||
class MBConvBlock(nn.Module):
|
||||
"""
|
||||
Mobile Inverted Residual Bottleneck Block
|
||||
|
||||
Args:
|
||||
block_args (namedtuple): BlockArgs, see above
|
||||
global_params (namedtuple): GlobalParam, see above
|
||||
|
||||
Attributes:
|
||||
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, global_params, image_size=None):
|
||||
super().__init__()
|
||||
self._block_args = block_args
|
||||
self._bn_mom = 1 - global_params.batch_norm_momentum
|
||||
self._bn_eps = global_params.batch_norm_epsilon
|
||||
self.has_se = (self._block_args.se_ratio
|
||||
is not None) and (0 < self._block_args.se_ratio <= 1)
|
||||
self.id_skip = block_args.id_skip # skip connection and drop connect
|
||||
|
||||
# Expansion phase
|
||||
inp = self._block_args.input_filters # number of input channels
|
||||
oup = (
|
||||
self._block_args.input_filters * self._block_args.expand_ratio
|
||||
) # number of output channels
|
||||
if self._block_args.expand_ratio != 1:
|
||||
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
||||
self._expand_conv = Conv2d(
|
||||
in_channels=inp, out_channels=oup, kernel_size=1, bias=False
|
||||
)
|
||||
self._bn0 = nn.BatchNorm2d(
|
||||
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
|
||||
)
|
||||
# image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing
|
||||
|
||||
# Depthwise convolution phase
|
||||
k = self._block_args.kernel_size
|
||||
s = self._block_args.stride
|
||||
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
||||
self._depthwise_conv = Conv2d(
|
||||
in_channels=oup,
|
||||
out_channels=oup,
|
||||
groups=oup, # groups makes it depthwise
|
||||
kernel_size=k,
|
||||
stride=s,
|
||||
bias=False,
|
||||
)
|
||||
self._bn1 = nn.BatchNorm2d(
|
||||
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
|
||||
)
|
||||
image_size = calculate_output_image_size(image_size, s)
|
||||
|
||||
# Squeeze and Excitation layer, if desired
|
||||
if self.has_se:
|
||||
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
|
||||
num_squeezed_channels = max(
|
||||
1,
|
||||
int(
|
||||
self._block_args.input_filters * self._block_args.se_ratio
|
||||
)
|
||||
)
|
||||
self._se_reduce = Conv2d(
|
||||
in_channels=oup,
|
||||
out_channels=num_squeezed_channels,
|
||||
kernel_size=1
|
||||
)
|
||||
self._se_expand = Conv2d(
|
||||
in_channels=num_squeezed_channels,
|
||||
out_channels=oup,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
# Output phase
|
||||
final_oup = self._block_args.output_filters
|
||||
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
||||
self._project_conv = Conv2d(
|
||||
in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
|
||||
)
|
||||
self._bn2 = nn.BatchNorm2d(
|
||||
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
|
||||
)
|
||||
self._swish = MemoryEfficientSwish()
|
||||
|
||||
def forward(self, inputs, drop_connect_rate=None):
|
||||
"""
|
||||
:param inputs: input tensor
|
||||
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
|
||||
:return: output of block
|
||||
"""
|
||||
|
||||
# Expansion and Depthwise Convolution
|
||||
x = inputs
|
||||
if self._block_args.expand_ratio != 1:
|
||||
x = self._swish(self._bn0(self._expand_conv(inputs)))
|
||||
x = self._swish(self._bn1(self._depthwise_conv(x)))
|
||||
|
||||
# Squeeze and Excitation
|
||||
if self.has_se:
|
||||
x_squeezed = F.adaptive_avg_pool2d(x, 1)
|
||||
x_squeezed = self._se_expand(
|
||||
self._swish(self._se_reduce(x_squeezed))
|
||||
)
|
||||
x = torch.sigmoid(x_squeezed) * x
|
||||
|
||||
x = self._bn2(self._project_conv(x))
|
||||
|
||||
# Skip connection and drop connect
|
||||
input_filters, output_filters = (
|
||||
self._block_args.input_filters,
|
||||
self._block_args.output_filters,
|
||||
)
|
||||
if (
|
||||
self.id_skip and self._block_args.stride == 1
|
||||
and input_filters == output_filters
|
||||
):
|
||||
if drop_connect_rate:
|
||||
x = drop_connect(
|
||||
x, p=drop_connect_rate, training=self.training
|
||||
)
|
||||
x = x + inputs # skip connection
|
||||
return x
|
||||
|
||||
def set_swish(self, memory_efficient=True):
|
||||
"""Sets swish function as memory efficient (for training) or standard (for export)"""
|
||||
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
||||
|
||||
|
||||
class EfficientNet(Backbone):
|
||||
"""
|
||||
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
|
||||
|
||||
Args:
|
||||
blocks_args (list): A list of BlockArgs to construct blocks
|
||||
global_params (namedtuple): A set of GlobalParams shared between blocks
|
||||
|
||||
Example:
|
||||
model = EfficientNet.from_pretrained('efficientnet-b0')
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, blocks_args=None, global_params=None):
|
||||
super().__init__()
|
||||
assert isinstance(blocks_args, list), "blocks_args should be a list"
|
||||
assert len(blocks_args) > 0, "block args must be greater than 0"
|
||||
self._global_params = global_params
|
||||
self._blocks_args = blocks_args
|
||||
|
||||
# Batch norm parameters
|
||||
bn_mom = 1 - self._global_params.batch_norm_momentum
|
||||
bn_eps = self._global_params.batch_norm_epsilon
|
||||
|
||||
# Get stem static or dynamic convolution depending on image size
|
||||
image_size = global_params.image_size
|
||||
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
|
||||
|
||||
# Stem
|
||||
in_channels = 3 # rgb
|
||||
out_channels = round_filters(
|
||||
32, self._global_params
|
||||
) # number of output channels
|
||||
self._conv_stem = Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, bias=False
|
||||
)
|
||||
self._bn0 = nn.BatchNorm2d(
|
||||
num_features=out_channels, momentum=bn_mom, eps=bn_eps
|
||||
)
|
||||
image_size = calculate_output_image_size(image_size, 2)
|
||||
|
||||
# Build blocks
|
||||
self._blocks = nn.ModuleList([])
|
||||
for block_args in self._blocks_args:
|
||||
|
||||
# Update block input and output filters based on depth multiplier.
|
||||
block_args = block_args._replace(
|
||||
input_filters=round_filters(
|
||||
block_args.input_filters, self._global_params
|
||||
),
|
||||
output_filters=round_filters(
|
||||
block_args.output_filters, self._global_params
|
||||
),
|
||||
num_repeat=round_repeats(
|
||||
block_args.num_repeat, self._global_params
|
||||
),
|
||||
)
|
||||
|
||||
# The first block needs to take care of stride and filter size increase.
|
||||
self._blocks.append(
|
||||
MBConvBlock(
|
||||
block_args, self._global_params, image_size=image_size
|
||||
)
|
||||
)
|
||||
image_size = calculate_output_image_size(
|
||||
image_size, block_args.stride
|
||||
)
|
||||
if block_args.num_repeat > 1:
|
||||
block_args = block_args._replace(
|
||||
input_filters=block_args.output_filters, stride=1
|
||||
)
|
||||
for _ in range(block_args.num_repeat - 1):
|
||||
self._blocks.append(
|
||||
MBConvBlock(
|
||||
block_args, self._global_params, image_size=image_size
|
||||
)
|
||||
)
|
||||
# image_size = calculate_output_image_size(image_size, block_args.stride) # ?
|
||||
|
||||
# Head
|
||||
in_channels = block_args.output_filters # output of final block
|
||||
out_channels = round_filters(1280, self._global_params)
|
||||
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
||||
self._conv_head = Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, bias=False
|
||||
)
|
||||
self._bn1 = nn.BatchNorm2d(
|
||||
num_features=out_channels, momentum=bn_mom, eps=bn_eps
|
||||
)
|
||||
|
||||
# Final linear layer
|
||||
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self._dropout = nn.Dropout(self._global_params.dropout_rate)
|
||||
# self._fc = nn.Linear(out_channels, self._global_params.num_classes)
|
||||
self._swish = MemoryEfficientSwish()
|
||||
|
||||
self._out_features = out_channels
|
||||
|
||||
def set_swish(self, memory_efficient=True):
|
||||
"""Sets swish function as memory efficient (for training) or standard (for export)"""
|
||||
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
||||
for block in self._blocks:
|
||||
block.set_swish(memory_efficient)
|
||||
|
||||
def extract_features(self, inputs):
|
||||
"""Returns output of the final convolution layer"""
|
||||
|
||||
# Stem
|
||||
x = self._swish(self._bn0(self._conv_stem(inputs)))
|
||||
|
||||
# Blocks
|
||||
for idx, block in enumerate(self._blocks):
|
||||
drop_connect_rate = self._global_params.drop_connect_rate
|
||||
if drop_connect_rate:
|
||||
drop_connect_rate *= float(idx) / len(self._blocks)
|
||||
x = block(x, drop_connect_rate=drop_connect_rate)
|
||||
|
||||
# Head
|
||||
x = self._swish(self._bn1(self._conv_head(x)))
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Calls extract_features to extract features, applies
|
||||
final linear layer, and returns logits.
|
||||
"""
|
||||
bs = inputs.size(0)
|
||||
# Convolution layers
|
||||
x = self.extract_features(inputs)
|
||||
|
||||
# Pooling and final linear layer
|
||||
x = self._avg_pooling(x)
|
||||
x = x.view(bs, -1)
|
||||
x = self._dropout(x)
|
||||
# x = self._fc(x)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, model_name, override_params=None):
|
||||
cls._check_model_name_is_valid(model_name)
|
||||
blocks_args, global_params = get_model_params(
|
||||
model_name, override_params
|
||||
)
|
||||
return cls(blocks_args, global_params)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_name, advprop=False, num_classes=1000, in_channels=3
|
||||
):
|
||||
model = cls.from_name(
|
||||
model_name, override_params={"num_classes": num_classes}
|
||||
)
|
||||
load_pretrained_weights(
|
||||
model, model_name, load_fc=(num_classes == 1000), advprop=advprop
|
||||
)
|
||||
model._change_in_channels(in_channels)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_image_size(cls, model_name):
|
||||
cls._check_model_name_is_valid(model_name)
|
||||
_, _, res, _ = efficientnet_params(model_name)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def _check_model_name_is_valid(cls, model_name):
|
||||
"""Validates model name."""
|
||||
valid_models = ["efficientnet-b" + str(i) for i in range(9)]
|
||||
if model_name not in valid_models:
|
||||
raise ValueError(
|
||||
"model_name should be one of: " + ", ".join(valid_models)
|
||||
)
|
||||
|
||||
def _change_in_channels(model, in_channels):
|
||||
if in_channels != 3:
|
||||
Conv2d = get_same_padding_conv2d(
|
||||
image_size=model._global_params.image_size
|
||||
)
|
||||
out_channels = round_filters(32, model._global_params)
|
||||
model._conv_stem = Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, bias=False
|
||||
)
|
||||
|
||||
|
||||
def build_efficientnet(name, pretrained):
|
||||
if pretrained:
|
||||
return EfficientNet.from_pretrained("efficientnet-{}".format(name))
|
||||
else:
|
||||
return EfficientNet.from_name("efficientnet-{}".format(name))
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b0(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b0", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b1(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b1", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b2(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b2", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b3(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b3", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b4(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b4", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b5(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b5", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b6(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b6", pretrained)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def efficientnet_b7(pretrained=True, **kwargs):
|
||||
return build_efficientnet("b7", pretrained)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
This file contains helper functions for building the model and for loading model parameters.
|
||||
These helper functions are built to mirror those in the official TensorFlow implementation.
|
||||
"""
|
||||
|
||||
import re
|
||||
import math
|
||||
import collections
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils import model_zoo
|
||||
|
||||
########################################################################
|
||||
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
|
||||
########################################################################
|
||||
|
||||
# Parameters for the entire model (stem, all blocks, and head)
|
||||
GlobalParams = collections.namedtuple(
|
||||
"GlobalParams",
|
||||
[
|
||||
"batch_norm_momentum",
|
||||
"batch_norm_epsilon",
|
||||
"dropout_rate",
|
||||
"num_classes",
|
||||
"width_coefficient",
|
||||
"depth_coefficient",
|
||||
"depth_divisor",
|
||||
"min_depth",
|
||||
"drop_connect_rate",
|
||||
"image_size",
|
||||
],
|
||||
)
|
||||
|
||||
# Parameters for an individual model block
|
||||
BlockArgs = collections.namedtuple(
|
||||
"BlockArgs",
|
||||
[
|
||||
"kernel_size",
|
||||
"num_repeat",
|
||||
"input_filters",
|
||||
"output_filters",
|
||||
"expand_ratio",
|
||||
"id_skip",
|
||||
"stride",
|
||||
"se_ratio",
|
||||
],
|
||||
)
|
||||
|
||||
# Change namedtuple defaults
|
||||
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
|
||||
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
|
||||
|
||||
|
||||
class SwishImplementation(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, i):
|
||||
result = i * torch.sigmoid(i)
|
||||
ctx.save_for_backward(i)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
i = ctx.saved_variables[0]
|
||||
sigmoid_i = torch.sigmoid(i)
|
||||
return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))
|
||||
|
||||
|
||||
class MemoryEfficientSwish(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return SwishImplementation.apply(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def round_filters(filters, global_params):
|
||||
"""Calculate and round number of filters based on depth multiplier."""
|
||||
multiplier = global_params.width_coefficient
|
||||
if not multiplier:
|
||||
return filters
|
||||
divisor = global_params.depth_divisor
|
||||
min_depth = global_params.min_depth
|
||||
filters *= multiplier
|
||||
min_depth = min_depth or divisor
|
||||
new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor)
|
||||
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
||||
new_filters += divisor
|
||||
return int(new_filters)
|
||||
|
||||
|
||||
def round_repeats(repeats, global_params):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
multiplier = global_params.depth_coefficient
|
||||
if not multiplier:
|
||||
return repeats
|
||||
return int(math.ceil(multiplier * repeats))
|
||||
|
||||
|
||||
def drop_connect(inputs, p, training):
|
||||
"""Drop connect."""
|
||||
if not training:
|
||||
return inputs
|
||||
batch_size = inputs.shape[0]
|
||||
keep_prob = 1 - p
|
||||
random_tensor = keep_prob
|
||||
random_tensor += torch.rand(
|
||||
[batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
|
||||
)
|
||||
binary_tensor = torch.floor(random_tensor)
|
||||
output = inputs / keep_prob * binary_tensor
|
||||
return output
|
||||
|
||||
|
||||
def get_same_padding_conv2d(image_size=None):
|
||||
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
|
||||
Static padding is necessary for ONNX exporting of models."""
|
||||
if image_size is None:
|
||||
return Conv2dDynamicSamePadding
|
||||
else:
|
||||
return partial(Conv2dStaticSamePadding, image_size=image_size)
|
||||
|
||||
|
||||
def get_width_and_height_from_size(x):
|
||||
"""Obtains width and height from a int or tuple"""
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
if isinstance(x, list) or isinstance(x, tuple):
|
||||
return x
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
|
||||
def calculate_output_image_size(input_image_size, stride):
|
||||
"""
|
||||
Calculates the output image size when using Conv2dSamePadding with a stride.
|
||||
Necessary for static padding. Thanks to mannatsingh for pointing this out.
|
||||
"""
|
||||
if input_image_size is None:
|
||||
return None
|
||||
image_height, image_width = get_width_and_height_from_size(
|
||||
input_image_size
|
||||
)
|
||||
stride = stride if isinstance(stride, int) else stride[0]
|
||||
image_height = int(math.ceil(image_height / stride))
|
||||
image_width = int(math.ceil(image_width / stride))
|
||||
return [image_height, image_width]
|
||||
|
||||
|
||||
class Conv2dDynamicSamePadding(nn.Conv2d):
|
||||
"""2D Convolutions like TensorFlow, for a dynamic image size"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation,
|
||||
groups, bias
|
||||
)
|
||||
self.stride = self.stride if len(self.stride
|
||||
) == 2 else [self.stride[0]] * 2
|
||||
|
||||
def forward(self, x):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
sh, sw = self.stride
|
||||
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
||||
pad_h = max(
|
||||
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
|
||||
)
|
||||
pad_w = max(
|
||||
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
|
||||
)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(
|
||||
x,
|
||||
[pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
|
||||
)
|
||||
return F.conv2d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
class Conv2dStaticSamePadding(nn.Conv2d):
|
||||
"""2D Convolutions like TensorFlow, for a fixed image size"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
image_size=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
||||
self.stride = self.stride if len(self.stride
|
||||
) == 2 else [self.stride[0]] * 2
|
||||
|
||||
# Calculate padding based on image size and save it
|
||||
assert image_size is not None
|
||||
ih, iw = (image_size,
|
||||
image_size) if isinstance(image_size, int) else image_size
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
sh, sw = self.stride
|
||||
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
||||
pad_h = max(
|
||||
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
|
||||
)
|
||||
pad_w = max(
|
||||
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
|
||||
)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
self.static_padding = nn.ZeroPad2d(
|
||||
(pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
|
||||
)
|
||||
else:
|
||||
self.static_padding = Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.static_padding(x)
|
||||
x = F.conv2d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self, ):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
########################################################################
|
||||
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
|
||||
########################################################################
|
||||
|
||||
|
||||
def efficientnet_params(model_name):
|
||||
"""Map EfficientNet model name to parameter coefficients."""
|
||||
params_dict = {
|
||||
# Coefficients: width,depth,res,dropout
|
||||
"efficientnet-b0": (1.0, 1.0, 224, 0.2),
|
||||
"efficientnet-b1": (1.0, 1.1, 240, 0.2),
|
||||
"efficientnet-b2": (1.1, 1.2, 260, 0.3),
|
||||
"efficientnet-b3": (1.2, 1.4, 300, 0.3),
|
||||
"efficientnet-b4": (1.4, 1.8, 380, 0.4),
|
||||
"efficientnet-b5": (1.6, 2.2, 456, 0.4),
|
||||
"efficientnet-b6": (1.8, 2.6, 528, 0.5),
|
||||
"efficientnet-b7": (2.0, 3.1, 600, 0.5),
|
||||
"efficientnet-b8": (2.2, 3.6, 672, 0.5),
|
||||
"efficientnet-l2": (4.3, 5.3, 800, 0.5),
|
||||
}
|
||||
return params_dict[model_name]
|
||||
|
||||
|
||||
class BlockDecoder(object):
|
||||
"""Block Decoder for readability, straight from the official TensorFlow repository"""
|
||||
|
||||
@staticmethod
|
||||
def _decode_block_string(block_string):
|
||||
"""Gets a block through a string notation of arguments."""
|
||||
assert isinstance(block_string, str)
|
||||
|
||||
ops = block_string.split("_")
|
||||
options = {}
|
||||
for op in ops:
|
||||
splits = re.split(r"(\d.*)", op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
# Check stride
|
||||
assert ("s" in options and len(options["s"]) == 1) or (
|
||||
len(options["s"]) == 2 and options["s"][0] == options["s"][1]
|
||||
)
|
||||
|
||||
return BlockArgs(
|
||||
kernel_size=int(options["k"]),
|
||||
num_repeat=int(options["r"]),
|
||||
input_filters=int(options["i"]),
|
||||
output_filters=int(options["o"]),
|
||||
expand_ratio=int(options["e"]),
|
||||
id_skip=("noskip" not in block_string),
|
||||
se_ratio=float(options["se"]) if "se" in options else None,
|
||||
stride=[int(options["s"][0])],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _encode_block_string(block):
|
||||
"""Encodes a block to a string."""
|
||||
args = [
|
||||
"r%d" % block.num_repeat,
|
||||
"k%d" % block.kernel_size,
|
||||
"s%d%d" % (block.strides[0], block.strides[1]),
|
||||
"e%s" % block.expand_ratio,
|
||||
"i%d" % block.input_filters,
|
||||
"o%d" % block.output_filters,
|
||||
]
|
||||
if 0 < block.se_ratio <= 1:
|
||||
args.append("se%s" % block.se_ratio)
|
||||
if block.id_skip is False:
|
||||
args.append("noskip")
|
||||
return "_".join(args)
|
||||
|
||||
@staticmethod
|
||||
def decode(string_list):
|
||||
"""
|
||||
Decodes a list of string notations to specify blocks inside the network.
|
||||
|
||||
:param string_list: a list of strings, each string is a notation of block
|
||||
:return: a list of BlockArgs namedtuples of block args
|
||||
"""
|
||||
assert isinstance(string_list, list)
|
||||
blocks_args = []
|
||||
for block_string in string_list:
|
||||
blocks_args.append(BlockDecoder._decode_block_string(block_string))
|
||||
return blocks_args
|
||||
|
||||
@staticmethod
|
||||
def encode(blocks_args):
|
||||
"""
|
||||
Encodes a list of BlockArgs to a list of strings.
|
||||
|
||||
:param blocks_args: a list of BlockArgs namedtuples of block args
|
||||
:return: a list of strings, each string is a notation of block
|
||||
"""
|
||||
block_strings = []
|
||||
for block in blocks_args:
|
||||
block_strings.append(BlockDecoder._encode_block_string(block))
|
||||
return block_strings
|
||||
|
||||
|
||||
def efficientnet(
|
||||
width_coefficient=None,
|
||||
depth_coefficient=None,
|
||||
dropout_rate=0.2,
|
||||
drop_connect_rate=0.2,
|
||||
image_size=None,
|
||||
num_classes=1000,
|
||||
):
|
||||
"""Creates a efficientnet model."""
|
||||
|
||||
blocks_args = [
|
||||
"r1_k3_s11_e1_i32_o16_se0.25",
|
||||
"r2_k3_s22_e6_i16_o24_se0.25",
|
||||
"r2_k5_s22_e6_i24_o40_se0.25",
|
||||
"r3_k3_s22_e6_i40_o80_se0.25",
|
||||
"r3_k5_s11_e6_i80_o112_se0.25",
|
||||
"r4_k5_s22_e6_i112_o192_se0.25",
|
||||
"r1_k3_s11_e6_i192_o320_se0.25",
|
||||
]
|
||||
blocks_args = BlockDecoder.decode(blocks_args)
|
||||
|
||||
global_params = GlobalParams(
|
||||
batch_norm_momentum=0.99,
|
||||
batch_norm_epsilon=1e-3,
|
||||
dropout_rate=dropout_rate,
|
||||
drop_connect_rate=drop_connect_rate,
|
||||
# data_format='channels_last', # removed, this is always true in PyTorch
|
||||
num_classes=num_classes,
|
||||
width_coefficient=width_coefficient,
|
||||
depth_coefficient=depth_coefficient,
|
||||
depth_divisor=8,
|
||||
min_depth=None,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
return blocks_args, global_params
|
||||
|
||||
|
||||
def get_model_params(model_name, override_params):
|
||||
"""Get the block args and global params for a given model"""
|
||||
if model_name.startswith("efficientnet"):
|
||||
w, d, s, p = efficientnet_params(model_name)
|
||||
# note: all models have drop connect rate = 0.2
|
||||
blocks_args, global_params = efficientnet(
|
||||
width_coefficient=w,
|
||||
depth_coefficient=d,
|
||||
dropout_rate=p,
|
||||
image_size=s
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"model name is not pre-defined: %s" % model_name
|
||||
)
|
||||
if override_params:
|
||||
# ValueError will be raised here if override_params has fields not included in global_params.
|
||||
global_params = global_params._replace(**override_params)
|
||||
return blocks_args, global_params
|
||||
|
||||
|
||||
url_map = {
|
||||
"efficientnet-b0":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||
"efficientnet-b1":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||
"efficientnet-b2":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||
"efficientnet-b3":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||
"efficientnet-b4":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
"efficientnet-b5":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
"efficientnet-b6":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
"efficientnet-b7":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
|
||||
}
|
||||
|
||||
url_map_advprop = {
|
||||
"efficientnet-b0":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
|
||||
"efficientnet-b1":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
|
||||
"efficientnet-b2":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
|
||||
"efficientnet-b3":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
|
||||
"efficientnet-b4":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
|
||||
"efficientnet-b5":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
|
||||
"efficientnet-b6":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
|
||||
"efficientnet-b7":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
|
||||
"efficientnet-b8":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
|
||||
}
|
||||
|
||||
|
||||
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
|
||||
"""Loads pretrained weights, and downloads if loading for the first time."""
|
||||
# AutoAugment or Advprop (different preprocessing)
|
||||
url_map_ = url_map_advprop if advprop else url_map
|
||||
state_dict = model_zoo.load_url(url_map_[model_name])
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
"""
|
||||
if load_fc:
|
||||
model.load_state_dict(state_dict)
|
||||
else:
|
||||
state_dict.pop('_fc.weight')
|
||||
state_dict.pop('_fc.bias')
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
||||
|
||||
print('Loaded pretrained weights for {}'.format(model_name))
|
||||
"""
|
||||
Reference in New Issue
Block a user