26 lines
910 B
Python
26 lines
910 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from daln.grl import WarmStartGradientReverseLayer
|
|
|
|
|
|
class NuclearWassersteinDiscrepancy(nn.Module):
    """Nuclear-norm discrepancy between the two halves of a batch of logits.

    The wrapped ``classifier`` is applied to features that first pass through a
    ``WarmStartGradientReverseLayer`` (from ``daln.grl``). The resulting logits
    are split in half along the batch dimension and scored with
    :meth:`n_discrepancy`.

    Args:
        classifier: Module mapping features to logits. It is expected to return
            a 2-D ``(batch, num_classes)`` tensor, since :meth:`n_discrepancy`
            applies ``softmax(dim=1)`` and a matrix (nuclear) norm.
        alpha, lo, hi, max_iters, auto_step: Forwarded unchanged to
            ``WarmStartGradientReverseLayer``. The defaults equal the values
            that were previously hard-coded. Their exact semantics live in
            ``daln.grl``. Presumably they define a reversal coefficient warmed
            up from ``lo`` to ``hi`` over ``max_iters`` steps; confirm there.
    """

    def __init__(
        self,
        classifier: nn.Module,
        *,
        alpha: float = 1.0,
        lo: float = 0.0,
        hi: float = 1.0,
        max_iters: int = 1000,
        auto_step: bool = True,
    ) -> None:
        super().__init__()
        self.grl = WarmStartGradientReverseLayer(
            alpha=alpha, lo=lo, hi=hi, max_iters=max_iters, auto_step=auto_step
        )
        self.classifier = classifier

    @staticmethod
    def n_discrepancy(y_s: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
        """Return ``(||softmax(y_s)||_* - ||softmax(y_t)||_*) / y_t.shape[0]``.

        ``||.||_*`` is the nuclear norm (sum of singular values) of the
        row-wise softmax probability matrix.

        Args:
            y_s: 2-D logits for the first half of the batch.
            y_t: 2-D logits for the second half of the batch. Its row count is
                the normaliser.

        Returns:
            A 0-d tensor.

        Raises:
            ValueError: If either input is not 2-D.
        """
        if y_s.dim() != 2 or y_t.dim() != 2:
            raise ValueError(
                "expected 2-D logits, got shapes "
                f"{tuple(y_s.shape)} and {tuple(y_t.shape)}"
            )
        pre_s = F.softmax(y_s, dim=1)
        pre_t = F.softmax(y_t, dim=1)
        # torch.linalg.matrix_norm is the documented replacement for the
        # deprecated torch.norm(x, 'nuc'); results are identical for 2-D input.
        nuc_s = torch.linalg.matrix_norm(pre_s, ord="nuc")
        nuc_t = torch.linalg.matrix_norm(pre_t, ord="nuc")
        return (nuc_s - nuc_t) / y_t.shape[0]

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Score a batch whose first half is ``y_s`` and second half ``y_t``.

        Args:
            f: Features for both halves concatenated along dim 0 in equal
                numbers. The first half is scored as ``y_s`` and the second
                half as ``y_t``.

        Returns:
            The 0-d tensor produced by :meth:`n_discrepancy`.

        Raises:
            ValueError: If the batch is empty or has an odd size. ``chunk(2)``
                would then either not yield two halves or silently shift one
                sample of the second half into the first.
        """
        batch = f.shape[0]
        if batch == 0 or batch % 2:
            raise ValueError(
                f"expected a non-empty batch of even size (two equal halves), got {batch}"
            )
        # Gradients flowing back into ``f`` pass through ``self.grl``.
        logits = self.classifier(self.grl(f))
        y_s, y_t = logits.chunk(2, dim=0)
        return self.n_discrepancy(y_s, y_t)
|