release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling

View File

@@ -0,0 +1,38 @@
import torch
from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.

    Adapts a source-pretrained model to the target domain purely by
    re-estimating BatchNorm running statistics on unlabeled target data;
    no parameters are updated by gradients.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # BN statistics are wiped exactly once, before the first epoch.
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        # Pretrained source weights are mandatory: they provide everything
        # except the BN statistics, which this trainer re-estimates.
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        # On the first epoch only, reset running mean/var of every
        # BatchNorm layer so they get re-estimated from target batches.
        if self.done_reset_bn_stats:
            return
        for module in self.model.modules():
            if "BatchNorm" in type(module).__name__:
                module.reset_running_stats()
        self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        # A forward pass in train mode is enough to update BN running
        # statistics; no loss or gradient is needed.
        target_images = batch_u["img"].to(self.device)
        with torch.no_grad():
            self.model(target_images)
        return None

View File

@@ -0,0 +1,85 @@
import copy
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head
@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.

    A frozen copy of the source-pretrained model provides reference
    source features; the trainable model is adapted so that a domain
    critic cannot tell its target features from the source features.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        # Only these submodules of self.model are unfrozen during training.
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")
        # Frozen snapshot of the source model; supplies the reference
        # source-feature distribution for the critic.
        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)
        self.build_critic()
        # Binary domain-classification loss (critic emits raw logits).
        self.bce = nn.BCEWithLogitsLoss()
    def check_cfg(self, cfg):
        # ADDA adapts a pretrained source model, so weights are mandatory.
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"
    def build_critic(self):
        """Build the domain critic (MLP body + linear head) and register it."""
        cfg = self.cfg
        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        # Single-logit output: 1 = source domain, 0 = target domain.
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
    def forward_backward(self, batch_x, batch_u):
        """One adversarial round: update the critic, then the model.

        Returns a dict with the critic and model loss values.
        """
        # Restrict training to the feature-producing layers of the model.
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        # Domain labels: source = 1, target = 0.
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)
        # Critic step: classify source vs target features. Target features
        # are detached so this loss does not reach the model.
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())
        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")
        # Model step: fool the critic (target features labeled as source).
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")
        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary

View File

@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
    """A bank of per-domain softmax classifiers ("experts").

    Holds one linear classifier per source domain; the forward pass
    evaluates the expert selected by index and returns class
    probabilities.
    """

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        # One linear head per source domain, each mapping fdim -> num_classes.
        self.linears = nn.ModuleList(
            nn.Linear(fdim, num_classes) for _ in range(n_source)
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        """Return softmax class probabilities from the i-th expert."""
        return self.softmax(self.linears[i](x))
@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.

    Trains one expert classifier per source domain on top of a shared
    feature extractor F. Experts teach each other via a consistency loss
    on strongly-augmented source images, and the ensemble of experts
    provides pseudo labels for the unlabeled target data.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        # Each labeled mini-batch holds an equal-size chunk per domain.
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        # Weight of the unsupervised (pseudo-label) loss.
        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        # Confidence threshold for accepting a pseudo label.
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
    def check_cfg(self, cfg):
        # The domain sampler guarantees the per-domain batch structure
        # assumed by forward_backward; strong transforms are required for
        # the consistency and unsupervised losses.
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
    def build_data_loader(self):
        """Build loaders yielding two views per image: normal + strong."""
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname
    def build_model(self):
        """Build the shared extractor F and the per-domain experts E."""
        cfg = self.cfg
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim
        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)
    def forward_backward(self, batch_x, batch_u):
        """One step: expert CE + cross-expert consistency + pseudo-label loss."""
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        # Split the batch into per-domain chunks (see self.split_batch).
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]
        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            # Pseudo label = label predicted by the most confident expert.
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            # Only keep pseudo labels above the confidence threshold.
            label_u_mask = (max_expert_p >= self.conf_thre).float()
        loss_x = 0
        loss_cr = 0
        acc_x = 0
        # Weak views train each domain's own expert; strong views are
        # used for the cross-expert consistency target.
        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)
        for feat_xi, feat_x2i, label_xi, i in zip(
            feat_x, feat_x2, label_x, domain_x
        ):
            # Indices of all OTHER source domains (the "peer" experts).
            cr_s = [j for j in domain_x if j != i]
            # Learning expert
            pred_xi = self.E(i, feat_xi)
            # Cross entropy against one-hot labels (epsilon avoids log(0)).
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()
            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            # Peers' average prediction (strong view) should match the
            # domain expert's prediction (weak view): squared error.
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()
        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain
        # Unsupervised loss
        # Ensemble (mean over experts) prediction on the strong target view
        # is trained against the pseudo labels, masked by confidence.
        pred_u = []
        for k in range(self.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()
        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)
        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": acc_x,
            "loss_cr": loss_cr.item(),
            "loss_u": loss_u.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary
    def parse_batch_train(self, batch_x, batch_u):
        """Move both views of source/target batches to the device.

        Source labels are converted to one-hot (the CE loss above is
        written against one-hot targets). Domain labels stay on CPU.
        """
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        label_x = create_onehot(label_x, self.num_classes)
        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        return input_x, input_x2, label_x, domain_x, input_u, input_u2
    def model_inference(self, input):
        """Average the probability outputs of all experts."""
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p

