release code
This commit is contained in:
5
Dassl.ProGrad.pytorch/dassl/utils/__init__.py
Normal file
5
Dassl.ProGrad.pytorch/dassl/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .tools import *
|
||||
from .logger import *
|
||||
from .meters import *
|
||||
from .registry import *
|
||||
from .torchtools import *
|
||||
73
Dassl.ProGrad.pytorch/dassl/utils/logger.py
Normal file
73
Dassl.ProGrad.pytorch/dassl/utils/logger.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import os.path as osp
|
||||
|
||||
from .tools import mkdir_if_missing
|
||||
|
||||
__all__ = ["Logger", "setup_logger"]
|
||||
|
||||
|
||||
class Logger:
    """Write console output to external text file.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_

    Args:
        fpath (str): directory to save logging file.

    Examples::
        >>> import sys
        >>> import os.path as osp
        >>> save_dir = 'output/experiment-1'
        >>> log_name = 'train.log'
        >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        # Keep a handle on the current stdout so every message still
        # reaches the console in addition to the log file.
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(osp.dirname(fpath))
            self.file = open(fpath, "w")

    def __del__(self):
        self.close()

    def __enter__(self):
        # BUGFIX: previously this returned None, so
        # ``with Logger(path) as log:`` bound ``log`` to None.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Mirror *msg* to the console and, if open, the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush both streams; fsync so the log survives a hard crash."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # NOTE(review): this also closes the original stdout stream, so
        # console printing stops working after close(). Kept for backward
        # compatibility with the upstream implementation -- confirm before
        # calling close() mid-run.
        self.console.close()
        if self.file is not None:
            self.file.close()
|
||||
|
||||
|
||||
def setup_logger(output=None):
    """Redirect ``sys.stdout`` to a :class:`Logger` writing to *output*.

    Args:
        output (str, optional): either a ``.txt``/``.log`` file path, or a
            directory in which ``log.txt`` is created. If None, do nothing.
    """
    if output is None:
        return

    # A log-like suffix means the caller supplied the full file path;
    # anything else is treated as a directory.
    is_file = output.endswith(".txt") or output.endswith(".log")
    fpath = output if is_file else osp.join(output, "log.txt")

    # Never clobber a previous run's log: append a timestamp suffix.
    if osp.exists(fpath):
        fpath += time.strftime("-%Y-%m-%d-%H-%M-%S")

    sys.stdout = Logger(fpath)
|
||||
80
Dassl.ProGrad.pytorch/dassl/utils/meters.py
Normal file
80
Dassl.ProGrad.pytorch/dassl/utils/meters.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
__all__ = ["AverageMeter", "MetricMeter"]
|
||||
|
||||
|
||||
class AverageMeter:
    """Compute and store the average and current value.

    Examples::
        >>> # 1. Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # 2. Update meter after every mini-batch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, ema=False):
        """
        Args:
            ema (bool, optional): apply exponential moving average.
        """
        self.ema = ema
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record the value *val*, observed *n* times."""
        # Accept 0-dim tensors transparently.
        if isinstance(val, torch.Tensor):
            val = val.item()

        self.val = val
        self.sum += val * n
        self.count += n

        # EMA uses a fixed 0.9 decay; otherwise report the exact mean.
        if self.ema:
            self.avg = self.avg * 0.9 + self.val * 0.1
        else:
            self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class MetricMeter:
    """Store the average and current value for a set of metrics.

    Examples::
        >>> # 1. Create an instance of MetricMeter
        >>> metric = MetricMeter()
        >>> # 2. Update using a dictionary as input
        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
        >>> metric.update(input_dict)
        >>> # 3. Convert to string and print
        >>> print(str(metric))
    """

    def __init__(self, delimiter="\t"):
        # One AverageMeter per metric name, created lazily on first update.
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter

    def update(self, input_dict):
        """Feed a ``{name: value}`` dict into the per-metric meters."""
        if input_dict is None:
            return

        if not isinstance(input_dict, dict):
            raise TypeError(
                "Input to MetricMeter.update() must be a dictionary"
            )

        for name, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(value)

    def __str__(self):
        """Render ``name current (average)`` pairs joined by the delimiter."""
        parts = (
            f"{name} {meter.val:.4f} ({meter.avg:.4f})"
            for name, meter in self.meters.items()
        )
        return self.delimiter.join(parts)
|
||||
69
Dassl.ProGrad.pytorch/dassl/utils/registry.py
Normal file
69
Dassl.ProGrad.pytorch/dassl/utils/registry.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Modified from https://github.com/facebookresearch/fvcore
|
||||
"""
|
||||
__all__ = ["Registry"]
|
||||
|
||||
|
||||
class Registry:
    """A registry providing name -> object mapping, to support
    custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        # Refuse silent overwrites unless the caller explicitly forces it.
        if not force and name in self._obj_map:
            raise KeyError(
                'An object named "{}" was already '
                'registered in "{}" registry'.format(name, self._name)
            )
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        """Register *obj*; with no *obj*, act as a decorator factory."""
        if obj is not None:
            # Plain function-call usage: registry.register(MyClass)
            self._do_register(obj.__name__, obj, force=force)
            return

        # Decorator usage: @registry.register()
        def deco(fn_or_class):
            self._do_register(fn_or_class.__name__, fn_or_class, force=force)
            return fn_or_class

        return deco

    def get(self, name):
        """Look up a registered object; raise KeyError when absent."""
        if name not in self._obj_map:
            raise KeyError(
                'Object name "{}" does not exist '
                'in "{}" registry'.format(name, self._name)
            )
        return self._obj_map[name]

    def registered_names(self):
        """Return all registered names as a list."""
        return list(self._obj_map)
|
||||
196
Dassl.ProGrad.pytorch/dassl/utils/tools.py
Normal file
196
Dassl.ProGrad.pytorch/dassl/utils/tools.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Modified from https://github.com/KaiyangZhou/deep-person-reid
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import errno
|
||||
import numpy as np
|
||||
import random
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from difflib import SequenceMatcher
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
__all__ = [
|
||||
"mkdir_if_missing",
|
||||
"check_isfile",
|
||||
"read_json",
|
||||
"write_json",
|
||||
"set_random_seed",
|
||||
"download_url",
|
||||
"read_image",
|
||||
"collect_env_info",
|
||||
"listdir_nohidden",
|
||||
"get_most_similar_str_to_a_from_b",
|
||||
"check_availability",
|
||||
"tolist_if_not",
|
||||
]
|
||||
|
||||
|
||||
def mkdir_if_missing(dirname):
    """Create *dirname* (including parents) if it does not exist.

    Args:
        dirname (str): directory path. An empty string is a no-op, which
            makes ``mkdir_if_missing(osp.dirname(fname))`` safe for bare
            file names.
    """
    # BUGFIX: osp.dirname() of a bare filename is "" and os.makedirs("")
    # raises FileNotFoundError; treat it as "current directory exists".
    if not dirname:
        return
    # exist_ok replaces the racy osp.exists() check + errno.EEXIST dance
    # of the previous implementation.
    os.makedirs(dirname, exist_ok=True)
|
||||
|
||||
|
||||
def check_isfile(fpath):
    """Check if the given path is a file.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    found = osp.isfile(fpath)
    if not found:
        # Warn instead of raising: callers may treat the file as optional.
        warnings.warn('No file found at "{}"'.format(fpath))
    return found
