release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
from .build import TRAINER_REGISTRY, build_trainer # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip
from .da import *
from .dg import *
from .ssl import *

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
TRAINER_REGISTRY = Registry("TRAINER")
def build_trainer(cfg):
avai_trainers = TRAINER_REGISTRY.registered_names()
check_availability(cfg.TRAINER.NAME, avai_trainers)
if cfg.VERBOSE:
print("Loading trainer: {}".format(cfg.TRAINER.NAME))
return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)

View File

@@ -0,0 +1,9 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling

View File

@@ -0,0 +1,38 @@
import torch
from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
"""Adaptive Batch Normalization.
https://arxiv.org/abs/1603.04779.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.done_reset_bn_stats = False
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def before_epoch(self):
if not self.done_reset_bn_stats:
for m in self.model.modules():
classname = m.__class__.__name__
if classname.find("BatchNorm") != -1:
m.reset_running_stats()
self.done_reset_bn_stats = True
def forward_backward(self, batch_x, batch_u):
input_u = batch_u["img"].to(self.device)
with torch.no_grad():
self.model(input_u)
return None

View File

@@ -0,0 +1,85 @@
import copy
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head
@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
"""Adversarial Discriminative Domain Adaptation.
https://arxiv.org/abs/1702.05464.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.open_layers = ["backbone"]
if isinstance(self.model.head, nn.Module):
self.open_layers.append("head")
self.source_model = copy.deepcopy(self.model)
self.source_model.eval()
for param in self.source_model.parameters():
param.requires_grad_(False)
self.build_critic()
self.bce = nn.BCEWithLogitsLoss()
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim // 2],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
def forward_backward(self, batch_x, batch_u):
open_specified_layers(self.model, self.open_layers)
input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
_, feat_x = self.source_model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
logit_xd = self.critic(feat_x)
logit_ud = self.critic(feat_u.detach())
loss_critic = self.bce(logit_xd, domain_x)
loss_critic += self.bce(logit_ud, domain_u)
self.model_backward_and_update(loss_critic, "critic")
logit_ud = self.critic(feat_u)
loss_model = self.bce(logit_ud, 1 - domain_u)
self.model_backward_and_update(loss_model, "model")
loss_summary = {
"loss_critic": loss_critic.item(),
"loss_model": loss_model.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary

View File

@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
"""Domain Adaptive Ensemble Learning.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
input_x = torch.split(input_x, self.split_batch, 0)
input_x2 = torch.split(input_x2, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Generate pseudo label
with torch.no_grad():
feat_u = self.F(input_u)
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1) # (B, K, C)
# Get the highest probability and index (label) for each expert
experts_max_p, experts_max_idx = pred_u.max(2) # (B, K)
# Get the most confident expert
max_expert_p, max_expert_idx = experts_max_p.max(1) # (B)
pseudo_label_u = []
for i, experts_label in zip(max_expert_idx, experts_max_idx):
pseudo_label_u.append(experts_label[i])
pseudo_label_u = torch.stack(pseudo_label_u, 0)
pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
pseudo_label_u = pseudo_label_u.to(self.device)
label_u_mask = (max_expert_p >= self.conf_thre).float()
loss_x = 0
loss_cr = 0
acc_x = 0
feat_x = [self.F(x) for x in input_x]
feat_x2 = [self.F(x) for x in input_x2]
feat_u2 = self.F(input_u2)
for feat_xi, feat_x2i, label_xi, i in zip(
feat_x, feat_x2, label_x, domain_x
):
cr_s = [j for j in domain_x if j != i]
# Learning expert
pred_xi = self.E(i, feat_xi)
loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
expert_label_xi = pred_xi.detach()
acc_x += compute_accuracy(pred_xi.detach(),
label_xi.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat_x2i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc_x /= self.n_domain
# Unsupervised loss
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u2)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1)
pred_u = pred_u.mean(1)
l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
loss_u = (l_u * label_u_mask).mean()
loss = 0
loss += loss_x
loss += loss_cr
loss += loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": acc_x,
"loss_cr": loss_cr.item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
label_x = create_onehot(label_x, self.num_classes)
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, input_x2, label_x, domain_x, input_u, input_u2
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p

View File

@@ -0,0 +1,78 @@
import numpy as np
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad
@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
"""Domain-Adversarial Neural Networks.
https://arxiv.org/abs/1505.07818.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.build_critic()
self.ce = nn.CrossEntropyLoss()
self.bce = nn.BCEWithLogitsLoss()
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
global_step = self.batch_idx + self.epoch * self.num_batches
progress = global_step / (self.max_epoch * self.num_batches)
lmda = 2 / (1 + np.exp(-10 * progress)) - 1
logit_x, feat_x = self.model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
loss_x = self.ce(logit_x, label_x)
feat_x = self.revgrad(feat_x, grad_scaling=lmda)
feat_u = self.revgrad(feat_u, grad_scaling=lmda)
output_xd = self.critic(feat_x)
output_ud = self.critic(feat_u)
loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
loss = loss_x + loss_d
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary

View File

@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
class PairClassifiers(nn.Module):
def __init__(self, fdim, num_classes):
super().__init__()
self.c1 = nn.Linear(fdim, num_classes)
self.c2 = nn.Linear(fdim, num_classes)
def forward(self, x):
z1 = self.c1(x)
if not self.training:
return z1
z2 = self.c2(x)
return z1, z2
@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
"""Moment Matching for Multi-Source Domain Adaptation.
https://arxiv.org/abs/1812.01754.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
self.lmda = cfg.TRAINER.M3SDA.LMDA
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C")
self.C = nn.ModuleList(
[
PairClassifiers(fdim, self.num_classes)
for _ in range(self.num_source_domains)
]
)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, domain_x, input_u = parsed
input_x = torch.split(input_x, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Step A
loss_x = 0
feat_x = []
for x, y, d in zip(input_x, label_x, domain_x):
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
feat_x.append(f)
loss_x /= self.n_domain
feat_u = self.F(input_u)
loss_msda = self.moment_distance(feat_x, feat_u)
loss_step_A = loss_x + loss_msda * self.lmda
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_u = self.F(input_u)
loss_x, loss_dis = 0, 0
for x, y, d in zip(input_x, label_x, domain_x):
with torch.no_grad():
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_x /= self.n_domain
loss_dis /= self.n_domain
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, "C")
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
loss_dis = 0
for d in domain_x:
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_dis /= self.n_domain
loss_step_C = loss_dis
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def moment_distance(self, x, u):
# x (list): a list of feature matrix.
# u (torch.Tensor): feature matrix.
x_mean = [xi.mean(0) for xi in x]
u_mean = u.mean(0)
dist1 = self.pairwise_distance(x_mean, u_mean)
x_var = [xi.var(0) for xi in x]
u_var = u.var(0)
dist2 = self.pairwise_distance(x_var, u_var)
return (dist1+dist2) / 2
def pairwise_distance(self, x, u):
# x (list): a list of feature vector.
# u (torch.Tensor): feature vector.
dist = 0
count = 0
for xi in x:
dist += self.euclidean(xi, u)
count += 1
for i in range(len(x) - 1):
for j in range(i + 1, len(x)):
dist += self.euclidean(x[i], x[j])
count += 1
return dist / count
def euclidean(self, input1, input2):
return ((input1 - input2)**2).sum().sqrt()
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, domain_x, input_u
def model_inference(self, input):
f = self.F(input)
p = 0
for C_i in self.C:
z = C_i(f)
p += F.softmax(z, 1)
p = p / len(self.C)
return p

View File

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
"""Maximum Classifier Discrepancy.
https://arxiv.org/abs/1712.02560.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.n_step_F = cfg.TRAINER.MCD.N_STEP_F
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C1")
self.C1 = nn.Linear(fdim, self.num_classes)
self.C1.to(self.device)
print("# params: {:,}".format(count_num_param(self.C1)))
self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)
print("Building C2")
self.C2 = nn.Linear(fdim, self.num_classes)
self.C2.to(self.device)
print("# params: {:,}".format(count_num_param(self.C2)))
self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u = parsed
# Step A
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_step_A = loss_x1 + loss_x2
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_x = loss_x1 + loss_x2
with torch.no_grad():
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_dis = self.discrepancy(pred_u1, pred_u2)
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, ["C1", "C2"])
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_step_C = self.discrepancy(pred_u1, pred_u2)
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def model_inference(self, input):
feat = self.F(input)
return self.C1(feat)

