release code
This commit is contained in:
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from dassl.utils import count_num_param
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.modeling import build_head
|
||||
from dassl.modeling.ops import ReverseGrad
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
|
||||
class DANN(TrainerXU):
|
||||
"""Domain-Adversarial Neural Networks.
|
||||
|
||||
https://arxiv.org/abs/1505.07818.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.build_critic()
|
||||
self.ce = nn.CrossEntropyLoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def build_critic(self):
|
||||
cfg = self.cfg
|
||||
|
||||
print("Building critic network")
|
||||
fdim = self.model.fdim
|
||||
critic_body = build_head(
|
||||
"mlp",
|
||||
verbose=cfg.VERBOSE,
|
||||
in_features=fdim,
|
||||
hidden_layers=[fdim, fdim],
|
||||
activation="leaky_relu",
|
||||
)
|
||||
self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
|
||||
print("# params: {:,}".format(count_num_param(self.critic)))
|
||||
self.critic.to(self.device)
|
||||
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
|
||||
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
|
||||
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
|
||||
self.revgrad = ReverseGrad()
|
||||
|
||||
def forward_backward(self, batch_x, batch_u):
|
||||
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
|
||||
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
|
||||
|
||||
global_step = self.batch_idx + self.epoch * self.num_batches
|
||||
progress = global_step / (self.max_epoch * self.num_batches)
|
||||
lmda = 2 / (1 + np.exp(-10 * progress)) - 1
|
||||
|
||||
logit_x, feat_x = self.model(input_x, return_feature=True)
|
||||
_, feat_u = self.model(input_u, return_feature=True)
|
||||
|
||||
loss_x = self.ce(logit_x, label_x)
|
||||
|
||||
feat_x = self.revgrad(feat_x, grad_scaling=lmda)
|
||||
feat_u = self.revgrad(feat_u, grad_scaling=lmda)
|
||||
output_xd = self.critic(feat_x)
|
||||
output_ud = self.critic(feat_u)
|
||||
loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
|
||||
|
||||
loss = loss_x + loss_d
|
||||
self.model_backward_and_update(loss)
|
||||
|
||||
loss_summary = {
|
||||
"loss_x": loss_x.item(),
|
||||
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
|
||||
"loss_d": loss_d.item(),
|
||||
}
|
||||
|
||||
if (self.batch_idx + 1) == self.num_batches:
|
||||
self.update_lr()
|
||||
|
||||
return loss_summary
|
||||
Reference in New Issue
Block a user