|
||||
|
||||
|
||||
def read_json(fpath):
    """Read json file from a path."""
    with open(fpath, "r") as fin:
        return json.load(fin)
|
||||
|
||||
|
||||
def write_json(obj, fpath):
    """Serialize *obj* as pretty-printed JSON at *fpath*."""
    # Make sure the target directory exists before opening the file.
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, "w") as out:
        json.dump(obj, out, indent=4, separators=(",", ": "))
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed the python, numpy and torch (CPU + all GPUs) RNGs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
|
||||
|
||||
|
||||
def download_url(url, dst):
    """Download file from a url to a destination.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    # BUGFIX/modernization: previously imported ``six.moves.urllib``, a
    # python2 compatibility shim and an avoidable third-party dependency.
    # The stdlib module is equivalent on python3.
    import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    def _reporthook(count, block_size, total_size):
        # urlretrieve invokes this after every block; remember the start
        # time on the first call to compute the average speed.
        global start_time
        if count == 0:
            start_time = time.time()
            return
        duration = time.time() - start_time
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024*duration))
        percent = int(count * block_size * 100 / total_size)
        sys.stdout.write(
            "\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
            (percent, progress_size / (1024*1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write("\n")
|
||||
|
||||
|
||||
def read_image(path):
    """Read image from path using ``PIL.Image``.

    Args:
        path (str): path to an image.

    Returns:
        PIL image
    """
    if not osp.exists(path):
        raise IOError("No file exists at {}".format(path))

    # Retry indefinitely: transient IOErrors can occur under heavy IO
    # (e.g. shared/network file systems during training).
    while True:
        try:
            return Image.open(path).convert("RGB")
        except IOError:
            print(
                "Cannot read image from {}, "
                "probably due to heavy IO. Will re-try".format(path)
            )
|
||||
|
||||
|
||||
def collect_env_info():
    """Return env info as a string.

    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info

    # torch's standard report plus the Pillow version, which it omits.
    return "{}\n Pillow ({})".format(get_pretty_env_info(), PIL.__version__)
