"""Generic training engine: SimpleNet model wrapper and iterative trainer classes."""
import json
|
|
import time
|
|
import numpy as np
|
|
import os.path as osp
|
|
import datetime
|
|
from collections import OrderedDict
|
|
import torch
|
|
import torch.nn as nn
|
|
from tqdm import tqdm
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from dassl.data import DataManager
|
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
|
from dassl.utils import (
|
|
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
|
|
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
|
|
load_pretrained_weights
|
|
)
|
|
from dassl.modeling import build_head, build_backbone
|
|
from dassl.evaluation import build_evaluator
|
|
|
|
|
|
class SimpleNet(nn.Module):
    """A CNN backbone optionally followed by a head (e.g. an MLP) and a
    linear classifier.

    The head is only built when both a head name and hidden layers are
    configured; the classifier is only built when ``num_classes > 0``.
    """

    def __init__(self, cfg, model_cfg, num_classes, **kwargs):
        super().__init__()
        self.backbone = build_backbone(
            model_cfg.BACKBONE.NAME,
            verbose=cfg.VERBOSE,
            pretrained=model_cfg.BACKBONE.PRETRAINED,
            **kwargs,
        )
        feat_dim = self.backbone.out_features

        # Optional head: requires both a name and a hidden-layer spec.
        self.head = None
        if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
            self.head = build_head(
                model_cfg.HEAD.NAME,
                verbose=cfg.VERBOSE,
                in_features=feat_dim,
                hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
                activation=model_cfg.HEAD.ACTIVATION,
                bn=model_cfg.HEAD.BN,
                dropout=model_cfg.HEAD.DROPOUT,
                **kwargs,
            )
            feat_dim = self.head.out_features

        # Optional linear classifier on top of the feature extractor.
        self.classifier = (
            nn.Linear(feat_dim, num_classes) if num_classes > 0 else None
        )

        self._fdim = feat_dim

    @property
    def fdim(self):
        """Dimension of the feature vector fed to the classifier."""
        return self._fdim

    def forward(self, x, return_feature=False):
        """Return logits, or features when no classifier exists.

        With ``return_feature=True`` (and a classifier present), return
        a ``(logits, features)`` tuple instead.
        """
        features = self.backbone(x)
        if self.head is not None:
            features = self.head(features)

        if self.classifier is None:
            return features

        logits = self.classifier(features)
        return (logits, features) if return_feature else logits