View File

@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
class Prototypes(nn.Module):
def __init__(self, fdim, num_classes, temp=0.05):
super().__init__()
self.prototypes = nn.Linear(fdim, num_classes, bias=False)
self.temp = temp
def forward(self, x):
x = F.normalize(x, p=2, dim=1)
out = self.prototypes(x)
out = out / self.temp
return out
@TRAINER_REGISTRY.register()
class MME(TrainerXU):
"""Minimax Entropy.
https://arxiv.org/abs/1904.06487.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.MME.LMDA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building C")
self.C = Prototypes(self.F.fdim, self.num_classes)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
feat_x = self.F(input_x)
logit_x = self.C(feat_x)
loss_x = F.cross_entropy(logit_x, label_x)
self.model_backward_and_update(loss_x)
feat_u = self.F(input_u)
feat_u = self.revgrad(feat_u)
logit_u = self.C(feat_u)
prob_u = F.softmax(logit_u, 1)
loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
self.model_backward_and_update(loss_u * self.lmda)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.C(self.F(input))

View File

@@ -0,0 +1,78 @@
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
"""Self-ensembling for visual domain adaptation.
https://arxiv.org/abs/1706.05208.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
self.conf_thre = cfg.TRAINER.SE.CONF_THRE
self.rampup = cfg.TRAINER.SE.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS == 2
def forward_backward(self, batch_x, batch_u):
global_step = self.batch_idx + self.epoch * self.num_batches
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u1, input_u2 = parsed
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
prob_u = F.softmax(self.model(input_u1), 1)
t_prob_u = F.softmax(self.teacher(input_u2), 1)
loss_u = ((prob_u - t_prob_u)**2).sum(1)
if self.conf_thre:
max_prob = t_prob_u.max(1)[0]
mask = (max_prob > self.conf_thre).float()
loss_u = (loss_u * mask).mean()
else:
weight_u = sigmoid_rampup(global_step, self.rampup)
loss_u = loss_u.mean() * weight_u
loss = loss_x + loss_u
self.model_backward_and_update(loss)
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u1, input_u2 = input_u
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u1 = input_u1.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, label_x, input_u1, input_u2

