release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from .tools import *
from .logger import *
from .meters import *
from .registry import *
from .torchtools import *

View File

@@ -0,0 +1,73 @@
import os
import sys
import time
import os.path as osp
from .tools import mkdir_if_missing
__all__ = ["Logger", "setup_logger"]
class Logger:
    """Write console output to both the console and an external text file.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_

    Args:
        fpath (str, optional): path of the logging file. If None, only the
            console is written to.

    Examples::
        >>> import sys
        >>> import os.path as osp
        >>> save_dir = 'output/experiment-1'
        >>> log_name = 'train.log'
        >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        # Capture the current stdout so every write is mirrored to it.
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(osp.dirname(fpath))
            self.file = open(fpath, "w")

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return self so `with Logger(...) as logger:` binds the instance.
        # (Previously returned None, making the as-target unusable.)
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write msg to the console and (if open) the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush both streams; fsync forces the log onto disk so it
        survives a crash."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # NOTE(review): this also closes the captured console stream
        # (often sys.stdout) — kept as-is for backward compatibility.
        self.console.close()
        if self.file is not None:
            self.file.close()
def setup_logger(output=None):
    """Redirect ``sys.stdout`` to a :class:`Logger` writing into ``output``.

    Args:
        output (str, optional): a ``.txt``/``.log`` file path, or a
            directory in which ``log.txt`` is created. If None, do nothing.
    """
    if output is None:
        return

    is_file = output.endswith((".txt", ".log"))
    fpath = output if is_file else osp.join(output, "log.txt")

    # Never clobber an existing log file: append a timestamp suffix instead.
    if osp.exists(fpath):
        fpath += time.strftime("-%Y-%m-%d-%H-%M-%S")

    sys.stdout = Logger(fpath)

View File

@@ -0,0 +1,80 @@
from collections import defaultdict
import torch
__all__ = ["AverageMeter", "MetricMeter"]
class AverageMeter:
    """Compute and store the average and current value.

    Examples::
        >>> # 1. Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # 2. Update meter after every mini-batch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, ema=False):
        """
        Args:
            ema (bool, optional): apply exponential moving average.
        """
        self.ema = ema
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times."""
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.count += n
        self.sum += val * n
        if not self.ema:
            self.avg = self.sum / self.count
        else:
            # Exponential smoothing with a fixed decay of 0.9.
            self.avg = 0.9 * self.avg + 0.1 * self.val
class MetricMeter:
    """Store the average and current value for a set of metrics.

    Examples::
        >>> # 1. Create an instance of MetricMeter
        >>> metric = MetricMeter()
        >>> # 2. Update using a dictionary as input
        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
        >>> metric.update(input_dict)
        >>> # 3. Convert to string and print
        >>> print(str(metric))
    """

    def __init__(self, delimiter="\t"):
        # One AverageMeter per metric name, created lazily on first update.
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter

    def update(self, input_dict):
        """Feed a ``{name: value}`` dict into the per-name meters."""
        if input_dict is None:
            return
        if not isinstance(input_dict, dict):
            raise TypeError(
                "Input to MetricMeter.update() must be a dictionary"
            )
        for name, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(value)

    def __str__(self):
        parts = [
            f"{name} {meter.val:.4f} ({meter.avg:.4f})"
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(parts)

View File

@@ -0,0 +1,69 @@
"""
Modified from https://github.com/facebookresearch/fvcore
"""
__all__ = ["Registry"]
class Registry:
    """A registry providing name -> object mapping, to support
    custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        self._name = name
        self._obj_map = dict()

    def _do_register(self, name, obj, force=False):
        # Refuse duplicate names unless the caller explicitly forces it.
        if name in self._obj_map and not force:
            raise KeyError(
                'An object named "{}" was already '
                'registered in "{}" registry'.format(name, self._name)
            )
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        """Register ``obj`` under its ``__name__``.

        Args:
            obj (callable, optional): function or class to register. If
                None, a decorator is returned instead.
            force (bool, optional): overwrite an existing entry.

        Returns:
            The registered object (or a decorator when ``obj`` is None),
            so the direct-call form composes like the decorator form.
            (Previously the direct-call form returned None.)
        """
        if obj is None:
            # Used as a decorator
            def wrapper(fn_or_class):
                name = fn_or_class.__name__
                self._do_register(name, fn_or_class, force=force)
                return fn_or_class

            return wrapper

        # Used as a function call
        name = obj.__name__
        self._do_register(name, obj, force=force)
        return obj

    def get(self, name):
        """Look up a registered object; raise KeyError if absent."""
        if name not in self._obj_map:
            raise KeyError(
                'Object name "{}" does not exist '
                'in "{}" registry'.format(name, self._name)
            )
        return self._obj_map[name]

    def registered_names(self):
        """Return a list of all registered names."""
        return list(self._obj_map.keys())

