46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
import torch.nn as nn
|
|
|
|
|
|
class _DSBN(nn.Module):
|
|
"""Domain Specific Batch Normalization.
|
|
|
|
Args:
|
|
num_features (int): number of features.
|
|
n_domain (int): number of domains.
|
|
bn_type (str): type of bn. Choices are ['1d', '2d'].
|
|
"""
|
|
|
|
def __init__(self, num_features, n_domain, bn_type):
|
|
super().__init__()
|
|
if bn_type == "1d":
|
|
BN = nn.BatchNorm1d
|
|
elif bn_type == "2d":
|
|
BN = nn.BatchNorm2d
|
|
else:
|
|
raise ValueError
|
|
|
|
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
|
|
|
|
self.valid_domain_idxs = list(range(n_domain))
|
|
self.n_domain = n_domain
|
|
self.domain_idx = 0
|
|
|
|
def select_bn(self, domain_idx=0):
|
|
assert domain_idx in self.valid_domain_idxs
|
|
self.domain_idx = domain_idx
|
|
|
|
def forward(self, x):
|
|
return self.bn[self.domain_idx](x)
|
|
|
|
|
|
class DSBN1d(_DSBN):
|
|
|
|
def __init__(self, num_features, n_domain):
|
|
super().__init__(num_features, n_domain, "1d")
|
|
|
|
|
|
class DSBN2d(_DSBN):
|
|
|
|
def __init__(self, num_features, n_domain):
|
|
super().__init__(num_features, n_domain, "2d")
|