release code
Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
    EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
    crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
    MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
    deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance
Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import torch
from torch.nn import functional as F


def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated. Default is 'mean'.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    zeros = torch.zeros(log_prob.size())
    target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1)
    target = target.type_as(input)
    target = (1-label_smooth) * target + label_smooth/num_classes
    loss = (-target * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError
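A minimal usage sketch for the smoothed cross entropy above (not part of the diff); the import path assumes the package layout added in this commit.

# Label-smoothed cross entropy on random logits.
import torch
from dassl.modeling.ops import cross_entropy

logits = torch.randn(4, 10)          # (batch, num_classes)
labels = torch.randint(0, 10, (4,))  # integer class labels
loss = cross_entropy(logits, labels, label_smooth=0.1)  # scalar with 'mean' reduction
print(loss.item())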
Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import torch.nn as nn


class _DSBN(nn.Module):
    """Domain Specific Batch Normalization.

    Args:
        num_features (int): number of features.
        n_domain (int): number of domains.
        bn_type (str): type of bn. Choices are ['1d', '2d'].
    """

    def __init__(self, num_features, n_domain, bn_type):
        super().__init__()
        if bn_type == "1d":
            BN = nn.BatchNorm1d
        elif bn_type == "2d":
            BN = nn.BatchNorm2d
        else:
            raise ValueError

        self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))

        self.valid_domain_idxs = list(range(n_domain))
        self.n_domain = n_domain
        self.domain_idx = 0

    def select_bn(self, domain_idx=0):
        assert domain_idx in self.valid_domain_idxs
        self.domain_idx = domain_idx

    def forward(self, x):
        return self.bn[self.domain_idx](x)


class DSBN1d(_DSBN):

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "1d")


class DSBN2d(_DSBN):

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "2d")
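A minimal sketch (not part of the diff) showing how DSBN2d switches per-domain batch-norm branches; the import path assumes the layout added here.

import torch
from dassl.modeling.ops import DSBN2d

bn = DSBN2d(num_features=64, n_domain=2)
x = torch.randn(8, 64, 32, 32)
bn.select_bn(0)   # use the BN branch for domain 0
y0 = bn(x)
bn.select_bn(1)   # use the BN branch for domain 1
y1 = bn(x)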
Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn


def deactivate_efdmix(m):
    if type(m) == EFDMix:
        m.set_activation_status(False)


def activate_efdmix(m):
    if type(m) == EFDMix:
        m.set_activation_status(True)


def random_efdmix(m):
    if type(m) == EFDMix:
        m.update_mix_method("random")


def crossdomain_efdmix(m):
    if type(m) == EFDMix:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_efdmix(model):
    # Assume EFDMix was initially activated
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        model.apply(activate_efdmix)


@contextmanager
def run_with_efdmix(model, mix=None):
    # Assume EFDMix was initially deactivated
    if mix == "random":
        model.apply(random_efdmix)

    elif mix == "crossdomain":
        model.apply(crossdomain_efdmix)

    try:
        model.apply(activate_efdmix)
        yield
    finally:
        model.apply(deactivate_efdmix)


class EFDMix(nn.Module):
    """EFDMix.

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of applying EFDMix.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)
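A minimal sketch (not part of the diff) of EFDMix inserted after a CNN stage and toggled with the helpers above; the tiny backbone is a placeholder, only EFDMix and run_without_efdmix come from this commit.

import torch
import torch.nn as nn
from dassl.modeling.ops import EFDMix, run_without_efdmix

backbone = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    EFDMix(p=0.5, alpha=0.1),   # mixes sorted feature values across instances
    nn.Conv2d(64, 64, 3, padding=1),
)
backbone.train()
x = torch.randn(8, 3, 32, 32)
y_mixed = backbone(x)
with run_without_efdmix(backbone):  # temporarily disable mixing
    y_plain = backbone(x)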
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py (new file, 124 lines)
@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn


def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_mixstyle(model):
    # Assume MixStyle was initially activated
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model, mix=None):
    # Assume MixStyle was initially deactivated
    if mix == "random":
        model.apply(random_mixstyle)

    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix
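A minimal sketch (not part of the diff) of MixStyle applied to intermediate features and driven by the context manager above; the toy model is a placeholder.

import torch
import torch.nn as nn
from dassl.modeling.ops import MixStyle, run_with_mixstyle, deactivate_mixstyle

model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    MixStyle(p=0.5, alpha=0.1),
    nn.Conv2d(32, 32, 3, padding=1),
)
model.train()
model.apply(deactivate_mixstyle)        # start with MixStyle off
x = torch.randn(6, 3, 32, 32)
with run_with_mixstyle(model, mix="crossdomain"):
    out = model(x)                      # MixStyle active, cross-domain mixing
out_plain = model(x)                    # MixStyle deactivated again after the context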
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch


def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup.

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda=max(lmda, 1-lmda).
            Default is False.
    """
    lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        lmda = torch.max(lmda, 1 - lmda)
    lmda = lmda.to(x1.device)
    xmix = x1*lmda + x2 * (1-lmda)
    lmda = lmda[:, :, 0, 0]
    ymix = y1*lmda + y2 * (1-lmda)
    return xmix, ymix
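A minimal sketch (not part of the diff) mixing two batches with the op above; one-hot labels are built with create_onehot from utils.py in this commit, and the import paths assume the layout added here.

import torch
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import create_onehot

x1, x2 = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
y1 = create_onehot(torch.randint(0, 10, (4,)), 10)
y2 = create_onehot(torch.randint(0, 10, (4,)), 10)
xmix, ymix = mixup(x1, x2, y1, y2, beta=1.0, preserve_order=True)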
Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class MaximumMeanDiscrepancy(nn.Module):

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        tmp_list = []
        for i, row in enumerate(distmat):
            row1 = torch.cat([row[:i], row[i + 1:]])
            tmp_list.append(row1)
        return torch.stack(tmp_list)

    @staticmethod
    def euclidean_squared_distance(x, y):
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        # distmat.addmm_(1, -2, x, y.t())
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        return distmat


if __name__ == "__main__":
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
    d = mmd(input1, input2)
    print(d.item())
Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class OptimalTransport(nn.Module):

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            dist_mat.addmm_(
                batch1, batch2.t(), beta=1, alpha=-2
            )  # squared euclidean distance
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to be one of "
                "[cosine | euclidean | fast_euclidean]".format(dist_metric)
            )
        return dist_mat


class SinkhornDivergence(OptimalTransport):
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        nx, ny = C.shape
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        real_iter = 0  # check if algorithm terminates before max_iter
        # Sinkhorn iterations
        for i in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            real_iter += 1
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))


class MinibatchEnergyDistance(SinkhornDivergence):

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost


if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    import numpy as np

    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    print(dist.item())
    print(pi)
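A minimal sketch (not part of the diff) of SinkhornDivergence used as a differentiable divergence between two feature batches; the feature tensors are placeholders.

import torch
from dassl.modeling.ops import SinkhornDivergence

feat_source = torch.randn(16, 128)
feat_target = torch.randn(16, 128)
div = SinkhornDivergence(dist_metric="cosine", eps=0.01, max_iter=5)
loss = div(feat_source, feat_target)  # 2*W(x, y) - W(x, x) - W(y, y)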
Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function


class _ReverseGrad(Function):

    @staticmethod
    def forward(ctx, input, grad_scaling):
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        grad_scaling = ctx.grad_scaling
        return -grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward,
    but reverses the sign of the gradient in
    the backward.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, \
            "grad_scaling must be non-negative, but got {}".format(grad_scaling)
        return reverse_grad(x, grad_scaling)
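A minimal sketch (not part of the diff) of gradient reversal in a DANN-style adversarial setup; the small feature extractor and domain head are placeholders.

import torch
import torch.nn as nn
from dassl.modeling.ops import ReverseGrad

features = nn.Linear(32, 16)
domain_head = nn.Linear(16, 2)
grl = ReverseGrad()

x = torch.randn(4, 32)
f = features(x)
logits = domain_head(grl(f, grad_scaling=0.5))  # identity in the forward pass
logits.sum().backward()                         # gradients into `features` are sign-flipped and scaled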
Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import torch.nn as nn


class Sequential2(nn.Sequential):
    """An alternative sequential container to nn.Sequential,
    which accepts an arbitrary number of input arguments.
    """

    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
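A minimal sketch (not part of the diff): Sequential2 can chain a module whose forward takes several arguments, which plain nn.Sequential cannot. The Scale module below is a hypothetical placeholder.

import torch
import torch.nn as nn
from dassl.modeling.ops import Sequential2


class Scale(nn.Module):
    # hypothetical two-argument first stage
    def forward(self, x, factor):
        return x * factor


net = Sequential2(Scale(), nn.Linear(8, 4))
out = net(torch.randn(2, 8), 2.0)  # the first module gets both args, later modules get one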
Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn


class _TransNorm(nn.Module):
    """Transferable normalization.

    Reference:
        - Wang et al. Transferable Normalization: Towards Improving
        Transferability of Deep Neural Networks. NeurIPS 2019.

    Args:
        num_features (int): number of features.
        eps (float): epsilon.
        momentum (float): value for updating running_mean and running_var.
        adaptive_alpha (bool): apply domain adaptive alpha.
    """

    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.adaptive_alpha = adaptive_alpha

        self.register_buffer("running_mean_s", torch.zeros(num_features))
        self.register_buffer("running_var_s", torch.ones(num_features))
        self.register_buffer("running_mean_t", torch.zeros(num_features))
        self.register_buffer("running_var_t", torch.ones(num_features))

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def reset_running_stats(self):
        self.running_mean_s.zero_()
        self.running_var_s.fill_(1)
        self.running_mean_t.zero_()
        self.running_var_t.fill_(1)

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _check_input(self, x):
        raise NotImplementedError

    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
        C = self.num_features
        ratio_s = mean_s / (var_s + self.eps).sqrt()
        ratio_t = mean_t / (var_t + self.eps).sqrt()
        dist = (ratio_s - ratio_t).abs()
        dist_inv = 1 / (1+dist)
        return C * dist_inv / dist_inv.sum()

    def forward(self, input):
        self._check_input(input)
        C = self.num_features
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError

        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)

        if not self.training:
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input-mean_t) / (var_t + self.eps).sqrt()
            output = output*weight + bias

            if self.adaptive_alpha:
                mean_s = self.running_mean_s.view(*new_shape)
                var_s = self.running_var_s.view(*new_shape)
                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
                alpha = alpha.reshape(*new_shape)
                output = (1 + alpha.detach()) * output

            return output

        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)

        x_s = input_s.transpose(0, 1).reshape(C, -1)
        mean_s = x_s.mean(1)
        var_s = x_s.var(1)
        self.running_mean_s.mul_(self.momentum)
        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
        self.running_var_s.mul_(self.momentum)
        self.running_var_s.add_((1 - self.momentum) * var_s.data)
        mean_s = mean_s.reshape(*new_shape)
        var_s = var_s.reshape(*new_shape)
        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
        output_s = output_s*weight + bias

        x_t = input_t.transpose(0, 1).reshape(C, -1)
        mean_t = x_t.mean(1)
        var_t = x_t.var(1)
        self.running_mean_t.mul_(self.momentum)
        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
        self.running_var_t.mul_(self.momentum)
        self.running_var_t.add_((1 - self.momentum) * var_t.data)
        mean_t = mean_t.reshape(*new_shape)
        var_t = var_t.reshape(*new_shape)
        output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
        output_t = output_t*weight + bias

        output = torch.cat([output_s, output_t], 0)

        if self.adaptive_alpha:
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output

        return output


class TransNorm1d(_TransNorm):

    def _check_input(self, x):
        if x.dim() != 2:
            raise ValueError(
                "Expected the input to be 2-D, "
                "but got {}-D".format(x.dim())
            )


class TransNorm2d(_TransNorm):

    def _check_input(self, x):
        if x.dim() != 4:
            raise ValueError(
                "Expected the input to be 4-D, "
                "but got {}-D".format(x.dim())
            )
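A minimal sketch (not part of the diff): during training, forward() splits the batch in half, so source examples go in the first half and target examples in the second; the tensors below are placeholders.

import torch
from dassl.modeling.ops import TransNorm2d

tn = TransNorm2d(num_features=16)
tn.train()
x_source = torch.randn(4, 16, 8, 8)
x_target = torch.randn(4, 16, 8, 8)
out = tn(torch.cat([x_source, x_target], dim=0))  # per-domain statistics, shared affine params
tn.eval()
out_t = tn(x_target)  # inference normalizes with the target running statistics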
Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import numpy as np
import torch


def sharpen_prob(p, temperature=2):
    """Sharpening probability with a temperature.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes)
        temperature (float): temperature.
    """
    p = p.pow(temperature)
    return p / p.sum(1, keepdim=True)


def reverse_index(data, label):
    """Reverse order."""
    inv_idx = torch.arange(data.size(0) - 1, -1, -1).long()
    return data[inv_idx], label[inv_idx]


def shuffle_index(data, label):
    """Shuffle order."""
    rnd_idx = torch.randperm(data.shape[0])
    return data[rnd_idx], label[rnd_idx]


def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor.
        num_classes (int): number of classes.
    """
    onehot = torch.zeros(label.shape[0], num_classes)
    return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current/rampup_length
    return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    ratio = np.clip(current / rampup_length, 0.0, 1.0)
    return float(ratio)


def ema_model_update(model, ema_model, alpha):
    """Exponential moving average of model parameters.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
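A minimal sketch (not part of the diff) of a mean-teacher style loop built from the helpers above; the student/teacher modules and step counts are placeholders.

import copy
import torch.nn as nn
from dassl.modeling.ops.utils import ema_model_update, sigmoid_rampup

student = nn.Linear(8, 2)
teacher = copy.deepcopy(student)

max_steps = 1000
for step in range(5):
    # ... optimizer step on `student` would go here ...
    w = sigmoid_rampup(step, max_steps)        # consistency weight ramps up towards 1
    ema_model_update(student, teacher, alpha=0.999)
    print(round(w, 4))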