release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance

View File

@@ -0,0 +1,30 @@
import torch
from torch.nn import functional as F
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss with optional label smoothing.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated. One of 'mean', 'sum' or 'none'.
            Default is 'mean'.

    Returns:
        torch.Tensor: a scalar for 'mean'/'sum', or per-sample losses
        of shape (batch,) for 'none'.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    # Build the one-hot target directly on input's device/dtype; the
    # previous implementation created a CPU zeros tensor and scattered
    # into it, forcing a device round-trip for GPU inputs.
    target = F.one_hot(target, num_classes).type_as(input)
    # Smoothed target: (1 - eps) * onehot + eps / num_classes.
    target = (1-label_smooth) * target + label_smooth/num_classes
    loss = (-target * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError(
            "reduction must be 'mean', 'sum' or 'none', "
            "but got {}".format(reduction)
        )

View File

@@ -0,0 +1,45 @@
import torch.nn as nn
class _DSBN(nn.Module):
"""Domain Specific Batch Normalization.
Args:
num_features (int): number of features.
n_domain (int): number of domains.
bn_type (str): type of bn. Choices are ['1d', '2d'].
"""
def __init__(self, num_features, n_domain, bn_type):
super().__init__()
if bn_type == "1d":
BN = nn.BatchNorm1d
elif bn_type == "2d":
BN = nn.BatchNorm2d
else:
raise ValueError
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
self.valid_domain_idxs = list(range(n_domain))
self.n_domain = n_domain
self.domain_idx = 0
def select_bn(self, domain_idx=0):
assert domain_idx in self.valid_domain_idxs
self.domain_idx = domain_idx
def forward(self, x):
return self.bn[self.domain_idx](x)
class DSBN1d(_DSBN):
    """Domain-specific BatchNorm over 2-D inputs (the '1d' variant of _DSBN)."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, bn_type="1d")
class DSBN2d(_DSBN):
    """Domain-specific BatchNorm over 4-D inputs (the '2d' variant of _DSBN)."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, bn_type="2d")

View File

@@ -0,0 +1,118 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_efdmix(m):
    """Disable EFDMix on ``m`` if it is exactly an EFDMix module."""
    if type(m) == EFDMix:
        m.set_activation_status(status=False)


def activate_efdmix(m):
    """Enable EFDMix on ``m`` if it is exactly an EFDMix module."""
    if type(m) == EFDMix:
        m.set_activation_status(status=True)


def random_efdmix(m):
    """Switch ``m`` to random mixing if it is exactly an EFDMix module."""
    if type(m) == EFDMix:
        m.update_mix_method(mix="random")


def crossdomain_efdmix(m):
    """Switch ``m`` to cross-domain mixing if it is exactly an EFDMix module."""
    if type(m) == EFDMix:
        m.update_mix_method(mix="crossdomain")
@contextmanager
def run_without_efdmix(model):
    """Temporarily disable EFDMix inside the ``with`` block.

    Assumes EFDMix was initially activated; it is re-activated on exit
    even if the body raises.
    """
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        model.apply(activate_efdmix)
@contextmanager
def run_with_efdmix(model, mix=None):
    """Temporarily enable EFDMix inside the ``with`` block.

    Assumes EFDMix was initially deactivated. ``mix`` may be 'random'
    or 'crossdomain' to switch the mixing strategy before activation;
    EFDMix is deactivated again on exit even if the body raises.
    """
    if mix == "random":
        model.apply(random_efdmix)
    elif mix == "crossdomain":
        model.apply(crossdomain_efdmix)
    try:
        model.apply(activate_efdmix)
        yield
    finally:
        model.apply(deactivate_efdmix)
class EFDMix(nn.Module):
    """EFDMix: mixes sorted (exact-distribution) feature values between
    randomly paired instances in the batch.

    Reference:
        Zhang et al. Exact Feature Distribution Matching for Arbitrary
        Style Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using EFDMix.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix. Choices are ['random', 'crossdomain'].
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        # Fixed: previously reported the class name as "MixStyle"
        # (copy-paste from the sibling mixstyle module).
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        # Enable/disable the mixing behavior in forward().
        self._activated = status

    def update_mix_method(self, mix="random"):
        # Switch between 'random' and 'crossdomain' pairing.
        self.mix = mix

    def forward(self, x):
        """Apply EFDMix to a (B, C, H, W) feature map (train mode only)."""
        # Identity when evaluating or deactivated.
        if not self.training or not self._activated:
            return x
        # Apply stochastically with probability p.
        if random.random() > self.p:
            return x
        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)
        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            # (assumes batch = [source half, target half] — TODO confirm)
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError
        # Place the partner's sorted values at this sample's rank
        # positions; gradients flow only through x_view because the
        # subtracted copy is detached.
        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)

View File

@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_mixstyle(m):
    """Disable MixStyle on ``m`` if it is exactly a MixStyle module."""
    if type(m) == MixStyle:
        m.set_activation_status(status=False)


def activate_mixstyle(m):
    """Enable MixStyle on ``m`` if it is exactly a MixStyle module."""
    if type(m) == MixStyle:
        m.set_activation_status(status=True)


def random_mixstyle(m):
    """Switch ``m`` to random mixing if it is exactly a MixStyle module."""
    if type(m) == MixStyle:
        m.update_mix_method(mix="random")


def crossdomain_mixstyle(m):
    """Switch ``m`` to cross-domain mixing if it is exactly a MixStyle module."""
    if type(m) == MixStyle:
        m.update_mix_method(mix="crossdomain")
@contextmanager
def run_without_mixstyle(model):
    """Temporarily disable MixStyle inside the ``with`` block.

    Assumes MixStyle was initially activated; it is re-activated on
    exit even if the body raises.
    """
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)
@contextmanager
def run_with_mixstyle(model, mix=None):
    """Temporarily enable MixStyle inside the ``with`` block.

    Assumes MixStyle was initially deactivated. ``mix`` may be 'random'
    or 'crossdomain' to switch the mixing strategy before activation;
    MixStyle is deactivated again on exit even if the body raises.
    """
    if mix == "random":
        model.apply(random_mixstyle)
    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)
    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)
class MixStyle(nn.Module):
    """MixStyle.

    Mixes per-instance feature statistics (channel-wise mean and std of
    a 4-D feature map) between randomly paired samples in the batch.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """
    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix. Choices are ['random', 'crossdomain'].
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        # Runtime switch; toggled via set_activation_status or the
        # run_with_mixstyle / run_without_mixstyle context managers.
        self._activated = True
    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )
    def set_activation_status(self, status=True):
        # Enable/disable the mixing behavior in forward().
        self._activated = status
    def update_mix_method(self, mix="random"):
        # Switch between 'random' and 'crossdomain' pairing.
        self.mix = mix
    def forward(self, x):
        """Apply MixStyle to a (B, C, H, W) feature map (train mode only)."""
        # Identity when evaluating or deactivated.
        if not self.training or not self._activated:
            return x
        # Apply stochastically with probability p.
        if random.random() > self.p:
            return x
        B = x.size(0)
        # Per-instance, per-channel statistics over the spatial dims.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Stop gradients through the style statistics.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig
        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)
        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            # (assumes batch = [source half, target half] — TODO confirm)
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError
        # Interpolate this sample's statistics with its partner's.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)
        # Re-style the normalized content with the mixed statistics.
        return x_normed*sig_mix + mu_mix

View File

@@ -0,0 +1,23 @@
import torch
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup data augmentation.

    Draws a per-sample coefficient lmda ~ Beta(beta, beta) and returns
    the convex combinations lmda*(x1, y1) + (1-lmda)*(x2, y2).

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda=max(lmda, 1-lmda).
            Default is False.

    Returns:
        tuple: (mixed data, mixed labels).
    """
    dist = torch.distributions.Beta(beta, beta)
    lmda = dist.sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        # Keep the first argument dominant in the mixture.
        lmda = torch.max(lmda, 1 - lmda)
    lmda = lmda.to(x1.device)
    xmix = lmda * x1 + (1-lmda) * x2
    # Labels only need the per-sample coefficient of shape (b, 1).
    lmda = lmda[:, :, 0, 0]
    ymix = lmda * y1 + (1-lmda) * y2
    return xmix, ymix

View File

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class MaximumMeanDiscrepancy(nn.Module):
    """Maximum Mean Discrepancy (MMD) between two batches of features.

    MMD^2(x, y) = E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')], where the
    within-batch kernel means exclude the diagonal (self-similarity).

    Args:
        kernel_type (str): kernel. Choices are ['linear', 'poly', 'rbf'].
        normalize (bool): L2-normalize inputs before computing the MMD.
    """

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
        """Sum of RBF kernels at several bandwidths.

        ``sigmas`` default is a tuple now: a mutable list default is
        shared across calls (classic Python pitfall).
        """
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        """Drop diagonal entries, returning an (n, n-1) matrix.

        Vectorized with a boolean mask; the previous version rebuilt
        the matrix row by row in a Python loop.
        """
        n = distmat.size(0)
        mask = ~torch.eye(n, dtype=torch.bool, device=distmat.device)
        return distmat[mask].view(n, n - 1)

    @staticmethod
    def euclidean_squared_distance(x, y):
        """Pairwise squared Euclidean distances, shape (m, n)."""
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        # distmat += -2 * x @ y.t() (modern keyword addmm_ signature).
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        return distmat
if __name__ == "__main__":
    # Smoke test: RBF-kernel MMD between two small random batches.
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    input1 = torch.rand(3, 100)
    input2 = torch.rand(3, 100)
    print(mmd(input1, input2).item())

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class OptimalTransport(nn.Module):
    """Base class providing pairwise cost matrices for OT solvers."""

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Compute the (n1, n2) pairwise cost matrix between two batches.

        Args:
            batch1 (torch.Tensor): shape (n1, dim).
            batch2 (torch.Tensor): shape (n2, dim).
            dist_metric (str): one of 'cosine', 'euclidean',
                'fast_euclidean'.

        Raises:
            ValueError: for an unknown ``dist_metric``.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # squared euclidean distance. Fixed: the old positional
            # addmm_(beta, alpha, m1, m2) form was deprecated and then
            # removed from PyTorch; use the keyword signature (matching
            # euclidean_squared_distance in the mmd module).
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean]".format(dist_metric)
            )
        return dist_mat