View File

@@ -0,0 +1,196 @@
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import os
import sys
import json
import time
import errno
import numpy as np
import random
import os.path as osp
import warnings
from difflib import SequenceMatcher
import PIL
import torch
from PIL import Image
__all__ = [
"mkdir_if_missing",
"check_isfile",
"read_json",
"write_json",
"set_random_seed",
"download_url",
"read_image",
"collect_env_info",
"listdir_nohidden",
"get_most_similar_str_to_a_from_b",
"check_availability",
"tolist_if_not",
]
def mkdir_if_missing(dirname):
    """Create dirname (and any missing parents) if it does not exist.

    Args:
        dirname (str): directory path. An empty string is a no-op, so the
            function is safe on results like ``osp.dirname("file.txt")``
            (previously that raised FileNotFoundError).
    """
    if not dirname:
        return
    if not osp.exists(dirname):
        # exist_ok covers the race where another process creates the
        # directory between the exists() check and makedirs().
        os.makedirs(dirname, exist_ok=True)
def check_isfile(fpath):
    """Check if the given path is a file.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    found = osp.isfile(fpath)
    if not found:
        # Warn rather than raise: callers treat a missing file as soft.
        warnings.warn('No file found at "{}"'.format(fpath))
    return found
def read_json(fpath):
    """Read json file from a path."""
    with open(fpath, "r") as fobj:
        return json.load(fobj)
def write_json(obj, fpath):
    """Writes to a json file."""
    # Create the parent directory first so open() cannot fail on it.
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, "w") as fobj:
        json.dump(obj, fobj, indent=4, separators=(",", ": "))
def set_random_seed(seed):
    """Seed the python, numpy and torch (CPU + all CUDA) RNGs.

    Args:
        seed (int): seed value.
    """
    # Same seeding order as before; only the shape of the code differs.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
def download_url(url, dst):
    """Download file from a url to a destination, printing progress.

    Args:
        url (str): url to download file.
        dst (str): destination path.

    Uses the stdlib ``urllib.request`` (the previous implementation pulled
    in the third-party ``six`` package) and keeps timing state in a
    closure instead of a module-level global.
    """
    import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    start_time = time.time()

    def _reporthook(count, block_size, total_size):
        # Guard against a zero duration on the very first callback.
        duration = max(time.time() - start_time, 1e-6)
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024*duration))
        percent = int(count * block_size * 100 / total_size)
        sys.stdout.write(
            "\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
            (percent, progress_size / (1024*1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write("\n")
def read_image(path):
    """Read image from path using ``PIL.Image``.

    Args:
        path (str): path to an image.

    Returns:
        PIL image
    """
    if not osp.exists(path):
        raise IOError("No file exists at {}".format(path))

    # Retry indefinitely on IOError: transient read failures are expected
    # when many workers hammer the same filesystem.
    while True:
        try:
            return Image.open(path).convert("RGB")
        except IOError:
            print(
                "Cannot read image from {}, "
                "probably due to heavy IO. Will re-try".format(path)
            )
def collect_env_info():
    """Return env info as a string.

    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info

    # Append the Pillow version, which collect_env does not report.
    return get_pretty_env_info() + "\n Pillow ({})".format(PIL.__version__)