|
|
|
|
|
|
class TrainerBase:
    """Base class for iterative trainers.

    Keeps per-name registries of models, optimizers and LR schedulers
    (see :meth:`register_model`), and provides generic checkpoint
    save/resume/load, tensorboard writing, and the skeleton of the
    epoch-based training loop. Subclasses implement ``run_epoch``,
    ``test`` and the batch-parsing / forward-backward hooks.
    """

    def __init__(self):
        self._models = OrderedDict()
        self._optims = OrderedDict()
        self._scheds = OrderedDict()
        self._writer = None

    def register_model(self, name="model", model=None, optim=None, sched=None):
        """Register a (model, optimizer, scheduler) triplet under ``name``.

        Raises:
            AttributeError: if called before ``super().__init__()``.
        """
        if self.__dict__.get("_models") is None:
            raise AttributeError(
                "Cannot assign model before super().__init__() call"
            )

        if self.__dict__.get("_optims") is None:
            raise AttributeError(
                "Cannot assign optim before super().__init__() call"
            )

        if self.__dict__.get("_scheds") is None:
            raise AttributeError(
                "Cannot assign sched before super().__init__() call"
            )

        assert name not in self._models, "Found duplicate model names"

        self._models[name] = model
        self._optims[name] = optim
        self._scheds[name] = sched

    def get_model_names(self, names=None):
        """Return registered model names, or validate the given subset."""
        names_real = list(self._models.keys())
        if names is not None:
            names = tolist_if_not(names)
            for name in names:
                assert name in names_real
            return names
        else:
            return names_real

    def save_model(self, epoch, directory, is_best=False, model_name=""):
        """Save a checkpoint (model/optim/sched state) for every registered model."""
        names = self.get_model_names()

        for name in names:
            model_dict = self._models[name].state_dict()

            optim_dict = None
            if self._optims[name] is not None:
                optim_dict = self._optims[name].state_dict()

            sched_dict = None
            if self._scheds[name] is not None:
                sched_dict = self._scheds[name].state_dict()

            save_checkpoint(
                {
                    # epoch is stored 1-based so resuming continues at the
                    # next epoch
                    "state_dict": model_dict,
                    "epoch": epoch + 1,
                    "optimizer": optim_dict,
                    "scheduler": sched_dict,
                },
                osp.join(directory, name),
                is_best=is_best,
                model_name=model_name,
            )

    def resume_model_if_exist(self, directory):
        """Resume all registered models from ``directory`` if checkpoints exist.

        Returns:
            int: the epoch to start training from (0 when nothing found).
        """
        names = self.get_model_names()
        file_missing = False

        for name in names:
            path = osp.join(directory, name)
            if not osp.exists(path):
                file_missing = True
                break

        if file_missing:
            print("No checkpoint found, train from scratch")
            return 0

        print(
            'Found checkpoint in "{}". Will resume training'.format(directory)
        )

        for name in names:
            path = osp.join(directory, name)
            start_epoch = resume_from_checkpoint(
                path, self._models[name], self._optims[name],
                self._scheds[name]
            )

        return start_epoch

    def load_model(self, directory, epoch=None):
        """Load pretrained weights for every registered model.

        Args:
            directory (str): root directory containing one sub-folder per
                registered model. When falsy, loading is skipped.
            epoch (int, optional): load ``model.pth.tar-<epoch>``;
                defaults to the best model (``model-best.pth.tar``).

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained "
                "model is given (ignore this if it's done on purpose)"
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Use a separate local so the `epoch` argument is not clobbered
            ckpt_epoch = checkpoint["epoch"]

            print(
                "Loading weights to {} "
                'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch)
            )
            self._models[name].load_state_dict(state_dict)

    def set_model_mode(self, mode="train", names=None):
        """Switch the given (or all) registered models to train/eval mode."""
        names = self.get_model_names(names)

        for name in names:
            if mode == "train":
                self._models[name].train()
            elif mode in ["test", "eval"]:
                self._models[name].eval()
            else:
                raise KeyError("Unknown mode: {}".format(mode))

    def update_lr(self, names=None):
        """Step the LR scheduler of the given (or all) registered models."""
        names = self.get_model_names(names)

        for name in names:
            if self._scheds[name] is not None:
                self._scheds[name].step()

    def detect_anomaly(self, loss):
        """Raise FloatingPointError if the loss contains inf/NaN."""
        if not torch.isfinite(loss).all():
            raise FloatingPointError("Loss is infinite or NaN!")

    def init_writer(self, log_dir):
        """Create the tensorboard SummaryWriter (idempotent)."""
        if self.__dict__.get("_writer") is None or self._writer is None:
            print(
                "Initializing summary writer for tensorboard "
                "with log_dir={}".format(log_dir)
            )
            self._writer = SummaryWriter(log_dir=log_dir)

    def close_writer(self):
        if self._writer is not None:
            self._writer.close()

    def write_scalar(self, tag, scalar_value, global_step=None):
        """Write a scalar to tensorboard; no-op when no writer exists.

        The writer is only initialized when training is needed, so tests
        and evaluation-only runs silently skip logging.
        """
        if self._writer is not None:
            self._writer.add_scalar(tag, scalar_value, global_step)

    def train(self, start_epoch, max_epoch):
        """Generic training loop: before/epoch*/after hooks."""
        self.start_epoch = start_epoch
        self.max_epoch = max_epoch

        self.before_train()
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()

    def before_train(self):
        pass

    def after_train(self):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def run_epoch(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def parse_batch_train(self, batch):
        raise NotImplementedError

    def parse_batch_test(self, batch):
        raise NotImplementedError

    def forward_backward(self, batch):
        raise NotImplementedError

    def model_inference(self, input):
        raise NotImplementedError

    def model_zero_grad(self, names=None):
        """Zero gradients of the given (or all) registered optimizers."""
        names = self.get_model_names(names)
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].zero_grad()

    def model_backward(self, loss):
        self.detect_anomaly(loss)
        loss.backward()

    def model_update(self, names=None):
        """Step the given (or all) registered optimizers."""
        names = self.get_model_names(names)
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].step()

    def model_backward_and_update(self, loss, names=None):
        self.model_zero_grad(names)
        self.model_backward(loss)
        self.model_update(names)

    def prograd_backward_and_update(
        self, loss_a, loss_b, lambda_=1, names=None
    ):
        """Gradient-projection update (ProGrad-style).

        ``loss_a`` must decline; ``loss_b`` merely must not increase.
        For each parameter, if the gradient of ``loss_a`` conflicts with
        the gradient of ``loss_b`` (negative cosine similarity), the
        component of ``loss_a``'s gradient along ``loss_b``'s normalized
        gradient is subtracted, scaled by ``lambda_``.

        Fixes vs. the naive implementation:
        - ``loss_b`` gradients are stored per model name, so with more
          than one registered model each parameter is paired with its
          own gradient (a flat list would re-pair every model's params
          with the first model's gradients).
        - Parameters without gradients (e.g. frozen) are skipped.
        """
        self.model_zero_grad(names)
        # get name of the model parameters
        names = self.get_model_names(names)

        # backward loss_b first and snapshot its gradients per model
        self.detect_anomaly(loss_b)
        loss_b.backward(retain_graph=True)

        b_grads = {
            name: [
                p.grad.clone() if p.grad is not None else None
                for p in self._models[name].parameters()
            ]
            for name in names
        }

        # clear the loss_b gradients; the optimizer must not step on them
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].zero_grad()

        # backward loss_a
        self.detect_anomaly(loss_a)
        loss_a.backward()
        for name in names:
            for p, b_grad in zip(self._models[name].parameters(), b_grads[name]):
                # skip frozen / grad-less parameters
                if p.grad is None or b_grad is None:
                    continue
                # calculate cosine distance
                b_grad_norm = b_grad / torch.linalg.norm(b_grad)
                a_grad = p.grad.clone()
                a_grad_norm = a_grad / torch.linalg.norm(a_grad)

                if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
                    # remove the conflicting component along b_grad_norm
                    p.grad = a_grad - lambda_ * torch.dot(
                        a_grad.flatten(), b_grad_norm.flatten()
                    ) * b_grad_norm

        # optimizer
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].step()
|
|
|
|
|
|
class SimpleTrainer(TrainerBase):
    """A simple trainer class implementing generic functions."""

    def __init__(self, cfg):
        super().__init__()
        self.check_cfg(cfg)

        # Pick GPU only when both hardware and config allow it.
        if torch.cuda.is_available() and cfg.USE_CUDA:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save as attributes some frequently used variables
        self.start_epoch = self.epoch = 0
        self.max_epoch = cfg.OPTIM.MAX_EPOCH
        self.output_dir = cfg.OUTPUT_DIR

        self.cfg = cfg
        self.build_data_loader()
        self.build_model()
        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
        # Best validation result so far; -inf so the first result wins.
        self.best_result = -np.inf

    def check_cfg(self, cfg):
        """Check whether some variables are set correctly for
        the trainer (optional).

        For example, a trainer might require a particular sampler
        for training such as 'RandomDomainSampler', so it is good
        to do the checking:

        assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
        """
        pass

    def build_data_loader(self):
        """Create essential data-related attributes.

        A re-implementation of this method must create the
        same attributes (except self.dm).
        """
        dm = DataManager(self.cfg)

        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u  # optional, can be None
        self.val_loader = dm.val_loader  # optional, can be None
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname  # dict {label: classname}

        self.dm = dm

    def build_model(self):
        """Build and register model.

        The default builds a classification model along with its
        optimizer and scheduler.

        Custom trainers can re-implement this method if necessary.
        """
        cfg = self.cfg

        print("Building model")
        self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
        self.model.to(self.device)
        print("# params: {:,}".format(count_num_param(self.model)))
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        # NOTE: the unwrapped model is registered (checkpoints keep plain
        # keys); DataParallel wrapping below only affects self.model.
        self.register_model("model", self.model, self.optim, self.sched)

        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(
                f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
            )
            self.model = nn.DataParallel(self.model)

    def train(self):
        """Run the generic training loop over [start_epoch, max_epoch)."""
        super().train(self.start_epoch, self.max_epoch)

    def before_train(self):
        """Resume from checkpoint (if configured) and set up logging/timing."""
        directory = self.cfg.OUTPUT_DIR
        if self.cfg.RESUME:
            directory = self.cfg.RESUME
        self.start_epoch = self.resume_model_if_exist(directory)

        # Initialize summary writer
        writer_dir = osp.join(self.output_dir, "tensorboard")
        mkdir_if_missing(writer_dir)
        self.init_writer(writer_dir)

        # Remember the starting time (for computing the elapsed time)
        self.time_start = time.time()

    def after_train(self):
        """Optionally test with the final/best model, then report time and clean up."""
        print("Finished training")

        do_test = not self.cfg.TEST.NO_TEST
        if do_test:
            if self.cfg.TEST.FINAL_MODEL == "best_val":
                print("Deploy the model with the best val performance")
                self.load_model(self.output_dir)
            self.test()

        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed: {}".format(elapsed))

        # Close writer
        self.close_writer()

    def after_epoch(self):
        """Validate (when configured) and checkpoint the model."""
        last_epoch = (self.epoch + 1) == self.max_epoch
        do_test = not self.cfg.TEST.NO_TEST
        # Checkpoint only every CHECKPOINT_FREQ epochs; freq <= 0 disables it.
        meet_checkpoint_freq = (
            (self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
            if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
        )

        if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
            curr_result = self.test(split="val")
            is_best = curr_result > self.best_result
            if is_best:
                self.best_result = curr_result
                self.save_model(
                    self.epoch,
                    self.output_dir,
                    model_name="model-best.pth.tar"
                )

        if meet_checkpoint_freq or last_epoch:
            self.save_model(self.epoch, self.output_dir)

    @torch.no_grad()
    def output_test(self, split=None):
        """Testing pipeline that also dumps per-image predictions.

        Writes a JSON file (output.json in OUTPUT_DIR) mapping each image
        path to its prediction and ground truth, in addition to running
        the standard evaluator.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
        res_json = {}

        if split is None:
            split = self.cfg.TEST.SPLIT

        # Fall back to the test set when no val loader exists.
        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            img_path = batch['impath']
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)
            for i in range(len(img_path)):
                res_json[img_path[i]] = {
                    'predict': output[i].cpu().numpy().tolist(),
                    'gt': label[i].cpu().numpy().tolist()
                }
        with open(output_file, 'w') as f:
            json.dump(res_json, f)
        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        # First metric in the results dict is used as the headline score.
        return list(results.values())[0]

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        # Fall back to the test set when no val loader exists.
        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        # First metric in the results dict is used as the headline score.
        return list(results.values())[0]

    def model_inference(self, input):
        """Forward pass used at test time."""
        return self.model(input)

    def parse_batch_test(self, batch):
        """Extract (input, label) from a batch dict and move them to self.device."""
        input = batch["img"]
        label = batch["label"]

        input = input.to(self.device)
        label = label.to(self.device)

        return input, label

    def get_current_lr(self, names=None):
        """Return the learning rate of the first registered optimizer."""
        names = self.get_model_names(names)
        name = names[0]
        return self._optims[name].param_groups[0]["lr"]
