# Source listing metadata (not part of the program):
# captured 2025-08-16 21:13:50 +08:00; 736 lines, 24 KiB, Python.

import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator
class SimpleNet(nn.Module):
    """Feature extractor with an optional head and an optional linear classifier.

    The backbone is always built. A head is inserted only when both
    ``model_cfg.HEAD.NAME`` and ``model_cfg.HEAD.HIDDEN_LAYERS`` are truthy,
    and a linear classifier is appended only when ``num_classes > 0``.
    """

    def __init__(self, cfg, model_cfg, num_classes, **kwargs):
        super().__init__()

        self.backbone = build_backbone(
            model_cfg.BACKBONE.NAME,
            verbose=cfg.VERBOSE,
            pretrained=model_cfg.BACKBONE.PRETRAINED,
            **kwargs,
        )
        # Width of the features that reach the classifier; updated below
        # when a head sits between backbone and classifier.
        feat_dim = self.backbone.out_features

        head_cfg = model_cfg.HEAD
        self.head = None
        if head_cfg.NAME and head_cfg.HIDDEN_LAYERS:
            self.head = build_head(
                head_cfg.NAME,
                verbose=cfg.VERBOSE,
                in_features=feat_dim,
                hidden_layers=head_cfg.HIDDEN_LAYERS,
                activation=head_cfg.ACTIVATION,
                bn=head_cfg.BN,
                dropout=head_cfg.DROPOUT,
                **kwargs,
            )
            feat_dim = self.head.out_features

        self.classifier = (
            nn.Linear(feat_dim, num_classes) if num_classes > 0 else None
        )
        self._fdim = feat_dim

    @property
    def fdim(self):
        """Dimension of the features fed to the classifier."""
        return self._fdim

    def forward(self, x, return_feature=False):
        """Run the network on ``x``.

        Returns the features when no classifier exists; otherwise the
        classifier output, paired with the features as ``(output, features)``
        when ``return_feature`` is true.
        """
        feats = self.backbone(x)
        if self.head is not None:
            feats = self.head(feats)

        if self.classifier is None:
            return feats

        logits = self.classifier(feats)
        return (logits, feats) if return_feature else logits