View File

@@ -0,0 +1,78 @@
import numpy as np
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad
@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks.

    https://arxiv.org/abs/1505.07818.

    A domain critic learns to distinguish source from target features,
    while a gradient-reversal layer drives the feature extractor toward
    domain-invariant features.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        # Domain-classification loss (critic emits a single raw logit).
        self.bce = nn.BCEWithLogitsLoss()
    def build_critic(self):
        """Build the domain critic (MLP body + linear head) and register it."""
        cfg = self.cfg
        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        # Single-logit output: 1 = source domain, 0 = target domain.
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        # Reverses gradients flowing from the critic into the backbone.
        self.revgrad = ReverseGrad()
    def forward_backward(self, batch_x, batch_u):
        """One joint step: source CE loss + adversarial domain loss."""
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        # Domain labels: source = 1, target = 0.
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
        # Ramp the gradient-reversal coefficient from 0 to 1 over the
        # course of training (schedule from the DANN paper).
        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1
        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)
        loss_x = self.ce(logit_x, label_x)
        # Gradient reversal: the critic minimizes loss_d while the
        # backbone (through reversed gradients) maximizes it.
        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
        loss = loss_x + loss_d
        self.model_backward_and_update(loss)
        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary

View File

@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
class PairClassifiers(nn.Module):
    """Two parallel linear classifiers over the same feature input.

    In training mode both logit tensors are returned (needed for the
    discrepancy loss); in eval mode only the first classifier's output
    is returned.
    """

    def __init__(self, fdim, num_classes):
        super().__init__()
        self.c1 = nn.Linear(fdim, num_classes)
        self.c2 = nn.Linear(fdim, num_classes)

    def forward(self, x):
        out1 = self.c1(x)
        if self.training:
            return out1, self.c2(x)
        # Inference path: a single prediction head suffices.
        return out1
@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.

    Shared feature extractor F plus one pair of classifiers per source
    domain. Training alternates three steps: (A) supervised loss plus
    moment matching, (B) maximize classifier-pair disagreement on target
    data (classifiers only), (C) minimize it (extractor only).
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        # Each labeled mini-batch holds an equal-size chunk per domain.
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain
        # Number of extractor-only updates in step C.
        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        # Weight of the moment-matching loss in step A.
        self.lmda = cfg.TRAINER.M3SDA.LMDA
    def check_cfg(self, cfg):
        # The domain sampler guarantees the per-domain batch structure
        # assumed by forward_backward.
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
    def build_model(self):
        """Build the shared extractor F and per-domain classifier pairs C."""
        cfg = self.cfg
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim
        print("Building C")
        self.C = nn.ModuleList(
            [
                PairClassifiers(fdim, self.num_classes)
                for _ in range(self.num_source_domains)
            ]
        )
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)
    def forward_backward(self, batch_x, batch_u):
        """One round of the three-step A/B/C optimization."""
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed
        # Split the batch into per-domain chunks (see self.split_batch).
        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]
        # Step A
        # Supervised CE on every domain's classifier pair + moment matching.
        loss_x = 0
        feat_x = []
        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
            feat_x.append(f)
        loss_x /= self.n_domain
        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)
        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)
        # Step B
        # Train the classifiers to keep source accuracy while maximizing
        # disagreement on target data; features are computed under
        # no_grad so the extractor F receives no gradient here.
        with torch.no_grad():
            feat_u = self.F(input_u)
        loss_x, loss_dis = 0, 0
        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)
        loss_x /= self.n_domain
        loss_dis /= self.n_domain
        # Subtracting the discrepancy maximizes it w.r.t. the classifiers.
        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, "C")
        # Step C
        # Train the extractor to minimize the discrepancy (n_step_F times).
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            loss_dis = 0
            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)
            loss_dis /= self.n_domain
            loss_step_C = loss_dis
            self.model_backward_and_update(loss_step_C, "F")
        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary
    def moment_distance(self, x, u):
        # x (list): a list of feature matrix.
        # u (torch.Tensor): feature matrix.
        # First moments: per-dimension means of each domain vs target.
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)
        # Second moments: per-dimension variances.
        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)
        return (dist1+dist2) / 2
    def pairwise_distance(self, x, u):
        # x (list): a list of feature vector.
        # u (torch.Tensor): feature vector.
        # Average Euclidean distance over all source-target pairs plus
        # all source-source pairs.
        dist = 0
        count = 0
        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1
        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1
        return dist / count
    def euclidean(self, input1, input2):
        """Euclidean (L2) distance between two vectors."""
        return ((input1 - input2)**2).sum().sqrt()
    def discrepancy(self, y1, y2):
        """Mean absolute difference between two probability outputs."""
        return (y1 - y2).abs().mean()
    def parse_batch_train(self, batch_x, batch_u):
        """Move source images/labels and target images to the device.

        Domain labels are returned as-is (only .item() values are used).
        """
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]
        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        return input_x, label_x, domain_x, input_u
    def model_inference(self, input):
        """Average softmax predictions over all classifier pairs.

        In eval mode each PairClassifiers returns only its first head.
        """
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p