|
||||
|
||||
|
||||
def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Args:
        path (str): directory path.
        sort (bool): sort the items.
    """
    # Hidden entries follow the unix leading-dot convention.
    entries = [name for name in os.listdir(path) if not name.startswith(".")]
    return sorted(entries) if sort else entries
|
||||
|
||||
|
||||
def get_most_similar_str_to_a_from_b(a, b):
    """Return the most similar string to a in b.

    Args:
        a (str): probe string.
        b (list): a list of candidate strings.

    Returns:
        str or None: best-matching candidate (ties resolved in favor of
        the later candidate); None when *b* is empty.
    """
    best_score, best_match = 0, None
    for option in b:
        # ratio() is in [0, 1]; ">=" makes later candidates win ties,
        # matching the original semantics.
        score = SequenceMatcher(None, a, option).ratio()
        if score >= best_score:
            best_score, best_match = score, option
    return best_match
|
||||
|
||||
|
||||
def check_availability(requested, available):
    """Check if an element is available in a list.

    Args:
        requested (str): probe string.
        available (list): a list of available strings.

    Raises:
        ValueError: when *requested* is absent; the message suggests the
            closest available match.
    """
    if requested in available:
        return
    suggestion = get_most_similar_str_to_a_from_b(requested, available)
    raise ValueError(
        "The requested one is expected "
        "to belong to {}, but got [{}] "
        "(do you mean [{}]?)".format(available, requested, suggestion)
    )
|
||||
|
||||
|
||||
def tolist_if_not(x):
    """Convert to a list."""
    return x if isinstance(x, list) else [x]
|
||||
356
Dassl.ProGrad.pytorch/dassl/utils/torchtools.py
Normal file
356
Dassl.ProGrad.pytorch/dassl/utils/torchtools.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
Modified from https://github.com/KaiyangZhou/deep-person-reid
|
||||
"""
|
||||
import pickle
|
||||
import shutil
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .tools import mkdir_if_missing
|
||||
|
||||
__all__ = [
|
||||
"save_checkpoint",
|
||||
"load_checkpoint",
|
||||
"resume_from_checkpoint",
|
||||
"open_all_layers",
|
||||
"open_specified_layers",
|
||||
"count_num_param",
|
||||
"load_pretrained_weights",
|
||||
"init_network_weights",
|
||||
]
|
||||
|
||||
|
||||
def save_checkpoint(
    state,
    save_dir,
    is_best=False,
    remove_module_from_keys=True,
    model_name=""
):
    r"""Save checkpoint.

    Args:
        state (dict): dictionary containing at least "epoch" and
            "state_dict".
        save_dir (str): directory to save checkpoint.
        is_best (bool, optional): if True, this checkpoint will be copied and named
            ``model-best.pth.tar``. Default is False.
        remove_module_from_keys (bool, optional): whether to remove "module."
            from layer names. Default is True.
        model_name (str, optional): model name to save.

    Examples::
        >>> state = {
        >>>     'state_dict': model.state_dict(),
        >>>     'epoch': 10,
        >>>     'optimizer': optimizer.state_dict()
        >>> }
        >>> save_checkpoint(state, 'log/my_model')
    """
    mkdir_if_missing(save_dir)

    if remove_module_from_keys:
        # Strip the "module." prefix added by nn.DataParallel so the
        # checkpoint can be loaded into a bare (non-parallel) model.
        state_dict = state["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith("module."):
                k = k[7:]
            new_state_dict[k] = v
        state["state_dict"] = new_state_dict

    # save model
    epoch = state["epoch"]
    if not model_name:
        model_name = "model.pth.tar-" + str(epoch)
    fpath = osp.join(save_dir, model_name)
    torch.save(state, fpath)
    print('Checkpoint saved to "{}"'.format(fpath))

    # Record the latest model file name so resume_from_checkpoint() can
    # find it. BUGFIX: use a context manager so the file handle is closed
    # even if the write raises.
    checkpoint_file = osp.join(save_dir, "checkpoint")
    with open(checkpoint_file, "w+") as checkpoint:
        checkpoint.write("{}\n".format(osp.basename(fpath)))

    if is_best:
        best_fpath = osp.join(osp.dirname(fpath), "model-best.pth.tar")
        shutil.copy(fpath, best_fpath)
        print('Best checkpoint saved to "{}"'.format(best_fpath))
|
||||
|
||||
|
||||
def load_checkpoint(fpath):
    r"""Load checkpoint.

    ``UnicodeDecodeError`` can be well handled, which means
    python2-saved files can be read from python3.

    Args:
        fpath (str): path to checkpoint.

    Returns:
        dict

    Raises:
        ValueError: if ``fpath`` is None.
        FileNotFoundError: if ``fpath`` does not exist.

    Examples::
        >>> fpath = 'log/my_model/model.pth.tar-10'
        >>> checkpoint = load_checkpoint(fpath)
    """
    if fpath is None:
        raise ValueError("File path is None")

    if not osp.exists(fpath):
        raise FileNotFoundError('File is not found at "{}"'.format(fpath))

    map_location = None if torch.cuda.is_available() else "cpu"

    try:
        checkpoint = torch.load(fpath, map_location=map_location)

    except UnicodeDecodeError:
        # Legacy python2 checkpoint: decode strings as latin1.
        # BUGFIX: build a local pickle-compatible shim instead of
        # monkey-patching the global pickle module (the old code left
        # pickle.load/pickle.Unpickler patched for the whole process).
        import types

        legacy_pickle = types.SimpleNamespace(
            load=partial(pickle.load, encoding="latin1"),
            Unpickler=partial(pickle.Unpickler, encoding="latin1"),
        )
        checkpoint = torch.load(
            fpath, pickle_module=legacy_pickle, map_location=map_location
        )

    except Exception:
        print('Unable to load checkpoint from "{}"'.format(fpath))
        raise

    return checkpoint
|
||||
|
||||
|
||||
def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):
    r"""Resume training from a checkpoint.

    This will load (1) model weights and (2) ``state_dict``
    of optimizer if ``optimizer`` is not None.

    Args:
        fdir (str): directory where the model was saved.
        model (nn.Module): model.
        optimizer (Optimizer, optional): an Optimizer.
        scheduler (Scheduler, optional): an Scheduler.

    Returns:
        int: start_epoch.

    Examples::
        >>> fdir = 'log/my_model'
        >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler)
    """
    # The "checkpoint" bookkeeping file (written by save_checkpoint)
    # stores the latest model file name on its first line.
    with open(osp.join(fdir, "checkpoint"), "r") as bookkeeping:
        latest_name = bookkeeping.readlines()[0].strip("\n")
    fpath = osp.join(fdir, latest_name)

    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)

    model.load_state_dict(checkpoint["state_dict"])
    print("Loaded model weights")

    # Optimizer/scheduler state is restored only when both the caller
    # passed one and the checkpoint recorded one.
    if optimizer is not None and "optimizer" in checkpoint.keys():
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Loaded optimizer")

    if scheduler is not None and "scheduler" in checkpoint.keys():
        scheduler.load_state_dict(checkpoint["scheduler"])
        print("Loaded scheduler")

    start_epoch = checkpoint["epoch"]
    print("Previous epoch: {}".format(start_epoch))

    return start_epoch