class SinkhornDivergence(OptimalTransport):
    """Sinkhorn divergence: a debiased entropy-regularized OT distance.

    forward(x, y) = 2*W(x, y) - W(x, x) - W(y, y), where W is the
    entropic transport cost computed with log-domain Sinkhorn
    iterations between uniform marginals on the two batches.
    """
    # Early-stop threshold on the l1 change of the dual variable u.
    thre = 1e-3
    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        """
        Args:
            dist_metric (str): ground cost; see OptimalTransport.distance.
            eps (float): entropic regularization strength.
            max_iter (int): maximum number of Sinkhorn iterations.
            bp_to_sinkhorn (bool): backprop through the transport plan
                instead of detaching it.
        """
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn
    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy
    def transport_cost(self, x, y, return_pi=False):
        """Entropic transport cost <pi, C>; optionally also return pi."""
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            # Treat the plan as a constant; gradients then flow only
            # through the cost matrix C.
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost
    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Run log-domain Sinkhorn iterations and return the plan.

        Args:
            C (torch.Tensor): (nx, ny) cost matrix.
            eps (float): entropic regularization strength.
            max_iter (int): iteration cap.
            thre (float): early-stop threshold on u's l1 change.
        """
        nx, ny = C.shape
        # Uniform marginal weights over each batch.
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        # Dual potentials, updated in the log domain for stability.
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps
        real_iter = 0  # check if algorithm terminates before max_iter
        # Sinkhorn iterations
        for i in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            real_iter += 1
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))
class MinibatchEnergyDistance(SinkhornDivergence):
    """Minibatch energy distance built from Sinkhorn transport costs.

    Splits each batch into two halves and combines the pairwise
    transport costs: the four cross-batch terms minus twice the two
    within-batch terms.
    """

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        """Return the energy distance between batches x, y of shape (batch, dim)."""
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        cross = (
            self.transport_cost(x1, y1) + self.transport_cost(x1, y2)
            + self.transport_cost(x2, y1) + self.transport_cost(x2, y2)
        )
        within = self.transport_cost(x1, x2) + self.transport_cost(y1, y2)
        return cross - 2 * within
if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    import numpy as np
    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    # Fixed: a pdb.set_trace() left in release code hangs any
    # non-interactive run; print the results instead.
    print("transport cost:", dist.item())
    print("transport plan:\n", pi)

View File

@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function
class _ReverseGrad(Function):
    """Autograd function: identity forward, negated-and-scaled backward."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Remember the scaling factor for the backward pass.
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient w.r.t. grad_scaling itself, hence the trailing None.
        return -ctx.grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    Acts as the identity in the forward pass, but multiplies the
    incoming gradient by ``-grad_scaling`` in the backward pass.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, "
            "but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)