def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Args:
        path (str): directory path.
        sort (bool): sort the items.
    """
    # Hidden means a leading dot, the usual Unix convention.
    items = [name for name in os.listdir(path) if not name.startswith(".")]
    return sorted(items) if sort else items
def get_most_similar_str_to_a_from_b(a, b):
    """Return the most similar string to a in b.

    Args:
        a (str): probe string.
        b (list): a list of candidate strings.
    """
    chosen, best = None, 0
    for cand in b:
        score = SequenceMatcher(None, a, cand).ratio()
        # '>=' so that later candidates win ties.
        if score >= best:
            best, chosen = score, cand
    return chosen
def check_availability(requested, available):
    """Check if an element is available in a list.

    Args:
        requested (str): probe string.
        available (list): a list of available strings.
    """
    if requested in available:
        return
    # Suggest the closest known name in the error message.
    suggestion = get_most_similar_str_to_a_from_b(requested, available)
    raise ValueError(
        "The requested one is expected "
        "to belong to {}, but got [{}] "
        "(do you mean [{}]?)".format(available, requested, suggestion)
    )
def tolist_if_not(x):
    """Convert to a list."""
    # Non-lists (including tuples) are wrapped, not converted element-wise.
    return x if isinstance(x, list) else [x]

View File

@@ -0,0 +1,356 @@
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import pickle
import shutil
import os.path as osp
import warnings
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
from .tools import mkdir_if_missing
__all__ = [
"save_checkpoint",
"load_checkpoint",
"resume_from_checkpoint",
"open_all_layers",
"open_specified_layers",
"count_num_param",
"load_pretrained_weights",
"init_network_weights",
]
def save_checkpoint(
    state,
    save_dir,
    is_best=False,
    remove_module_from_keys=True,
    model_name=""
):
    r"""Save checkpoint.

    Args:
        state (dict): dictionary containing at least "state_dict" and
            "epoch". The caller's dict is no longer mutated in place.
        save_dir (str): directory to save checkpoint.
        is_best (bool, optional): if True, this checkpoint will be copied and named
            ``model-best.pth.tar``. Default is False.
        remove_module_from_keys (bool, optional): whether to remove "module."
            from layer names. Default is True.
        model_name (str, optional): model name to save.

    Examples::
        >>> state = {
        >>>     'state_dict': model.state_dict(),
        >>>     'epoch': 10,
        >>>     'optimizer': optimizer.state_dict()
        >>> }
        >>> save_checkpoint(state, 'log/my_model')
    """
    mkdir_if_missing(save_dir)

    if remove_module_from_keys:
        # Strip the "module." prefix added by nn.DataParallel. Work on a
        # shallow copy of `state` so the caller's dict is not mutated as a
        # side effect (the previous version rebound state["state_dict"]).
        new_state_dict = OrderedDict(
            (k[7:] if k.startswith("module.") else k, v)
            for k, v in state["state_dict"].items()
        )
        state = dict(state)
        state["state_dict"] = new_state_dict

    # save model
    epoch = state["epoch"]
    if not model_name:
        model_name = "model.pth.tar-" + str(epoch)
    fpath = osp.join(save_dir, model_name)
    torch.save(state, fpath)
    print('Checkpoint saved to "{}"'.format(fpath))

    # Record the latest checkpoint name; the context manager guarantees
    # the handle is closed even if the write fails.
    checkpoint_file = osp.join(save_dir, "checkpoint")
    with open(checkpoint_file, "w+") as checkpoint:
        checkpoint.write("{}\n".format(osp.basename(fpath)))

    if is_best:
        best_fpath = osp.join(osp.dirname(fpath), "model-best.pth.tar")
        shutil.copy(fpath, best_fpath)
        print('Best checkpoint saved to "{}"'.format(best_fpath))
def load_checkpoint(fpath):
    r"""Load checkpoint.

    ``UnicodeDecodeError`` can be well handled, which means
    python2-saved files can be read from python3.

    Args:
        fpath (str): path to checkpoint.

    Returns:
        dict

    Raises:
        ValueError: if ``fpath`` is None.
        FileNotFoundError: if ``fpath`` does not exist.

    Examples::
        >>> fpath = 'log/my_model/model.pth.tar-10'
        >>> checkpoint = load_checkpoint(fpath)
    """
    if fpath is None:
        raise ValueError("File path is None")
    if not osp.exists(fpath):
        raise FileNotFoundError('File is not found at "{}"'.format(fpath))

    # Fall back to CPU tensors when no GPU is present.
    map_location = None if torch.cuda.is_available() else "cpu"

    try:
        checkpoint = torch.load(fpath, map_location=map_location)
    except UnicodeDecodeError:
        # Legacy python2 pickle: retry with latin1 decoding. Restore the
        # pickle module afterwards — the previous version rebound
        # pickle.load/pickle.Unpickler permanently, leaking the latin1
        # behavior to every later pickle user in the process.
        orig_load, orig_unpickler = pickle.load, pickle.Unpickler
        pickle.load = partial(orig_load, encoding="latin1")
        pickle.Unpickler = partial(orig_unpickler, encoding="latin1")
        try:
            checkpoint = torch.load(
                fpath, pickle_module=pickle, map_location=map_location
            )
        finally:
            pickle.load, pickle.Unpickler = orig_load, orig_unpickler
    except Exception:
        print('Unable to load checkpoint from "{}"'.format(fpath))
        raise
    return checkpoint
def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):
    r"""Resume training from a checkpoint.

    This will load (1) model weights and (2) ``state_dict``
    of optimizer if ``optimizer`` is not None.

    Args:
        fdir (str): directory where the model was saved.
        model (nn.Module): model.
        optimizer (Optimizer, optional): an Optimizer.
        scheduler (Scheduler, optional): an Scheduler.

    Returns:
        int: start_epoch.

    Examples::
        >>> fdir = 'log/my_model'
        >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler)
    """
    # The "checkpoint" file (written by save_checkpoint) holds the name of
    # the most recently saved model on its first line.
    with open(osp.join(fdir, "checkpoint"), "r") as fobj:
        model_name = fobj.readlines()[0].strip("\n")
    fpath = osp.join(fdir, model_name)

    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)

    model.load_state_dict(checkpoint["state_dict"])
    print("Loaded model weights")

    # Optimizer and scheduler states are restored only when both the
    # object and the saved state are present.
    for key, target in (("optimizer", optimizer), ("scheduler", scheduler)):
        if target is not None and key in checkpoint.keys():
            target.load_state_dict(checkpoint[key])
            print("Loaded {}".format(key))

    start_epoch = checkpoint["epoch"]
    print("Previous epoch: {}".format(start_epoch))
    return start_epoch
def adjust_learning_rate(
    optimizer,
    base_lr,
    epoch,
    stepsize=20,
    gamma=0.1,
    linear_decay=False,
    final_lr=0,
    max_epoch=100,
):
    r"""Adjust learning rate.

    Deprecated.

    Args:
        optimizer (Optimizer): optimizer whose param groups are updated.
        base_lr (float): initial learning rate.
        epoch (int): current epoch.
        stepsize (int, optional): period of step decay.
        gamma (float, optional): multiplicative decay factor.
        linear_decay (bool, optional): interpolate linearly from ``base_lr``
            to ``final_lr`` over ``max_epoch`` instead of step decay.
        final_lr (float, optional): final lr for linear decay.
        max_epoch (int, optional): total epochs for linear decay.
    """
    if linear_decay:
        # Linearly interpolate between base_lr and final_lr.
        progress = epoch / max_epoch
        new_lr = (1.0 - progress) * base_lr + progress * final_lr
    else:
        # Multiply by gamma once every stepsize epochs.
        new_lr = base_lr * gamma ** (epoch // stepsize)

    for group in optimizer.param_groups:
        group["lr"] = new_lr
def set_bn_to_eval(m):
    r"""Set BatchNorm layers to eval mode.

    In eval mode the running mean/var stop updating, while the affine
    scale/shift parameters remain trainable.
    """
    if "BatchNorm" in m.__class__.__name__:
        m.eval()
def open_all_layers(model):
    r"""Open all layers in model for training.

    Examples::
        >>> open_all_layers(model)
    """
    model.train()
    for param in model.parameters():
        param.requires_grad_(True)
def open_specified_layers(model, open_layers):
    r"""Open specified layers in model for training while keeping
    other layers frozen.

    Args:
        model (nn.Module): neural net model.
        open_layers (str or list): layers open for training.

    Examples::
        >>> # Only model.classifier will be updated.
        >>> open_layers = 'classifier'
        >>> open_specified_layers(model, open_layers)
        >>> # Only model.fc and model.classifier will be updated.
        >>> open_layers = ['fc', 'classifier']
        >>> open_specified_layers(model, open_layers)
    """
    # Operate on the wrapped module when the model is data-parallel.
    if isinstance(model, nn.DataParallel):
        model = model.module

    if isinstance(open_layers, str):
        open_layers = [open_layers]

    for layer in open_layers:
        assert hasattr(
            model, layer
        ), '"{}" is not an attribute of the model, please provide the correct name'.format(
            layer
        )

    # train(True)/train(False) plus requires_grad toggling: listed layers
    # become trainable, everything else is frozen.
    for name, module in model.named_children():
        trainable = name in open_layers
        module.train(trainable)
        for param in module.parameters():
            param.requires_grad = trainable
def count_num_param(model):
    r"""Count number of parameters in a model.

    Args:
        model (nn.Module): network model.

    Examples::
        >>> model_size = count_num_param(model)
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def load_pretrained_weights(model, weight_path):
    r"""Load pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored.
        - Can automatically deal with keys containing "module.".

    Args:
        model (nn.Module): network model.
        weight_path (str): path to pretrained weights.

    Examples::
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = load_checkpoint(weight_path)
    # A full checkpoint wraps the weights under "state_dict"; a raw
    # state dict is used as-is.
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for key, tensor in state_dict.items():
        if key.startswith("module."):
            key = key[7:]  # drop the DataParallel prefix
        compatible = (
            key in model_dict and model_dict[key].size() == tensor.size()
        )
        if compatible:
            new_state_dict[key] = tensor
            matched_layers.append(key)
        else:
            discarded_layers.append(key)

    # Merge matched weights into the model's own state dict so layers
    # missing from the checkpoint keep their current values.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights "{}" cannot be loaded, '
            "please check the key names manually "
            "(** ignored and continue **)".format(weight_path)
        )
    else:
        print(
            'Successfully loaded pretrained weights from "{}"'.
            format(weight_path)
        )
        if discarded_layers:
            print(
                "** The following layers are discarded "
                "due to unmatched keys or layer size: {}".
                format(discarded_layers)
            )
def init_network_weights(model, init_type="normal", gain=0.02):
    """Initialize the weights of Conv/Linear/Norm layers in-place.

    Args:
        model (nn.Module): network to initialize; applied recursively via
            ``model.apply``.
        init_type (str, optional): one of "normal", "xavier", "kaiming",
            "orthogonal". Default is "normal".
        gain (float, optional): std for "normal", gain for the others.

    Raises:
        NotImplementedError: if ``init_type`` is unknown (only when a
            Conv/Linear layer is actually visited).
    """

    def _init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (
            classname.find("Conv") != -1 or classname.find("Linear") != -1
        ):
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm") != -1:
            # Guard against affine=False, where weight/bias are None
            # (previously crashed with AttributeError on .data), matching
            # the guard the InstanceNorm branch already had.
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("InstanceNorm") != -1:
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)