478 lines
16 KiB
Python
478 lines
16 KiB
Python
"""
|
|
This file contains helper functions for building the model and for loading model parameters.
|
|
These helper functions are built to mirror those in the official TensorFlow implementation.
|
|
"""
|
|
|
|
import re
|
|
import math
|
|
import collections
|
|
from functools import partial
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch.utils import model_zoo
|
|
|
|
########################################################################
|
|
############### HELPER FUNCTIONS FOR MODEL ARCHITECTURE ################
|
|
########################################################################
|
|
|
|
# Settings shared by the entire model (stem, all blocks, and head).
GlobalParams = collections.namedtuple(
    "GlobalParams",
    (
        "batch_norm_momentum",
        "batch_norm_epsilon",
        "dropout_rate",
        "num_classes",
        "width_coefficient",  # read by round_filters
        "depth_coefficient",  # read by round_repeats
        "depth_divisor",  # read by round_filters
        "min_depth",  # read by round_filters
        "drop_connect_rate",
        "image_size",
    ),
)

# Settings describing one individual model block.
BlockArgs = collections.namedtuple(
    "BlockArgs",
    (
        "kernel_size",
        "num_repeat",
        "input_filters",
        "output_filters",
        "expand_ratio",
        "id_skip",
        "stride",
        "se_ratio",
    ),
)

# Give every field a default of None so both tuples can be built from any
# subset of keyword arguments.
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)
|
|
|
|
|
|
class SwishImplementation(torch.autograd.Function):
    """Swish activation, ``i * sigmoid(i)``, as a custom autograd function.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there instead of being stored.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for ``backward``."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the Swish output with respect to its input."""
        # ``saved_tensors`` is the supported accessor for tensors stored with
        # ``save_for_backward``; ``saved_variables`` is a deprecated alias.
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
|
|
|
|
|
class MemoryEfficientSwish(nn.Module):
    """Swish activation module that delegates to ``SwishImplementation``."""

    def forward(self, x):
        """Run ``x`` through ``SwishImplementation`` and return the result."""
        activated = SwishImplementation.apply(x)
        return activated
|
|
|
|
|
|
class Swish(nn.Module):
    """Swish activation module computed with ordinary tensor operations."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``."""
        gate = torch.sigmoid(x)
        return x * gate
|
|
|
|
|
|
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round to the divisor.

    :param filters: base number of filters
    :param global_params: object providing ``width_coefficient``,
        ``depth_divisor`` and ``min_depth``
    :return: ``filters`` unchanged when ``width_coefficient`` is falsy,
        otherwise the scaled count rounded to a multiple of
        ``depth_divisor`` and floored at ``min_depth`` (or the divisor
        when ``min_depth`` is falsy)
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters

    divisor = global_params.depth_divisor
    lower_bound = global_params.min_depth or divisor
    scaled = filters * width_mult

    # Round to the nearest multiple of the divisor, then apply the floor.
    nearest_multiple = int(scaled + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, nearest_multiple)

    # Never let rounding shrink the scaled count by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
|
|
|
|
|
|
def round_repeats(repeats, global_params):
    """Scale a block repeat count by the depth multiplier, rounding up.

    :param repeats: base number of repeats
    :param global_params: object providing ``depth_coefficient``
    :return: ``repeats`` unchanged when ``depth_coefficient`` is falsy,
        otherwise ``ceil(depth_coefficient * repeats)`` as an int
    """
    depth_mult = global_params.depth_coefficient
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats
|
|
|
|
|
|
def drop_connect(inputs, p, training):
    """Randomly zero whole samples of a batch and rescale the survivors.

    :param inputs: tensor whose first dimension is the batch; the mask is
        created with shape ``[batch, 1, 1, 1]``
    :param p: drop probability; each sample is kept with probability ``1 - p``
    :param training: when falsy, ``inputs`` is returned untouched
    :return: ``inputs / (1 - p)`` multiplied by a per-sample 0/1 mask
    """
    if not training:
        return inputs

    keep_prob = 1 - p
    num_samples = inputs.shape[0]
    noise = torch.rand(
        [num_samples, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0.
    keep_mask = torch.floor(keep_prob + noise)
    return inputs / keep_prob * keep_mask
|
|
|
|
|
|
def get_same_padding_conv2d(image_size=None):
    """Pick the same-padding convolution class to use.

    Static padding is necessary for ONNX exporting of models.

    :param image_size: fixed input image size, or ``None``
    :return: ``Conv2dDynamicSamePadding`` when ``image_size`` is ``None``,
        otherwise ``Conv2dStaticSamePadding`` with ``image_size`` pre-bound
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
|
|
|
|
|
|
def get_width_and_height_from_size(x):
    """Normalise an image size given as an int or as a list/tuple.

    :param x: an int, a list, or a tuple
    :return: ``(x, x)`` for an int; the list or tuple itself otherwise
    :raises TypeError: for any other type
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError()
|
|
|
|
|
|
def calculate_output_image_size(input_image_size, stride):
    """Compute the image size produced by a strided same-padding convolution.

    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    :param input_image_size: int, list/tuple of two sizes, or ``None``
    :param stride: int, or a sequence whose first element is used
    :return: ``None`` when ``input_image_size`` is ``None``, otherwise
        ``[ceil(height / stride), ceil(width / stride)]``
    """
    if input_image_size is None:
        return None

    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(height / step)), int(math.ceil(width / step))]