|
|
|
|
|
|
class TrainerXU(SimpleTrainer):
    """A base trainer using both labeled and unlabeled data.

    In the context of domain adaptation, labeled and unlabeled data
    come from source and target domains respectively.

    When it comes to semi-supervised learning, all data comes from the
    same domain.
    """

    def run_epoch(self):
        """Run one epoch over the labeled and unlabeled loaders in lockstep.

        The epoch length is controlled by cfg.TRAIN.COUNT_ITER; whichever
        loader is exhausted first is restarted so both always yield a batch.
        """
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        # Decide to iterate over labeled or unlabeled dataset
        len_train_loader_x = len(self.train_loader_x)
        len_train_loader_u = len(self.train_loader_u)
        if self.cfg.TRAIN.COUNT_ITER == "train_x":
            self.num_batches = len_train_loader_x
        elif self.cfg.TRAIN.COUNT_ITER == "train_u":
            self.num_batches = len_train_loader_u
        elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
            self.num_batches = min(len_train_loader_x, len_train_loader_u)
        else:
            raise ValueError

        train_loader_x_iter = iter(self.train_loader_x)
        train_loader_u_iter = iter(self.train_loader_u)

        end = time.time()
        for self.batch_idx in range(self.num_batches):
            # Restart the labeled iterator when it runs out.
            try:
                batch_x = next(train_loader_x_iter)
            except StopIteration:
                train_loader_x_iter = iter(self.train_loader_x)
                batch_x = next(train_loader_x_iter)

            # Restart the unlabeled iterator when it runs out.
            try:
                batch_u = next(train_loader_u_iter)
            except StopIteration:
                train_loader_u_iter = iter(self.train_loader_u)
                batch_u = next(train_loader_u_iter)

            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch_x, batch_u)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            # Print every PRINT_FREQ batches, or every batch on short epochs.
            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                # ETA = average batch time x remaining batches (this epoch
                # plus all remaining epochs).
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            # Log running loss averages and LR at the global iteration index.
            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch_x, batch_u):
        """Extract labeled (input, label) and unlabeled input; move to device.

        Note: the unlabeled batch's label (if present) is intentionally unused.
        """
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, input_u
|
|
|
|
|
|
class TrainerX(SimpleTrainer):
    """A base trainer using labeled data only."""

    def run_epoch(self):
        """Run one training epoch over the labeled loader."""
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            # Print every PRINT_FREQ batches, or every batch on short epochs.
            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                # ETA = average batch time x remaining batches (this epoch
                # plus all remaining epochs).
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            # Log running loss averages and LR at the global iteration index.
            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch):
        """Extract (input, label, domain) from a batch dict; move to device."""
        input = batch["img"]
        label = batch["label"]
        domain = batch["domain"]

        input = input.to(self.device)
        label = label.to(self.device)
        domain = domain.to(self.device)

        return input, label, domain