class TrainerBase:
    """Base class for iterative trainer.

    Holds three parallel ordered registries (models, optimizers and LR
    schedulers) keyed by model name, an optional tensorboard writer, and the
    generic epoch loop whose hooks subclasses override.
    """

    def __init__(self):
        self._models = OrderedDict()
        self._optims = OrderedDict()
        self._scheds = OrderedDict()
        self._writer = None

    def register_model(self, name="model", model=None, optim=None, sched=None):
        """Register ``model`` with its optimizer and scheduler under ``name``.

        Raises:
            AttributeError: if called before ``TrainerBase.__init__`` has run.
            AssertionError: if ``name`` is already registered.
        """
        for attr, label in (
            ("_models", "model"),
            ("_optims", "optim"),
            ("_scheds", "sched"),
        ):
            if self.__dict__.get(attr) is None:
                raise AttributeError(
                    f"Cannot assign {label} before super().__init__() call"
                )

        assert name not in self._models, "Found duplicate model names"

        self._models[name] = model
        self._optims[name] = optim
        self._scheds[name] = sched

    def get_model_names(self, names=None):
        """Return the requested model names, or all registered names.

        When ``names`` is given it is normalized with ``tolist_if_not`` and
        every entry is asserted to be a registered model name.
        """
        names_real = list(self._models.keys())
        if names is None:
            return names_real

        names = tolist_if_not(names)
        for name in names:
            assert name in names_real
        return names

    def save_model(self, epoch, directory, is_best=False, model_name=""):
        """Save one checkpoint per registered model under ``directory/<name>``.

        The stored ``"epoch"`` entry is ``epoch + 1``. Optimizer and scheduler
        states are stored as ``None`` when the model has none registered.
        """
        for name in self.get_model_names():
            optim = self._optims[name]
            sched = self._scheds[name]
            save_checkpoint(
                {
                    "state_dict": self._models[name].state_dict(),
                    "epoch": epoch + 1,
                    "optimizer": None if optim is None else optim.state_dict(),
                    "scheduler": None if sched is None else sched.state_dict(),
                },
                osp.join(directory, name),
                is_best=is_best,
                model_name=model_name,
            )

    def resume_model_if_exist(self, directory):
        """Resume every registered model from ``directory`` if possible.

        Returns:
            0 when any model's path ``directory/<name>`` is missing (training
            starts from scratch); otherwise the value returned by
            ``resume_from_checkpoint`` for the last registered model.
        """
        names = self.get_model_names()

        if any(not osp.exists(osp.join(directory, name)) for name in names):
            print("No checkpoint found, train from scratch")
            return 0

        print(
            'Found checkpoint in "{}". Will resume training'.format(directory)
        )

        # Defined up front so an empty registry returns 0 instead of raising
        # UnboundLocalError.
        start_epoch = 0
        for name in names:
            start_epoch = resume_from_checkpoint(
                osp.join(directory, name),
                self._models[name],
                self._optims[name],
                self._scheds[name],
            )
        return start_epoch

    def load_model(self, directory, epoch=None):
        """Load weights for every registered model from ``directory``.

        Args:
            directory: root folder containing one sub-folder per model name.
                A falsy value skips loading entirely.
            epoch: when given, ``model.pth.tar-<epoch>`` is loaded instead of
                the default ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if a model's checkpoint file does not exist.
        """
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained "
                "model is given (ignore this if it's done on purpose)"
            )
            return

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"
        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in self.get_model_names():
            model_path = osp.join(directory, name, model_file)
            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Separate local so the ``epoch`` argument is not overwritten.
            ckpt_epoch = checkpoint["epoch"]

            print(
                "Loading weights to {} "
                'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch)
            )
            self._models[name].load_state_dict(state_dict)

    def set_model_mode(self, mode="train", names=None):
        """Switch the selected models to train or eval mode.

        Raises:
            KeyError: if ``mode`` is not "train", "test" or "eval".
        """
        for name in self.get_model_names(names):
            if mode == "train":
                self._models[name].train()
            elif mode in ["test", "eval"]:
                self._models[name].eval()
            else:
                raise KeyError(f"Unknown mode: {mode!r}")

    def update_lr(self, names=None):
        """Step the LR scheduler of each selected model that has one."""
        for name in self.get_model_names(names):
            if self._scheds[name] is not None:
                self._scheds[name].step()

    def detect_anomaly(self, loss):
        """Raise ``FloatingPointError`` if ``loss`` holds a non-finite value."""
        if not torch.isfinite(loss).all():
            raise FloatingPointError("Loss is infinite or NaN!")

    def init_writer(self, log_dir):
        """Create the tensorboard writer unless one already exists."""
        if self.__dict__.get("_writer") is None:
            print(
                "Initializing summary writer for tensorboard "
                "with log_dir={}".format(log_dir)
            )
            self._writer = SummaryWriter(log_dir=log_dir)

    def close_writer(self):
        """Close the tensorboard writer if it was initialized."""
        if self._writer is not None:
            self._writer.close()

    def write_scalar(self, tag, scalar_value, global_step=None):
        """Log a scalar; a no-op when the writer is not initialized.

        Note that the writer is only used when training is needed.
        """
        if self._writer is not None:
            self._writer.add_scalar(tag, scalar_value, global_step)

    def train(self, start_epoch, max_epoch):
        """Generic training loops."""
        self.start_epoch = start_epoch
        self.max_epoch = max_epoch

        self.before_train()
        # ``self.epoch`` is the loop variable so the hooks can read it.
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()

    def before_train(self):
        pass

    def after_train(self):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def run_epoch(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def parse_batch_train(self, batch):
        raise NotImplementedError

    def parse_batch_test(self, batch):
        raise NotImplementedError

    def forward_backward(self, batch):
        raise NotImplementedError

    def model_inference(self, input):
        raise NotImplementedError

    def model_zero_grad(self, names=None):
        """Zero the gradients via each selected model's optimizer, if any."""
        for name in self.get_model_names(names):
            if self._optims[name] is not None:
                self._optims[name].zero_grad()

    def model_backward(self, loss):
        """Check ``loss`` for non-finite values, then backpropagate it."""
        self.detect_anomaly(loss)
        loss.backward()

    def model_update(self, names=None):
        """Step the optimizer of each selected model that has one."""
        for name in self.get_model_names(names):
            if self._optims[name] is not None:
                self._optims[name].step()

    def model_backward_and_update(self, loss, names=None):
        """Zero gradients, backpropagate ``loss`` and step the optimizers."""
        self.model_zero_grad(names)
        self.model_backward(loss)
        self.model_update(names)

    def prograd_backward_and_update(
        self, loss_a, loss_b, lambda_=1, names=None
    ):
        """Update with ``loss_a`` gradients projected against ``loss_b``.

        For each parameter, when the gradient of ``loss_a`` has a negative
        cosine with the gradient of ``loss_b``, the component of the
        ``loss_a`` gradient along the ``loss_b`` direction is removed (scaled
        by ``lambda_``) before the optimizer step. In other words: loss_b not
        increasing is okay, loss_a has to decline.

        Args:
            loss_a: loss whose (possibly projected) gradient drives the step.
            loss_b: loss providing the reference gradient direction.
            lambda_: scale of the removed conflicting component.
            names: model names to update; all registered models when None.
        """
        self.model_zero_grad(names)
        names = self.get_model_names(names)

        # Flatten once so every reference gradient stays paired with its own
        # parameter, even when several models are registered.
        params = [
            p for name in names for p in self._models[name].parameters()
        ]
        optims = [
            self._optims[name] for name in names
            if self._optims[name] is not None
        ]

        # backward loss_b; the graph is retained for the loss_a pass below
        self.detect_anomaly(loss_b)
        loss_b.backward(retain_graph=True)

        # Parameters that received no gradient (e.g. frozen) are kept as None
        # and skipped during projection.
        b_grads = [None if p.grad is None else p.grad.clone() for p in params]

        # discard the loss_b gradients without stepping the optimizers
        for optim in optims:
            optim.zero_grad()

        # backward loss_a
        self.detect_anomaly(loss_a)
        loss_a.backward()

        for p, b_grad in zip(params, b_grads):
            if p.grad is None or b_grad is None:
                continue

            b_unit = b_grad / torch.linalg.norm(b_grad)
            a_grad = p.grad.clone()
            a_unit = a_grad / torch.linalg.norm(a_grad)

            # A zero-norm gradient yields NaN here; NaN < 0 is False, so the
            # gradient is then left untouched.
            if torch.dot(a_unit.flatten(), b_unit.flatten()) < 0:
                p.grad = a_grad - lambda_ * torch.dot(
                    a_grad.flatten(), b_unit.flatten()
                ) * b_unit

        for optim in optims:
            optim.step()
class SimpleTrainer(TrainerBase):
    """A simple trainer class implementing generic functions."""

    def __init__(self, cfg):
        super().__init__()
        self.check_cfg(cfg)

        if torch.cuda.is_available() and cfg.USE_CUDA:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save as attributes some frequently used variables
        self.start_epoch = self.epoch = 0
        self.max_epoch = cfg.OPTIM.MAX_EPOCH
        self.output_dir = cfg.OUTPUT_DIR

        self.cfg = cfg
        self.build_data_loader()
        self.build_model()
        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
        # Best validation result seen so far (used by after_epoch).
        self.best_result = -np.inf

    def check_cfg(self, cfg):
        """Check whether some variables are set correctly for
        the trainer (optional).

        For example, a trainer might require a particular sampler
        for training such as 'RandomDomainSampler', so it is good
        to do the checking:

        assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
        """
        pass

    def build_data_loader(self):
        """Create essential data-related attributes.

        A re-implementation of this method must create the
        same attributes (except self.dm).
        """
        dm = DataManager(self.cfg)

        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u  # optional, can be None
        self.val_loader = dm.val_loader  # optional, can be None
        self.test_loader = dm.test_loader

        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname  # dict {label: classname}

        self.dm = dm

    def build_model(self):
        """Build and register model.

        The default builds a classification model along with its
        optimizer and scheduler.

        Custom trainers can re-implement this method if necessary.
        """
        cfg = self.cfg

        print("Building model")
        self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
        self.model.to(self.device)
        print("# params: {:,}".format(count_num_param(self.model)))
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("model", self.model, self.optim, self.sched)

        # Only wrap when the model actually lives on a GPU: DataParallel
        # requires the module's parameters on its first device, so wrapping
        # a CPU model on a multi-GPU host would fail at forward time.
        device_count = torch.cuda.device_count()
        if self.device.type == "cuda" and device_count > 1:
            print(
                f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
            )
            self.model = nn.DataParallel(self.model)

    def train(self):
        """Run the generic loop from ``self.start_epoch`` to ``self.max_epoch``."""
        super().train(self.start_epoch, self.max_epoch)

    def before_train(self):
        """Resume from a checkpoint if present and set up logging/timing."""
        directory = self.cfg.OUTPUT_DIR
        if self.cfg.RESUME:
            directory = self.cfg.RESUME
        self.start_epoch = self.resume_model_if_exist(directory)

        # Initialize summary writer
        writer_dir = osp.join(self.output_dir, "tensorboard")
        mkdir_if_missing(writer_dir)
        self.init_writer(writer_dir)

        # Remember the starting time (for computing the elapsed time)
        self.time_start = time.time()

    def after_train(self):
        """Optionally run the final test, report elapsed time, close writer."""
        print("Finished training")

        do_test = not self.cfg.TEST.NO_TEST
        if do_test:
            if self.cfg.TEST.FINAL_MODEL == "best_val":
                print("Deploy the model with the best val performance")
                self.load_model(self.output_dir)
            self.test()

        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed: {}".format(elapsed))

        # Close writer
        self.close_writer()

    def after_epoch(self):
        """Track the best validation model and save periodic checkpoints."""
        last_epoch = (self.epoch + 1) == self.max_epoch
        do_test = not self.cfg.TEST.NO_TEST
        checkpoint_freq = self.cfg.TRAIN.CHECKPOINT_FREQ
        meet_checkpoint_freq = (
            checkpoint_freq > 0 and (self.epoch + 1) % checkpoint_freq == 0
        )

        if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
            curr_result = self.test(split="val")
            if curr_result > self.best_result:
                self.best_result = curr_result
                self.save_model(
                    self.epoch,
                    self.output_dir,
                    model_name="model-best.pth.tar"
                )

        if meet_checkpoint_freq or last_epoch:
            self.save_model(self.epoch, self.output_dir)

    def _select_eval_loader(self, split):
        """Resolve the evaluation split and return ``(split, data_loader)``.

        ``split`` defaults to ``cfg.TEST.SPLIT``. The validation loader is
        used only for "val" when it exists; every other case falls back to
        the test loader and the returned split is "test", so logged tags
        name the data that was actually evaluated.
        """
        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            split = "test"
            data_loader = self.test_loader

        print("Do evaluation on {} set".format(split))
        return split, data_loader

    def _log_eval_results(self, split, results):
        """Write each metric as ``<split>/<metric>`` and return the first."""
        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)
        return list(results.values())[0]

    @torch.no_grad()
    def output_test(self, split=None):
        """Testing pipeline, which could also output the results.

        In addition to evaluating, writes ``output.json`` in
        ``cfg.OUTPUT_DIR``, mapping each image path to its model output
        ('predict') and ground-truth label ('gt').

        Returns:
            The first value of the evaluator's result dict.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
        res_json = {}

        split, data_loader = self._select_eval_loader(split)

        for batch in tqdm(data_loader):
            img_path = batch['impath']
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)
            # One host conversion per batch instead of two per sample.
            for path, pred, gt in zip(
                img_path, output.tolist(), label.tolist()
            ):
                res_json[path] = {'predict': pred, 'gt': gt}

        with open(output_file, 'w') as f:
            json.dump(res_json, f)

        results = self.evaluator.evaluate()
        return self._log_eval_results(split, results)

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline.

        Returns:
            The first value of the evaluator's result dict.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        split, data_loader = self._select_eval_loader(split)

        for batch in tqdm(data_loader):
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()
        return self._log_eval_results(split, results)

    def model_inference(self, input):
        """Forward ``input`` through ``self.model``."""
        return self.model(input)

    def parse_batch_test(self, batch):
        """Move the batch's "img" and "label" entries to ``self.device``."""
        input = batch["img"]
        label = batch["label"]

        input = input.to(self.device)
        label = label.to(self.device)

        return input, label

    def get_current_lr(self, names=None):
        """Return the LR of the first param group of the first selected model."""
        names = self.get_model_names(names)
        name = names[0]
        return self._optims[name].param_groups[0]["lr"]
