release code
This commit is contained in:
735
Dassl.ProGrad.pytorch/dassl/engine/trainer.py
Normal file
735
Dassl.ProGrad.pytorch/dassl/engine/trainer.py
Normal file
@@ -0,0 +1,735 @@
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from dassl.data import DataManager
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from dassl.utils import (
|
||||
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
|
||||
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
|
||||
load_pretrained_weights
|
||||
)
|
||||
from dassl.modeling import build_head, build_backbone
|
||||
from dassl.evaluation import build_evaluator
|
||||
|
||||
|
||||
class SimpleNet(nn.Module):
|
||||
"""A simple neural network composed of a CNN backbone
|
||||
and optionally a head such as mlp for classification.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, model_cfg, num_classes, **kwargs):
|
||||
super().__init__()
|
||||
self.backbone = build_backbone(
|
||||
model_cfg.BACKBONE.NAME,
|
||||
verbose=cfg.VERBOSE,
|
||||
pretrained=model_cfg.BACKBONE.PRETRAINED,
|
||||
**kwargs,
|
||||
)
|
||||
fdim = self.backbone.out_features
|
||||
|
||||
self.head = None
|
||||
if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
|
||||
self.head = build_head(
|
||||
model_cfg.HEAD.NAME,
|
||||
verbose=cfg.VERBOSE,
|
||||
in_features=fdim,
|
||||
hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
|
||||
activation=model_cfg.HEAD.ACTIVATION,
|
||||
bn=model_cfg.HEAD.BN,
|
||||
dropout=model_cfg.HEAD.DROPOUT,
|
||||
**kwargs,
|
||||
)
|
||||
fdim = self.head.out_features
|
||||
|
||||
self.classifier = None
|
||||
if num_classes > 0:
|
||||
self.classifier = nn.Linear(fdim, num_classes)
|
||||
|
||||
self._fdim = fdim
|
||||
|
||||
@property
|
||||
def fdim(self):
|
||||
return self._fdim
|
||||
|
||||
def forward(self, x, return_feature=False):
|
||||
f = self.backbone(x)
|
||||
if self.head is not None:
|
||||
f = self.head(f)
|
||||
|
||||
if self.classifier is None:
|
||||
return f
|
||||
|
||||
y = self.classifier(f)
|
||||
|
||||
if return_feature:
|
||||
return y, f
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class TrainerBase:
|
||||
"""Base class for iterative trainer."""
|
||||
|
||||
def __init__(self):
|
||||
self._models = OrderedDict()
|
||||
self._optims = OrderedDict()
|
||||
self._scheds = OrderedDict()
|
||||
self._writer = None
|
||||
|
||||
def register_model(self, name="model", model=None, optim=None, sched=None):
|
||||
if self.__dict__.get("_models") is None:
|
||||
raise AttributeError(
|
||||
"Cannot assign model before super().__init__() call"
|
||||
)
|
||||
|
||||
if self.__dict__.get("_optims") is None:
|
||||
raise AttributeError(
|
||||
"Cannot assign optim before super().__init__() call"
|
||||
)
|
||||
|
||||
if self.__dict__.get("_scheds") is None:
|
||||
raise AttributeError(
|
||||
"Cannot assign sched before super().__init__() call"
|
||||
)
|
||||
|
||||
assert name not in self._models, "Found duplicate model names"
|
||||
|
||||
self._models[name] = model
|
||||
self._optims[name] = optim
|
||||
self._scheds[name] = sched
|
||||
|
||||
def get_model_names(self, names=None):
|
||||
names_real = list(self._models.keys())
|
||||
if names is not None:
|
||||
names = tolist_if_not(names)
|
||||
for name in names:
|
||||
assert name in names_real
|
||||
return names
|
||||
else:
|
||||
return names_real
|
||||
|
||||
def save_model(self, epoch, directory, is_best=False, model_name=""):
|
||||
names = self.get_model_names()
|
||||
|
||||
for name in names:
|
||||
model_dict = self._models[name].state_dict()
|
||||
|
||||
optim_dict = None
|
||||
if self._optims[name] is not None:
|
||||
optim_dict = self._optims[name].state_dict()
|
||||
|
||||
sched_dict = None
|
||||
if self._scheds[name] is not None:
|
||||
sched_dict = self._scheds[name].state_dict()
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"state_dict": model_dict,
|
||||
"epoch": epoch + 1,
|
||||
"optimizer": optim_dict,
|
||||
"scheduler": sched_dict,
|
||||
},
|
||||
osp.join(directory, name),
|
||||
is_best=is_best,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
def resume_model_if_exist(self, directory):
|
||||
names = self.get_model_names()
|
||||
file_missing = False
|
||||
|
||||
for name in names:
|
||||
path = osp.join(directory, name)
|
||||
if not osp.exists(path):
|
||||
file_missing = True
|
||||
break
|
||||
|
||||
if file_missing:
|
||||
print("No checkpoint found, train from scratch")
|
||||
return 0
|
||||
|
||||
print(
|
||||
'Found checkpoint in "{}". Will resume training'.format(directory)
|
||||
)
|
||||
|
||||
for name in names:
|
||||
path = osp.join(directory, name)
|
||||
start_epoch = resume_from_checkpoint(
|
||||
path, self._models[name], self._optims[name],
|
||||
self._scheds[name]
|
||||
)
|
||||
|
||||
return start_epoch
|
||||
|
||||
def load_model(self, directory, epoch=None):
|
||||
if not directory:
|
||||
print(
|
||||
"Note that load_model() is skipped as no pretrained "
|
||||
"model is given (ignore this if it's done on purpose)"
|
||||
)
|
||||
return
|
||||
|
||||
names = self.get_model_names()
|
||||
|
||||
# By default, the best model is loaded
|
||||
model_file = "model-best.pth.tar"
|
||||
|
||||
if epoch is not None:
|
||||
model_file = "model.pth.tar-" + str(epoch)
|
||||
|
||||
for name in names:
|
||||
model_path = osp.join(directory, name, model_file)
|
||||
|
||||
if not osp.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
'Model not found at "{}"'.format(model_path)
|
||||
)
|
||||
|
||||
checkpoint = load_checkpoint(model_path)
|
||||
state_dict = checkpoint["state_dict"]
|
||||
epoch = checkpoint["epoch"]
|
||||
|
||||
print(
|
||||
"Loading weights to {} "
|
||||
'from "{}" (epoch = {})'.format(name, model_path, epoch)
|
||||
)
|
||||
self._models[name].load_state_dict(state_dict)
|
||||
|
||||
def set_model_mode(self, mode="train", names=None):
|
||||
names = self.get_model_names(names)
|
||||
|
||||
for name in names:
|
||||
if mode == "train":
|
||||
self._models[name].train()
|
||||
elif mode in ["test", "eval"]:
|
||||
self._models[name].eval()
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
def update_lr(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
|
||||
for name in names:
|
||||
if self._scheds[name] is not None:
|
||||
self._scheds[name].step()
|
||||
|
||||
def detect_anomaly(self, loss):
|
||||
if not torch.isfinite(loss).all():
|
||||
raise FloatingPointError("Loss is infinite or NaN!")
|
||||
|
||||
def init_writer(self, log_dir):
|
||||
if self.__dict__.get("_writer") is None or self._writer is None:
|
||||
print(
|
||||
"Initializing summary writer for tensorboard "
|
||||
"with log_dir={}".format(log_dir)
|
||||
)
|
||||
self._writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def close_writer(self):
|
||||
if self._writer is not None:
|
||||
self._writer.close()
|
||||
|
||||
def write_scalar(self, tag, scalar_value, global_step=None):
|
||||
if self._writer is None:
|
||||
# Do nothing if writer is not initialized
|
||||
# Note that writer is only used when training is needed
|
||||
pass
|
||||
else:
|
||||
self._writer.add_scalar(tag, scalar_value, global_step)
|
||||
|
||||
def train(self, start_epoch, max_epoch):
|
||||
"""Generic training loops."""
|
||||
self.start_epoch = start_epoch
|
||||
self.max_epoch = max_epoch
|
||||
|
||||
self.before_train()
|
||||
for self.epoch in range(self.start_epoch, self.max_epoch):
|
||||
self.before_epoch()
|
||||
self.run_epoch()
|
||||
self.after_epoch()
|
||||
self.after_train()
|
||||
|
||||
def before_train(self):
|
||||
pass
|
||||
|
||||
def after_train(self):
|
||||
pass
|
||||
|
||||
def before_epoch(self):
|
||||
pass
|
||||
|
||||
def after_epoch(self):
|
||||
pass
|
||||
|
||||
def run_epoch(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_batch_train(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_batch_test(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_backward(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def model_inference(self, input):
|
||||
raise NotImplementedError
|
||||
|
||||
def model_zero_grad(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
for name in names:
|
||||
if self._optims[name] is not None:
|
||||
self._optims[name].zero_grad()
|
||||
|
||||
def model_backward(self, loss):
|
||||
self.detect_anomaly(loss)
|
||||
loss.backward()
|
||||
|
||||
def model_update(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
for name in names:
|
||||
if self._optims[name] is not None:
|
||||
self._optims[name].step()
|
||||
|
||||
def model_backward_and_update(self, loss, names=None):
|
||||
self.model_zero_grad(names)
|
||||
self.model_backward(loss)
|
||||
self.model_update(names)
|
||||
|
||||
def prograd_backward_and_update(
|
||||
self, loss_a, loss_b, lambda_=1, names=None
|
||||
):
|
||||
# loss_b not increase is okay
|
||||
# loss_a has to decline
|
||||
self.model_zero_grad(names)
|
||||
# get name of the model parameters
|
||||
names = self.get_model_names(names)
|
||||
# backward loss_a
|
||||
self.detect_anomaly(loss_b)
|
||||
loss_b.backward(retain_graph=True)
|
||||
# normalize gradient
|
||||
b_grads = []
|
||||
for name in names:
|
||||
for p in self._models[name].parameters():
|
||||
b_grads.append(p.grad.clone())
|
||||
|
||||
# optimizer don't step
|
||||
for name in names:
|
||||
self._optims[name].zero_grad()
|
||||
|
||||
# backward loss_a
|
||||
self.detect_anomaly(loss_a)
|
||||
loss_a.backward()
|
||||
for name in names:
|
||||
for p, b_grad in zip(self._models[name].parameters(), b_grads):
|
||||
# calculate cosine distance
|
||||
b_grad_norm = b_grad / torch.linalg.norm(b_grad)
|
||||
a_grad = p.grad.clone()
|
||||
a_grad_norm = a_grad / torch.linalg.norm(a_grad)
|
||||
|
||||
if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
|
||||
p.grad = a_grad - lambda_ * torch.dot(
|
||||
a_grad.flatten(), b_grad_norm.flatten()
|
||||
) * b_grad_norm
|
||||
|
||||
# optimizer
|
||||
for name in names:
|
||||
self._optims[name].step()
|
||||
|
||||
|
||||
class SimpleTrainer(TrainerBase):
|
||||
"""A simple trainer class implementing generic functions."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.check_cfg(cfg)
|
||||
|
||||
if torch.cuda.is_available() and cfg.USE_CUDA:
|
||||
self.device = torch.device("cuda")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
# Save as attributes some frequently used variables
|
||||
self.start_epoch = self.epoch = 0
|
||||
self.max_epoch = cfg.OPTIM.MAX_EPOCH
|
||||
self.output_dir = cfg.OUTPUT_DIR
|
||||
|
||||
self.cfg = cfg
|
||||
self.build_data_loader()
|
||||
self.build_model()
|
||||
self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
|
||||
self.best_result = -np.inf
|
||||
|
||||
def check_cfg(self, cfg):
|
||||
"""Check whether some variables are set correctly for
|
||||
the trainer (optional).
|
||||
|
||||
For example, a trainer might require a particular sampler
|
||||
for training such as 'RandomDomainSampler', so it is good
|
||||
to do the checking:
|
||||
|
||||
assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
|
||||
"""
|
||||
pass
|
||||
|
||||
def build_data_loader(self):
|
||||
"""Create essential data-related attributes.
|
||||
|
||||
A re-implementation of this method must create the
|
||||
same attributes (except self.dm).
|
||||
"""
|
||||
dm = DataManager(self.cfg)
|
||||
|
||||
self.train_loader_x = dm.train_loader_x
|
||||
self.train_loader_u = dm.train_loader_u # optional, can be None
|
||||
self.val_loader = dm.val_loader # optional, can be None
|
||||
self.test_loader = dm.test_loader
|
||||
self.num_classes = dm.num_classes
|
||||
self.num_source_domains = dm.num_source_domains
|
||||
self.lab2cname = dm.lab2cname # dict {label: classname}
|
||||
|
||||
self.dm = dm
|
||||
|
||||
def build_model(self):
|
||||
"""Build and register model.
|
||||
|
||||
The default builds a classification model along with its
|
||||
optimizer and scheduler.
|
||||
|
||||
Custom trainers can re-implement this method if necessary.
|
||||
"""
|
||||
cfg = self.cfg
|
||||
|
||||
print("Building model")
|
||||
self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
|
||||
if cfg.MODEL.INIT_WEIGHTS:
|
||||
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
|
||||
self.model.to(self.device)
|
||||
print("# params: {:,}".format(count_num_param(self.model)))
|
||||
self.optim = build_optimizer(self.model, cfg.OPTIM)
|
||||
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
|
||||
self.register_model("model", self.model, self.optim, self.sched)
|
||||
|
||||
device_count = torch.cuda.device_count()
|
||||
if device_count > 1:
|
||||
print(
|
||||
f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
|
||||
)
|
||||
self.model = nn.DataParallel(self.model)
|
||||
|
||||
def train(self):
|
||||
super().train(self.start_epoch, self.max_epoch)
|
||||
|
||||
def before_train(self):
|
||||
directory = self.cfg.OUTPUT_DIR
|
||||
if self.cfg.RESUME:
|
||||
directory = self.cfg.RESUME
|
||||
self.start_epoch = self.resume_model_if_exist(directory)
|
||||
|
||||
# Initialize summary writer
|
||||
writer_dir = osp.join(self.output_dir, "tensorboard")
|
||||
mkdir_if_missing(writer_dir)
|
||||
self.init_writer(writer_dir)
|
||||
|
||||
# Remember the starting time (for computing the elapsed time)
|
||||
self.time_start = time.time()
|
||||
|
||||
def after_train(self):
|
||||
print("Finished training")
|
||||
|
||||
do_test = not self.cfg.TEST.NO_TEST
|
||||
if do_test:
|
||||
if self.cfg.TEST.FINAL_MODEL == "best_val":
|
||||
print("Deploy the model with the best val performance")
|
||||
self.load_model(self.output_dir)
|
||||
self.test()
|
||||
|
||||
# Show elapsed time
|
||||
elapsed = round(time.time() - self.time_start)
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
print("Elapsed: {}".format(elapsed))
|
||||
|
||||
# Close writer
|
||||
self.close_writer()
|
||||
|
||||
def after_epoch(self):
|
||||
last_epoch = (self.epoch + 1) == self.max_epoch
|
||||
do_test = not self.cfg.TEST.NO_TEST
|
||||
meet_checkpoint_freq = (
|
||||
(self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
|
||||
if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
|
||||
)
|
||||
|
||||
if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
|
||||
curr_result = self.test(split="val")
|
||||
is_best = curr_result > self.best_result
|
||||
if is_best:
|
||||
self.best_result = curr_result
|
||||
self.save_model(
|
||||
self.epoch,
|
||||
self.output_dir,
|
||||
model_name="model-best.pth.tar"
|
||||
)
|
||||
|
||||
if meet_checkpoint_freq or last_epoch:
|
||||
self.save_model(self.epoch, self.output_dir)
|
||||
|
||||
@torch.no_grad()
|
||||
def output_test(self, split=None):
|
||||
"""testing pipline, which could also output the results."""
|
||||
self.set_model_mode("eval")
|
||||
self.evaluator.reset()
|
||||
|
||||
output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
|
||||
res_json = {}
|
||||
|
||||
if split is None:
|
||||
split = self.cfg.TEST.SPLIT
|
||||
|
||||
if split == "val" and self.val_loader is not None:
|
||||
data_loader = self.val_loader
|
||||
print("Do evaluation on {} set".format(split))
|
||||
else:
|
||||
data_loader = self.test_loader
|
||||
print("Do evaluation on test set")
|
||||
|
||||
for batch_idx, batch in enumerate(tqdm(data_loader)):
|
||||
img_path = batch['impath']
|
||||
input, label = self.parse_batch_test(batch)
|
||||
output = self.model_inference(input)
|
||||
self.evaluator.process(output, label)
|
||||
for i in range(len(img_path)):
|
||||
res_json[img_path[i]] = {
|
||||
'predict': output[i].cpu().numpy().tolist(),
|
||||
'gt': label[i].cpu().numpy().tolist()
|
||||
}
|
||||
with open(output_file, 'w') as f:
|
||||
json.dump(res_json, f)
|
||||
results = self.evaluator.evaluate()
|
||||
|
||||
for k, v in results.items():
|
||||
tag = "{}/{}".format(split, k)
|
||||
self.write_scalar(tag, v, self.epoch)
|
||||
|
||||
return list(results.values())[0]
|
||||
|
||||
@torch.no_grad()
|
||||
def test(self, split=None):
|
||||
"""A generic testing pipeline."""
|
||||
self.set_model_mode("eval")
|
||||
self.evaluator.reset()
|
||||
|
||||
if split is None:
|
||||
split = self.cfg.TEST.SPLIT
|
||||
|
||||
if split == "val" and self.val_loader is not None:
|
||||
data_loader = self.val_loader
|
||||
print("Do evaluation on {} set".format(split))
|
||||
else:
|
||||
data_loader = self.test_loader
|
||||
print("Do evaluation on test set")
|
||||
|
||||
for batch_idx, batch in enumerate(tqdm(data_loader)):
|
||||
input, label = self.parse_batch_test(batch)
|
||||
output = self.model_inference(input)
|
||||
self.evaluator.process(output, label)
|
||||
|
||||
results = self.evaluator.evaluate()
|
||||
|
||||
for k, v in results.items():
|
||||
tag = "{}/{}".format(split, k)
|
||||
self.write_scalar(tag, v, self.epoch)
|
||||
|
||||
return list(results.values())[0]
|
||||
|
||||
def model_inference(self, input):
|
||||
return self.model(input)
|
||||
|
||||
def parse_batch_test(self, batch):
|
||||
input = batch["img"]
|
||||
label = batch["label"]
|
||||
|
||||
input = input.to(self.device)
|
||||
label = label.to(self.device)
|
||||
|
||||
return input, label
|
||||
|
||||
def get_current_lr(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
name = names[0]
|
||||
return self._optims[name].param_groups[0]["lr"]
|
||||
|
||||
|
||||
class TrainerXU(SimpleTrainer):
|
||||
"""A base trainer using both labeled and unlabeled data.
|
||||
|
||||
In the context of domain adaptation, labeled and unlabeled data
|
||||
come from source and target domains respectively.
|
||||
|
||||
When it comes to semi-supervised learning, all data comes from the
|
||||
same domain.
|
||||
"""
|
||||
|
||||
def run_epoch(self):
|
||||
self.set_model_mode("train")
|
||||
losses = MetricMeter()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
|
||||
# Decide to iterate over labeled or unlabeled dataset
|
||||
len_train_loader_x = len(self.train_loader_x)
|
||||
len_train_loader_u = len(self.train_loader_u)
|
||||
if self.cfg.TRAIN.COUNT_ITER == "train_x":
|
||||
self.num_batches = len_train_loader_x
|
||||
elif self.cfg.TRAIN.COUNT_ITER == "train_u":
|
||||
self.num_batches = len_train_loader_u
|
||||
elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
|
||||
self.num_batches = min(len_train_loader_x, len_train_loader_u)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
train_loader_x_iter = iter(self.train_loader_x)
|
||||
train_loader_u_iter = iter(self.train_loader_u)
|
||||
|
||||
end = time.time()
|
||||
for self.batch_idx in range(self.num_batches):
|
||||
try:
|
||||
batch_x = next(train_loader_x_iter)
|
||||
except StopIteration:
|
||||
train_loader_x_iter = iter(self.train_loader_x)
|
||||
batch_x = next(train_loader_x_iter)
|
||||
|
||||
try:
|
||||
batch_u = next(train_loader_u_iter)
|
||||
except StopIteration:
|
||||
train_loader_u_iter = iter(self.train_loader_u)
|
||||
batch_u = next(train_loader_u_iter)
|
||||
|
||||
data_time.update(time.time() - end)
|
||||
loss_summary = self.forward_backward(batch_x, batch_u)
|
||||
batch_time.update(time.time() - end)
|
||||
losses.update(loss_summary)
|
||||
|
||||
if (
|
||||
self.batch_idx + 1
|
||||
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
|
||||
nb_remain = 0
|
||||
nb_remain += self.num_batches - self.batch_idx - 1
|
||||
nb_remain += (
|
||||
self.max_epoch - self.epoch - 1
|
||||
) * self.num_batches
|
||||
eta_seconds = batch_time.avg * nb_remain
|
||||
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
print(
|
||||
"epoch [{0}/{1}][{2}/{3}]\t"
|
||||
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
|
||||
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
|
||||
"eta {eta}\t"
|
||||
"{losses}\t"
|
||||
"lr {lr:.6e}".format(
|
||||
self.epoch + 1,
|
||||
self.max_epoch,
|
||||
self.batch_idx + 1,
|
||||
self.num_batches,
|
||||
batch_time=batch_time,
|
||||
data_time=data_time,
|
||||
eta=eta,
|
||||
losses=losses,
|
||||
lr=self.get_current_lr(),
|
||||
)
|
||||
)
|
||||
|
||||
n_iter = self.epoch * self.num_batches + self.batch_idx
|
||||
for name, meter in losses.meters.items():
|
||||
self.write_scalar("train/" + name, meter.avg, n_iter)
|
||||
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
|
||||
|
||||
end = time.time()
|
||||
|
||||
def parse_batch_train(self, batch_x, batch_u):
|
||||
input_x = batch_x["img"]
|
||||
label_x = batch_x["label"]
|
||||
input_u = batch_u["img"]
|
||||
|
||||
input_x = input_x.to(self.device)
|
||||
label_x = label_x.to(self.device)
|
||||
input_u = input_u.to(self.device)
|
||||
|
||||
return input_x, label_x, input_u
|
||||
|
||||
|
||||
class TrainerX(SimpleTrainer):
|
||||
"""A base trainer using labeled data only."""
|
||||
|
||||
def run_epoch(self):
|
||||
self.set_model_mode("train")
|
||||
losses = MetricMeter()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
self.num_batches = len(self.train_loader_x)
|
||||
|
||||
end = time.time()
|
||||
for self.batch_idx, batch in enumerate(self.train_loader_x):
|
||||
data_time.update(time.time() - end)
|
||||
loss_summary = self.forward_backward(batch)
|
||||
batch_time.update(time.time() - end)
|
||||
losses.update(loss_summary)
|
||||
|
||||
if (
|
||||
self.batch_idx + 1
|
||||
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
|
||||
nb_remain = 0
|
||||
nb_remain += self.num_batches - self.batch_idx - 1
|
||||
nb_remain += (
|
||||
self.max_epoch - self.epoch - 1
|
||||
) * self.num_batches
|
||||
eta_seconds = batch_time.avg * nb_remain
|
||||
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
print(
|
||||
"epoch [{0}/{1}][{2}/{3}]\t"
|
||||
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
|
||||
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
|
||||
"eta {eta}\t"
|
||||
"{losses}\t"
|
||||
"lr {lr:.6e}".format(
|
||||
self.epoch + 1,
|
||||
self.max_epoch,
|
||||
self.batch_idx + 1,
|
||||
self.num_batches,
|
||||
batch_time=batch_time,
|
||||
data_time=data_time,
|
||||
eta=eta,
|
||||
losses=losses,
|
||||
lr=self.get_current_lr(),
|
||||
)
|
||||
)
|
||||
|
||||
n_iter = self.epoch * self.num_batches + self.batch_idx
|
||||
for name, meter in losses.meters.items():
|
||||
self.write_scalar("train/" + name, meter.avg, n_iter)
|
||||
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
|
||||
|
||||
end = time.time()
|
||||
|
||||
def parse_batch_train(self, batch):
|
||||
input = batch["img"]
|
||||
label = batch["label"]
|
||||
domain = batch["domain"]
|
||||
|
||||
input = input.to(self.device)
|
||||
label = label.to(self.device)
|
||||
domain = domain.to(self.device)
|
||||
|
||||
return input, label, domain
|
||||
Reference in New Issue
Block a user