View File

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.

    Feature extractor F plus two classifiers C1/C2. Training alternates
    three steps: (A) supervised loss on both classifiers, (B) maximize
    the classifiers' disagreement on target data (classifiers only),
    (C) minimize it (extractor only).
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        # Number of extractor-only updates in step C.
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F
    def build_model(self):
        """Build the feature extractor F and the two classifiers C1, C2."""
        cfg = self.cfg
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim
        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)
        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)
    def forward_backward(self, batch_x, batch_u):
        """One round of the three-step A/B/C optimization."""
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed
        # Step A
        # Supervised training of F, C1 and C2 on labeled source data.
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)
        # Step B
        # Train C1/C2 to keep source accuracy while maximizing their
        # disagreement on target data; features are computed under
        # no_grad so F receives no gradient here.
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2
        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)
        # Subtracting the discrepancy maximizes it w.r.t. C1/C2.
        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])
        # Step C
        # Train the extractor F to minimize the discrepancy (n_step_F times).
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")
        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary
    def discrepancy(self, y1, y2):
        """Mean absolute difference between two probability outputs."""
        return (y1 - y2).abs().mean()
    def model_inference(self, input):
        """Predict with the first classifier only."""
        feat = self.F(input)
        return self.C1(feat)

View File

@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
class Prototypes(nn.Module):
    """Cosine-similarity classifier with temperature scaling.

    Input features are L2-normalized, compared (dot product) against a
    set of learnable class prototypes, and the similarities are
    sharpened by dividing by a temperature.
    """

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        # Bias-free linear layer: each weight row acts as one prototype.
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        normalized = F.normalize(x, p=2, dim=1)
        return self.prototypes(normalized) / self.temp
@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.

    A prototype-based cosine classifier is trained on labeled source
    data; on unlabeled target data the prediction entropy is maximized
    w.r.t. the classifier but minimized w.r.t. the feature extractor,
    via a gradient-reversal layer.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight of the entropy term.
        self.lmda = cfg.TRAINER.MME.LMDA
    def build_model(self):
        """Build the feature extractor F and the prototype classifier C."""
        cfg = self.cfg
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)
        # Reverses gradients flowing from the entropy loss into F.
        self.revgrad = ReverseGrad()
    def forward_backward(self, batch_x, batch_u):
        """Two sequential updates: supervised step, then entropy step."""
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        # Supervised step on labeled source data.
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)
        # Entropy step on unlabeled target data. loss_u is the negative
        # entropy of target predictions; with the gradient reversal on
        # features this realizes the paper's minimax game (C maximizes
        # entropy, F minimizes it).
        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)
        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary
    def model_inference(self, input):
        """Classify by prototype similarity on extracted features."""
        return self.C(self.F(input))

View File

@@ -0,0 +1,78 @@
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.

    Mean-teacher setup: the student is trained with a supervised loss on
    source data plus a consistency loss against an EMA teacher's
    predictions on a second augmented view of each target image.
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        # EMA decay used for the teacher update.
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        # If truthy, gate the consistency loss by teacher confidence;
        # otherwise its weight is ramped up via sigmoid_rampup.
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP
        # Teacher: EMA copy of the student; never updated by gradients.
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)
    def check_cfg(self, cfg):
        # Two augmented views per image are required (student/teacher).
        assert cfg.DATALOADER.K_TRANSFORMS == 2
    def forward_backward(self, batch_x, batch_u):
        """One step: supervised loss + consistency loss, then EMA update."""
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed
        # Supervised loss on labeled source data.
        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        # Consistency: squared error between student and teacher class
        # probabilities on two views of the same unlabeled target image.
        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)
        if self.conf_thre:
            # Keep only samples where the teacher is confident enough.
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            # Otherwise, ramp the consistency weight up over training.
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u
        loss = loss_x + loss_u
        self.model_backward_and_update(loss)
        # EMA update of the teacher; the effective decay warms up early
        # in training (small alpha at the first steps).
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)
        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }
        # Step the LR schedulers at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary
    def parse_batch_train(self, batch_x, batch_u):
        """Return (source view 1, source labels, target view 1, target view 2).

        Only the first source view is used; both target views feed the
        student/teacher consistency loss.
        """
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u1, input_u2 = input_u
        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)
        return input_x, label_x, input_u1, input_u2

View File

@@ -0,0 +1,34 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        """Train on the labeled source batch; the unlabeled batch is ignored."""
        images, targets = self.parse_batch_train(batch_x, batch_u)
        logits = self.model(images)
        loss = F.cross_entropy(logits, targets)
        self.model_backward_and_update(loss)
        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, targets)[0].item(),
        }
        # Step the LR scheduler at the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()
        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Move source images and labels to the device; batch_u is unused."""
        images = batch_x["img"].to(self.device)
        targets = batch_x["label"].to(self.device)
        return images, targets