View File

@@ -0,0 +1,15 @@
import torch.nn as nn
class Sequential2(nn.Sequential):
    """A sequential container accepting any number of input arguments.

    Unlike ``nn.Sequential``, the first module may receive several
    positional arguments; whenever the running value is a tuple it is
    unpacked into the next module's call.
    """

    def forward(self, *inputs):
        outputs = inputs
        for module in self._modules.values():
            if isinstance(outputs, tuple):
                outputs = module(*outputs)
            else:
                outputs = module(outputs)
        return outputs

View File

@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
class _TransNorm(nn.Module):
    """Transferable normalization.

    Keeps separate running BatchNorm statistics for source and target
    domains and (optionally) re-weights channels by how similar the
    two domains' statistics are.

    Reference:
        - Wang et al. Transferable Normalization: Towards Improving
        Transferability of Deep Neural Networks. NeurIPS 2019.

    Args:
        num_features (int): number of features.
        eps (float): epsilon.
        momentum (float): value for updating running_mean and running_var.
        adaptive_alpha (bool): apply domain adaptive alpha.

    NOTE(review): in training mode forward() assumes the batch is the
    concatenation [source half; target half] along dim 0 — confirm
    with callers.
    """
    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.adaptive_alpha = adaptive_alpha
        # Separate running statistics for source (_s) and target (_t).
        self.register_buffer("running_mean_s", torch.zeros(num_features))
        self.register_buffer("running_var_s", torch.ones(num_features))
        self.register_buffer("running_mean_t", torch.zeros(num_features))
        self.register_buffer("running_var_t", torch.ones(num_features))
        # Affine parameters (gamma/beta) shared by both domains.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
    def resnet_running_stats(self):
        # Reset all running statistics to their initial values.
        # NOTE(review): the name looks like a typo of
        # "reset_running_stats"; kept as-is since callers may rely on it.
        self.running_mean_s.zero_()
        self.running_var_s.fill_(1)
        self.running_mean_t.zero_()
        self.running_var_t.fill_(1)
    def reset_parameters(self):
        # Re-initialize the affine parameters.
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
    def _check_input(self, x):
        # Subclasses enforce the expected input rank (2-D or 4-D).
        raise NotImplementedError
    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
        # Per-channel transferability weights: channels whose
        # source/target mean-to-std ratios are close get larger
        # weights; normalized so the weights sum to C.
        C = self.num_features
        ratio_s = mean_s / (var_s + self.eps).sqrt()
        ratio_t = mean_t / (var_t + self.eps).sqrt()
        dist = (ratio_s - ratio_t).abs()
        dist_inv = 1 / (1+dist)
        return C * dist_inv / dist_inv.sum()
    def forward(self, input):
        """Normalize ``input``; train mode expects [source; target] halves."""
        self._check_input(input)
        C = self.num_features
        # Broadcast shape for per-channel tensors.
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError
        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)
        if not self.training:
            # Evaluation: normalize with target-domain running stats.
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input-mean_t) / (var_t + self.eps).sqrt()
            output = output*weight + bias
            if self.adaptive_alpha:
                mean_s = self.running_mean_s.view(*new_shape)
                var_s = self.running_var_s.view(*new_shape)
                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
                alpha = alpha.reshape(*new_shape)
                output = (1 + alpha.detach()) * output
            return output
        # Training: first half is source data, second half target data.
        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)
        # --- source branch: batch stats + running-stats update ---
        x_s = input_s.transpose(0, 1).reshape(C, -1)
        mean_s = x_s.mean(1)
        var_s = x_s.var(1)
        # NOTE(review): the update is
        # running = momentum * running + (1 - momentum) * batch,
        # the reverse of nn.BatchNorm's momentum convention — confirm
        # this is intended.
        self.running_mean_s.mul_(self.momentum)
        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
        self.running_var_s.mul_(self.momentum)
        self.running_var_s.add_((1 - self.momentum) * var_s.data)
        mean_s = mean_s.reshape(*new_shape)
        var_s = var_s.reshape(*new_shape)
        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
        output_s = output_s*weight + bias
        # --- target branch: same scheme with the target half ---
        x_t = input_t.transpose(0, 1).reshape(C, -1)
        mean_t = x_t.mean(1)
        var_t = x_t.var(1)
        self.running_mean_t.mul_(self.momentum)
        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
        self.running_var_t.mul_(self.momentum)
        self.running_var_t.add_((1 - self.momentum) * var_t.data)
        mean_t = mean_t.reshape(*new_shape)
        var_t = var_t.reshape(*new_shape)
        output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
        output_t = output_t*weight + bias
        output = torch.cat([output_s, output_t], 0)
        if self.adaptive_alpha:
            # Re-weight channels by transferability (no grad via alpha).
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output
        return output
