release code
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py (new file, 124 lines added)
@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager

import torch
import torch.nn as nn


def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_mixstyle(model):
    # Assume MixStyle was initially activated
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model, mix=None):
    # Assume MixStyle was initially deactivated
    if mix == "random":
        model.apply(random_mixstyle)

    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)
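            # Hypothetical worked example: with B = 8, the inverse index is
            # [7, 6, 5, 4, 3, 2, 1, 0]; chunk() splits it into [7, 6, 5, 4]
            # and [3, 2, 1, 0], which are shuffled internally and re-joined,
            # so instances in the first half of the batch (one domain) are
            # paired with instances from the second half (the other domain),
            # and vice versa.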
        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
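For context, a minimal usage sketch, not part of this commit: MixStyle is typically inserted after early convolutional stages, where feature statistics still carry style information. The ToyBackbone below and its layer sizes are illustrative assumptions; only MixStyle itself comes from this file.

import torch
import torch.nn as nn
from dassl.modeling.ops.mixstyle import MixStyle

class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Conv2d(3, 16, 3, padding=1)
        self.stage2 = nn.Conv2d(16, 32, 3, padding=1)
        # Mix styles with probability 0.5 using Beta(0.1, 0.1) weights
        self.mixstyle = MixStyle(p=0.5, alpha=0.1, mix="random")

    def forward(self, x):
        x = self.stage1(x)
        x = self.mixstyle(x)  # no-op in eval mode or when deactivated
        return self.stage2(x)

model = ToyBackbone().train()
out = model(torch.randn(8, 3, 32, 32))  # shape: (8, 32, 32, 32)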
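Likewise, a sketch of the context managers defined above, reusing the hypothetical ToyBackbone. Note that run_with_mixstyle assumes MixStyle was initially deactivated, and run_without_mixstyle assumes it was initially activated:

model = ToyBackbone().train()
model.apply(deactivate_mixstyle)  # start deactivated, per the assumption

# Temporarily activate MixStyle, switching all layers to "random" mixing:
with run_with_mixstyle(model, mix="random"):
    out = model(torch.randn(8, 3, 32, 32))
# On exit, every MixStyle layer is deactivated again.

model.apply(activate_mixstyle)  # start activated, per the assumption

# Temporarily deactivate MixStyle, e.g. for a clean validation pass:
with run_without_mixstyle(model):
    out = model(torch.randn(8, 3, 32, 32))
# On exit, every MixStyle layer is re-activated.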
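The forward pass first normalizes each instance with its own (detached) channel statistics, then re-styles it with a convex combination of its own and a permuted partner's statistics. A small numeric check of that arithmetic, with hypothetical values for one channel of one instance:

lmda = 0.7                    # weight sampled from Beta(alpha, alpha)
mu, sig = 2.0, 4.0            # this instance's own mean / std
mu2, sig2 = 0.0, 1.0          # the permuted partner's mean / std

mu_mix = mu * lmda + mu2 * (1 - lmda)     # 0.7*2.0 + 0.3*0.0 = 1.4
sig_mix = sig * lmda + sig2 * (1 - lmda)  # 0.7*4.0 + 0.3*1.0 = 3.1

x = 6.0
x_normed = (x - mu) / sig                 # (6.0 - 2.0) / 4.0 = 1.0
y = x_normed * sig_mix + mu_mix           # 1.0*3.1 + 1.4 = 4.5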