release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad

View File

@@ -0,0 +1,83 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
"""Cross-gradient training.
https://arxiv.org/abs/1804.10745.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.eps_f = cfg.TRAINER.CG.EPS_F
self.eps_d = cfg.TRAINER.CG.EPS_D
self.alpha_f = cfg.TRAINER.CG.ALPHA_F
self.alpha_d = cfg.TRAINER.CG.ALPHA_D
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
input.requires_grad = True
# Compute domain perturbation
loss_d = F.cross_entropy(self.D(input), domain)
loss_d.backward()
grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_d = input.data + self.eps_f * grad_d
# Compute label perturbation
input.grad.data.zero_()
loss_f = F.cross_entropy(self.F(input), label)
loss_f.backward()
grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_f = input.data + self.eps_d * grad_f
input = input.detach()
# Update label net
loss_f1 = F.cross_entropy(self.F(input), label)
loss_f2 = F.cross_entropy(self.F(input_d), label)
loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
self.model_backward_and_update(loss_f, "F")
# Update domain net
loss_d1 = F.cross_entropy(self.D(input), domain)
loss_d2 = F.cross_entropy(self.D(input_f), domain)
loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
self.model_backward_and_update(loss_d, "D")
loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)

View File

@@ -0,0 +1,169 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
"""Domain Adaptive Ensemble Learning.
DG version: only use labeled source data.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch):
parsed_data = self.parse_batch_train(batch)
input, input2, label, domain = parsed_data
input = torch.split(input, self.split_batch, 0)
input2 = torch.split(input2, self.split_batch, 0)
label = torch.split(label, self.split_batch, 0)
domain = torch.split(domain, self.split_batch, 0)
domain = [d[0].item() for d in domain]
loss_x = 0
loss_cr = 0
acc = 0
feat = [self.F(x) for x in input]
feat2 = [self.F(x) for x in input2]
for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
cr_s = [j for j in domain if j != i]
# Learning expert
pred_i = self.E(i, feat_i)
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
expert_label_i = pred_i.detach()
acc += compute_accuracy(pred_i.detach(),
label_i.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat2_i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc /= self.n_domain
loss = 0
loss += loss_x
loss += loss_cr
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc": acc,
"loss_cr": loss_cr.item()
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
input2 = batch["img2"]
label = batch["label"]
domain = batch["domain"]
label = create_onehot(label, self.num_classes)
input = input.to(self.device)
input2 = input2.to(self.device)
label = label.to(self.device)
return input, input2, label, domain
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p

View File

@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
"""Deep Domain-Adversarial Image Generation.
https://arxiv.org/abs/2003.06054.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.DDAIG.LMDA
self.clamp = cfg.TRAINER.DDAIG.CLAMP
self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
self.warmup = cfg.TRAINER.DDAIG.WARMUP
self.alpha = cfg.TRAINER.DDAIG.ALPHA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
print("Building G")
self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
self.G.to(self.device)
print("# params: {:,}".format(count_num_param(self.G)))
self.optim_G = build_optimizer(self.G, cfg.OPTIM)
self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
self.register_model("G", self.G, self.optim_G, self.sched_G)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
#############
# Update G
#############
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
loss_g = 0
# Minimize label loss
loss_g += F.cross_entropy(self.F(input_p), label)
# Maximize domain loss
loss_g -= F.cross_entropy(self.D(input_p), domain)
self.model_backward_and_update(loss_g, "G")
# Perturb data with new G
with torch.no_grad():
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
#############
# Update F
#############
loss_f = F.cross_entropy(self.F(input), label)
if (self.epoch + 1) > self.warmup:
loss_fp = F.cross_entropy(self.F(input_p), label)
loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
self.model_backward_and_update(loss_f, "F")
#############
# Update D
#############
loss_d = F.cross_entropy(self.D(input), domain)
self.model_backward_and_update(loss_d, "D")
loss_summary = {
"loss_g": loss_g.item(),
"loss_f": loss_f.item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
"""Vanilla baseline."""
def forward_backward(self, batch):
input, label = self.parse_batch_train(batch)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label