class TrainerXU(SimpleTrainer):
    """Base trainer that consumes labeled and unlabeled batches together.

    For domain adaptation the labeled data is the source domain and the
    unlabeled data is the target domain; for semi-supervised learning both
    come from the same domain.
    """

    def run_epoch(self):
        """Run one epoch, pairing a labeled and an unlabeled batch per step."""
        self.set_model_mode("train")
        loss_meter = MetricMeter()
        batch_timer = AverageMeter()
        data_timer = AverageMeter()

        # The epoch length follows the labeled loader, the unlabeled loader,
        # or the shorter of the two, depending on the config.
        n_x = len(self.train_loader_x)
        n_u = len(self.train_loader_u)
        count_mode = self.cfg.TRAIN.COUNT_ITER
        if count_mode == "train_x":
            self.num_batches = n_x
        elif count_mode == "train_u":
            self.num_batches = n_u
        elif count_mode == "smaller_one":
            self.num_batches = min(n_x, n_u)
        else:
            raise ValueError

        iter_x = iter(self.train_loader_x)
        iter_u = iter(self.train_loader_u)

        tick = time.time()
        for step in range(self.num_batches):
            self.batch_idx = step

            # Restart whichever loader runs out before the epoch ends.
            try:
                batch_x = next(iter_x)
            except StopIteration:
                iter_x = iter(self.train_loader_x)
                batch_x = next(iter_x)

            try:
                batch_u = next(iter_u)
            except StopIteration:
                iter_u = iter(self.train_loader_u)
                batch_u = next(iter_u)

            data_timer.update(time.time() - tick)
            loss_summary = self.forward_backward(batch_x, batch_u)
            batch_timer.update(time.time() - tick)
            loss_meter.update(loss_summary)

            print_freq = self.cfg.TRAIN.PRINT_FREQ
            due = (self.batch_idx + 1) % print_freq == 0
            if due or self.num_batches < print_freq:
                left_this_epoch = self.num_batches - self.batch_idx - 1
                left_later_epochs = (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_timer.avg * (
                    left_this_epoch + left_later_epochs
                )
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                message = (
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}"
                ).format(
                    self.epoch + 1,
                    self.max_epoch,
                    self.batch_idx + 1,
                    self.num_batches,
                    batch_time=batch_timer,
                    data_time=data_timer,
                    eta=eta,
                    losses=loss_meter,
                    lr=self.get_current_lr(),
                )
                print(message)

            # Scalars are logged every step, not only when printing.
            global_step = self.epoch * self.num_batches + self.batch_idx
            for loss_name, meter in loss_meter.meters.items():
                self.write_scalar("train/" + loss_name, meter.avg, global_step)
            self.write_scalar("train/lr", self.get_current_lr(), global_step)

            tick = time.time()

    def parse_batch_train(self, batch_x, batch_u):
        """Return labeled images, labeled targets and unlabeled images on device."""
        device = self.device
        input_x = batch_x["img"].to(device)
        label_x = batch_x["label"].to(device)
        input_u = batch_u["img"].to(device)
        return input_x, label_x, input_u