class TransNorm1d(_TransNorm):
    """TransNorm over 2-D inputs of shape (batch, num_features)."""

    def _check_input(self, x):
        # Only rank-2 tensors are valid for the 1d variant.
        if x.dim() != 2:
            raise ValueError(
                "Expected the input to be 2-D, "
                "but got {}-D".format(x.dim())
            )
class TransNorm2d(_TransNorm):
    """TransNorm over 4-D inputs of shape (batch, channels, H, W)."""

    def _check_input(self, x):
        # Only rank-4 tensors are valid for the 2d variant.
        if x.dim() != 4:
            raise ValueError(
                "Expected the input to be 4-D, "
                "but got {}-D".format(x.dim())
            )

View File

@@ -0,0 +1,75 @@
import numpy as np
import torch
def sharpen_prob(p, temperature=2):
    """Sharpen a probability distribution with a temperature.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes)
        temperature (float): temperature.

    Returns:
        torch.Tensor: row-renormalized ``p ** temperature``.
    """
    sharpened = p ** temperature
    return sharpened / sharpened.sum(1, keepdim=True)
def reverse_index(data, label):
    """Reverse the sample order of ``data`` and ``label`` together."""
    idx = torch.arange(data.size(0)).flip(0).long()
    return data[idx], label[idx]
def shuffle_index(data, label):
    """Shuffle ``data`` and ``label`` with one shared random permutation."""
    perm = torch.randperm(data.shape[0])
    return data[perm], label[perm]
def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor.
        num_classes (int): number of classes.

    Returns:
        torch.Tensor: (len(label), num_classes) float one-hot matrix
        (always on CPU, matching the freshly created zeros tensor).
    """
    out = torch.zeros(label.shape[0], num_classes)
    # scatter writes a 1 at each row's label column; labels are moved
    # to CPU to match the destination tensor's device.
    return out.scatter(1, label.unsqueeze(1).data.cpu(), 1)
def sigmoid_rampup(current, rampup_length):
    """Exponential (sigmoid-shaped) rampup from exp(-5) up to 1.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    clipped = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * phase ** 2))
def linear_rampup(current, rampup_length):
    """Linear rampup from 0 to 1 over ``rampup_length`` steps.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    return float(np.clip(current / rampup_length, 0.0, 1.0))
def ema_model_update(model, ema_model, alpha):
    """Update ``ema_model`` parameters as an EMA of ``model``'s.

    ema = alpha * ema + (1 - alpha) * param, applied in place.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    pairs = zip(ema_model.parameters(), model.parameters())
    for ema_param, param in pairs:
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)