|
||||
|
||||
|
||||
def adjust_learning_rate(
    optimizer,
    base_lr,
    epoch,
    stepsize=20,
    gamma=0.1,
    linear_decay=False,
    final_lr=0,
    max_epoch=100,
):
    r"""Adjust learning rate.

    Deprecated.
    """
    if linear_decay:
        # Interpolate linearly from base_lr (epoch 0) to final_lr
        # (epoch max_epoch).
        progress = epoch / max_epoch
        lr = progress*final_lr + (1.0-progress) * base_lr
    else:
        # Step decay: multiply by gamma once every `stepsize` epochs.
        lr = base_lr * (gamma**(epoch // stepsize))

    for group in optimizer.param_groups:
        group["lr"] = lr
|
||||
|
||||
|
||||
def set_bn_to_eval(m):
    r"""Set BatchNorm layers to eval mode.

    In eval mode the running mean/var stop updating while the affine
    scale/shift parameters remain trainable.
    """
    if "BatchNorm" in m.__class__.__name__:
        m.eval()
|
||||
|
||||
|
||||
def open_all_layers(model):
    r"""Open all layers in model for training.

    Puts the model in train mode and enables gradients everywhere.

    Examples::
        >>> open_all_layers(model)
    """
    model.train()
    for param in model.parameters():
        param.requires_grad = True
|
||||
|
||||
|
||||
def open_specified_layers(model, open_layers):
    r"""Open specified layers in model for training while keeping
    other layers frozen.

    Args:
        model (nn.Module): neural net model.
        open_layers (str or list): layers open for training.

    Examples::
        >>> # Only model.classifier will be updated.
        >>> open_layers = 'classifier'
        >>> open_specified_layers(model, open_layers)
        >>> # Only model.fc and model.classifier will be updated.
        >>> open_layers = ['fc', 'classifier']
        >>> open_specified_layers(model, open_layers)
    """
    # Operate on the wrapped module when the model is data-parallel.
    if isinstance(model, nn.DataParallel):
        model = model.module

    if isinstance(open_layers, str):
        open_layers = [open_layers]

    # Fail fast on typos in layer names.
    for layer in open_layers:
        assert hasattr(
            model, layer
        ), '"{}" is not an attribute of the model, please provide the correct name'.format(
            layer
        )

    for name, child in model.named_children():
        trainable = name in open_layers
        # train(False) is equivalent to eval().
        child.train(mode=trainable)
        for param in child.parameters():
            param.requires_grad = trainable
|
||||
|
||||
|
||||
def count_num_param(model):
    r"""Count number of parameters in a model.

    Args:
        model (nn.Module): network model.

    Examples::
        >>> model_size = count_num_param(model)
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def load_pretrained_weights(model, weight_path):
    r"""Load pretrianed weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = load_checkpoint(weight_path)
    # A full checkpoint stores weights under "state_dict"; a bare
    # state-dict file is accepted as-is.
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for key, value in state_dict.items():
        # Strip the nn.DataParallel prefix if present.
        if key.startswith("module."):
            key = key[7:]

        if key in model_dict and model_dict[key].size() == value.size():
            new_state_dict[key] = value
            matched_layers.append(key)
        else:
            discarded_layers.append(key)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            "please check the key names manually "
            "(** ignored and continue **)".format(weight_path)
        )
        return

    print(
        'Successfully loaded pretrained weights from "{}"'.
        format(weight_path)
    )
    if discarded_layers:
        print(
            "** The following layers are discarded "
            "due to unmatched keys or layer size: {}".
            format(discarded_layers)
        )
|
||||
|
||||
|
||||
def init_network_weights(model, init_type="normal", gain=0.02):
    """Initialize network weights in-place.

    Args:
        model (nn.Module): network to initialize.
        init_type (str): one of "normal", "xavier", "kaiming",
            "orthogonal".
        gain (float): scaling factor (the std for "normal").
    """

    def _init_func(m):
        classname = m.__class__.__name__

        is_conv_or_linear = hasattr(m, "weight") and (
            "Conv" in classname or "Linear" in classname
        )

        if is_conv_or_linear:
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        elif "BatchNorm" in classname:
            # BatchNorm affine params start as the identity transform.
            nn.init.constant_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)

        elif "InstanceNorm" in classname:
            # InstanceNorm defaults to affine=False -> params may be None.
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)
|
||||
Reference in New Issue
Block a user