class TrainerX(SimpleTrainer):
    """Base trainer that learns from labeled data only."""

    def run_epoch(self):
        """Run one pass over the labeled training loader."""
        self.set_model_mode("train")
        loss_meter = MetricMeter()
        batch_timer = AverageMeter()
        data_timer = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        tick = time.time()
        for step, batch in enumerate(self.train_loader_x):
            self.batch_idx = step

            data_timer.update(time.time() - tick)
            loss_summary = self.forward_backward(batch)
            batch_timer.update(time.time() - tick)
            loss_meter.update(loss_summary)

            print_freq = self.cfg.TRAIN.PRINT_FREQ
            due = (self.batch_idx + 1) % print_freq == 0
            if due or self.num_batches < print_freq:
                left_this_epoch = self.num_batches - self.batch_idx - 1
                left_later_epochs = (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_timer.avg * (
                    left_this_epoch + left_later_epochs
                )
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                message = (
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}"
                ).format(
                    self.epoch + 1,
                    self.max_epoch,
                    self.batch_idx + 1,
                    self.num_batches,
                    batch_time=batch_timer,
                    data_time=data_timer,
                    eta=eta,
                    losses=loss_meter,
                    lr=self.get_current_lr(),
                )
                print(message)

            # Scalars are logged every step, not only when printing.
            global_step = self.epoch * self.num_batches + self.batch_idx
            for loss_name, meter in loss_meter.meters.items():
                self.write_scalar("train/" + loss_name, meter.avg, global_step)
            self.write_scalar("train/lr", self.get_current_lr(), global_step)

            tick = time.time()

    def parse_batch_train(self, batch):
        """Return the batch's images, labels and domain entries on device."""
        device = self.device
        input = batch["img"].to(device)
        label = batch["label"].to(device)
        domain = batch["domain"].to(device)
        return input, label, domain