View File

@@ -0,0 +1,34 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
"""Baseline model for domain adaptation, which is
trained using source data only.
"""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,4 @@
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad

View File

@@ -0,0 +1,83 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
"""Cross-gradient training.
https://arxiv.org/abs/1804.10745.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.eps_f = cfg.TRAINER.CG.EPS_F
self.eps_d = cfg.TRAINER.CG.EPS_D
self.alpha_f = cfg.TRAINER.CG.ALPHA_F
self.alpha_d = cfg.TRAINER.CG.ALPHA_D
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
input.requires_grad = True
# Compute domain perturbation
loss_d = F.cross_entropy(self.D(input), domain)
loss_d.backward()
grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_d = input.data + self.eps_f * grad_d
# Compute label perturbation
input.grad.data.zero_()
loss_f = F.cross_entropy(self.F(input), label)
loss_f.backward()
grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_f = input.data + self.eps_d * grad_f
input = input.detach()
# Update label net
loss_f1 = F.cross_entropy(self.F(input), label)
loss_f2 = F.cross_entropy(self.F(input_d), label)
loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
self.model_backward_and_update(loss_f, "F")
# Update domain net
loss_d1 = F.cross_entropy(self.D(input), domain)
loss_d2 = F.cross_entropy(self.D(input_f), domain)
loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
self.model_backward_and_update(loss_d, "D")
loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)

View File

@@ -0,0 +1,169 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
"""Domain Adaptive Ensemble Learning.
DG version: only use labeled source data.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch):
parsed_data = self.parse_batch_train(batch)
input, input2, label, domain = parsed_data
input = torch.split(input, self.split_batch, 0)
input2 = torch.split(input2, self.split_batch, 0)
label = torch.split(label, self.split_batch, 0)
domain = torch.split(domain, self.split_batch, 0)
domain = [d[0].item() for d in domain]
loss_x = 0
loss_cr = 0
acc = 0
feat = [self.F(x) for x in input]
feat2 = [self.F(x) for x in input2]
for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
cr_s = [j for j in domain if j != i]
# Learning expert
pred_i = self.E(i, feat_i)
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
expert_label_i = pred_i.detach()
acc += compute_accuracy(pred_i.detach(),
label_i.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat2_i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc /= self.n_domain
loss = 0
loss += loss_x
loss += loss_cr
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc": acc,
"loss_cr": loss_cr.item()
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
input2 = batch["img2"]
label = batch["label"]
domain = batch["domain"]
label = create_onehot(label, self.num_classes)
input = input.to(self.device)
input2 = input2.to(self.device)
label = label.to(self.device)
return input, input2, label, domain
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p

View File

@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
"""Deep Domain-Adversarial Image Generation.
https://arxiv.org/abs/2003.06054.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.DDAIG.LMDA
self.clamp = cfg.TRAINER.DDAIG.CLAMP
self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
self.warmup = cfg.TRAINER.DDAIG.WARMUP
self.alpha = cfg.TRAINER.DDAIG.ALPHA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
print("Building G")
self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
self.G.to(self.device)
print("# params: {:,}".format(count_num_param(self.G)))
self.optim_G = build_optimizer(self.G, cfg.OPTIM)
self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
self.register_model("G", self.G, self.optim_G, self.sched_G)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
#############
# Update G
#############
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
loss_g = 0
# Minimize label loss
loss_g += F.cross_entropy(self.F(input_p), label)
# Maximize domain loss
loss_g -= F.cross_entropy(self.D(input_p), domain)
self.model_backward_and_update(loss_g, "G")
# Perturb data with new G
with torch.no_grad():
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
#############
# Update F
#############
loss_f = F.cross_entropy(self.F(input), label)
if (self.epoch + 1) > self.warmup:
loss_fp = F.cross_entropy(self.F(input_p), label)
loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
self.model_backward_and_update(loss_f, "F")
#############
# Update D
#############
loss_d = F.cross_entropy(self.D(input), domain)
self.model_backward_and_update(loss_d, "D")
loss_summary = {
"loss_g": loss_g.item(),
"loss_f": loss_f.item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
"""Vanilla baseline."""
def forward_backward(self, batch):
input, label = self.parse_batch_train(batch)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,5 @@
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline

View File

@@ -0,0 +1,41 @@
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
"""Entropy Minimization.
http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.ENTMIN.LMDA
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
output_u = F.softmax(self.model(input_u), 1)
loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()
loss = loss_x + loss_u * self.lmda
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary

View File

@@ -0,0 +1,112 @@
import torch
from torch.nn import functional as F
from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform
@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
"""FixMatch: Simplifying Semi-Supervised Learning with
Consistency and Confidence.
https://arxiv.org/abs/2001.07685.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE
def check_cfg(self, cfg):
assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = self.dm.train_loader_x
self.train_loader_u = self.dm.train_loader_u
self.val_loader = self.dm.val_loader
self.test_loader = self.dm.test_loader
self.num_classes = self.dm.num_classes
def assess_y_pred_quality(self, y_pred, y_true, mask):
n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
acc_thre = n_masked_correct / (mask.sum() + 1e-5)
acc_raw = y_pred.eq(y_true).sum() / y_pred.numel() # raw accuracy
keep_rate = mask.sum() / mask.numel()
output = {
"acc_thre": acc_thre,
"acc_raw": acc_raw,
"keep_rate": keep_rate
}
return output
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
input_u = torch.cat([input_x, input_u], 0)
input_u2 = torch.cat([input_x2, input_u2], 0)
n_x = input_x.size(0)
# Generate pseudo labels
with torch.no_grad():
output_u = F.softmax(self.model(input_u), 1)
max_prob, label_u_pred = output_u.max(1)
mask_u = (max_prob >= self.conf_thre).float()
# Evaluate pseudo labels' accuracy
y_u_pred_stats = self.assess_y_pred_quality(
label_u_pred[n_x:], label_u, mask_u[n_x:]
)
# Supervised loss
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
# Unsupervised loss
output_u = self.model(input_u2)
loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
loss_u = (loss_u * mask_u).mean()
loss = loss_x + loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
"y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
"y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
"y_u_pred_keep": y_u_pred_stats["keep_rate"],
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
# label_u is used only for evaluating pseudo labels' accuracy
label_u = batch_u["label"]
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
label_u = label_u.to(self.device)
return input_x, input_x2, label_x, input_u, input_u2, label_u

View File

@@ -0,0 +1,54 @@
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
"""Mean teacher.
https://arxiv.org/abs/1703.01780.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
self.rampup = cfg.TRAINER.MEANTEA.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
target_u = F.softmax(self.teacher(input_u), 1)
prob_u = F.softmax(self.model(input_u), 1)
loss_u = ((prob_u - target_u)**2).sum(1).mean()
weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
global_step = self.batch_idx + self.epoch * self.num_batches
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary

View File

@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
sharpen_prob, create_onehot, linear_rampup, shuffle_index
)
@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
"""MixMatch: A Holistic Approach to Semi-Supervised Learning.
https://arxiv.org/abs/1905.02249.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
self.temp = cfg.TRAINER.MIXMATCH.TEMP
self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS > 1
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
num_x = input_x.shape[0]
global_step = self.batch_idx + self.epoch * self.num_batches
weight_u = self.weight_u * linear_rampup(global_step, self.rampup)
# Generate pseudo-label for unlabeled data
with torch.no_grad():
output_u = 0
for input_ui in input_u:
output_ui = F.softmax(self.model(input_ui), 1)
output_u += output_ui
output_u /= len(input_u)
label_u = sharpen_prob(output_u, self.temp)
label_u = [label_u] * len(input_u)
label_u = torch.cat(label_u, 0)
input_u = torch.cat(input_u, 0)
# Combine and shuffle labeled and unlabeled data
input_xu = torch.cat([input_x, input_u], 0)
label_xu = torch.cat([label_x, label_u], 0)
input_xu, label_xu = shuffle_index(input_xu, label_xu)
# Mixup
input_x, label_x = mixup(
input_x,
input_xu[:num_x],
label_x,
label_xu[:num_x],
self.beta,
preserve_order=True,
)
input_u, label_u = mixup(
input_u,
input_xu[num_x:],
label_u,
label_xu[num_x:],
self.beta,
preserve_order=True,
)
# Compute losses
output_x = F.softmax(self.model(input_x), 1)
loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()
output_u = F.softmax(self.model(input_u), 1)
loss_u = ((label_u - output_u)**2).mean()
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
label_x = create_onehot(label_x, self.num_classes)
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = [input_ui.to(self.device) for input_ui in input_u]
return input_x, label_x, input_u

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
"""Supervised Baseline."""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,735 @@
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator
class SimpleNet(nn.Module):
"""A simple neural network composed of a CNN backbone
and optionally a head such as mlp for classification.
"""
def __init__(self, cfg, model_cfg, num_classes, **kwargs):
super().__init__()
self.backbone = build_backbone(
model_cfg.BACKBONE.NAME,
verbose=cfg.VERBOSE,
pretrained=model_cfg.BACKBONE.PRETRAINED,
**kwargs,
)
fdim = self.backbone.out_features
self.head = None
if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
self.head = build_head(
model_cfg.HEAD.NAME,
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
activation=model_cfg.HEAD.ACTIVATION,
bn=model_cfg.HEAD.BN,
dropout=model_cfg.HEAD.DROPOUT,
**kwargs,
)
fdim = self.head.out_features
self.classifier = None
if num_classes > 0:
self.classifier = nn.Linear(fdim, num_classes)
self._fdim = fdim
@property
def fdim(self):
return self._fdim
def forward(self, x, return_feature=False):
f = self.backbone(x)
if self.head is not None:
f = self.head(f)
if self.classifier is None:
return f
y = self.classifier(f)
if return_feature:
return y, f
return y
class TrainerBase:
"""Base class for iterative trainer."""
def __init__(self):
self._models = OrderedDict()
self._optims = OrderedDict()
self._scheds = OrderedDict()
self._writer = None
def register_model(self, name="model", model=None, optim=None, sched=None):
if self.__dict__.get("_models") is None:
raise AttributeError(
"Cannot assign model before super().__init__() call"
)
if self.__dict__.get("_optims") is None:
raise AttributeError(
"Cannot assign optim before super().__init__() call"
)
if self.__dict__.get("_scheds") is None:
raise AttributeError(
"Cannot assign sched before super().__init__() call"
)
assert name not in self._models, "Found duplicate model names"
self._models[name] = model
self._optims[name] = optim
self._scheds[name] = sched
def get_model_names(self, names=None):
names_real = list(self._models.keys())
if names is not None:
names = tolist_if_not(names)
for name in names:
assert name in names_real
return names
else:
return names_real
def save_model(self, epoch, directory, is_best=False, model_name=""):
names = self.get_model_names()
for name in names:
model_dict = self._models[name].state_dict()
optim_dict = None
if self._optims[name] is not None:
optim_dict = self._optims[name].state_dict()
sched_dict = None
if self._scheds[name] is not None:
sched_dict = self._scheds[name].state_dict()
save_checkpoint(
{
"state_dict": model_dict,
"epoch": epoch + 1,
"optimizer": optim_dict,
"scheduler": sched_dict,
},
osp.join(directory, name),
is_best=is_best,
model_name=model_name,
)
def resume_model_if_exist(self, directory):
names = self.get_model_names()
file_missing = False
for name in names:
path = osp.join(directory, name)
if not osp.exists(path):
file_missing = True
break
if file_missing:
print("No checkpoint found, train from scratch")
return 0
print(
'Found checkpoint in "{}". Will resume training'.format(directory)
)
for name in names:
path = osp.join(directory, name)
start_epoch = resume_from_checkpoint(
path, self._models[name], self._optims[name],
self._scheds[name]
)
return start_epoch
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained "
"model is given (ignore this if it's done on purpose)"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path)
)
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
print(
"Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch)
)
self._models[name].load_state_dict(state_dict)
def set_model_mode(self, mode="train", names=None):
names = self.get_model_names(names)
for name in names:
if mode == "train":
self._models[name].train()
elif mode in ["test", "eval"]:
self._models[name].eval()
else:
raise KeyError
def update_lr(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._scheds[name] is not None:
self._scheds[name].step()
def detect_anomaly(self, loss):
if not torch.isfinite(loss).all():
raise FloatingPointError("Loss is infinite or NaN!")
def init_writer(self, log_dir):
if self.__dict__.get("_writer") is None or self._writer is None:
print(
"Initializing summary writer for tensorboard "
"with log_dir={}".format(log_dir)
)
self._writer = SummaryWriter(log_dir=log_dir)
def close_writer(self):
if self._writer is not None:
self._writer.close()
def write_scalar(self, tag, scalar_value, global_step=None):
if self._writer is None:
# Do nothing if writer is not initialized
# Note that writer is only used when training is needed
pass
else:
self._writer.add_scalar(tag, scalar_value, global_step)
def train(self, start_epoch, max_epoch):
"""Generic training loops."""
self.start_epoch = start_epoch
self.max_epoch = max_epoch
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
self.before_epoch()
self.run_epoch()
self.after_epoch()
self.after_train()
def before_train(self):
pass
def after_train(self):
pass
def before_epoch(self):
pass
def after_epoch(self):
pass
def run_epoch(self):
raise NotImplementedError
def test(self):
raise NotImplementedError
def parse_batch_train(self, batch):
raise NotImplementedError
def parse_batch_test(self, batch):
raise NotImplementedError
def forward_backward(self, batch):
raise NotImplementedError
def model_inference(self, input):
raise NotImplementedError
def model_zero_grad(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].zero_grad()
def model_backward(self, loss):
self.detect_anomaly(loss)
loss.backward()
def model_update(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].step()
def model_backward_and_update(self, loss, names=None):
self.model_zero_grad(names)
self.model_backward(loss)
self.model_update(names)
def prograd_backward_and_update(
self, loss_a, loss_b, lambda_=1, names=None
):
# loss_b not increase is okay
# loss_a has to decline
self.model_zero_grad(names)
# get name of the model parameters
names = self.get_model_names(names)
# backward loss_a
self.detect_anomaly(loss_b)
loss_b.backward(retain_graph=True)
# normalize gradient
b_grads = []
for name in names:
for p in self._models[name].parameters():
b_grads.append(p.grad.clone())
# optimizer don't step
for name in names:
self._optims[name].zero_grad()
# backward loss_a
self.detect_anomaly(loss_a)
loss_a.backward()
for name in names:
for p, b_grad in zip(self._models[name].parameters(), b_grads):
# calculate cosine distance
b_grad_norm = b_grad / torch.linalg.norm(b_grad)
a_grad = p.grad.clone()
a_grad_norm = a_grad / torch.linalg.norm(a_grad)
if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
p.grad = a_grad - lambda_ * torch.dot(
a_grad.flatten(), b_grad_norm.flatten()
) * b_grad_norm
# optimizer
for name in names:
self._optims[name].step()
class SimpleTrainer(TrainerBase):
"""A simple trainer class implementing generic functions."""
def __init__(self, cfg):
super().__init__()
self.check_cfg(cfg)
if torch.cuda.is_available() and cfg.USE_CUDA:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
# Save as attributes some frequently used variables
self.start_epoch = self.epoch = 0
self.max_epoch = cfg.OPTIM.MAX_EPOCH
self.output_dir = cfg.OUTPUT_DIR
self.cfg = cfg
self.build_data_loader()
self.build_model()
self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
self.best_result = -np.inf
def check_cfg(self, cfg):
"""Check whether some variables are set correctly for
the trainer (optional).
For example, a trainer might require a particular sampler
for training such as 'RandomDomainSampler', so it is good
to do the checking:
assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
"""
pass
def build_data_loader(self):
"""Create essential data-related attributes.
A re-implementation of this method must create the
same attributes (except self.dm).
"""
dm = DataManager(self.cfg)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u # optional, can be None
self.val_loader = dm.val_loader # optional, can be None
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname # dict {label: classname}
self.dm = dm
def build_model(self):
"""Build and register model.
The default builds a classification model along with its
optimizer and scheduler.
Custom trainers can re-implement this method if necessary.
"""
cfg = self.cfg
print("Building model")
self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
print("# params: {:,}".format(count_num_param(self.model)))
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("model", self.model, self.optim, self.sched)
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
)
self.model = nn.DataParallel(self.model)
def train(self):
super().train(self.start_epoch, self.max_epoch)
def before_train(self):
directory = self.cfg.OUTPUT_DIR
if self.cfg.RESUME:
directory = self.cfg.RESUME
self.start_epoch = self.resume_model_if_exist(directory)
# Initialize summary writer
writer_dir = osp.join(self.output_dir, "tensorboard")
mkdir_if_missing(writer_dir)
self.init_writer(writer_dir)
# Remember the starting time (for computing the elapsed time)
self.time_start = time.time()
def after_train(self):
print("Finished training")
do_test = not self.cfg.TEST.NO_TEST
if do_test:
if self.cfg.TEST.FINAL_MODEL == "best_val":
print("Deploy the model with the best val performance")
self.load_model(self.output_dir)
self.test()
# Show elapsed time
elapsed = round(time.time() - self.time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed: {}".format(elapsed))
# Close writer
self.close_writer()
def after_epoch(self):
last_epoch = (self.epoch + 1) == self.max_epoch
do_test = not self.cfg.TEST.NO_TEST
meet_checkpoint_freq = (
(self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
)
if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
curr_result = self.test(split="val")
is_best = curr_result > self.best_result
if is_best:
self.best_result = curr_result
self.save_model(
self.epoch,
self.output_dir,
model_name="model-best.pth.tar"
)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
@torch.no_grad()
def output_test(self, split=None):
"""testing pipline, which could also output the results."""
self.set_model_mode("eval")
self.evaluator.reset()
output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
res_json = {}
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
img_path = batch['impath']
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
for i in range(len(img_path)):
res_json[img_path[i]] = {
'predict': output[i].cpu().numpy().tolist(),
'gt': label[i].cpu().numpy().tolist()
}
with open(output_file, 'w') as f:
json.dump(res_json, f)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
@torch.no_grad()
def test(self, split=None):
"""A generic testing pipeline."""
self.set_model_mode("eval")
self.evaluator.reset()
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
def model_inference(self, input):
return self.model(input)
def parse_batch_test(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def get_current_lr(self, names=None):
names = self.get_model_names(names)
name = names[0]
return self._optims[name].param_groups[0]["lr"]
class TrainerXU(SimpleTrainer):
"""A base trainer using both labeled and unlabeled data.
In the context of domain adaptation, labeled and unlabeled data
come from source and target domains respectively.
When it comes to semi-supervised learning, all data comes from the
same domain.
"""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
# Decide to iterate over labeled or unlabeled dataset
len_train_loader_x = len(self.train_loader_x)
len_train_loader_u = len(self.train_loader_u)
if self.cfg.TRAIN.COUNT_ITER == "train_x":
self.num_batches = len_train_loader_x
elif self.cfg.TRAIN.COUNT_ITER == "train_u":
self.num_batches = len_train_loader_u
elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
self.num_batches = min(len_train_loader_x, len_train_loader_u)
else:
raise ValueError
train_loader_x_iter = iter(self.train_loader_x)
train_loader_u_iter = iter(self.train_loader_u)
end = time.time()
for self.batch_idx in range(self.num_batches):
try:
batch_x = next(train_loader_x_iter)
except StopIteration:
train_loader_x_iter = iter(self.train_loader_x)
batch_x = next(train_loader_x_iter)
try:
batch_u = next(train_loader_u_iter)
except StopIteration:
train_loader_u_iter = iter(self.train_loader_u)
batch_u = next(train_loader_u_iter)
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch_x, batch_u)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, input_u
class TrainerX(SimpleTrainer):
"""A base trainer using labeled data only."""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.num_batches = len(self.train_loader_x)
end = time.time()
for self.batch_idx, batch in enumerate(self.train_loader_x):
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
domain = batch["domain"]
input = input.to(self.device)
label = label.to(self.device)
domain = domain.to(self.device)
return input, label, domain