|
|
|
|
|
|
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding, for a dynamic image size.

    The padding is recomputed from the input's spatial size on every call to
    ``forward``; the underlying ``nn.Conv2d`` is created with padding 0.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Always keep the stride as a (height, width) pair.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        """Pad ``x`` so the output is ``ceil(input / stride)`` in size, then convolve."""
        in_h, in_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)

        # Total padding needed per axis so the kernel covers the target output.
        pad_h = max(
            (out_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            # When the total is odd, the extra pixel goes to the right/bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
|
|
|
|
|
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding, for a fixed image size.

    The padding is computed once in ``__init__`` from ``image_size`` and kept
    in ``self.static_padding``, which ``forward`` applies before convolving.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Always keep the stride as a (height, width) pair.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Calculate padding based on image size and save it.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)
        pad_h = max(
            (out_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            # When the total is odd, the extra pixel goes to the right/bottom.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the convolution."""
        padded = self.static_padding(x)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
|
|
|
|
|
class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` as-is."""
        return input
|
|
|
|
|
|
########################################################################
|
|
############## HELPER FUNCTIONS FOR LOADING MODEL PARAMS ###############
|
|
########################################################################
|
|
|
|
|
|
def efficientnet_params(model_name):
    """Look up the scaling coefficients for a named EfficientNet variant.

    :param model_name: one of the keys of the table below
    :return: tuple ``(width, depth, resolution, dropout)``
    :raises KeyError: if ``model_name`` is not in the table
    """
    # name: (width, depth, res, dropout)
    coefficients = {
        "efficientnet-b0": (1.0, 1.0, 224, 0.2),
        "efficientnet-b1": (1.0, 1.1, 240, 0.2),
        "efficientnet-b2": (1.1, 1.2, 260, 0.3),
        "efficientnet-b3": (1.2, 1.4, 300, 0.3),
        "efficientnet-b4": (1.4, 1.8, 380, 0.4),
        "efficientnet-b5": (1.6, 2.2, 456, 0.4),
        "efficientnet-b6": (1.8, 2.6, 528, 0.5),
        "efficientnet-b7": (2.0, 3.1, 600, 0.5),
        "efficientnet-b8": (2.2, 3.6, 672, 0.5),
        "efficientnet-l2": (4.3, 5.3, 800, 0.5),
    }
    selected = coefficients[model_name]
    return selected
|
|
|
|
|
|
class BlockDecoder(object):
    """Convert between block-string notation and ``BlockArgs`` tuples.

    A block string is a ``_``-separated list of options, for example
    ``"r1_k3_s11_e1_i32_o16_se0.25"``. The option prefixes map to fields as
    follows: ``r`` -> ``num_repeat``, ``k`` -> ``kernel_size``, ``s`` ->
    ``stride``, ``e`` -> ``expand_ratio``, ``i`` -> ``input_filters``,
    ``o`` -> ``output_filters``, ``se`` -> ``se_ratio``. The presence of
    ``noskip`` sets ``id_skip`` to ``False``.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Parse one block string into a ``BlockArgs``.

        :param block_string: block notation such as ``"r1_k3_s11_e1_i32_o16_se0.25"``
        :return: ``BlockArgs`` whose ``stride`` is a one-element list
        :raises AssertionError: if ``block_string`` is not a ``str`` or the
            stride option is missing or malformed
        :raises KeyError: if one of the ``k``/``r``/``i``/``o``/``e`` options
            is missing
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split("_"):
            # Split "<prefix><value>" at the first digit, e.g. "se0.25" ->
            # ["se", "0.25", ""]. Tokens without digits (e.g. "noskip") give
            # a single element and are skipped here.
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # The stride must be present and be either one digit or two equal
        # digits. A missing "s" option now fails this check instead of
        # raising KeyError from inside the condition.
        stride_digits = options.get("s", "")
        assert len(stride_digits) == 1 or (
            len(stride_digits) == 2 and stride_digits[0] == stride_digits[1]
        ), "invalid or missing stride option in block string: %r" % block_string

        return BlockArgs(
            kernel_size=int(options["k"]),
            num_repeat=int(options["r"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            expand_ratio=int(options["e"]),
            id_skip=("noskip" not in block_string),
            se_ratio=float(options["se"]) if "se" in options else None,
            stride=[int(stride_digits[0])],
        )

    @staticmethod
    def _encode_block_string(block):
        """Encode one block as a block string.

        :param block: object with the ``BlockArgs`` fields; ``stride`` may be
            an int or a sequence of one or two ints
        :return: block notation; the stride is always written as two digits
        """
        # The field is named ``stride``. An object that only carries the
        # previously-read ``strides`` attribute is still accepted.
        stride = getattr(block, "stride", None)
        if stride is None:
            stride = getattr(block, "strides", None)
        if isinstance(stride, int):
            stride = [stride]
        stride_h = stride[0]
        # _decode_block_string stores a one-element list; reuse it for width.
        stride_w = stride[1] if len(stride) > 1 else stride[0]

        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (stride_h, stride_w),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        # se_ratio is None when the block string had no "se" option.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """
        Decodes a list of string notations to specify blocks inside the network.

        :param string_list: a list of strings, each string is a notation of block
        :return: a list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        return [
            BlockDecoder._decode_block_string(block_string)
            for block_string in string_list
        ]

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
|
|
|
|
|
|
def efficientnet(
    width_coefficient=None,
    depth_coefficient=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    image_size=None,
    num_classes=1000,
):
    """Build the block specifications and global parameters for an EfficientNet.

    The seven block strings are fixed; the arguments only populate the
    returned ``GlobalParams``.

    :return: tuple ``(blocks_args, global_params)`` where ``blocks_args`` is
        the list produced by ``BlockDecoder.decode``
    """
    # One entry per stage, in network order (see BlockDecoder for the notation).
    block_strings = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    decoded_blocks = BlockDecoder.decode(block_strings)

    # The TensorFlow ``data_format`` option is intentionally not carried over.
    shared_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        drop_connect_rate=drop_connect_rate,
        image_size=image_size,
    )

    return decoded_blocks, shared_params
|
|
|
|
|
|
def get_model_params(model_name, override_params):
    """Return the block args and global params for a named model.

    :param model_name: name starting with ``"efficientnet"``
    :param override_params: optional mapping of ``GlobalParams`` field names
        to replacement values; ignored when falsy
    :return: tuple ``(blocks_args, global_params)``
    :raises NotImplementedError: if ``model_name`` does not start with
        ``"efficientnet"``
    """
    if not model_name.startswith("efficientnet"):
        raise NotImplementedError(
            "model name is not pre-defined: %s" % model_name
        )

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # The drop connect rate is left at efficientnet()'s default (0.2) for
    # every model.
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if override_params:
        # _replace rejects keys that are not GlobalParams fields (ValueError;
        # TypeError on Python 3.13+).
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
|
|
|
|
|
|
# Checkpoint download URLs used when ``advprop`` is False, keyed by model name.
# Only b0-b7 have an entry here.
url_map = {
    "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
    "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
    "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
    "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
    "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
    "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
    "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
    "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}
|
|
|
|
# Checkpoint download URLs used when ``advprop`` is True, keyed by model name.
# Covers b0-b8.
url_map_advprop = {
    "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
    "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
    "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
    "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
    "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
    "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
    "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
    "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
    "efficientnet-b8": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}
|
|
|
|
|
|
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Load pretrained weights into ``model``, downloading them on first use.

    :param model: module receiving the weights via ``load_state_dict``
    :param model_name: key into ``url_map`` (or ``url_map_advprop``)
    :param load_fc: when ``False``, the ``_fc.weight`` and ``_fc.bias``
        entries of the checkpoint are discarded before loading, so the
        model's own final layer is left untouched
    :param advprop: select the AdvProp checkpoints instead of the standard
        ones (they use different preprocessing)
    :raises KeyError: if the selected URL table has no entry for ``model_name``
    """
    # AutoAugment or Advprop (different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    state_dict = model_zoo.load_url(url_map_[model_name])

    if not load_fc:
        # Honour load_fc: without this the checkpoint's classifier weights
        # were always loaded, and a model whose final layer has a different
        # shape would make load_state_dict fail even with strict=False.
        for fc_key in ("_fc.weight", "_fc.bias"):
            state_dict.pop(fc_key, None)

    # strict=False tolerates keys that are missing from, or unexpected in,
    # the checkpoint.
    model.load_state_dict(state_dict, strict=False)
|