release code
6
Dassl.ProGrad.pytorch/dassl/engine/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .build import TRAINER_REGISTRY, build_trainer # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip

from .da import *
from .dg import *
from .ssl import *
11
Dassl.ProGrad.pytorch/dassl/engine/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

TRAINER_REGISTRY = Registry("TRAINER")


def build_trainer(cfg):
    avai_trainers = TRAINER_REGISTRY.registered_names()
    check_availability(cfg.TRAINER.NAME, avai_trainers)
    if cfg.VERBOSE:
        print("Loading trainer: {}".format(cfg.TRAINER.NAME))
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
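As a point of reference, here is a minimal sketch of how this registry is meant to be used: a trainer registers itself with the decorator, and build_trainer resolves cfg.TRAINER.NAME to the registered class and instantiates it. "DummyTrainer" and the bare cfg below are illustrative only; real trainers subclass TrainerX/TrainerXU and receive a full yacs config.

from dassl.engine import TRAINER_REGISTRY

@TRAINER_REGISTRY.register()
class DummyTrainer:
    # Stands in for a real trainer; shows only the registration pattern.
    def __init__(self, cfg):
        self.cfg = cfg

print(TRAINER_REGISTRY.registered_names())        # [..., "DummyTrainer"]
trainer_cls = TRAINER_REGISTRY.get("DummyTrainer")
trainer = trainer_cls(cfg=None)                   # build_trainer(cfg) performs this lookup + call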
9
Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
Normal file
@@ -0,0 +1,9 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
@@ -0,0 +1,38 @@
import torch

from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU


@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        if not self.done_reset_bn_stats:
            for m in self.model.modules():
                classname = m.__class__.__name__
                if classname.find("BatchNorm") != -1:
                    m.reset_running_stats()

            self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        input_u = batch_u["img"].to(self.device)

        with torch.no_grad():
            self.model(input_u)

        return None
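AdaBN relies on the fact that BatchNorm running statistics are updated during a forward pass in train mode even under torch.no_grad(): resetting them in before_epoch and then forwarding target batches re-estimates the statistics on the target domain without touching any weights. A minimal standalone check of that behaviour:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)                 # in train mode by default
print(bn.running_mean)                 # starts at zeros
with torch.no_grad():
    bn(torch.randn(32, 3) + 5.0)       # forward pass, no gradients
print(bn.running_mean)                 # moved toward ~5: running stats updated anyway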
85
Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
Normal file
@@ -0,0 +1,85 @@
import copy
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head


@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)

        self.build_critic()

        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())

        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
210
Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
Normal file
@@ -0,0 +1,210 @@
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(
            feat_x, feat_x2, label_x, domain_x
        ):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": acc_x,
            "loss_cr": loss_cr.item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]

        label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
@@ -0,0 +1,78 @@
import numpy as np
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad


@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks.

    https://arxiv.org/abs/1505.07818.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        loss_x = self.ce(logit_x, label_x)

        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
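The adversarial weight lmda above follows the ramp-up schedule from the DANN paper, lambda(p) = 2 / (1 + exp(-10 p)) - 1, where p is the fraction of training completed; it starts near 0 so the reversed critic gradient is suppressed early on and approaches 1 by the end of training. A quick standalone check of the ramp:

import numpy as np

def dann_lmda(progress):
    # progress in [0, 1]; same expression as in forward_backward above
    return 2 / (1 + np.exp(-10 * progress)) - 1

for p in (0.0, 0.1, 0.5, 1.0):
    print(p, round(dann_lmda(p), 3))   # 0.0, ~0.462, ~0.987, ~1.0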
208
Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
Normal file
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


class PairClassifiers(nn.Module):

    def __init__(self, fdim, num_classes):
        super().__init__()
        self.c1 = nn.Linear(fdim, num_classes)
        self.c2 = nn.Linear(fdim, num_classes)

    def forward(self, x):
        z1 = self.c1(x)
        if not self.training:
            return z1
        z2 = self.c2(x)
        return z1, z2


@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        self.lmda = cfg.TRAINER.M3SDA.LMDA

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C")
        self.C = nn.ModuleList(
            [
                PairClassifiers(fdim, self.num_classes)
                for _ in range(self.num_source_domains)
            ]
        )
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed

        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Step A
        loss_x = 0
        feat_x = []

        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            feat_x.append(f)

        loss_x /= self.n_domain

        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)

        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_u = self.F(input_u)

        loss_x, loss_dis = 0, 0

        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)

        loss_x /= self.n_domain
        loss_dis /= self.n_domain

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, "C")

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)

            loss_dis = 0

            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)

            loss_dis /= self.n_domain
            loss_step_C = loss_dis

            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def moment_distance(self, x, u):
        # x (list): a list of feature matrix.
        # u (torch.Tensor): feature matrix.
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)

        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)

        return (dist1+dist2) / 2

    def pairwise_distance(self, x, u):
        # x (list): a list of feature vector.
        # u (torch.Tensor): feature vector.
        dist = 0
        count = 0

        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1

        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1

        return dist / count

    def euclidean(self, input1, input2):
        return ((input1 - input2)**2).sum().sqrt()

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, domain_x, input_u

    def model_inference(self, input):
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p
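moment_distance above matches the first moment (mean) and second moment (variance) of source and target features: pairwise_distance averages the Euclidean distance over all K source-target pairs plus all K(K-1)/2 source-source pairs, and the two moment terms are then averaged. A small self-contained restatement of that pairing logic (random features and three hypothetical source domains, first-moment term only):

import torch

torch.manual_seed(0)
feat_x = [torch.randn(16, 8) for _ in range(3)]   # 3 source feature batches
feat_u = torch.randn(16, 8)                       # target feature batch

means = [f.mean(0) for f in feat_x] + [feat_u.mean(0)]
# all unordered pairs of {source1, source2, source3, target}: 3 + 3 = 6 terms,
# exactly the pairs visited by pairwise_distance
pairs = [(i, j) for i in range(4) for j in range(i + 1, 4)]
dist1 = sum(((means[i] - means[j])**2).sum().sqrt() for i, j in pairs) / len(pairs)
print(float(dist1))   # the variance term is computed the same way, then the two are averaged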
105
Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
Normal file
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)

        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)
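The discrepancy term above is simply the mean absolute difference between the two classifiers' softmax outputs: step B maximizes it on target data with respect to C1/C2 (hence the minus sign in loss_step_B), and step C minimizes it with respect to the feature extractor F. A tiny illustration of the measure on dummy predictions:

import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits1 = torch.randn(4, 5)
logits2 = torch.randn(4, 5)
p1, p2 = F.softmax(logits1, 1), F.softmax(logits2, 1)
disc = (p1 - p2).abs().mean()
print(float(disc))   # 0 when the two classifiers agree exactly, larger as they diverge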
86
Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
Normal file
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet


class Prototypes(nn.Module):

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        x = F.normalize(x, p=2, dim=1)
        out = self.prototypes(x)
        out = out / self.temp
        return out


@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.C(self.F(input))
78
Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
Normal file
@@ -0,0 +1,78 @@
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS == 2

    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

        if self.conf_thre:
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u1, input_u2 = input_u

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, label_x, input_u1, input_u2
34
Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
Normal file
@@ -0,0 +1,34 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
4
Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad
83
Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
Normal file
@@ -0,0 +1,83 @@
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
169
Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
Normal file
@@ -0,0 +1,169 @@
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss_x = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        loss = 0
        loss += loss_x
        loss += loss_cr
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc": acc,
            "loss_cr": loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        input2 = batch["img2"]
        label = batch["label"]
        domain = batch["domain"]

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
107
Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
Normal file
@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA
        self.clamp = cfg.TRAINER.DDAIG.CLAMP
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP
        self.alpha = cfg.TRAINER.DDAIG.ALPHA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

        print("Building G")
        self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print("# params: {:,}".format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model("G", self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )
        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, "G")

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, "F")

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {
            "loss_g": loss_g.item(),
            "loss_f": loss_f.item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
32
Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
Normal file
@@ -0,0 +1,32 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
    """Vanilla baseline."""

    def forward_backward(self, batch):
        input, label = self.parse_batch_train(batch)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
@@ -0,0 +1,41 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda

        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
@@ -0,0 +1,112 @@
import torch
from torch.nn import functional as F

from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform


@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

            # Evaluate pseudo labels' accuracy
            y_u_pred_stats = self.assess_y_pred_quality(
                label_u_pred[n_x:], label_u, mask_u[n_x:]
            )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u
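The unsupervised objective above is the standard FixMatch rule: predictions on the weakly augmented view become hard pseudo labels, and only samples whose maximum probability clears conf_thre contribute to the cross-entropy on the strongly augmented view. A compact standalone restatement of that masking (dummy tensors; 0.95 is just an example threshold value):

import torch
from torch.nn import functional as F

torch.manual_seed(0)
weak_logits = torch.randn(8, 10)     # stands in for model(input_u)
strong_logits = torch.randn(8, 10)   # stands in for model(input_u2)

probs = F.softmax(weak_logits, 1)
max_prob, pseudo = probs.max(1)
mask = (max_prob >= 0.95).float()
loss_u = (F.cross_entropy(strong_logits, pseudo, reduction="none") * mask).mean()
print(float(loss_u), int(mask.sum()))  # loss is zero if no sample passes the threshold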
54
Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
Normal file
@@ -0,0 +1,54 @@
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
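ema_model_update above keeps the teacher tracking an exponential moving average of the student; the alpha passed in ramps from 0 toward EMA_ALPHA, so early steps follow the student almost directly. A standalone sketch of the usual mean-teacher EMA rule (this restates the standard update for illustration, not necessarily the exact dassl helper):

import copy
import torch.nn as nn

def ema_update(student: nn.Module, teacher: nn.Module, alpha: float):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.data.mul_(alpha).add_(p_s.data, alpha=1 - alpha)

student = nn.Linear(4, 2)
teacher = copy.deepcopy(student)
ema_update(student, teacher, alpha=0.99)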
98
Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
Normal file
@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
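MixMatch leans on four tensor helpers from dassl.modeling.ops.utils that this commit does not include. The sketch below is a hedged approximation of their behavior as used above (the real signatures in the repo may differ slightly): one-hot encoding, temperature sharpening, a linear ramp-up weight, and a consistent random shuffle.

import torch

def create_onehot(label, num_classes):
    # Convert integer class labels of shape (N,) into one-hot vectors (N, C).
    onehot = torch.zeros(label.shape[0], num_classes, device=label.device)
    return onehot.scatter_(1, label.unsqueeze(1), 1.0)

def sharpen_prob(p, temperature=2.0):
    # Temperature sharpening: raise probabilities to 1/T and renormalize.
    sharpened = p.pow(1.0 / temperature)
    return sharpened / sharpened.sum(1, keepdim=True)

def linear_rampup(current, rampup_length):
    # Linearly increase a weight from 0 to 1 over `rampup_length` steps.
    if rampup_length == 0:
        return 1.0
    return float(min(current, rampup_length)) / rampup_length

def shuffle_index(*tensors):
    # Apply one random permutation consistently to all given tensors.
    perm = torch.randperm(tensors[0].shape[0])
    return tuple(t[perm] for t in tensors)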
32
Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
Normal file
@@ -0,0 +1,32 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised Baseline."""

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
735
Dassl.ProGrad.pytorch/dassl/engine/trainer.py
Normal file
@@ -0,0 +1,735 @@
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
    MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
    save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
    load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator


class SimpleNet(nn.Module):
    """A simple neural network composed of a CNN backbone
    and optionally a head such as mlp for classification.
    """

    def __init__(self, cfg, model_cfg, num_classes, **kwargs):
        super().__init__()
        self.backbone = build_backbone(
            model_cfg.BACKBONE.NAME,
            verbose=cfg.VERBOSE,
            pretrained=model_cfg.BACKBONE.PRETRAINED,
            **kwargs,
        )
        fdim = self.backbone.out_features

        self.head = None
        if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
            self.head = build_head(
                model_cfg.HEAD.NAME,
                verbose=cfg.VERBOSE,
                in_features=fdim,
                hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
                activation=model_cfg.HEAD.ACTIVATION,
                bn=model_cfg.HEAD.BN,
                dropout=model_cfg.HEAD.DROPOUT,
                **kwargs,
            )
            fdim = self.head.out_features

        self.classifier = None
        if num_classes > 0:
            self.classifier = nn.Linear(fdim, num_classes)

        self._fdim = fdim

    @property
    def fdim(self):
        return self._fdim

    def forward(self, x, return_feature=False):
        f = self.backbone(x)
        if self.head is not None:
            f = self.head(f)

        if self.classifier is None:
            return f

        y = self.classifier(f)

        if return_feature:
            return y, f

        return y
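For orientation, here is a small, self-contained illustration (not the Dassl API; TinyBackbone and the sizes are hypothetical) of the composition SimpleNet implements: a backbone exposing out_features, an optional head, and an optional linear classifier, with forward() returning features when no classifier exists and logits otherwise.

import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    # Stand-in for a registered Dassl backbone: any module that exposes
    # `out_features` so a classifier can be attached on top of it.
    def __init__(self, out_features=64):
        super().__init__()
        self.out_features = out_features
        self.net = nn.Sequential(
            nn.Conv2d(3, out_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)

backbone = TinyBackbone(out_features=64)
classifier = nn.Linear(backbone.out_features, 10)  # plays the role of SimpleNet.classifier

x = torch.randn(4, 3, 32, 32)
f = backbone(x)    # features, shape (4, 64) -> what forward() returns with return_feature
y = classifier(f)  # logits, shape (4, 10)   -> what forward() returns by default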
class TrainerBase:
    """Base class for iterative trainer."""

    def __init__(self):
        self._models = OrderedDict()
        self._optims = OrderedDict()
        self._scheds = OrderedDict()
        self._writer = None

    def register_model(self, name="model", model=None, optim=None, sched=None):
        if self.__dict__.get("_models") is None:
            raise AttributeError(
                "Cannot assign model before super().__init__() call"
            )

        if self.__dict__.get("_optims") is None:
            raise AttributeError(
                "Cannot assign optim before super().__init__() call"
            )

        if self.__dict__.get("_scheds") is None:
            raise AttributeError(
                "Cannot assign sched before super().__init__() call"
            )

        assert name not in self._models, "Found duplicate model names"

        self._models[name] = model
        self._optims[name] = optim
        self._scheds[name] = sched

    def get_model_names(self, names=None):
        names_real = list(self._models.keys())
        if names is not None:
            names = tolist_if_not(names)
            for name in names:
                assert name in names_real
            return names
        else:
            return names_real

    def save_model(self, epoch, directory, is_best=False, model_name=""):
        names = self.get_model_names()

        for name in names:
            model_dict = self._models[name].state_dict()

            optim_dict = None
            if self._optims[name] is not None:
                optim_dict = self._optims[name].state_dict()

            sched_dict = None
            if self._scheds[name] is not None:
                sched_dict = self._scheds[name].state_dict()

            save_checkpoint(
                {
                    "state_dict": model_dict,
                    "epoch": epoch + 1,
                    "optimizer": optim_dict,
                    "scheduler": sched_dict,
                },
                osp.join(directory, name),
                is_best=is_best,
                model_name=model_name,
            )

    def resume_model_if_exist(self, directory):
        names = self.get_model_names()
        file_missing = False

        for name in names:
            path = osp.join(directory, name)
            if not osp.exists(path):
                file_missing = True
                break

        if file_missing:
            print("No checkpoint found, train from scratch")
            return 0

        print(
            'Found checkpoint in "{}". Will resume training'.format(directory)
        )

        for name in names:
            path = osp.join(directory, name)
            start_epoch = resume_from_checkpoint(
                path, self._models[name], self._optims[name],
                self._scheds[name]
            )

        return start_epoch
    def load_model(self, directory, epoch=None):
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained "
                "model is given (ignore this if it's done on purpose)"
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            print(
                "Loading weights to {} "
                'from "{}" (epoch = {})'.format(name, model_path, epoch)
            )
            self._models[name].load_state_dict(state_dict)

    def set_model_mode(self, mode="train", names=None):
        names = self.get_model_names(names)

        for name in names:
            if mode == "train":
                self._models[name].train()
            elif mode in ["test", "eval"]:
                self._models[name].eval()
            else:
                raise KeyError

    def update_lr(self, names=None):
        names = self.get_model_names(names)

        for name in names:
            if self._scheds[name] is not None:
                self._scheds[name].step()

    def detect_anomaly(self, loss):
        if not torch.isfinite(loss).all():
            raise FloatingPointError("Loss is infinite or NaN!")

    def init_writer(self, log_dir):
        if self.__dict__.get("_writer") is None or self._writer is None:
            print(
                "Initializing summary writer for tensorboard "
                "with log_dir={}".format(log_dir)
            )
            self._writer = SummaryWriter(log_dir=log_dir)

    def close_writer(self):
        if self._writer is not None:
            self._writer.close()

    def write_scalar(self, tag, scalar_value, global_step=None):
        if self._writer is None:
            # Do nothing if writer is not initialized
            # Note that writer is only used when training is needed
            pass
        else:
            self._writer.add_scalar(tag, scalar_value, global_step)
    def train(self, start_epoch, max_epoch):
        """Generic training loops."""
        self.start_epoch = start_epoch
        self.max_epoch = max_epoch

        self.before_train()
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()

    def before_train(self):
        pass

    def after_train(self):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def run_epoch(self):
        raise NotImplementedError

    def test(self):
        raise NotImplementedError

    def parse_batch_train(self, batch):
        raise NotImplementedError

    def parse_batch_test(self, batch):
        raise NotImplementedError

    def forward_backward(self, batch):
        raise NotImplementedError

    def model_inference(self, input):
        raise NotImplementedError

    def model_zero_grad(self, names=None):
        names = self.get_model_names(names)
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].zero_grad()

    def model_backward(self, loss):
        self.detect_anomaly(loss)
        loss.backward()

    def model_update(self, names=None):
        names = self.get_model_names(names)
        for name in names:
            if self._optims[name] is not None:
                self._optims[name].step()

    def model_backward_and_update(self, loss, names=None):
        self.model_zero_grad(names)
        self.model_backward(loss)
        self.model_update(names)
    def prograd_backward_and_update(
        self, loss_a, loss_b, lambda_=1, names=None
    ):
        # loss_b is only required not to increase;
        # loss_a has to decline
        self.model_zero_grad(names)
        # get names of the registered models
        names = self.get_model_names(names)
        # backward loss_b and keep the graph for the second backward pass
        self.detect_anomaly(loss_b)
        loss_b.backward(retain_graph=True)
        # store loss_b's gradients (normalized per parameter below)
        b_grads = []
        for name in names:
            for p in self._models[name].parameters():
                b_grads.append(p.grad.clone())

        # clear gradients without stepping the optimizer
        for name in names:
            self._optims[name].zero_grad()

        # backward loss_a
        self.detect_anomaly(loss_a)
        loss_a.backward()
        for name in names:
            for p, b_grad in zip(self._models[name].parameters(), b_grads):
                # calculate cosine distance
                b_grad_norm = b_grad / torch.linalg.norm(b_grad)
                a_grad = p.grad.clone()
                a_grad_norm = a_grad / torch.linalg.norm(a_grad)

                if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
                    p.grad = a_grad - lambda_ * torch.dot(
                        a_grad.flatten(), b_grad_norm.flatten()
                    ) * b_grad_norm

        # optimizer step
        for name in names:
            self._optims[name].step()
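prograd_backward_and_update is the ProGrad-specific update: loss_b's gradient defines a reference direction, and wherever loss_a's gradient conflicts with it (negative dot product), the conflicting component is removed before the optimizer step. The toy example below reproduces just that projection rule on plain tensors (hypothetical 2-D gradients, not tied to any model) so the geometry can be checked by hand.

import torch

def prograd_project(a_grad, b_grad, lambda_=1.0):
    # Mirror of the per-parameter rule above: if a_grad points against
    # b_grad, subtract its (scaled) projection onto b_grad's direction.
    b_norm = b_grad / torch.linalg.norm(b_grad)
    a_norm = a_grad / torch.linalg.norm(a_grad)
    if torch.dot(a_norm.flatten(), b_norm.flatten()) < 0:
        return a_grad - lambda_ * torch.dot(a_grad.flatten(), b_norm.flatten()) * b_norm
    return a_grad

a = torch.tensor([1.0, -1.0])  # gradient of loss_a (to be kept)
b = torch.tensor([0.0, 1.0])   # gradient of loss_b (reference direction)

projected = prograd_project(a, b)
print(projected)               # tensor([1., 0.]) -> conflicting component removed
print(torch.dot(projected, b)) # tensor(0.) -> no longer opposes loss_b's gradient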
class SimpleTrainer(TrainerBase):
    """A simple trainer class implementing generic functions."""

    def __init__(self, cfg):
        super().__init__()
        self.check_cfg(cfg)

        if torch.cuda.is_available() and cfg.USE_CUDA:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save as attributes some frequently used variables
        self.start_epoch = self.epoch = 0
        self.max_epoch = cfg.OPTIM.MAX_EPOCH
        self.output_dir = cfg.OUTPUT_DIR

        self.cfg = cfg
        self.build_data_loader()
        self.build_model()
        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
        self.best_result = -np.inf

    def check_cfg(self, cfg):
        """Check whether some variables are set correctly for
        the trainer (optional).

        For example, a trainer might require a particular sampler
        for training such as 'RandomDomainSampler', so it is good
        to do the checking:

        assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
        """
        pass

    def build_data_loader(self):
        """Create essential data-related attributes.

        A re-implementation of this method must create the
        same attributes (except self.dm).
        """
        dm = DataManager(self.cfg)

        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u  # optional, can be None
        self.val_loader = dm.val_loader  # optional, can be None
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname  # dict {label: classname}

        self.dm = dm

    def build_model(self):
        """Build and register model.

        The default builds a classification model along with its
        optimizer and scheduler.

        Custom trainers can re-implement this method if necessary.
        """
        cfg = self.cfg

        print("Building model")
        self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
        self.model.to(self.device)
        print("# params: {:,}".format(count_num_param(self.model)))
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("model", self.model, self.optim, self.sched)

        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(
                f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
            )
            self.model = nn.DataParallel(self.model)
    def train(self):
        super().train(self.start_epoch, self.max_epoch)

    def before_train(self):
        directory = self.cfg.OUTPUT_DIR
        if self.cfg.RESUME:
            directory = self.cfg.RESUME
        self.start_epoch = self.resume_model_if_exist(directory)

        # Initialize summary writer
        writer_dir = osp.join(self.output_dir, "tensorboard")
        mkdir_if_missing(writer_dir)
        self.init_writer(writer_dir)

        # Remember the starting time (for computing the elapsed time)
        self.time_start = time.time()

    def after_train(self):
        print("Finished training")

        do_test = not self.cfg.TEST.NO_TEST
        if do_test:
            if self.cfg.TEST.FINAL_MODEL == "best_val":
                print("Deploy the model with the best val performance")
                self.load_model(self.output_dir)
            self.test()

        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed: {}".format(elapsed))

        # Close writer
        self.close_writer()

    def after_epoch(self):
        last_epoch = (self.epoch + 1) == self.max_epoch
        do_test = not self.cfg.TEST.NO_TEST
        meet_checkpoint_freq = (
            (self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
            if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
        )

        if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
            curr_result = self.test(split="val")
            is_best = curr_result > self.best_result
            if is_best:
                self.best_result = curr_result
                self.save_model(
                    self.epoch,
                    self.output_dir,
                    model_name="model-best.pth.tar"
                )

        if meet_checkpoint_freq or last_epoch:
            self.save_model(self.epoch, self.output_dir)
    @torch.no_grad()
    def output_test(self, split=None):
        """Testing pipeline that also writes per-image predictions to a JSON file."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
        res_json = {}

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            img_path = batch['impath']
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)
            for i in range(len(img_path)):
                res_json[img_path[i]] = {
                    'predict': output[i].cpu().numpy().tolist(),
                    'gt': label[i].cpu().numpy().tolist()
                }
        with open(output_file, 'w') as f:
            json.dump(res_json, f)
        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]
    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]

    def model_inference(self, input):
        return self.model(input)

    def parse_batch_test(self, batch):
        input = batch["img"]
        label = batch["label"]

        input = input.to(self.device)
        label = label.to(self.device)

        return input, label

    def get_current_lr(self, names=None):
        names = self.get_model_names(names)
        name = names[0]
        return self._optims[name].param_groups[0]["lr"]
class TrainerXU(SimpleTrainer):
    """A base trainer using both labeled and unlabeled data.

    In the context of domain adaptation, labeled and unlabeled data
    come from source and target domains respectively.

    When it comes to semi-supervised learning, all data comes from the
    same domain.
    """

    def run_epoch(self):
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        # Decide to iterate over labeled or unlabeled dataset
        len_train_loader_x = len(self.train_loader_x)
        len_train_loader_u = len(self.train_loader_u)
        if self.cfg.TRAIN.COUNT_ITER == "train_x":
            self.num_batches = len_train_loader_x
        elif self.cfg.TRAIN.COUNT_ITER == "train_u":
            self.num_batches = len_train_loader_u
        elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
            self.num_batches = min(len_train_loader_x, len_train_loader_u)
        else:
            raise ValueError

        train_loader_x_iter = iter(self.train_loader_x)
        train_loader_u_iter = iter(self.train_loader_u)

        end = time.time()
        for self.batch_idx in range(self.num_batches):
            try:
                batch_x = next(train_loader_x_iter)
            except StopIteration:
                train_loader_x_iter = iter(self.train_loader_x)
                batch_x = next(train_loader_x_iter)

            try:
                batch_u = next(train_loader_u_iter)
            except StopIteration:
                train_loader_u_iter = iter(self.train_loader_u)
                batch_u = next(train_loader_u_iter)

            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch_x, batch_u)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, input_u
class TrainerX(SimpleTrainer):
    """A base trainer using labeled data only."""

    def run_epoch(self):
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        domain = batch["domain"]

        input = input.to(self.device)
        label = label.to(self.device)
        domain = domain.to(self.device)

        return input, label, domain
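As a usage note, new trainers plug into this engine by subclassing TrainerX (or TrainerXU) and registering themselves with TRAINER_REGISTRY, exactly like the trainers added in this commit. The sketch below is a minimal, hypothetical example (the class name is made up; it assumes the dassl package is installed) that mirrors the SupBaseline pattern but for labeled-only training.

from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class MyCrossEntropy(TrainerX):
    """Hypothetical example trainer: plain cross-entropy on labeled data."""

    def forward_backward(self, batch):
        # TrainerX.parse_batch_train returns (input, label, domain);
        # the domain labels are simply ignored here.
        input, label, _ = self.parse_batch_train(batch)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary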