release code

miunangel committed 2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions


@@ -0,0 +1,18 @@
"""
Dassl
------
PyTorch toolbox for domain adaptation and semi-supervised learning.
URL: https://github.com/KaiyangZhou/Dassl.pytorch
@article{zhou2020domain,
title={Domain Adaptive Ensemble Learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={arXiv preprint arXiv:2003.07325},
year={2020}
}
"""
__version__ = "0.5.0"
__author__ = "Kaiyang Zhou"
__homepage__ = "https://kaiyangzhou.github.io/"


@@ -0,0 +1,5 @@
from .defaults import _C as cfg_default
def get_cfg_default():
return cfg_default.clone()


@@ -0,0 +1,275 @@
from yacs.config import CfgNode as CN
###########################
# Config definition
###########################
_C = CN()
_C.VERSION = 1
# Directory to save the output files (like log.txt and model weights)
_C.OUTPUT_DIR = "./output"
# Path to a directory where the files were saved previously
_C.RESUME = ""
# Set seed to negative value to randomize everything
# Set seed to positive value to use a fixed seed
_C.SEED = -1
_C.USE_CUDA = True
# Print detailed information
# E.g. trainer, dataset, and backbone
_C.VERBOSE = True
###########################
# Input
###########################
_C.INPUT = CN()
_C.INPUT.SIZE = (224, 224)
# Mode of interpolation in resize functions
_C.INPUT.INTERPOLATION = "bilinear"
# For available choices please refer to transforms.py
_C.INPUT.TRANSFORMS = ()
# If True, tfm_train and tfm_test will be None
_C.INPUT.NO_TRANSFORM = False
# Default mean and std come from ImageNet
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding for random crop
_C.INPUT.CROP_PADDING = 4
# Cutout
_C.INPUT.CUTOUT_N = 1
_C.INPUT.CUTOUT_LEN = 16
# Gaussian noise
_C.INPUT.GN_MEAN = 0.0
_C.INPUT.GN_STD = 0.15
# RandomAugment
_C.INPUT.RANDAUGMENT_N = 2
_C.INPUT.RANDAUGMENT_M = 10
# ColorJitter (brightness, contrast, saturation, hue)
_C.INPUT.COLORJITTER_B = 0.4
_C.INPUT.COLORJITTER_C = 0.4
_C.INPUT.COLORJITTER_S = 0.4
_C.INPUT.COLORJITTER_H = 0.1
# Probability of applying random grayscale
_C.INPUT.RGS_P = 0.2
# Gaussian blur
_C.INPUT.GB_P = 0.5  # probability of applying this operation
_C.INPUT.GB_K = 21 # kernel size (should be an odd number)
###########################
# Dataset
###########################
_C.DATASET = CN()
# Directory where datasets are stored
_C.DATASET.ROOT = ""
_C.DATASET.NAME = ""
# List of names of source domains
_C.DATASET.SOURCE_DOMAINS = ()
# List of names of target domains
_C.DATASET.TARGET_DOMAINS = ()
# Number of labeled instances in total
# Useful for semi-supervised learning
_C.DATASET.NUM_LABELED = -1
# Number of images per class
_C.DATASET.NUM_SHOTS = -1
# Percentage of validation data (only used for SSL datasets)
# Set to 0 if you do not want to use val data
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
_C.DATASET.VAL_PERCENT = 0.1
# Fold index for STL-10 dataset (normal range is 0 - 9)
# Negative number means None
_C.DATASET.STL10_FOLD = -1
# CIFAR-10/100-C's corruption type and intensity level
_C.DATASET.CIFAR_C_TYPE = ""
_C.DATASET.CIFAR_C_LEVEL = 1
# Use all data in the unlabeled data set (e.g. FixMatch)
_C.DATASET.ALL_AS_UNLABELED = False
###########################
# Dataloader
###########################
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Apply transformations to an image K times (during training)
_C.DATALOADER.K_TRANSFORMS = 1
# img0 denotes image tensor without augmentation
# Useful for consistency learning
_C.DATALOADER.RETURN_IMG0 = False
# Setting for the train_x data-loader
_C.DATALOADER.TRAIN_X = CN()
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
# Parameter for RandomDomainSampler
# 0 or -1 means sampling from all domains
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
# Parameter of RandomClassSampler
# Number of instances per class
_C.DATALOADER.TRAIN_X.N_INS = 16
# Setting for the train_u data-loader
_C.DATALOADER.TRAIN_U = CN()
# Set to False if you want train_u to use
# its own data loader params
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
_C.DATALOADER.TRAIN_U.N_INS = 16
# Setting for the test data-loader
_C.DATALOADER.TEST = CN()
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
_C.DATALOADER.TEST.BATCH_SIZE = 32
###########################
# Model
###########################
_C.MODEL = CN()
# Path to model weights (for initialization)
_C.MODEL.INIT_WEIGHTS = ""
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = ""
_C.MODEL.BACKBONE.PRETRAINED = True
# Definition of embedding layers
_C.MODEL.HEAD = CN()
# If none, no embedding layers are constructed;
# the backbone's output is passed directly to the classifier
_C.MODEL.HEAD.NAME = ""
# Structure of hidden layers (a list), e.g. [512, 512]
# If undefined, no embedding layer will be constructed
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
_C.MODEL.HEAD.ACTIVATION = "relu"
_C.MODEL.HEAD.BN = True
_C.MODEL.HEAD.DROPOUT = 0.0
###########################
# Optimization
###########################
_C.OPTIM = CN()
_C.OPTIM.NAME = "adam"
_C.OPTIM.LR = 0.0003
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9
_C.OPTIM.SGD_DAMPNING = 0
_C.OPTIM.SGD_NESTEROV = False
_C.OPTIM.RMSPROP_ALPHA = 0.99
_C.OPTIM.ADAM_BETA1 = 0.9
_C.OPTIM.ADAM_BETA2 = 0.999
# STAGED_LR allows different layers to have
# different lr, e.g. pre-trained base layers
# can be assigned a smaller lr than the new
# classification layer
_C.OPTIM.STAGED_LR = False
_C.OPTIM.NEW_LAYERS = ()
_C.OPTIM.BASE_LR_MULT = 0.1
# Learning rate scheduler
_C.OPTIM.LR_SCHEDULER = "single_step"
# -1 or 0 means the stepsize is equal to max_epoch
_C.OPTIM.STEPSIZE = (-1, )
_C.OPTIM.GAMMA = 0.1
_C.OPTIM.MAX_EPOCH = 10
# Set WARMUP_EPOCH larger than 0 to activate warmup training
_C.OPTIM.WARMUP_EPOCH = -1
# Either linear or constant
_C.OPTIM.WARMUP_TYPE = "linear"
# Constant learning rate when type=constant
_C.OPTIM.WARMUP_CONS_LR = 1e-5
# Minimum learning rate when type=linear
_C.OPTIM.WARMUP_MIN_LR = 1e-5
# Recount epoch for the next scheduler (last_epoch=-1)
# Otherwise last_epoch=warmup_epoch
_C.OPTIM.WARMUP_RECOUNT = True
###########################
# Train
###########################
_C.TRAIN = CN()
# How often (in epochs) to save the model during training
# Set to 0 or a negative value to save only the last checkpoint
_C.TRAIN.CHECKPOINT_FREQ = 0
# How often (in batches) to print training information
_C.TRAIN.PRINT_FREQ = 10
# Use 'train_x', 'train_u' or 'smaller_one' to count
# the number of iterations in an epoch (for DA and SSL)
_C.TRAIN.COUNT_ITER = "train_x"
###########################
# Test
###########################
_C.TEST = CN()
_C.TEST.EVALUATOR = "Classification"
_C.TEST.PER_CLASS_RESULT = False
# Compute confusion matrix, which will be saved
# to $OUTPUT_DIR/cmat.pt
_C.TEST.COMPUTE_CMAT = False
# If NO_TEST=True, no testing will be conducted
_C.TEST.NO_TEST = False
# Use test or val set for FINAL evaluation
_C.TEST.SPLIT = "test"
# Which model to test after training
# Either last_step or best_val
_C.TEST.FINAL_MODEL = "last_step"
###########################
# Trainer specifics
###########################
_C.TRAINER = CN()
_C.TRAINER.NAME = ""
# MCD
_C.TRAINER.MCD = CN()
_C.TRAINER.MCD.N_STEP_F = 4 # number of steps to train F
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss
# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300
# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss
_C.TRAINER.M3SDA.N_STEP_F = 4 # follow MCD
# DAEL
_C.TRAINER.DAEL = CN()
_C.TRAINER.DAEL.WEIGHT_U = 0.5 # weight on the unlabeled loss
_C.TRAINER.DAEL.CONF_THRE = 0.95 # confidence threshold
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()
# CrossGrad
_C.TRAINER.CG = CN()
_C.TRAINER.CG.EPS_F = 1.0 # scaling parameter for D's gradients
_C.TRAINER.CG.EPS_D = 1.0 # scaling parameter for F's gradients
_C.TRAINER.CG.ALPHA_F = 0.5 # balancing weight for the label net's loss
_C.TRAINER.CG.ALPHA_D = 0.5 # balancing weight for the domain net's loss
# DDAIG
_C.TRAINER.DDAIG = CN()
_C.TRAINER.DDAIG.G_ARCH = "" # generator's architecture
_C.TRAINER.DDAIG.LMDA = 0.3 # perturbation weight
_C.TRAINER.DDAIG.CLAMP = False # clamp perturbation values
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
_C.TRAINER.DDAIG.WARMUP = 0
_C.TRAINER.DDAIG.ALPHA = 0.5 # balancing weight for the losses
# EntMin
_C.TRAINER.ENTMIN = CN()
_C.TRAINER.ENTMIN.LMDA = 1e-3 # weight on the entropy loss
# Mean Teacher
_C.TRAINER.MEANTEA = CN()
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0 # weight on the unlabeled loss
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
_C.TRAINER.MEANTEA.RAMPUP = 5 # epochs used to ramp up the loss_u weight
# MixMatch
_C.TRAINER.MIXMATCH = CN()
_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss
_C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability
_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
_C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight
# FixMatch
_C.TRAINER.FIXMATCH = CN()
_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss
_C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold
_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
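For reference, a minimal usage sketch of these defaults (the YAML filename and override values are hypothetical; merge_from_file, merge_from_list and freeze are standard yacs CfgNode methods):

from dassl.config import get_cfg_default

cfg = get_cfg_default()  # returns a clone, so the module-level defaults stay untouched
cfg.merge_from_file("my_config.yaml")  # hypothetical YAML file with overrides
cfg.merge_from_list(["OPTIM.LR", 0.001, "OPTIM.MAX_EPOCH", 30])
cfg.freeze()  # make the config read-only before training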


@@ -0,0 +1 @@
from .data_manager import DataManager, DatasetWrapper


@@ -0,0 +1,264 @@
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset
from dassl.utils import read_image
from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import build_transform
INTERPOLATION_MODES = {
"bilinear": Image.BILINEAR,
"bicubic": Image.BICUBIC,
"nearest": Image.NEAREST,
}
def build_data_loader(
cfg,
sampler_type="SequentialSampler",
data_source=None,
batch_size=64,
n_domain=0,
n_ins=2,
tfm=None,
is_train=True,
dataset_wrapper=None,
):
# Build sampler
sampler = build_sampler(
sampler_type,
cfg=cfg,
data_source=data_source,
batch_size=batch_size,
n_domain=n_domain,
n_ins=n_ins,
)
if dataset_wrapper is None:
dataset_wrapper = DatasetWrapper
# Build data loader
data_loader = torch.utils.data.DataLoader(
dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
batch_size=batch_size,
sampler=sampler,
num_workers=cfg.DATALOADER.NUM_WORKERS,
drop_last=is_train and len(data_source) >= batch_size,
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
)
assert len(data_loader) > 0
return data_loader
class DataManager:
def __init__(
self,
cfg,
custom_tfm_train=None,
custom_tfm_test=None,
dataset_wrapper=None
):
# Load dataset
dataset = build_dataset(cfg)
# Build transform
if custom_tfm_train is None:
tfm_train = build_transform(cfg, is_train=True)
else:
print("* Using custom transform for training")
tfm_train = custom_tfm_train
if custom_tfm_test is None:
tfm_test = build_transform(cfg, is_train=False)
else:
print("* Using custom transform for testing")
tfm_test = custom_tfm_test
# Build train_loader_x
train_loader_x = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
data_source=dataset.train_x,
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
tfm=tfm_train,
is_train=True,
dataset_wrapper=dataset_wrapper,
)
# Build train_loader_u
train_loader_u = None
if dataset.train_u:
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
train_loader_u = build_data_loader(
cfg,
sampler_type=sampler_type_,
data_source=dataset.train_u,
batch_size=batch_size_,
n_domain=n_domain_,
n_ins=n_ins_,
tfm=tfm_train,
is_train=True,
dataset_wrapper=dataset_wrapper,
)
# Build val_loader
val_loader = None
if dataset.val:
val_loader = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
data_source=dataset.val,
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
tfm=tfm_test,
is_train=False,
dataset_wrapper=dataset_wrapper,
)
# Build test_loader
test_loader = build_data_loader(
cfg,
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
data_source=dataset.test,
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
tfm=tfm_test,
is_train=False,
dataset_wrapper=dataset_wrapper,
)
# Attributes
self._num_classes = dataset.num_classes
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
self._lab2cname = dataset.lab2cname
# Dataset and data-loaders
self.dataset = dataset
self.train_loader_x = train_loader_x
self.train_loader_u = train_loader_u
self.val_loader = val_loader
self.test_loader = test_loader
if cfg.VERBOSE:
self.show_dataset_summary(cfg)
@property
def num_classes(self):
return self._num_classes
@property
def num_source_domains(self):
return self._num_source_domains
@property
def lab2cname(self):
return self._lab2cname
def show_dataset_summary(self, cfg):
print("***** Dataset statistics *****")
print(" Dataset: {}".format(cfg.DATASET.NAME))
if cfg.DATASET.SOURCE_DOMAINS:
print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
if cfg.DATASET.TARGET_DOMAINS:
print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))
print(" # classes: {:,}".format(self.num_classes))
print(" # train_x: {:,}".format(len(self.dataset.train_x)))
if self.dataset.train_u:
print(" # train_u: {:,}".format(len(self.dataset.train_u)))
if self.dataset.val:
print(" # val: {:,}".format(len(self.dataset.val)))
print(" # test: {:,}".format(len(self.dataset.test)))
class DatasetWrapper(TorchDataset):
def __init__(self, cfg, data_source, transform=None, is_train=False):
self.cfg = cfg
self.data_source = data_source
self.transform = transform # accept list (tuple) as input
self.is_train = is_train
# Augmenting an image K>1 times is only allowed during training
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
if self.k_tfm > 1 and transform is None:
raise ValueError(
"Cannot augment the image {} times "
"because transform is None".format(self.k_tfm)
)
# Build transform that doesn't apply any data augmentation
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
to_tensor = []
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
to_tensor += [T.ToTensor()]
if "normalize" in cfg.INPUT.TRANSFORMS:
normalize = T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
)
to_tensor += [normalize]
self.to_tensor = T.Compose(to_tensor)
def __len__(self):
return len(self.data_source)
def __getitem__(self, idx):
item = self.data_source[idx]
output = {
"label": item.label,
"domain": item.domain,
"impath": item.impath
}
img0 = read_image(item.impath)
if self.transform is not None:
if isinstance(self.transform, (list, tuple)):
for i, tfm in enumerate(self.transform):
img = self._transform_image(tfm, img0)
keyname = "img"
if (i + 1) > 1:
keyname += str(i + 1)
output[keyname] = img
else:
img = self._transform_image(self.transform, img0)
output["img"] = img
if self.return_img0:
output["img0"] = self.to_tensor(img0)
return output
def _transform_image(self, tfm, img0):
img_list = []
for k in range(self.k_tfm):
img_list.append(tfm(img0))
img = img_list
if len(img) == 1:
img = img[0]
return img
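As a rough end-to-end sketch of how this manager is driven (dataset name, root and domains are placeholders; Office31 and its domains are registered later in this commit):

from dassl.config import get_cfg_default
from dassl.data import DataManager

cfg = get_cfg_default()
cfg.DATASET.NAME = "Office31"  # placeholder dataset choice
cfg.DATASET.ROOT = "/path/to/datasets"  # placeholder path
cfg.DATASET.SOURCE_DOMAINS = ["amazon"]
cfg.DATASET.TARGET_DOMAINS = ["webcam"]

dm = DataManager(cfg)
for batch in dm.train_loader_x:
    img, label = batch["img"], batch["label"]  # keys produced by DatasetWrapper.__getitem__
    break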


@@ -0,0 +1,6 @@
from .build import DATASET_REGISTRY, build_dataset # isort:skip
from .base_dataset import Datum, DatasetBase # isort:skip
from .da import *
from .dg import *
from .ssl import *


@@ -0,0 +1,225 @@
import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown
from dassl.utils import check_isfile
class Datum:
"""Data instance which defines the basic attributes.
Args:
impath (str): image path.
label (int): class label.
domain (int): domain label.
classname (str): class name.
"""
def __init__(self, impath="", label=0, domain=0, classname=""):
assert isinstance(impath, str)
assert check_isfile(impath)
self._impath = impath
self._label = label
self._domain = domain
self._classname = classname
@property
def impath(self):
return self._impath
@property
def label(self):
return self._label
@property
def domain(self):
return self._domain
@property
def classname(self):
return self._classname
class DatasetBase:
"""A unified dataset class for
1) domain adaptation
2) domain generalization
3) semi-supervised learning
"""
dataset_dir = "" # the directory where the dataset is stored
domains = [] # string names of all domains
def __init__(self, train_x=None, train_u=None, val=None, test=None):
self._train_x = train_x # labeled training data
self._train_u = train_u # unlabeled training data (optional)
self._val = val # validation data (optional)
self._test = test # test data
self._num_classes = self.get_num_classes(train_x)
self._lab2cname, self._classnames = self.get_lab2cname(train_x)
@property
def train_x(self):
return self._train_x
@property
def train_u(self):
return self._train_u
@property
def val(self):
return self._val
@property
def test(self):
return self._test
@property
def lab2cname(self):
return self._lab2cname
@property
def classnames(self):
return self._classnames
@property
def num_classes(self):
return self._num_classes
def get_num_classes(self, data_source):
"""Count number of classes.
Args:
data_source (list): a list of Datum objects.
"""
label_set = set()
for item in data_source:
label_set.add(item.label)
return max(label_set) + 1
def get_lab2cname(self, data_source):
"""Get a label-to-classname mapping (dict).
Args:
data_source (list): a list of Datum objects.
"""
container = set()
for item in data_source:
container.add((item.label, item.classname))
mapping = {label: classname for label, classname in container}
labels = list(mapping.keys())
labels.sort()
classnames = [mapping[label] for label in labels]
return mapping, classnames
def check_input_domains(self, source_domains, target_domains):
self.is_input_domain_valid(source_domains)
self.is_input_domain_valid(target_domains)
def is_input_domain_valid(self, input_domains):
for domain in input_domains:
if domain not in self.domains:
raise ValueError(
"Input domain must belong to {}, "
"but got [{}]".format(self.domains, domain)
)
def download_data(self, url, dst, from_gdrive=True):
if not osp.exists(osp.dirname(dst)):
os.makedirs(osp.dirname(dst))
if from_gdrive:
gdown.download(url, dst, quiet=False)
else:
raise NotImplementedError
print("Extracting file ...")
try:
tar = tarfile.open(dst)
tar.extractall(path=osp.dirname(dst))
tar.close()
        except Exception:  # not a tar archive; fall back to zip
zip_ref = zipfile.ZipFile(dst, "r")
zip_ref.extractall(osp.dirname(dst))
zip_ref.close()
print("File extracted to {}".format(osp.dirname(dst)))
def generate_fewshot_dataset(
self, *data_sources, num_shots=-1, repeat=False
):
"""Generate a few-shot dataset (typically for the training set).
This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class contains
        only a small number of images.
Args:
            data_sources: each individual argument is a list of Datum objects.
num_shots (int): number of instances per class to sample.
repeat (bool): repeat images if needed (default: False).
"""
if num_shots < 1:
if len(data_sources) == 1:
return data_sources[0]
return data_sources
print(f"Creating a {num_shots}-shot dataset")
output = []
for data_source in data_sources:
tracker = self.split_dataset_by_label(data_source)
dataset = []
for label, items in tracker.items():
if len(items) >= num_shots:
sampled_items = random.sample(items, num_shots)
else:
if repeat:
sampled_items = random.choices(items, k=num_shots)
else:
sampled_items = items
dataset.extend(sampled_items)
output.append(dataset)
if len(output) == 1:
return output[0]
return output
def split_dataset_by_label(self, data_source):
"""Split a dataset, i.e. a list of Datum objects,
into class-specific groups stored in a dictionary.
Args:
data_source (list): a list of Datum objects.
"""
output = defaultdict(list)
for item in data_source:
output[item.label].append(item)
return output
def split_dataset_by_domain(self, data_source):
"""Split a dataset, i.e. a list of Datum objects,
into domain-specific groups stored in a dictionary.
Args:
data_source (list): a list of Datum objects.
"""
output = defaultdict(list)
for item in data_source:
output[item.domain].append(item)
return output
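A small illustration of the helper methods above, assuming `items` is a list of valid Datum objects produced by a dataset class (Datum asserts that every impath points to an existing file, so placeholder paths will not run as-is):

base = DatasetBase(train_x=items)
few_shot = base.generate_fewshot_dataset(items, num_shots=4)  # 4 images per class
by_class = base.split_dataset_by_label(items)  # dict: label -> list of Datum
by_domain = base.split_dataset_by_domain(items)  # dict: domain -> list of Datum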


@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
DATASET_REGISTRY = Registry("DATASET")
def build_dataset(cfg):
avai_datasets = DATASET_REGISTRY.registered_names()
check_availability(cfg.DATASET.NAME, avai_datasets)
if cfg.VERBOSE:
print("Loading dataset: {}".format(cfg.DATASET.NAME))
return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)
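For illustration, the registry resolves dataset names to classes (Office31 is registered below in this commit; the name list assumes the dataset modules have been imported):

from dassl.data.datasets import DATASET_REGISTRY

print(DATASET_REGISTRY.registered_names())  # e.g. ['Office31', 'PACS', ...]
dataset_cls = DATASET_REGISTRY.get("Office31")  # a class, instantiated with a cfg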


@@ -0,0 +1,7 @@
from .digit5 import Digit5
from .visda17 import VisDA17
from .cifarstl import CIFARSTL
from .office31 import Office31
from .domainnet import DomainNet
from .office_home import OfficeHome
from .mini_domainnet import miniDomainNet


@@ -0,0 +1,68 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
"""CIFAR-10 and STL-10.
CIFAR-10:
- 60,000 32x32 colour images.
- 10 classes, with 6,000 images per class.
- 50,000 training images and 10,000 test images.
- URL: https://www.cs.toronto.edu/~kriz/cifar.html.
STL-10:
- 10 classes: airplane, bird, car, cat, deer, dog, horse,
monkey, ship, truck.
- Images are 96x96 pixels, color.
- 500 training images (10 pre-defined folds), 800 test images
per class.
- URL: https://cs.stanford.edu/~acoates/stl10/.
Reference:
- Krizhevsky. Learning Multiple Layers of Features
from Tiny Images. Tech report.
- Coates et al. An Analysis of Single Layer Networks in
Unsupervised Feature Learning. AISTATS 2011.
"""
dataset_dir = "cifar_stl"
domains = ["cifar", "stl"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
data_dir = osp.join(self.dataset_dir, dname, split)
class_names = listdir_nohidden(data_dir)
for class_name in class_names:
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
label = int(class_name.split("_")[0])
for imname in imnames:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label, domain=domain)
items.append(item)
return items


@@ -0,0 +1,124 @@
import random
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
items = []
for imname in listdir_nohidden(im_dir):
imname_noext = osp.splitext(imname)[0]
label = int(imname_noext.split("_")[1])
impath = osp.join(im_dir, imname)
items.append((impath, label))
if n_max is not None:
items = random.sample(items, n_max)
if n_repeat is not None:
items *= n_repeat
return items
def load_mnist(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_mnist_m(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST_M[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_svhn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SVHN[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_syn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SYN[split])
n_max = 25000 if split == "train" else 9000
return read_image_list(data_dir, n_max=n_max)
def load_usps(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, USPS[split])
n_repeat = 3 if split == "train" else None
return read_image_list(data_dir, n_repeat=n_repeat)
@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
"""Five digit datasets.
It contains:
- MNIST: hand-written digits.
- MNIST-M: variant of MNIST with blended background.
- SVHN: street view house number.
- SYN: synthetic digits.
- USPS: hand-written digits, slightly different from MNIST.
    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS, which has
    only 9,298 images in total, we use the entire dataset but replicate its
    training set three times to match the training set size of the other
    domains.
Reference:
- Lecun et al. Gradient-based learning applied to document
recognition. IEEE 1998.
- Ganin et al. Domain-adversarial training of neural networks.
JMLR 2016.
- Netzer et al. Reading digits in natural images with unsupervised
feature learning. NIPS-W 2011.
"""
dataset_dir = "digit5"
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
func = "load_" + dname
domain_dir = osp.join(self.dataset_dir, dname)
items_d = eval(func)(domain_dir, split=split)
for impath, label in items_d:
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=str(label)
)
items.append(item)
return items


@@ -0,0 +1,69 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
"""DomainNet.
Statistics:
- 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
Real, Sketch.
- Around 0.6M images.
- 345 categories.
- URL: http://ai.bu.edu/M3SDA/.
Special note: the t-shirt class (327) is missing in painting_train.txt.
Reference:
- Peng et al. Moment Matching for Multi-Source Domain
Adaptation. ICCV 2019.
"""
dataset_dir = "domainnet"
domains = [
"clipart", "infograph", "painting", "quickdraw", "real", "sketch"
]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.split_dir = osp.join(self.dataset_dir, "splits")
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
filename = dname + "_" + split + ".txt"
split_file = osp.join(self.split_dir, filename)
with open(split_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[1]
impath = osp.join(self.dataset_dir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items


@@ -0,0 +1,58 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
"""A subset of DomainNet.
Reference:
- Peng et al. Moment Matching for Multi-Source Domain
Adaptation. ICCV 2019.
- Zhou et al. Domain Adaptive Ensemble Learning.
"""
dataset_dir = "domainnet"
domains = ["clipart", "painting", "real", "sketch"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.split_dir = osp.join(self.dataset_dir, "splits_mini")
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
filename = dname + "_" + split + ".txt"
split_file = osp.join(self.split_dir, filename)
with open(split_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[1]
impath = osp.join(self.dataset_dir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items


@@ -0,0 +1,63 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class Office31(DatasetBase):
"""Office-31.
Statistics:
- 4,110 images.
- 31 classes related to office objects.
- 3 domains: Amazon, Webcam, Dslr.
- URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.
Reference:
- Saenko et al. Adapting visual category models to
new domains. ECCV 2010.
"""
dataset_dir = "office31"
domains = ["amazon", "webcam", "dslr"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains):
items = []
for domain, dname in enumerate(input_domains):
domain_dir = osp.join(self.dataset_dir, dname)
class_names = listdir_nohidden(domain_dir)
class_names.sort()
for label, class_name in enumerate(class_names):
class_path = osp.join(domain_dir, class_name)
imnames = listdir_nohidden(class_path)
for imname in imnames:
impath = osp.join(class_path, imname)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=class_name
)
items.append(item)
return items


@@ -0,0 +1,63 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
"""Office-Home.
Statistics:
- Around 15,500 images.
- 65 classes related to office and home objects.
- 4 domains: Art, Clipart, Product, Real World.
- URL: http://hemanthdv.org/OfficeHome-Dataset/.
Reference:
- Venkateswara et al. Deep Hashing Network for Unsupervised
Domain Adaptation. CVPR 2017.
"""
dataset_dir = "office_home"
domains = ["art", "clipart", "product", "real_world"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, input_domains):
items = []
for domain, dname in enumerate(input_domains):
domain_dir = osp.join(self.dataset_dir, dname)
class_names = listdir_nohidden(domain_dir)
class_names.sort()
for label, class_name in enumerate(class_names):
class_path = osp.join(domain_dir, class_name)
imnames = listdir_nohidden(class_path)
for imname in imnames:
impath = osp.join(class_path, imname)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=class_name.lower(),
)
items.append(item)
return items


@@ -0,0 +1,61 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
"""VisDA17.
Focusing on simulation-to-reality domain shift.
URL: http://ai.bu.edu/visda-2017/.
Reference:
- Peng et al. VisDA: The Visual Domain Adaptation
Challenge. ArXiv 2017.
"""
dataset_dir = "visda17"
domains = ["synthetic", "real"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train_x = self._read_data("synthetic")
train_u = self._read_data("real")
test = self._read_data("real")
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data(self, dname):
filedir = "train" if dname == "synthetic" else "validation"
image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
items = []
# There is only one source domain
domain = 0
with open(image_list, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
classname = impath.split("/")[0]
impath = osp.join(self.dataset_dir, filedir, impath)
label = int(label)
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items


@@ -0,0 +1,6 @@
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG


@@ -0,0 +1,123 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
AVAI_C_TYPES = [
"brightness",
"contrast",
"defocus_blur",
"elastic_transform",
"fog",
"frost",
"gaussian_blur",
"gaussian_noise",
"glass_blur",
"impulse_noise",
"jpeg_compression",
"motion_blur",
"pixelate",
"saturate",
"shot_noise",
"snow",
"spatter",
"speckle_noise",
"zoom_blur",
]
@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
"""CIFAR-10 -> CIFAR-10-C.
Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o
Statistics:
- 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
- 10 categories
Reference:
- Hendrycks et al. Benchmarking neural network robustness
to common corruptions and perturbations. ICLR 2019.
"""
dataset_dir = ""
domains = ["cifar10", "cifar10_c"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = root
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
target_domain = cfg.DATASET.TARGET_DOMAINS[0]
assert source_domain == self.domains[0]
assert target_domain == self.domains[1]
c_type = cfg.DATASET.CIFAR_C_TYPE
c_level = cfg.DATASET.CIFAR_C_LEVEL
if not c_type:
raise ValueError(
"Please specify DATASET.CIFAR_C_TYPE in the config file"
)
assert (
c_type in AVAI_C_TYPES
), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
assert 1 <= c_level <= 5
train_dir = osp.join(self.dataset_dir, source_domain, "train")
test_dir = osp.join(
self.dataset_dir, target_domain, c_type, str(c_level)
)
if not osp.exists(test_dir):
raise ValueError
train = self._read_data(train_dir)
test = self._read_data(test_dir)
super().__init__(train_x=train, test=test)
def _read_data(self, data_dir):
class_names = listdir_nohidden(data_dir)
class_names.sort()
items = []
for label, class_name in enumerate(class_names):
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
for imname in imnames:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label, domain=0)
items.append(item)
return items
@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
"""CIFAR-100 -> CIFAR-100-C.
Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o
Statistics:
- 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
        - 100 categories
Reference:
- Hendrycks et al. Benchmarking neural network robustness
to common corruptions and perturbations. ICLR 2019.
"""
dataset_dir = ""
domains = ["cifar100", "cifar100_c"]
def __init__(self, cfg):
super().__init__(cfg)


@@ -0,0 +1,124 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
items = []
for imname in listdir_nohidden(im_dir):
imname_noext = osp.splitext(imname)[0]
label = int(imname_noext.split("_")[1])
impath = osp.join(im_dir, imname)
items.append((impath, label))
if n_max is not None:
        # Note that the sampling process is NOT random;
        # it follows the protocol of Volpi et al. NIPS'18.
items = items[:n_max]
if n_repeat is not None:
items *= n_repeat
return items
def load_mnist(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_mnist_m(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, MNIST_M[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_svhn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SVHN[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_syn(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, SYN[split])
n_max = 10000 if split == "train" else None
return read_image_list(data_dir, n_max=n_max)
def load_usps(dataset_dir, split="train"):
data_dir = osp.join(dataset_dir, USPS[split])
return read_image_list(data_dir)
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
"""Digit recognition datasets for single-source domain generalization.
There are five digit datasets:
- MNIST: hand-written digits.
- MNIST-M: variant of MNIST with blended background.
- SVHN: street view house number.
- SYN: synthetic digits.
- USPS: hand-written digits, slightly different from MNIST.
    Protocol:
        Volpi et al. train a model on 10,000 MNIST images and evaluate it
        on the test splits of the other four datasets. This code does not
        restrict the source to MNIST: any dataset can serve as the source,
        but note that only 10,000 images will be sampled from the source
        dataset for training.
Reference:
- Lecun et al. Gradient-based learning applied to document
recognition. IEEE 1998.
- Ganin et al. Domain-adversarial training of neural networks.
JMLR 2016.
- Netzer et al. Reading digits in natural images with unsupervised
feature learning. NIPS-W 2011.
- Volpi et al. Generalizing to Unseen Domains via Adversarial Data
Augmentation. NIPS 2018.
"""
# Reuse the digit-5 folder instead of creating a new folder
dataset_dir = "digit5"
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
super().__init__(train_x=train, val=val, test=test)
def _read_data(self, input_domains, split="train"):
items = []
for domain, dname in enumerate(input_domains):
func = "load_" + dname
domain_dir = osp.join(self.dataset_dir, dname)
items_d = eval(func)(domain_dir, split=split)
for impath, label in items_d:
item = Datum(impath=impath, label=label, domain=domain)
items.append(item)
return items


@@ -0,0 +1,97 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
"""Digits-DG.
It contains 4 digit datasets:
- MNIST: hand-written digits.
- MNIST-M: variant of MNIST with blended background.
- SVHN: street view house number.
- SYN: synthetic digits.
Reference:
- Lecun et al. Gradient-based learning applied to document
recognition. IEEE 1998.
- Ganin et al. Domain-adversarial training of neural networks.
JMLR 2016.
- Netzer et al. Reading digits in natural images with unsupervised
feature learning. NIPS-W 2011.
- Zhou et al. Deep Domain-Adversarial Image Generation for Domain
Generalisation. AAAI 2020.
"""
dataset_dir = "digits_dg"
domains = ["mnist", "mnist_m", "svhn", "syn"]
data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
if not osp.exists(self.dataset_dir):
dst = osp.join(root, "digits_dg.zip")
self.download_data(self.data_url, dst, from_gdrive=True)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = self.read_data(
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
)
val = self.read_data(
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
)
test = self.read_data(
self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
)
super().__init__(train_x=train, val=val, test=test)
@staticmethod
def read_data(dataset_dir, input_domains, split):
def _load_data_from_directory(directory):
folders = listdir_nohidden(directory)
folders.sort()
items_ = []
for label, folder in enumerate(folders):
impaths = glob.glob(osp.join(directory, folder, "*.jpg"))
for impath in impaths:
items_.append((impath, label))
return items_
items = []
for domain, dname in enumerate(input_domains):
if split == "all":
train_dir = osp.join(dataset_dir, dname, "train")
impath_label_list = _load_data_from_directory(train_dir)
val_dir = osp.join(dataset_dir, dname, "val")
impath_label_list += _load_data_from_directory(val_dir)
else:
split_dir = osp.join(dataset_dir, dname, split)
impath_label_list = _load_data_from_directory(split_dir)
for impath, label in impath_label_list:
class_name = impath.split("/")[-2].lower()
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=class_name
)
items.append(item)
return items


@@ -0,0 +1,49 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase
@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
"""Office-Home.
Statistics:
- Around 15,500 images.
- 65 classes related to office and home objects.
- 4 domains: Art, Clipart, Product, Real World.
- URL: http://hemanthdv.org/OfficeHome-Dataset/.
Reference:
- Venkateswara et al. Deep Hashing Network for Unsupervised
Domain Adaptation. CVPR 2017.
"""
dataset_dir = "office_home_dg"
domains = ["art", "clipart", "product", "real_world"]
data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
if not osp.exists(self.dataset_dir):
dst = osp.join(root, "office_home_dg.zip")
self.download_data(self.data_url, dst, from_gdrive=True)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = DigitsDG.read_data(
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
)
val = DigitsDG.read_data(
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
)
test = DigitsDG.read_data(
self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
)
super().__init__(train_x=train, val=val, test=test)


@@ -0,0 +1,94 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class PACS(DatasetBase):
"""PACS.
Statistics:
- 4 domains: Photo (1,670), Art (2,048), Cartoon
(2,344), Sketch (3,929).
- 7 categories: dog, elephant, giraffe, guitar, horse,
house and person.
Reference:
- Li et al. Deeper, broader and artier domain generalization.
ICCV 2017.
"""
dataset_dir = "pacs"
domains = ["art_painting", "cartoon", "photo", "sketch"]
data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
# the following images contain errors and should be ignored
_error_paths = ["sketch/dog/n02103406_4068-1.png"]
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.image_dir = osp.join(self.dataset_dir, "images")
self.split_dir = osp.join(self.dataset_dir, "splits")
if not osp.exists(self.dataset_dir):
dst = osp.join(root, "pacs.zip")
self.download_data(self.data_url, dst, from_gdrive=True)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")
super().__init__(train_x=train, val=val, test=test)
def _read_data(self, input_domains, split):
items = []
for domain, dname in enumerate(input_domains):
if split == "all":
file_train = osp.join(
self.split_dir, dname + "_train_kfold.txt"
)
impath_label_list = self._read_split_pacs(file_train)
file_val = osp.join(
self.split_dir, dname + "_crossval_kfold.txt"
)
impath_label_list += self._read_split_pacs(file_val)
else:
file = osp.join(
self.split_dir, dname + "_" + split + "_kfold.txt"
)
impath_label_list = self._read_split_pacs(file)
for impath, label in impath_label_list:
classname = impath.split("/")[-2]
item = Datum(
impath=impath,
label=label,
domain=domain,
classname=classname
)
items.append(item)
return items
def _read_split_pacs(self, split_file):
items = []
with open(split_file, "r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
impath, label = line.split(" ")
if impath in self._error_paths:
continue
impath = osp.join(self.image_dir, impath)
label = int(label) - 1
items.append((impath, label))
return items


@@ -0,0 +1,60 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
"""VLCS.
Statistics:
- 4 domains: CALTECH, LABELME, PASCAL, SUN
- 5 categories: bird, car, chair, dog, and person.
Reference:
- Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
"""
dataset_dir = "VLCS"
domains = ["caltech", "labelme", "pascal", "sun"]
data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
if not osp.exists(self.dataset_dir):
dst = osp.join(root, "vlcs.zip")
self.download_data(self.data_url, dst, from_gdrive=True)
self.check_input_domains(
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
)
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")
super().__init__(train_x=train, val=val, test=test)
def _read_data(self, input_domains, split):
items = []
for domain, dname in enumerate(input_domains):
dname = dname.upper()
path = osp.join(self.dataset_dir, dname, split)
folders = listdir_nohidden(path)
folders.sort()
for label, folder in enumerate(folders):
impaths = glob.glob(osp.join(path, folder, "*.jpg"))
for impath in impaths:
item = Datum(impath=impath, label=label, domain=domain)
items.append(item)
return items


@@ -0,0 +1,3 @@
from .svhn import SVHN
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10


@@ -0,0 +1,108 @@
import math
import random
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
"""CIFAR10 for SSL.
Reference:
- Krizhevsky. Learning Multiple Layers of Features
from Tiny Images. Tech report.
"""
dataset_dir = "cifar10"
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
train_dir = osp.join(self.dataset_dir, "train")
test_dir = osp.join(self.dataset_dir, "test")
assert cfg.DATASET.NUM_LABELED > 0
train_x, train_u, val = self._read_data_train(
train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
)
test = self._read_data_test(test_dir)
if cfg.DATASET.ALL_AS_UNLABELED:
train_u = train_u + train_x
if len(val) == 0:
val = None
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
def _read_data_train(self, data_dir, num_labeled, val_percent):
class_names = listdir_nohidden(data_dir)
class_names.sort()
num_labeled_per_class = num_labeled / len(class_names)
items_x, items_u, items_v = [], [], []
for label, class_name in enumerate(class_names):
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
# Split into train and val following Oliver et al. 2018
# Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
num_val = math.floor(len(imnames) * val_percent)
imnames_train = imnames[num_val:]
imnames_val = imnames[:num_val]
# Note we do shuffle after split
random.shuffle(imnames_train)
for i, imname in enumerate(imnames_train):
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label)
if (i + 1) <= num_labeled_per_class:
items_x.append(item)
else:
items_u.append(item)
for imname in imnames_val:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label)
items_v.append(item)
return items_x, items_u, items_v
def _read_data_test(self, data_dir):
class_names = listdir_nohidden(data_dir)
class_names.sort()
items = []
for label, class_name in enumerate(class_names):
class_dir = osp.join(data_dir, class_name)
imnames = listdir_nohidden(class_dir)
for imname in imnames:
impath = osp.join(class_dir, imname)
item = Datum(impath=impath, label=label)
items.append(item)
return items
@DATASET_REGISTRY.register()
class CIFAR100(CIFAR10):
"""CIFAR100 for SSL.
Reference:
- Krizhevsky. Learning Multiple Layers of Features
from Tiny Images. Tech report.
"""
dataset_dir = "cifar100"
def __init__(self, cfg):
super().__init__(cfg)


@@ -0,0 +1,87 @@
import numpy as np
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class STL10(DatasetBase):
"""STL-10 dataset.
Description:
- 10 classes: airplane, bird, car, cat, deer, dog, horse,
monkey, ship, truck.
- Images are 96x96 pixels, color.
- 500 training images per class, 800 test images per class.
- 100,000 unlabeled images for unsupervised learning.
Reference:
- Coates et al. An Analysis of Single Layer Networks in
Unsupervised Feature Learning. AISTATS 2011.
"""
dataset_dir = "stl10"
def __init__(self, cfg):
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
self.dataset_dir = osp.join(root, self.dataset_dir)
train_dir = osp.join(self.dataset_dir, "train")
test_dir = osp.join(self.dataset_dir, "test")
unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
fold_file = osp.join(
self.dataset_dir, "stl10_binary", "fold_indices.txt"
)
# Only use the first five splits
assert 0 <= cfg.DATASET.STL10_FOLD <= 4
train_x = self._read_data_train(
train_dir, cfg.DATASET.STL10_FOLD, fold_file
)
train_u = self._read_data_all(unlabeled_dir)
test = self._read_data_all(test_dir)
if cfg.DATASET.ALL_AS_UNLABELED:
train_u = train_u + train_x
super().__init__(train_x=train_x, train_u=train_u, test=test)
def _read_data_train(self, data_dir, fold, fold_file):
imnames = listdir_nohidden(data_dir)
imnames.sort()
items = []
list_idx = list(range(len(imnames)))
if fold >= 0:
with open(fold_file, "r") as f:
str_idx = f.read().splitlines()[fold]
                # np.fromstring is deprecated, and uint8 would overflow for indices > 255
                list_idx = np.array(str_idx.split(), dtype=int)
for i in list_idx:
imname = imnames[i]
impath = osp.join(data_dir, imname)
label = osp.splitext(imname)[0].split("_")[1]
label = int(label)
item = Datum(impath=impath, label=label)
items.append(item)
return items
def _read_data_all(self, data_dir):
imnames = listdir_nohidden(data_dir)
items = []
for imname in imnames:
impath = osp.join(data_dir, imname)
label = osp.splitext(imname)[0].split("_")[1]
if label == "none":
label = -1
else:
label = int(label)
item = Datum(impath=impath, label=label)
items.append(item)
return items


@@ -0,0 +1,17 @@
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
"""SVHN for SSL.
Reference:
- Netzer et al. Reading Digits in Natural Images with
Unsupervised Feature Learning. NIPS-W 2011.
"""
dataset_dir = "svhn"
def __init__(self, cfg):
super().__init__(cfg)


@@ -0,0 +1,205 @@
import copy
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
class RandomDomainSampler(Sampler):
"""Randomly samples N domains each with K images
to form a minibatch of size N*K.
Args:
data_source (list): list of Datums.
batch_size (int): batch size.
n_domain (int): number of domains to sample in a minibatch.
"""
def __init__(self, data_source, batch_size, n_domain):
self.data_source = data_source
# Keep track of image indices for each domain
self.domain_dict = defaultdict(list)
for i, item in enumerate(data_source):
self.domain_dict[item.domain].append(i)
self.domains = list(self.domain_dict.keys())
# Make sure each domain has equal number of images
if n_domain is None or n_domain <= 0:
n_domain = len(self.domains)
assert batch_size % n_domain == 0
self.n_img_per_domain = batch_size // n_domain
self.batch_size = batch_size
# n_domain denotes number of domains sampled in a minibatch
self.n_domain = n_domain
self.length = len(list(self.__iter__()))
def __iter__(self):
domain_dict = copy.deepcopy(self.domain_dict)
final_idxs = []
stop_sampling = False
while not stop_sampling:
selected_domains = random.sample(self.domains, self.n_domain)
for domain in selected_domains:
idxs = domain_dict[domain]
selected_idxs = random.sample(idxs, self.n_img_per_domain)
final_idxs.extend(selected_idxs)
for idx in selected_idxs:
domain_dict[domain].remove(idx)
remaining = len(domain_dict[domain])
if remaining < self.n_img_per_domain:
stop_sampling = True
return iter(final_idxs)
def __len__(self):
return self.length
class SeqDomainSampler(Sampler):
"""Sequential domain sampler, which randomly samples K
images from each domain to form a minibatch.
Args:
data_source (list): list of Datums.
batch_size (int): batch size.
"""
def __init__(self, data_source, batch_size):
self.data_source = data_source
# Keep track of image indices for each domain
self.domain_dict = defaultdict(list)
for i, item in enumerate(data_source):
self.domain_dict[item.domain].append(i)
self.domains = list(self.domain_dict.keys())
self.domains.sort()
# Make sure each domain has equal number of images
n_domain = len(self.domains)
assert batch_size % n_domain == 0
self.n_img_per_domain = batch_size // n_domain
self.batch_size = batch_size
# n_domain denotes number of domains sampled in a minibatch
self.n_domain = n_domain
self.length = len(list(self.__iter__()))
def __iter__(self):
domain_dict = copy.deepcopy(self.domain_dict)
final_idxs = []
stop_sampling = False
while not stop_sampling:
for domain in self.domains:
idxs = domain_dict[domain]
selected_idxs = random.sample(idxs, self.n_img_per_domain)
final_idxs.extend(selected_idxs)
for idx in selected_idxs:
domain_dict[domain].remove(idx)
remaining = len(domain_dict[domain])
if remaining < self.n_img_per_domain:
stop_sampling = True
return iter(final_idxs)
def __len__(self):
return self.length
class RandomClassSampler(Sampler):
"""Randomly samples N classes each with K instances to
form a minibatch of size N*K.
Modified from https://github.com/KaiyangZhou/deep-person-reid.
Args:
data_source (list): list of Datums.
batch_size (int): batch size.
n_ins (int): number of instances per class to sample in a minibatch.
"""
def __init__(self, data_source, batch_size, n_ins):
if batch_size < n_ins:
raise ValueError(
"batch_size={} must be no less "
"than n_ins={}".format(batch_size, n_ins)
)
self.data_source = data_source
self.batch_size = batch_size
self.n_ins = n_ins
self.ncls_per_batch = self.batch_size // self.n_ins
self.index_dic = defaultdict(list)
for index, item in enumerate(data_source):
self.index_dic[item.label].append(index)
self.labels = list(self.index_dic.keys())
assert len(self.labels) >= self.ncls_per_batch
# estimate number of images in an epoch
self.length = len(list(self.__iter__()))
def __iter__(self):
batch_idxs_dict = defaultdict(list)
for label in self.labels:
idxs = copy.deepcopy(self.index_dic[label])
if len(idxs) < self.n_ins:
idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
random.shuffle(idxs)
batch_idxs = []
for idx in idxs:
batch_idxs.append(idx)
if len(batch_idxs) == self.n_ins:
batch_idxs_dict[label].append(batch_idxs)
batch_idxs = []
avai_labels = copy.deepcopy(self.labels)
final_idxs = []
while len(avai_labels) >= self.ncls_per_batch:
selected_labels = random.sample(avai_labels, self.ncls_per_batch)
for label in selected_labels:
batch_idxs = batch_idxs_dict[label].pop(0)
final_idxs.extend(batch_idxs)
if len(batch_idxs_dict[label]) == 0:
avai_labels.remove(label)
return iter(final_idxs)
def __len__(self):
return self.length
def build_sampler(
sampler_type,
cfg=None,
data_source=None,
batch_size=32,
n_domain=0,
n_ins=16
):
if sampler_type == "RandomSampler":
return RandomSampler(data_source)
elif sampler_type == "SequentialSampler":
return SequentialSampler(data_source)
elif sampler_type == "RandomDomainSampler":
return RandomDomainSampler(data_source, batch_size, n_domain)
elif sampler_type == "SeqDomainSampler":
return SeqDomainSampler(data_source, batch_size)
elif sampler_type == "RandomClassSampler":
return RandomClassSampler(data_source, batch_size, n_ins)
else:
raise ValueError("Unknown sampler type: {}".format(sampler_type))
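
A minimal usage sketch of the builder above (the three-field FakeDatum below is a hypothetical stand-in for dassl's real Datum class, which lives in dassl.data.datasets): build_sampler returns an index sampler that a DataLoader can consume.

from collections import namedtuple

# Hypothetical stand-in for dassl's Datum, just for illustration
FakeDatum = namedtuple("FakeDatum", ["impath", "label", "domain"])

data_source = [
    FakeDatum(impath=f"img_{d}_{i}.jpg", label=i % 2, domain=d)
    for d in range(3)
    for i in range(8)
]

# 3 domains x 2 images per domain = minibatches of 6 indices
sampler = build_sampler(
    "RandomDomainSampler",
    data_source=data_source,
    batch_size=6,
    n_domain=3,
)
print(len(sampler))  # total number of indices yielded per epoch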

View File

@@ -0,0 +1 @@
from .transforms import build_transform

View File

@@ -0,0 +1,273 @@
"""
Source: https://github.com/DeepVoltaire/AutoAugment
"""
import numpy as np
import random
from PIL import Image, ImageOps, ImageEnhance
class ImageNetPolicy:
"""Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy:
"""Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy:
"""Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(
self,
p1,
operation1,
magnitude_idx1,
p2,
operation2,
magnitude_idx2,
fillcolor=(128, 128, 128),
):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10,
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(
rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
).convert(img.mode)
func = {
"shearX":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
1, 0
),
fillcolor=fillcolor,
),
"translateY":
lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1, 0, 0, 0, 1, magnitude * img.size[1] * random.
choice([-1, 1])
),
fillcolor=fillcolor,
),
"rotate":
lambda img, magnitude: rotate_with_fill(img, magnitude),
"color":
lambda img, magnitude: ImageEnhance.Color(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"posterize":
lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize":
lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast":
lambda img, magnitude: ImageEnhance.Contrast(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness":
lambda img, magnitude: ImageEnhance.Sharpness(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"brightness":
lambda img, magnitude: ImageEnhance.Brightness(img).
enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast":
lambda img, magnitude: ImageOps.autocontrast(img),
"equalize":
lambda img, magnitude: ImageOps.equalize(img),
"invert":
lambda img, magnitude: ImageOps.invert(img),
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
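
Each SubPolicy resolves its magnitude index through the per-op linspace tables above, and the policy objects are plain callables on PIL images. A short sketch (the blank image is a stand-in):

import numpy as np
from PIL import Image

# Index 9 is the strongest magnitude, e.g. 30 degrees for "rotate"
print(np.linspace(0, 30, 10)[9])  # 30.0

img = Image.new("RGB", (224, 224))  # blank stand-in image
augmented = ImageNetPolicy()(img)   # one randomly chosen sub-policy is applied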

View File

@@ -0,0 +1,363 @@
"""
Credit to
1) https://github.com/ildoonet/pytorch-randaugment
2) https://github.com/kakaobrain/fast-autoaugment
"""
import numpy as np
import random
import PIL
import torch
import PIL.ImageOps
import PIL.ImageDraw
import PIL.ImageEnhance
from PIL import Image
def ShearX(img, v):
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v):
assert -0.3 <= v <= 0.3
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def TranslateX(img, v):
# [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[0]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateXabs(img, v):
# absolute translation in pixels, e.g. [0, 150]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v):
# [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if random.random() > 0.5:
v = -v
v = v * img.size[1]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def TranslateYabs(img, v):
# absolute translation in pixels, e.g. [0, 150]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
def Rotate(img, v):
assert -30 <= v <= 30
if random.random() > 0.5:
v = -v
return img.rotate(v)
def AutoContrast(img, _):
return PIL.ImageOps.autocontrast(img)
def Invert(img, _):
return PIL.ImageOps.invert(img)
def Equalize(img, _):
return PIL.ImageOps.equalize(img)
def Flip(img, _):
return PIL.ImageOps.mirror(img)
def Solarize(img, v):
assert 0 <= v <= 256
return PIL.ImageOps.solarize(img, v)
def SolarizeAdd(img, addition=0, threshold=128):
img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24
img_np = img_np + addition
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v):
assert 4 <= v <= 8
v = int(v)
return PIL.ImageOps.posterize(img, v)
def Contrast(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Color(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Color(img).enhance(v)
def Brightness(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Sharpness(img, v):
assert 0.0 <= v <= 2.0
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def Cutout(img, v):
# [0, 60] => percentage: [0, 0.2]
assert 0.0 <= v <= 0.2
if v <= 0.0:
return img
v = v * img.size[0]
return CutoutAbs(img, v)
def CutoutAbs(img, v):
# [0, 60] => percentage: [0, 0.2]
# assert 0 <= v <= 20
if v < 0:
return img
w, h = img.size
x0 = np.random.uniform(w)
y0 = np.random.uniform(h)
x0 = int(max(0, x0 - v/2.0))
y0 = int(max(0, y0 - v/2.0))
x1 = min(w, x0 + v)
y1 = min(h, y0 + v)
xy = (x0, y0, x1, y1)
color = (125, 123, 114)
# color = (0, 0, 0)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
def SamplePairing(imgs):
# [0, 0.4]
def f(img1, v):
i = np.random.choice(len(imgs))
img2 = PIL.Image.fromarray(imgs[i])
return PIL.Image.blend(img1, img2, v)
return f
def Identity(img, v):
return img
class Lighting:
"""Lighting noise (AlexNet - style PCA - based noise)."""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = torch.Tensor(eigval)
self.eigvec = torch.Tensor(eigvec)
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = (
self.eigvec.type_as(img).clone().mul(
alpha.view(1, 3).expand(3, 3)
).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
)
return img.add(rgb.view(3, 1, 1).expand_as(img))
class CutoutDefault:
"""
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def randaugment_list():
# 16 operations and their ranges
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
# augs = [
# (Identity, 0., 1.0),
# (ShearX, 0., 0.3), # 0
# (ShearY, 0., 0.3), # 1
# (TranslateX, 0., 0.33), # 2
# (TranslateY, 0., 0.33), # 3
# (Rotate, 0, 30), # 4
# (AutoContrast, 0, 1), # 5
# (Invert, 0, 1), # 6
# (Equalize, 0, 1), # 7
# (Solarize, 0, 110), # 8
# (Posterize, 4, 8), # 9
# # (Contrast, 0.1, 1.9), # 10
# (Color, 0.1, 1.9), # 11
# (Brightness, 0.1, 1.9), # 12
# (Sharpness, 0.1, 1.9), # 13
# # (Cutout, 0, 0.2), # 14
# # (SamplePairing(imgs), 0, 0.4) # 15
# ]
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
augs = [
(AutoContrast, 0, 1),
(Equalize, 0, 1),
(Invert, 0, 1),
(Rotate, 0, 30),
(Posterize, 4, 8),
(Solarize, 0, 256),
(SolarizeAdd, 0, 110),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Brightness, 0.1, 1.9),
(Sharpness, 0.1, 1.9),
(ShearX, 0.0, 0.3),
(ShearY, 0.0, 0.3),
(CutoutAbs, 0, 40),
(TranslateXabs, 0.0, 100),
(TranslateYabs, 0.0, 100),
]
return augs
def randaugment_list2():
augs = [
(AutoContrast, 0, 1),
(Brightness, 0.1, 1.9),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Equalize, 0, 1),
(Identity, 0, 1),
(Invert, 0, 1),
(Posterize, 4, 8),
(Rotate, -30, 30),
(Sharpness, 0.1, 1.9),
(ShearX, -0.3, 0.3),
(ShearY, -0.3, 0.3),
(Solarize, 0, 256),
(TranslateX, -0.3, 0.3),
(TranslateY, -0.3, 0.3),
]
return augs
def fixmatch_list():
# https://arxiv.org/abs/2001.07685
augs = [
(AutoContrast, 0, 1),
(Brightness, 0.05, 0.95),
(Color, 0.05, 0.95),
(Contrast, 0.05, 0.95),
(Equalize, 0, 1),
(Identity, 0, 1),
(Posterize, 4, 8),
(Rotate, -30, 30),
(Sharpness, 0.05, 0.95),
(ShearX, -0.3, 0.3),
(ShearY, -0.3, 0.3),
(Solarize, 0, 256),
(TranslateX, -0.3, 0.3),
(TranslateY, -0.3, 0.3),
]
return augs
class RandAugment:
def __init__(self, n=2, m=10):
assert 0 <= m <= 30
self.n = n
self.m = m
self.augment_list = randaugment_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
val = (self.m / 30) * (maxval-minval) + minval
img = op(img, val)
return img
class RandAugment2:
def __init__(self, n=2, p=0.6):
self.n = n
self.p = p
self.augment_list = randaugment_list2()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
if random.random() > self.p:
continue
m = random.random()
val = m * (maxval-minval) + minval
img = op(img, val)
return img
class RandAugmentFixMatch:
def __init__(self, n=2):
self.n = n
self.augment_list = fixmatch_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
m = random.random()
val = m * (maxval-minval) + minval
img = op(img, val)
return img
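
A usage sketch: RandAugment draws n ops per image and maps the global magnitude m in [0, 30] linearly into each op's (minval, maxval) range (the blank image is a stand-in):

from PIL import Image

img = Image.new("RGB", (32, 32))  # blank stand-in image
augment = RandAugment(n=2, m=10)
out = augment(img)
# e.g. for Rotate with (minval, maxval) = (0, 30): val = (10/30) * 30 = 10 degrees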

View File

@@ -0,0 +1,341 @@
import numpy as np
import random
import torch
from PIL import Image
from torchvision.transforms import (
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
RandomHorizontalFlip
)
from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch
AVAI_CHOICES = [
"random_flip",
"random_resized_crop",
"normalize",
"instance_norm",
"random_crop",
"random_translation",
"center_crop", # This has become a default operation for test
"cutout",
"imagenet_policy",
"cifar10_policy",
"svhn_policy",
"randaugment",
"randaugment_fixmatch",
"randaugment2",
"gaussian_noise",
"colorjitter",
"randomgrayscale",
"gaussian_blur",
]
INTERPOLATION_MODES = {
"bilinear": Image.BILINEAR,
"bicubic": Image.BICUBIC,
"nearest": Image.NEAREST,
}
class Random2DTranslation:
"""Given an image of (height, width), we resize it to
(height*1.125, width*1.125), and then perform random cropping.
Args:
height (int): target image height.
width (int): target image width.
p (float, optional): probability that this operation takes place.
Default is 0.5.
interpolation (int, optional): desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
self.height = height
self.width = width
self.p = p
self.interpolation = interpolation
def __call__(self, img):
if random.uniform(0, 1) > self.p:
return img.resize((self.width, self.height), self.interpolation)
new_width = int(round(self.width * 1.125))
new_height = int(round(self.height * 1.125))
resized_img = img.resize((new_width, new_height), self.interpolation)
x_maxrange = new_width - self.width
y_maxrange = new_height - self.height
x1 = int(round(random.uniform(0, x_maxrange)))
y1 = int(round(random.uniform(0, y_maxrange)))
cropped_img = resized_img.crop(
(x1, y1, x1 + self.width, y1 + self.height)
)
return cropped_img
class InstanceNormalization:
"""Normalize data using per-channel mean and standard deviation.
Reference:
- Ulyanov et al. Instance normalization: The missing ingredient
for fast stylization. ArXiv 2016.
- Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
ICLR 2018.
"""
def __init__(self, eps=1e-8):
self.eps = eps
def __call__(self, img):
C, H, W = img.shape
img_re = img.reshape(C, H * W)
mean = img_re.mean(1).view(C, 1, 1)
std = img_re.std(1).view(C, 1, 1)
return (img-mean) / (std + self.eps)
class Cutout:
"""Randomly mask out one or more patches from an image.
https://github.com/uoguelph-mlrg/Cutout
Args:
n_holes (int, optional): number of patches to cut out
of each image. Default is 1.
length (int, optional): length (in pixels) of each square
patch. Default is 16.
"""
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): tensor image of size (C, H, W).
Returns:
Tensor: image with n_holes of dimension
length x length cut out of it.
"""
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1:y2, x1:x2] = 0.0
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
return img * mask
class GaussianNoise:
"""Add gaussian noise."""
def __init__(self, mean=0, std=0.15, p=0.5):
self.mean = mean
self.std = std
self.p = p
def __call__(self, img):
if random.uniform(0, 1) > self.p:
return img
noise = torch.randn(img.size()) * self.std + self.mean
return img + noise
def build_transform(cfg, is_train=True, choices=None):
"""Build transformation function.
Args:
cfg (CfgNode): config.
is_train (bool, optional): for training (True) or test (False).
Default is True.
choices (list, optional): list of strings which will overwrite
cfg.INPUT.TRANSFORMS if given. Default is None.
"""
if cfg.INPUT.NO_TRANSFORM:
print("Note: no transform is applied!")
return None
if choices is None:
choices = cfg.INPUT.TRANSFORMS
for choice in choices:
assert choice in AVAI_CHOICES
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
if is_train:
return _build_transform_train(cfg, choices, target_size, normalize)
else:
return _build_transform_test(cfg, choices, target_size, normalize)
def _build_transform_train(cfg, choices, target_size, normalize):
print("Building transform_train")
tfm_train = []
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
# Make sure the image size matches the target size
conditions = []
conditions += ["random_crop" not in choices]
conditions += ["random_resized_crop" not in choices]
if all(conditions):
print(f"+ resize to {target_size}")
tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
if "random_translation" in choices:
print("+ random translation")
tfm_train += [
Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
]
if "random_crop" in choices:
crop_padding = cfg.INPUT.CROP_PADDING
print("+ random crop (padding = {})".format(crop_padding))
tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]
if "random_resized_crop" in choices:
print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
tfm_train += [
RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
]
if "center_crop" in choices:
print(f"+ center crop (size={cfg.INPUT.SIZE})")
tfm_train += [CenterCrop(cfg.INPUT.SIZE)]
if "random_flip" in choices:
print("+ random flip")
tfm_train += [RandomHorizontalFlip()]
if "imagenet_policy" in choices:
print("+ imagenet policy")
tfm_train += [ImageNetPolicy()]
if "cifar10_policy" in choices:
print("+ cifar10 policy")
tfm_train += [CIFAR10Policy()]
if "svhn_policy" in choices:
print("+ svhn policy")
tfm_train += [SVHNPolicy()]
if "randaugment" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
m_ = cfg.INPUT.RANDAUGMENT_M
print("+ randaugment (n={}, m={})".format(n_, m_))
tfm_train += [RandAugment(n_, m_)]
if "randaugment_fixmatch" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
print("+ randaugment_fixmatch (n={})".format(n_))
tfm_train += [RandAugmentFixMatch(n_)]
if "randaugment2" in choices:
n_ = cfg.INPUT.RANDAUGMENT_N
print("+ randaugment2 (n={})".format(n_))
tfm_train += [RandAugment2(n_)]
if "colorjitter" in choices:
print("+ color jitter")
tfm_train += [
ColorJitter(
brightness=cfg.INPUT.COLORJITTER_B,
contrast=cfg.INPUT.COLORJITTER_C,
saturation=cfg.INPUT.COLORJITTER_S,
hue=cfg.INPUT.COLORJITTER_H,
)
]
if "randomgrayscale" in choices:
print("+ random gray scale")
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
if "gaussian_blur" in choices:
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
tfm_train += [
RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
]
print("+ to torch tensor of range [0, 1]")
tfm_train += [ToTensor()]
if "cutout" in choices:
cutout_n = cfg.INPUT.CUTOUT_N
cutout_len = cfg.INPUT.CUTOUT_LEN
print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
tfm_train += [Cutout(cutout_n, cutout_len)]
if "normalize" in choices:
print(
"+ normalization (mean={}, "
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
)
tfm_train += [normalize]
if "gaussian_noise" in choices:
print(
"+ gaussian noise (mean={}, std={})".format(
cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
)
)
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
if "instance_norm" in choices:
print("+ instance normalization")
tfm_train += [InstanceNormalization()]
tfm_train = Compose(tfm_train)
return tfm_train
def _build_transform_test(cfg, choices, target_size, normalize):
print("Building transform_test")
tfm_test = []
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]
print(f"+ {target_size} center crop")
tfm_test += [CenterCrop(cfg.INPUT.SIZE)]
print("+ to torch tensor of range [0, 1]")
tfm_test += [ToTensor()]
if "normalize" in choices:
print(
"+ normalization (mean={}, "
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
)
tfm_test += [normalize]
if "instance_norm" in choices:
print("+ instance normalization")
tfm_test += [InstanceNormalization()]
tfm_test = Compose(tfm_test)
return tfm_test
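
A sketch of how these builders are typically driven from a config (assuming the package is importable as dassl, so get_cfg_default is available):

from dassl.config import get_cfg_default

cfg = get_cfg_default()
cfg.INPUT.TRANSFORMS = ("random_resized_crop", "random_flip", "normalize")

tfm_train = build_transform(cfg, is_train=True)
tfm_test = build_transform(cfg, is_train=False)  # resize + center crop (+ normalize)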

View File

@@ -0,0 +1,6 @@
from .build import TRAINER_REGISTRY, build_trainer # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip
from .da import *
from .dg import *
from .ssl import *

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
TRAINER_REGISTRY = Registry("TRAINER")
def build_trainer(cfg):
avai_trainers = TRAINER_REGISTRY.registered_names()
check_availability(cfg.TRAINER.NAME, avai_trainers)
if cfg.VERBOSE:
print("Loading trainer: {}".format(cfg.TRAINER.NAME))
return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
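
A sketch of the intended workflow (MyTrainer is a hypothetical example, not part of the repo): register a trainer class once, then select it by name through the config.

from dassl.engine import TRAINER_REGISTRY, TrainerX

@TRAINER_REGISTRY.register()
class MyTrainer(TrainerX):  # hypothetical trainer for illustration
    pass

# With cfg.TRAINER.NAME = "MyTrainer", build_trainer(cfg) returns MyTrainer(cfg).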

View File

@@ -0,0 +1,9 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling

View File

@@ -0,0 +1,38 @@
import torch
from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
"""Adaptive Batch Normalization.
https://arxiv.org/abs/1603.04779.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.done_reset_bn_stats = False
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def before_epoch(self):
if not self.done_reset_bn_stats:
for m in self.model.modules():
classname = m.__class__.__name__
if classname.find("BatchNorm") != -1:
m.reset_running_stats()
self.done_reset_bn_stats = True
def forward_backward(self, batch_x, batch_u):
input_u = batch_u["img"].to(self.device)
with torch.no_grad():
self.model(input_u)
return None
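
The trainer above relies on PyTorch's own BatchNorm bookkeeping; a minimal sketch of the mechanism it exploits:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.reset_running_stats()        # running_mean -> 0, running_var -> 1
bn.train()
with torch.no_grad():
    bn(torch.randn(32, 4))      # forward passes re-estimate the running stats
print(bn.running_mean)          # now reflects the (target) data just seen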

View File

@@ -0,0 +1,85 @@
import copy
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head
@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
"""Adversarial Discriminative Domain Adaptation.
https://arxiv.org/abs/1702.05464.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.open_layers = ["backbone"]
if isinstance(self.model.head, nn.Module):
self.open_layers.append("head")
self.source_model = copy.deepcopy(self.model)
self.source_model.eval()
for param in self.source_model.parameters():
param.requires_grad_(False)
self.build_critic()
self.bce = nn.BCEWithLogitsLoss()
def check_cfg(self, cfg):
assert check_isfile(
cfg.MODEL.INIT_WEIGHTS
), "The weights of source model must be provided"
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim // 2],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
def forward_backward(self, batch_x, batch_u):
open_specified_layers(self.model, self.open_layers)
input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
_, feat_x = self.source_model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
logit_xd = self.critic(feat_x)
logit_ud = self.critic(feat_u.detach())
loss_critic = self.bce(logit_xd, domain_x)
loss_critic += self.bce(logit_ud, domain_u)
self.model_backward_and_update(loss_critic, "critic")
logit_ud = self.critic(feat_u)
loss_model = self.bce(logit_ud, 1 - domain_u)
self.model_backward_and_update(loss_model, "model")
loss_summary = {
"loss_critic": loss_critic.item(),
"loss_model": loss_model.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary

View File

@@ -0,0 +1,210 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
"""Domain Adaptive Ensemble Learning.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
input_x = torch.split(input_x, self.split_batch, 0)
input_x2 = torch.split(input_x2, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Generate pseudo label
with torch.no_grad():
feat_u = self.F(input_u)
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1) # (B, K, C)
# Get the highest probability and index (label) for each expert
experts_max_p, experts_max_idx = pred_u.max(2) # (B, K)
# Get the most confident expert
max_expert_p, max_expert_idx = experts_max_p.max(1) # (B)
pseudo_label_u = []
for i, experts_label in zip(max_expert_idx, experts_max_idx):
pseudo_label_u.append(experts_label[i])
pseudo_label_u = torch.stack(pseudo_label_u, 0)
pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
pseudo_label_u = pseudo_label_u.to(self.device)
label_u_mask = (max_expert_p >= self.conf_thre).float()
loss_x = 0
loss_cr = 0
acc_x = 0
feat_x = [self.F(x) for x in input_x]
feat_x2 = [self.F(x) for x in input_x2]
feat_u2 = self.F(input_u2)
for feat_xi, feat_x2i, label_xi, i in zip(
feat_x, feat_x2, label_x, domain_x
):
cr_s = [j for j in domain_x if j != i]
# Learning expert
pred_xi = self.E(i, feat_xi)
loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
expert_label_xi = pred_xi.detach()
acc_x += compute_accuracy(pred_xi.detach(),
label_xi.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat_x2i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc_x /= self.n_domain
# Unsupervised loss
pred_u = []
for k in range(self.num_source_domains):
pred_uk = self.E(k, feat_u2)
pred_uk = pred_uk.unsqueeze(1)
pred_u.append(pred_uk)
pred_u = torch.cat(pred_u, 1)
pred_u = pred_u.mean(1)
l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
loss_u = (l_u * label_u_mask).mean()
loss = 0
loss += loss_x
loss += loss_cr
loss += loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": acc_x,
"loss_cr": loss_cr.item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
label_x = create_onehot(label_x, self.num_classes)
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, input_x2, label_x, domain_x, input_u, input_u2
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p
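
A toy illustration of the pseudo-labelling rule in forward_backward: with one image, two experts and three classes, the label comes from the single most confident expert and is kept only if its probability clears CONF_THRE.

import torch

pred_u = torch.tensor([[[0.2, 0.7, 0.1],     # expert 0
                        [0.1, 0.1, 0.8]]])   # expert 1 (more confident)
experts_max_p, experts_max_idx = pred_u.max(2)       # per-expert best class
max_expert_p, max_expert_idx = experts_max_p.max(1)  # expert 1 wins (p=0.8)
print(experts_max_idx[0, max_expert_idx[0]])         # pseudo label: class 2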

View File

@@ -0,0 +1,78 @@
import numpy as np
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad
@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
"""Domain-Adversarial Neural Networks.
https://arxiv.org/abs/1505.07818.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.build_critic()
self.ce = nn.CrossEntropyLoss()
self.bce = nn.BCEWithLogitsLoss()
def build_critic(self):
cfg = self.cfg
print("Building critic network")
fdim = self.model.fdim
critic_body = build_head(
"mlp",
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=[fdim, fdim],
activation="leaky_relu",
)
self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
print("# params: {:,}".format(count_num_param(self.critic)))
self.critic.to(self.device)
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
global_step = self.batch_idx + self.epoch * self.num_batches
progress = global_step / (self.max_epoch * self.num_batches)
lmda = 2 / (1 + np.exp(-10 * progress)) - 1
logit_x, feat_x = self.model(input_x, return_feature=True)
_, feat_u = self.model(input_u, return_feature=True)
loss_x = self.ce(logit_x, label_x)
feat_x = self.revgrad(feat_x, grad_scaling=lmda)
feat_u = self.revgrad(feat_u, grad_scaling=lmda)
output_xd = self.critic(feat_x)
output_ud = self.critic(feat_u)
loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
loss = loss_x + loss_d
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
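
The lmda variable follows the adaptation-factor schedule from the DANN paper, lmda(p) = 2 / (1 + exp(-10 p)) - 1, which ramps the adversarial signal from 0 to 1 as training progresses:

import numpy as np

for progress in (0.0, 0.25, 0.5, 1.0):
    lmda = 2 / (1 + np.exp(-10 * progress)) - 1
    print(f"progress={progress:.2f} -> lmda={lmda:.3f}")
# ramps smoothly from 0.000 at the start to ~1.000 at the end of training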

View File

@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
class PairClassifiers(nn.Module):
def __init__(self, fdim, num_classes):
super().__init__()
self.c1 = nn.Linear(fdim, num_classes)
self.c2 = nn.Linear(fdim, num_classes)
def forward(self, x):
z1 = self.c1(x)
if not self.training:
return z1
z2 = self.c2(x)
return z1, z2
@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
"""Moment Matching for Multi-Source Domain Adaptation.
https://arxiv.org/abs/1812.01754.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
self.lmda = cfg.TRAINER.M3SDA.LMDA
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C")
self.C = nn.ModuleList(
[
PairClassifiers(fdim, self.num_classes)
for _ in range(self.num_source_domains)
]
)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, domain_x, input_u = parsed
input_x = torch.split(input_x, self.split_batch, 0)
label_x = torch.split(label_x, self.split_batch, 0)
domain_x = torch.split(domain_x, self.split_batch, 0)
domain_x = [d[0].item() for d in domain_x]
# Step A
loss_x = 0
feat_x = []
for x, y, d in zip(input_x, label_x, domain_x):
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
feat_x.append(f)
loss_x /= self.n_domain
feat_u = self.F(input_u)
loss_msda = self.moment_distance(feat_x, feat_u)
loss_step_A = loss_x + loss_msda * self.lmda
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_u = self.F(input_u)
loss_x, loss_dis = 0, 0
for x, y, d in zip(input_x, label_x, domain_x):
with torch.no_grad():
f = self.F(x)
z1, z2 = self.C[d](f)
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_x /= self.n_domain
loss_dis /= self.n_domain
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, "C")
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
loss_dis = 0
for d in domain_x:
z1, z2 = self.C[d](feat_u)
p1 = F.softmax(z1, 1)
p2 = F.softmax(z2, 1)
loss_dis += self.discrepancy(p1, p2)
loss_dis /= self.n_domain
loss_step_C = loss_dis
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def moment_distance(self, x, u):
# x (list): a list of feature matrix.
# u (torch.Tensor): feature matrix.
x_mean = [xi.mean(0) for xi in x]
u_mean = u.mean(0)
dist1 = self.pairwise_distance(x_mean, u_mean)
x_var = [xi.var(0) for xi in x]
u_var = u.var(0)
dist2 = self.pairwise_distance(x_var, u_var)
return (dist1+dist2) / 2
def pairwise_distance(self, x, u):
# x (list): a list of feature vector.
# u (torch.Tensor): feature vector.
dist = 0
count = 0
for xi in x:
dist += self.euclidean(xi, u)
count += 1
for i in range(len(x) - 1):
for j in range(i + 1, len(x)):
dist += self.euclidean(x[i], x[j])
count += 1
return dist / count
def euclidean(self, input1, input2):
return ((input1 - input2)**2).sum().sqrt()
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
domain_x = batch_x["domain"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, domain_x, input_u
def model_inference(self, input):
f = self.F(input)
p = 0
for C_i in self.C:
z = C_i(f)
p += F.softmax(z, 1)
p = p / len(self.C)
return p
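
A runnable toy check of the first-moment part of moment_distance: with two source feature matrices, pairwise_distance averages the two source-to-target distances plus the one source-source pair (three terms in total); the variance term is computed the same way.

import torch

def euclidean(a, b):
    return ((a - b) ** 2).sum().sqrt()

x = [torch.randn(16, 8) for _ in range(2)]  # two source feature matrices
u = torch.randn(16, 8)                      # target feature matrix

means = [xi.mean(0) for xi in x]
terms = [euclidean(m, u.mean(0)) for m in means]  # source-target distances
terms.append(euclidean(means[0], means[1]))       # source-source pair
dist1 = sum(terms) / len(terms)                   # matches pairwise_distance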

View File

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
"""Maximum Classifier Discrepancy.
https://arxiv.org/abs/1712.02560.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.n_step_F = cfg.TRAINER.MCD.N_STEP_F
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building C1")
self.C1 = nn.Linear(fdim, self.num_classes)
self.C1.to(self.device)
print("# params: {:,}".format(count_num_param(self.C1)))
self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)
print("Building C2")
self.C2 = nn.Linear(fdim, self.num_classes)
self.C2.to(self.device)
print("# params: {:,}".format(count_num_param(self.C2)))
self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)
def forward_backward(self, batch_x, batch_u):
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u = parsed
# Step A
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_step_A = loss_x1 + loss_x2
self.model_backward_and_update(loss_step_A)
# Step B
with torch.no_grad():
feat_x = self.F(input_x)
logit_x1 = self.C1(feat_x)
logit_x2 = self.C2(feat_x)
loss_x1 = F.cross_entropy(logit_x1, label_x)
loss_x2 = F.cross_entropy(logit_x2, label_x)
loss_x = loss_x1 + loss_x2
with torch.no_grad():
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_dis = self.discrepancy(pred_u1, pred_u2)
loss_step_B = loss_x - loss_dis
self.model_backward_and_update(loss_step_B, ["C1", "C2"])
# Step C
for _ in range(self.n_step_F):
feat_u = self.F(input_u)
pred_u1 = F.softmax(self.C1(feat_u), 1)
pred_u2 = F.softmax(self.C2(feat_u), 1)
loss_step_C = self.discrepancy(pred_u1, pred_u2)
self.model_backward_and_update(loss_step_C, "F")
loss_summary = {
"loss_step_A": loss_step_A.item(),
"loss_step_B": loss_step_B.item(),
"loss_step_C": loss_step_C.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def discrepancy(self, y1, y2):
return (y1 - y2).abs().mean()
def model_inference(self, input):
feat = self.F(input)
return self.C1(feat)

View File

@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet
class Prototypes(nn.Module):
def __init__(self, fdim, num_classes, temp=0.05):
super().__init__()
self.prototypes = nn.Linear(fdim, num_classes, bias=False)
self.temp = temp
def forward(self, x):
x = F.normalize(x, p=2, dim=1)
out = self.prototypes(x)
out = out / self.temp
return out
@TRAINER_REGISTRY.register()
class MME(TrainerXU):
"""Minimax Entropy.
https://arxiv.org/abs/1904.06487.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.MME.LMDA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building C")
self.C = Prototypes(self.F.fdim, self.num_classes)
self.C.to(self.device)
print("# params: {:,}".format(count_num_param(self.C)))
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
self.register_model("C", self.C, self.optim_C, self.sched_C)
self.revgrad = ReverseGrad()
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
feat_x = self.F(input_x)
logit_x = self.C(feat_x)
loss_x = F.cross_entropy(logit_x, label_x)
self.model_backward_and_update(loss_x)
feat_u = self.F(input_u)
feat_u = self.revgrad(feat_u)
logit_u = self.C(feat_u)
prob_u = F.softmax(logit_u, 1)
loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
self.model_backward_and_update(loss_u * self.lmda)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.C(self.F(input))

View File

@@ -0,0 +1,78 @@
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
"""Self-ensembling for visual domain adaptation.
https://arxiv.org/abs/1706.05208.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
self.conf_thre = cfg.TRAINER.SE.CONF_THRE
self.rampup = cfg.TRAINER.SE.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS == 2
def forward_backward(self, batch_x, batch_u):
global_step = self.batch_idx + self.epoch * self.num_batches
parsed = self.parse_batch_train(batch_x, batch_u)
input_x, label_x, input_u1, input_u2 = parsed
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
prob_u = F.softmax(self.model(input_u1), 1)
t_prob_u = F.softmax(self.teacher(input_u2), 1)
loss_u = ((prob_u - t_prob_u)**2).sum(1)
if self.conf_thre:
max_prob = t_prob_u.max(1)[0]
mask = (max_prob > self.conf_thre).float()
loss_u = (loss_u * mask).mean()
else:
weight_u = sigmoid_rampup(global_step, self.rampup)
loss_u = loss_u.mean() * weight_u
loss = loss_x + loss_u
self.model_backward_and_update(loss)
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u1, input_u2 = input_u
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u1 = input_u1.to(self.device)
input_u2 = input_u2.to(self.device)
return input_x, label_x, input_u1, input_u2
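
The ema_model_update helper is imported from dassl.modeling.ops.utils; a minimal sketch of what such an update is expected to do (an assumption for illustration, not the repo's exact implementation):

import torch

@torch.no_grad()
def ema_update_sketch(student, teacher, alpha):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1 - alpha)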

View File

@@ -0,0 +1,34 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
"""Baseline model for domain adaptation, which is
trained using source data only.
"""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,4 @@
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad

View File

@@ -0,0 +1,83 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
"""Cross-gradient training.
https://arxiv.org/abs/1804.10745.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.eps_f = cfg.TRAINER.CG.EPS_F
self.eps_d = cfg.TRAINER.CG.EPS_D
self.alpha_f = cfg.TRAINER.CG.ALPHA_F
self.alpha_d = cfg.TRAINER.CG.ALPHA_D
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
input.requires_grad = True
# Compute domain perturbation
loss_d = F.cross_entropy(self.D(input), domain)
loss_d.backward()
grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_d = input.data + self.eps_f * grad_d
# Compute label perturbation
input.grad.data.zero_()
loss_f = F.cross_entropy(self.F(input), label)
loss_f.backward()
grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
input_f = input.data + self.eps_d * grad_f
input = input.detach()
# Update label net
loss_f1 = F.cross_entropy(self.F(input), label)
loss_f2 = F.cross_entropy(self.F(input_d), label)
loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
self.model_backward_and_update(loss_f, "F")
# Update domain net
loss_d1 = F.cross_entropy(self.D(input), domain)
loss_d2 = F.cross_entropy(self.D(input_f), domain)
loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
self.model_backward_and_update(loss_d, "D")
loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)

View File

@@ -0,0 +1,169 @@
import torch
import torch.nn as nn
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot
class Experts(nn.Module):
def __init__(self, n_source, fdim, num_classes):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
)
self.softmax = nn.Softmax(dim=1)
def forward(self, i, x):
x = self.linears[i](x)
x = self.softmax(x)
return x
@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
"""Domain Adaptive Ensemble Learning.
DG version: only use labeled source data.
https://arxiv.org/abs/2003.07325.
"""
def __init__(self, cfg):
super().__init__(cfg)
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
if n_domain <= 0:
n_domain = self.num_source_domains
self.split_batch = batch_size // n_domain
self.n_domain = n_domain
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
def check_cfg(self, cfg):
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u
self.val_loader = dm.val_loader
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, 0)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
fdim = self.F.fdim
print("Building E")
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
self.E.to(self.device)
print("# params: {:,}".format(count_num_param(self.E)))
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
self.register_model("E", self.E, self.optim_E, self.sched_E)
def forward_backward(self, batch):
parsed_data = self.parse_batch_train(batch)
input, input2, label, domain = parsed_data
input = torch.split(input, self.split_batch, 0)
input2 = torch.split(input2, self.split_batch, 0)
label = torch.split(label, self.split_batch, 0)
domain = torch.split(domain, self.split_batch, 0)
domain = [d[0].item() for d in domain]
loss_x = 0
loss_cr = 0
acc = 0
feat = [self.F(x) for x in input]
feat2 = [self.F(x) for x in input2]
for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
cr_s = [j for j in domain if j != i]
# Learning expert
pred_i = self.E(i, feat_i)
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
expert_label_i = pred_i.detach()
acc += compute_accuracy(pred_i.detach(),
label_i.max(1)[1])[0].item()
# Consistency regularization
cr_pred = []
for j in cr_s:
pred_j = self.E(j, feat2_i)
pred_j = pred_j.unsqueeze(1)
cr_pred.append(pred_j)
cr_pred = torch.cat(cr_pred, 1)
cr_pred = cr_pred.mean(1)
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
loss_x /= self.n_domain
loss_cr /= self.n_domain
acc /= self.n_domain
loss = 0
loss += loss_x
loss += loss_cr
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc": acc,
"loss_cr": loss_cr.item()
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
input2 = batch["img2"]
label = batch["label"]
domain = batch["domain"]
label = create_onehot(label, self.num_classes)
input = input.to(self.device)
input2 = input2.to(self.device)
label = label.to(self.device)
return input, input2, label, domain
def model_inference(self, input):
f = self.F(input)
p = []
for k in range(self.num_source_domains):
p_k = self.E(k, f)
p_k = p_k.unsqueeze(1)
p.append(p_k)
p = torch.cat(p, 1)
p = p.mean(1)
return p
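# A minimal, runnable sketch of the inference rule above: each source-domain
# expert emits a softmax distribution and the final prediction is their mean,
# mirroring model_inference(). Toy sizes; Experts is the class defined above.
if __name__ == "__main__":
    import torch

    n_source, fdim, num_classes = 3, 16, 5
    experts = Experts(n_source, fdim, num_classes)
    f = torch.randn(4, fdim)  # stand-in for features from self.F
    p = torch.stack([experts(k, f) for k in range(n_source)], 1).mean(1)
    print(p.shape, p.sum(1))  # (4, 5); each row sums to 1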

View File

@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet
@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
"""Deep Domain-Adversarial Image Generation.
https://arxiv.org/abs/2003.06054.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.DDAIG.LMDA
self.clamp = cfg.TRAINER.DDAIG.CLAMP
self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
self.warmup = cfg.TRAINER.DDAIG.WARMUP
self.alpha = cfg.TRAINER.DDAIG.ALPHA
def build_model(self):
cfg = self.cfg
print("Building F")
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
self.F.to(self.device)
print("# params: {:,}".format(count_num_param(self.F)))
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
self.register_model("F", self.F, self.optim_F, self.sched_F)
print("Building D")
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
self.D.to(self.device)
print("# params: {:,}".format(count_num_param(self.D)))
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
self.register_model("D", self.D, self.optim_D, self.sched_D)
print("Building G")
self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
self.G.to(self.device)
print("# params: {:,}".format(count_num_param(self.G)))
self.optim_G = build_optimizer(self.G, cfg.OPTIM)
self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
self.register_model("G", self.G, self.optim_G, self.sched_G)
def forward_backward(self, batch):
input, label, domain = self.parse_batch_train(batch)
#############
# Update G
#############
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
loss_g = 0
# Minimize label loss
loss_g += F.cross_entropy(self.F(input_p), label)
# Maximize domain loss
loss_g -= F.cross_entropy(self.D(input_p), domain)
self.model_backward_and_update(loss_g, "G")
# Perturb data with new G
with torch.no_grad():
input_p = self.G(input, lmda=self.lmda)
if self.clamp:
input_p = torch.clamp(
input_p, min=self.clamp_min, max=self.clamp_max
)
#############
# Update F
#############
loss_f = F.cross_entropy(self.F(input), label)
if (self.epoch + 1) > self.warmup:
loss_fp = F.cross_entropy(self.F(input_p), label)
loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
self.model_backward_and_update(loss_f, "F")
#############
# Update D
#############
loss_d = F.cross_entropy(self.D(input), domain)
self.model_backward_and_update(loss_d, "D")
loss_summary = {
"loss_g": loss_g.item(),
"loss_f": loss_f.item(),
"loss_d": loss_d.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def model_inference(self, input):
return self.F(input)
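# A minimal, runnable sketch of G's adversarial objective above: the perturbed
# input should keep the label prediction correct (minimize the F loss) while
# fooling the domain classifier (maximize the D loss). The Linear layers are
# hypothetical stand-ins; in the trainer only G's optimizer steps on loss_g.
if __name__ == "__main__":
    import torch
    import torch.nn as nn
    from torch.nn import functional as F

    x = torch.randn(4, 8)
    G_net = nn.Linear(8, 8)  # stand-in for the perturbation generator G
    F_net = nn.Linear(8, 5)  # stand-in for the label net F
    D_net = nn.Linear(8, 3)  # stand-in for the domain net D
    label = torch.randint(0, 5, (4,))
    domain = torch.randint(0, 3, (4,))

    x_p = x + 0.3 * G_net(x)  # 0.3 plays the role of LMDA
    loss_g = F.cross_entropy(F_net(x_p), label) - F.cross_entropy(D_net(x_p), domain)
    loss_g.backward()
    print(loss_g.item())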

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
"""Vanilla baseline."""
def forward_backward(self, batch):
input, label = self.parse_batch_train(batch)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,5 @@
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline

View File

@@ -0,0 +1,41 @@
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
"""Entropy Minimization.
http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.lmda = cfg.TRAINER.ENTMIN.LMDA
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
output_u = F.softmax(self.model(input_u), 1)
loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()
loss = loss_x + loss_u * self.lmda
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
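# A quick numeric check of the entropy term above: a uniform softmax over C
# classes has per-sample entropy log(C), which is what the trainer drives down.
if __name__ == "__main__":
    import math
    import torch

    p = torch.full((2, 4), 0.25)  # uniform over 4 classes
    ent = (-p * torch.log(p + 1e-5)).sum(1).mean()
    print(ent.item(), math.log(4))  # both ~1.386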

View File

@@ -0,0 +1,112 @@
import torch
from torch.nn import functional as F
from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform
@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
"""FixMatch: Simplifying Semi-Supervised Learning with
Consistency and Confidence.
https://arxiv.org/abs/2001.07685.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE
def check_cfg(self, cfg):
assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0
def build_data_loader(self):
cfg = self.cfg
tfm_train = build_transform(cfg, is_train=True)
custom_tfm_train = [tfm_train]
choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
custom_tfm_train += [tfm_train_strong]
self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
self.train_loader_x = self.dm.train_loader_x
self.train_loader_u = self.dm.train_loader_u
self.val_loader = self.dm.val_loader
self.test_loader = self.dm.test_loader
self.num_classes = self.dm.num_classes
self.num_source_domains = self.dm.num_source_domains
# lab2cname is required later by SimpleTrainer.__init__ to build the evaluator
self.lab2cname = self.dm.lab2cname
def assess_y_pred_quality(self, y_pred, y_true, mask):
n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
acc_thre = n_masked_correct / (mask.sum() + 1e-5)
acc_raw = y_pred.eq(y_true).sum() / y_pred.numel() # raw accuracy
keep_rate = mask.sum() / mask.numel()
output = {
"acc_thre": acc_thre,
"acc_raw": acc_raw,
"keep_rate": keep_rate
}
return output
def forward_backward(self, batch_x, batch_u):
parsed_data = self.parse_batch_train(batch_x, batch_u)
input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
input_u = torch.cat([input_x, input_u], 0)
input_u2 = torch.cat([input_x2, input_u2], 0)
n_x = input_x.size(0)
# Generate pseudo labels
with torch.no_grad():
output_u = F.softmax(self.model(input_u), 1)
max_prob, label_u_pred = output_u.max(1)
mask_u = (max_prob >= self.conf_thre).float()
# Evaluate pseudo labels' accuracy
y_u_pred_stats = self.assess_y_pred_quality(
label_u_pred[n_x:], label_u, mask_u[n_x:]
)
# Supervised loss
output_x = self.model(input_x)
loss_x = F.cross_entropy(output_x, label_x)
# Unsupervised loss
output_u = self.model(input_u2)
loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
loss_u = (loss_u * mask_u).mean()
loss = loss_x + loss_u * self.weight_u
self.model_backward_and_update(loss)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
"loss_u": loss_u.item(),
"y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
"y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
"y_u_pred_keep": y_u_pred_stats["keep_rate"],
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
input_x2 = batch_x["img2"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_u2 = batch_u["img2"]
# label_u is used only for evaluating pseudo labels' accuracy
label_u = batch_u["label"]
input_x = input_x.to(self.device)
input_x2 = input_x2.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
input_u2 = input_u2.to(self.device)
label_u = label_u.to(self.device)
return input_x, input_x2, label_x, input_u, input_u2, label_u
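# A minimal, runnable sketch of the masking step above: predictions on the
# weakly augmented view become hard targets for the strongly augmented view,
# kept only where the max softmax probability clears the confidence threshold.
if __name__ == "__main__":
    import torch
    from torch.nn import functional as F

    logits_weak = torch.randn(6, 5)    # stand-in for self.model(input_u)
    logits_strong = torch.randn(6, 5)  # stand-in for self.model(input_u2)
    max_prob, pseudo_label = F.softmax(logits_weak, 1).max(1)
    mask = (max_prob >= 0.95).float()  # 0.95 plays the role of CONF_THRE
    loss_u = (F.cross_entropy(logits_strong, pseudo_label, reduction="none") * mask).mean()
    print(loss_u.item(), mask.sum().item())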

View File

@@ -0,0 +1,54 @@
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
"""Mean teacher.
https://arxiv.org/abs/1703.01780.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
self.rampup = cfg.TRAINER.MEANTEA.RAMPUP
self.teacher = copy.deepcopy(self.model)
self.teacher.train()
for param in self.teacher.parameters():
param.requires_grad_(False)
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
logit_x = self.model(input_x)
loss_x = F.cross_entropy(logit_x, label_x)
target_u = F.softmax(self.teacher(input_u), 1)
prob_u = F.softmax(self.model(input_u), 1)
loss_u = ((prob_u - target_u)**2).sum(1).mean()
weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
global_step = self.batch_idx + self.epoch * self.num_batches
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
ema_model_update(self.model, self.teacher, ema_alpha)
loss_summary = {
"loss_x": loss_x.item(),
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
"loss_u": loss_u.item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
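# A minimal sketch of the EMA update that ema_model_update is assumed to
# perform: teacher <- alpha * teacher + (1 - alpha) * student. The actual
# implementation lives in dassl.modeling.ops.utils.
if __name__ == "__main__":
    import copy
    import torch
    import torch.nn as nn

    student = nn.Linear(4, 2)
    teacher = copy.deepcopy(student)
    alpha = 0.99
    with torch.no_grad():
        for pt, ps in zip(teacher.parameters(), student.parameters()):
            pt.mul_(alpha).add_(ps, alpha=1 - alpha)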

View File

@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
sharpen_prob, create_onehot, linear_rampup, shuffle_index
)
@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
"""MixMatch: A Holistic Approach to Semi-Supervised Learning.
https://arxiv.org/abs/1905.02249.
"""
def __init__(self, cfg):
super().__init__(cfg)
self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
self.temp = cfg.TRAINER.MIXMATCH.TEMP
self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP
def check_cfg(self, cfg):
assert cfg.DATALOADER.K_TRANSFORMS > 1
def forward_backward(self, batch_x, batch_u):
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
num_x = input_x.shape[0]
global_step = self.batch_idx + self.epoch * self.num_batches
weight_u = self.weight_u * linear_rampup(global_step, self.rampup)
# Generate pseudo-label for unlabeled data
with torch.no_grad():
output_u = 0
for input_ui in input_u:
output_ui = F.softmax(self.model(input_ui), 1)
output_u += output_ui
output_u /= len(input_u)
label_u = sharpen_prob(output_u, self.temp)
label_u = [label_u] * len(input_u)
label_u = torch.cat(label_u, 0)
input_u = torch.cat(input_u, 0)
# Combine and shuffle labeled and unlabeled data
input_xu = torch.cat([input_x, input_u], 0)
label_xu = torch.cat([label_x, label_u], 0)
input_xu, label_xu = shuffle_index(input_xu, label_xu)
# Mixup
input_x, label_x = mixup(
input_x,
input_xu[:num_x],
label_x,
label_xu[:num_x],
self.beta,
preserve_order=True,
)
input_u, label_u = mixup(
input_u,
input_xu[num_x:],
label_u,
label_xu[num_x:],
self.beta,
preserve_order=True,
)
# Compute losses
output_x = F.softmax(self.model(input_x), 1)
loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()
output_u = F.softmax(self.model(input_u), 1)
loss_u = ((label_u - output_u)**2).mean()
loss = loss_x + loss_u*weight_u
self.model_backward_and_update(loss)
loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"][0]
label_x = batch_x["label"]
label_x = create_onehot(label_x, self.num_classes)
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = [input_ui.to(self.device) for input_ui in input_u]
return input_x, label_x, input_u
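# A minimal sketch of temperature sharpening as sharpen_prob is assumed to
# implement it: raise probabilities to 1/T and renormalize, so T < 1 pushes
# the averaged prediction toward one-hot.
if __name__ == "__main__":
    import torch

    p = torch.tensor([[0.5, 0.3, 0.2]])
    T = 0.5
    sharpened = p ** (1 / T)
    sharpened = sharpened / sharpened.sum(1, keepdim=True)
    print(sharpened)  # tensor([[0.6579, 0.2368, 0.1053]])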

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
"""Supervised Baseline."""
def forward_backward(self, batch_x, batch_u):
input, label = self.parse_batch_train(batch_x, batch_u)
output = self.model(input)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch_x, batch_u):
input = batch_x["img"]
label = batch_x["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label

View File

@@ -0,0 +1,735 @@
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator
class SimpleNet(nn.Module):
"""A simple neural network composed of a CNN backbone
and optionally a head such as mlp for classification.
"""
def __init__(self, cfg, model_cfg, num_classes, **kwargs):
super().__init__()
self.backbone = build_backbone(
model_cfg.BACKBONE.NAME,
verbose=cfg.VERBOSE,
pretrained=model_cfg.BACKBONE.PRETRAINED,
**kwargs,
)
fdim = self.backbone.out_features
self.head = None
if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
self.head = build_head(
model_cfg.HEAD.NAME,
verbose=cfg.VERBOSE,
in_features=fdim,
hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
activation=model_cfg.HEAD.ACTIVATION,
bn=model_cfg.HEAD.BN,
dropout=model_cfg.HEAD.DROPOUT,
**kwargs,
)
fdim = self.head.out_features
self.classifier = None
if num_classes > 0:
self.classifier = nn.Linear(fdim, num_classes)
self._fdim = fdim
@property
def fdim(self):
return self._fdim
def forward(self, x, return_feature=False):
f = self.backbone(x)
if self.head is not None:
f = self.head(f)
if self.classifier is None:
return f
y = self.classifier(f)
if return_feature:
return y, f
return y
class TrainerBase:
"""Base class for iterative trainer."""
def __init__(self):
self._models = OrderedDict()
self._optims = OrderedDict()
self._scheds = OrderedDict()
self._writer = None
def register_model(self, name="model", model=None, optim=None, sched=None):
if self.__dict__.get("_models") is None:
raise AttributeError(
"Cannot assign model before super().__init__() call"
)
if self.__dict__.get("_optims") is None:
raise AttributeError(
"Cannot assign optim before super().__init__() call"
)
if self.__dict__.get("_scheds") is None:
raise AttributeError(
"Cannot assign sched before super().__init__() call"
)
assert name not in self._models, "Found duplicate model names"
self._models[name] = model
self._optims[name] = optim
self._scheds[name] = sched
def get_model_names(self, names=None):
names_real = list(self._models.keys())
if names is not None:
names = tolist_if_not(names)
for name in names:
assert name in names_real
return names
else:
return names_real
def save_model(self, epoch, directory, is_best=False, model_name=""):
names = self.get_model_names()
for name in names:
model_dict = self._models[name].state_dict()
optim_dict = None
if self._optims[name] is not None:
optim_dict = self._optims[name].state_dict()
sched_dict = None
if self._scheds[name] is not None:
sched_dict = self._scheds[name].state_dict()
save_checkpoint(
{
"state_dict": model_dict,
"epoch": epoch + 1,
"optimizer": optim_dict,
"scheduler": sched_dict,
},
osp.join(directory, name),
is_best=is_best,
model_name=model_name,
)
def resume_model_if_exist(self, directory):
names = self.get_model_names()
file_missing = False
for name in names:
path = osp.join(directory, name)
if not osp.exists(path):
file_missing = True
break
if file_missing:
print("No checkpoint found, train from scratch")
return 0
print(
'Found checkpoint in "{}". Will resume training'.format(directory)
)
for name in names:
path = osp.join(directory, name)
start_epoch = resume_from_checkpoint(
path, self._models[name], self._optims[name],
self._scheds[name]
)
return start_epoch
def load_model(self, directory, epoch=None):
if not directory:
print(
"Note that load_model() is skipped as no pretrained "
"model is given (ignore this if it's done on purpose)"
)
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError(
'Model not found at "{}"'.format(model_path)
)
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
print(
"Loading weights to {} "
'from "{}" (epoch = {})'.format(name, model_path, epoch)
)
self._models[name].load_state_dict(state_dict)
def set_model_mode(self, mode="train", names=None):
names = self.get_model_names(names)
for name in names:
if mode == "train":
self._models[name].train()
elif mode in ["test", "eval"]:
self._models[name].eval()
else:
raise KeyError
def update_lr(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._scheds[name] is not None:
self._scheds[name].step()
def detect_anomaly(self, loss):
if not torch.isfinite(loss).all():
raise FloatingPointError("Loss is infinite or NaN!")
def init_writer(self, log_dir):
if self.__dict__.get("_writer") is None or self._writer is None:
print(
"Initializing summary writer for tensorboard "
"with log_dir={}".format(log_dir)
)
self._writer = SummaryWriter(log_dir=log_dir)
def close_writer(self):
if self._writer is not None:
self._writer.close()
def write_scalar(self, tag, scalar_value, global_step=None):
if self._writer is None:
# Do nothing if writer is not initialized
# Note that writer is only used when training is needed
pass
else:
self._writer.add_scalar(tag, scalar_value, global_step)
def train(self, start_epoch, max_epoch):
"""Generic training loops."""
self.start_epoch = start_epoch
self.max_epoch = max_epoch
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
self.before_epoch()
self.run_epoch()
self.after_epoch()
self.after_train()
def before_train(self):
pass
def after_train(self):
pass
def before_epoch(self):
pass
def after_epoch(self):
pass
def run_epoch(self):
raise NotImplementedError
def test(self):
raise NotImplementedError
def parse_batch_train(self, batch):
raise NotImplementedError
def parse_batch_test(self, batch):
raise NotImplementedError
def forward_backward(self, batch):
raise NotImplementedError
def model_inference(self, input):
raise NotImplementedError
def model_zero_grad(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].zero_grad()
def model_backward(self, loss):
self.detect_anomaly(loss)
loss.backward()
def model_update(self, names=None):
names = self.get_model_names(names)
for name in names:
if self._optims[name] is not None:
self._optims[name].step()
def model_backward_and_update(self, loss, names=None):
self.model_zero_grad(names)
self.model_backward(loss)
self.model_update(names)
def prograd_backward_and_update(
self, loss_a, loss_b, lambda_=1, names=None
):
# loss_a must decrease; loss_b must not increase
self.model_zero_grad(names)
# Get the names of the registered models
names = self.get_model_names(names)
# Backward loss_b first to record its gradient direction
self.detect_anomaly(loss_b)
loss_b.backward(retain_graph=True)
# Save a copy of loss_b's gradients
b_grads = []
for name in names:
for p in self._models[name].parameters():
b_grads.append(p.grad.clone())
# Clear gradients without stepping the optimizers
for name in names:
self._optims[name].zero_grad()
# Backward loss_a
self.detect_anomaly(loss_a)
loss_a.backward()
for name in names:
for p, b_grad in zip(self._models[name].parameters(), b_grads):
# Detect conflict via the cosine similarity of the two gradients
b_grad_norm = b_grad / torch.linalg.norm(b_grad)
a_grad = p.grad.clone()
a_grad_norm = a_grad / torch.linalg.norm(a_grad)
if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
p.grad = a_grad - lambda_ * torch.dot(
a_grad.flatten(), b_grad_norm.flatten()
) * b_grad_norm
# Step the optimizers with the (possibly projected) gradients
for name in names:
self._optims[name].step()
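# A minimal, runnable sketch of the projection above: when grad_a conflicts
# with grad_b (negative cosine similarity), subtract lambda_ times the
# component of grad_a along grad_b, leaving an update that no longer pushes
# loss_b up.
if __name__ == "__main__":
    import torch

    g_a = torch.tensor([1.0, -1.0])
    g_b = torch.tensor([0.0, 1.0])
    b_hat = g_b / torch.linalg.norm(g_b)
    if torch.dot(g_a, b_hat) < 0:
        g_a = g_a - 1.0 * torch.dot(g_a, b_hat) * b_hat  # lambda_ = 1.0
    print(g_a)  # tensor([1., 0.]): the conflicting component is removed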
class SimpleTrainer(TrainerBase):
"""A simple trainer class implementing generic functions."""
def __init__(self, cfg):
super().__init__()
self.check_cfg(cfg)
if torch.cuda.is_available() and cfg.USE_CUDA:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
# Save as attributes some frequently used variables
self.start_epoch = self.epoch = 0
self.max_epoch = cfg.OPTIM.MAX_EPOCH
self.output_dir = cfg.OUTPUT_DIR
self.cfg = cfg
self.build_data_loader()
self.build_model()
self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
self.best_result = -np.inf
def check_cfg(self, cfg):
"""Check whether some variables are set correctly for
the trainer (optional).
For example, a trainer might require a particular sampler
for training, such as 'RandomDomainSampler', so it is good
to do the check here:
assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'
"""
pass
def build_data_loader(self):
"""Create essential data-related attributes.
A re-implementation of this method must create the
same attributes (except self.dm).
"""
dm = DataManager(self.cfg)
self.train_loader_x = dm.train_loader_x
self.train_loader_u = dm.train_loader_u # optional, can be None
self.val_loader = dm.val_loader # optional, can be None
self.test_loader = dm.test_loader
self.num_classes = dm.num_classes
self.num_source_domains = dm.num_source_domains
self.lab2cname = dm.lab2cname # dict {label: classname}
self.dm = dm
def build_model(self):
"""Build and register model.
The default builds a classification model along with its
optimizer and scheduler.
Custom trainers can re-implement this method if necessary.
"""
cfg = self.cfg
print("Building model")
self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
print("# params: {:,}".format(count_num_param(self.model)))
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("model", self.model, self.optim, self.sched)
device_count = torch.cuda.device_count()
if device_count > 1:
print(
f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
)
self.model = nn.DataParallel(self.model)
def train(self):
super().train(self.start_epoch, self.max_epoch)
def before_train(self):
directory = self.cfg.OUTPUT_DIR
if self.cfg.RESUME:
directory = self.cfg.RESUME
self.start_epoch = self.resume_model_if_exist(directory)
# Initialize summary writer
writer_dir = osp.join(self.output_dir, "tensorboard")
mkdir_if_missing(writer_dir)
self.init_writer(writer_dir)
# Remember the starting time (for computing the elapsed time)
self.time_start = time.time()
def after_train(self):
print("Finished training")
do_test = not self.cfg.TEST.NO_TEST
if do_test:
if self.cfg.TEST.FINAL_MODEL == "best_val":
print("Deploy the model with the best val performance")
self.load_model(self.output_dir)
self.test()
# Show elapsed time
elapsed = round(time.time() - self.time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed: {}".format(elapsed))
# Close writer
self.close_writer()
def after_epoch(self):
last_epoch = (self.epoch + 1) == self.max_epoch
do_test = not self.cfg.TEST.NO_TEST
meet_checkpoint_freq = (
(self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
)
if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
curr_result = self.test(split="val")
is_best = curr_result > self.best_result
if is_best:
self.best_result = curr_result
self.save_model(
self.epoch,
self.output_dir,
model_name="model-best.pth.tar"
)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
@torch.no_grad()
def output_test(self, split=None):
"""testing pipline, which could also output the results."""
self.set_model_mode("eval")
self.evaluator.reset()
output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')
res_json = {}
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
img_path = batch['impath']
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
for i in range(len(img_path)):
res_json[img_path[i]] = {
'predict': output[i].cpu().numpy().tolist(),
'gt': label[i].cpu().numpy().tolist()
}
with open(output_file, 'w') as f:
json.dump(res_json, f)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
@torch.no_grad()
def test(self, split=None):
"""A generic testing pipeline."""
self.set_model_mode("eval")
self.evaluator.reset()
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
print("Do evaluation on {} set".format(split))
else:
data_loader = self.test_loader
print("Do evaluation on test set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
input, label = self.parse_batch_test(batch)
output = self.model_inference(input)
self.evaluator.process(output, label)
results = self.evaluator.evaluate()
for k, v in results.items():
tag = "{}/{}".format(split, k)
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
def model_inference(self, input):
return self.model(input)
def parse_batch_test(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def get_current_lr(self, names=None):
names = self.get_model_names(names)
name = names[0]
return self._optims[name].param_groups[0]["lr"]
class TrainerXU(SimpleTrainer):
"""A base trainer using both labeled and unlabeled data.
In the context of domain adaptation, labeled and unlabeled data
come from source and target domains respectively.
When it comes to semi-supervised learning, all data comes from the
same domain.
"""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
# Decide to iterate over labeled or unlabeled dataset
len_train_loader_x = len(self.train_loader_x)
len_train_loader_u = len(self.train_loader_u)
if self.cfg.TRAIN.COUNT_ITER == "train_x":
self.num_batches = len_train_loader_x
elif self.cfg.TRAIN.COUNT_ITER == "train_u":
self.num_batches = len_train_loader_u
elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
self.num_batches = min(len_train_loader_x, len_train_loader_u)
else:
raise ValueError
train_loader_x_iter = iter(self.train_loader_x)
train_loader_u_iter = iter(self.train_loader_u)
end = time.time()
for self.batch_idx in range(self.num_batches):
try:
batch_x = next(train_loader_x_iter)
except StopIteration:
train_loader_x_iter = iter(self.train_loader_x)
batch_x = next(train_loader_x_iter)
try:
batch_u = next(train_loader_u_iter)
except StopIteration:
train_loader_u_iter = iter(self.train_loader_u)
batch_u = next(train_loader_u_iter)
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch_x, batch_u)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch_x, batch_u):
input_x = batch_x["img"]
label_x = batch_x["label"]
input_u = batch_u["img"]
input_x = input_x.to(self.device)
label_x = label_x.to(self.device)
input_u = input_u.to(self.device)
return input_x, label_x, input_u
class TrainerX(SimpleTrainer):
"""A base trainer using labeled data only."""
def run_epoch(self):
self.set_model_mode("train")
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.num_batches = len(self.train_loader_x)
end = time.time()
for self.batch_idx, batch in enumerate(self.train_loader_x):
data_time.update(time.time() - end)
loss_summary = self.forward_backward(batch)
batch_time.update(time.time() - end)
losses.update(loss_summary)
if (
self.batch_idx + 1
) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
"epoch [{0}/{1}][{2}/{3}]\t"
"time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
"data {data_time.val:.3f} ({data_time.avg:.3f})\t"
"eta {eta}\t"
"{losses}\t"
"lr {lr:.6e}".format(
self.epoch + 1,
self.max_epoch,
self.batch_idx + 1,
self.num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta,
losses=losses,
lr=self.get_current_lr(),
)
)
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
domain = batch["domain"]
input = input.to(self.device)
label = label.to(self.device)
domain = domain.to(self.device)
return input, label, domain

View File

@@ -0,0 +1,3 @@
from .build import build_evaluator, EVALUATOR_REGISTRY # isort:skip
from .evaluator import EvaluatorBase, Classification

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
EVALUATOR_REGISTRY = Registry("EVALUATOR")
def build_evaluator(cfg, **kwargs):
avai_evaluators = EVALUATOR_REGISTRY.registered_names()
check_availability(cfg.TEST.EVALUATOR, avai_evaluators)
if cfg.VERBOSE:
print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR))
return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs)

View File

@@ -0,0 +1,127 @@
import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix
from .build import EVALUATOR_REGISTRY
class EvaluatorBase:
"""Base evaluator."""
def __init__(self, cfg):
self.cfg = cfg
def reset(self):
raise NotImplementedError
def process(self, mo, gt):
raise NotImplementedError
def evaluate(self):
raise NotImplementedError
@EVALUATOR_REGISTRY.register()
class Classification(EvaluatorBase):
"""Evaluator for classification."""
def __init__(self, cfg, lab2cname=None, **kwargs):
super().__init__(cfg)
self._lab2cname = lab2cname
self._correct = 0
self._total = 0
self._per_class_res = None
self._y_true = []
self._y_pred = []
if cfg.TEST.PER_CLASS_RESULT:
assert lab2cname is not None
self._per_class_res = defaultdict(list)
def reset(self):
self._correct = 0
self._total = 0
self._y_true = []
self._y_pred = []
if self._per_class_res is not None:
self._per_class_res = defaultdict(list)
def process(self, mo, gt):
# mo (torch.Tensor): model output [batch, num_classes]
# gt (torch.LongTensor): ground truth [batch]
pred = mo.max(1)[1]
matches = pred.eq(gt).float()
self._correct += int(matches.sum().item())
self._total += gt.shape[0]
self._y_true.extend(gt.data.cpu().numpy().tolist())
self._y_pred.extend(pred.data.cpu().numpy().tolist())
if self._per_class_res is not None:
for i, label in enumerate(gt):
label = label.item()
matches_i = int(matches[i].item())
self._per_class_res[label].append(matches_i)
def evaluate(self):
results = OrderedDict()
acc = 100.0 * self._correct / self._total
err = 100.0 - acc
macro_f1 = 100.0 * f1_score(
self._y_true,
self._y_pred,
average="macro",
labels=np.unique(self._y_true)
)
# The first value will be returned by trainer.test()
results["accuracy"] = acc
results["error_rate"] = err
results["macro_f1"] = macro_f1
print(
"=> result\n"
f"* total: {self._total:,}\n"
f"* correct: {self._correct:,}\n"
f"* accuracy: {acc:.2f}%\n"
f"* error: {err:.2f}%\n"
f"* macro_f1: {macro_f1:.2f}%"
)
if self._per_class_res is not None:
labels = list(self._per_class_res.keys())
labels.sort()
print("=> per-class result")
accs = []
for label in labels:
classname = self._lab2cname[label]
res = self._per_class_res[label]
correct = sum(res)
total = len(res)
acc = 100.0 * correct / total
accs.append(acc)
print(
"* class: {} ({})\t"
"total: {:,}\t"
"correct: {:,}\t"
"acc: {:.2f}%".format(
label, classname, total, correct, acc
)
)
mean_acc = np.mean(accs)
print("* average: {:.2f}%".format(mean_acc))
results["perclass_accuracy"] = mean_acc
if self.cfg.TEST.COMPUTE_CMAT:
cmat = confusion_matrix(
self._y_true, self._y_pred, normalize="true"
)
save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
torch.save(cmat, save_path)
print('Confusion matrix is saved to "{}"'.format(save_path))
return results
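# A minimal usage sketch of the evaluator above; the config fields shown are
# the ones the class actually reads, with hypothetical values.
if __name__ == "__main__":
    import torch
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.TEST = CN()
    cfg.TEST.PER_CLASS_RESULT = False
    cfg.TEST.COMPUTE_CMAT = False
    cfg.OUTPUT_DIR = "."

    evaluator = Classification(cfg)
    evaluator.reset()
    evaluator.process(torch.randn(8, 3), torch.randint(0, 3, (8,)))
    print(evaluator.evaluate()["accuracy"])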

View File

@@ -0,0 +1,4 @@
from .accuracy import compute_accuracy
from .distance import (
cosine_distance, compute_distance_matrix, euclidean_squared_distance
)

View File

@@ -0,0 +1,30 @@
def compute_accuracy(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for
the specified values of k.
Args:
output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).
target (torch.LongTensor): ground truth labels with shape (batch_size).
topk (tuple, optional): accuracy at top-k will be computed. For example,
topk=(1, 5) means accuracy at top-1 and top-5 will be computed.
Returns:
list: accuracy at top-k.
"""
maxk = max(topk)
batch_size = target.size(0)
if isinstance(output, (tuple, list)):
output = output[0]
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
# Use reshape (not view): the slice of a transposed tensor may be non-contiguous
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
acc = correct_k.mul_(100.0 / batch_size)
res.append(acc)
return res
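# A quick usage sketch: top-1 and top-5 accuracy on random logits.
if __name__ == "__main__":
    import torch

    output = torch.randn(32, 10)
    target = torch.randint(0, 10, (32,))
    top1, top5 = compute_accuracy(output, target, topk=(1, 5))
    print(top1.item(), top5.item())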

View File

@@ -0,0 +1,77 @@
"""
Source: https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.nn import functional as F
def compute_distance_matrix(input1, input2, metric="euclidean"):
"""A wrapper function for computing distance matrix.
Each input matrix has the shape (n_data, feature_dim).
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
metric (str, optional): "euclidean" or "cosine".
Default is "euclidean".
Returns:
torch.Tensor: distance matrix.
"""
# check input
assert isinstance(input1, torch.Tensor)
assert isinstance(input2, torch.Tensor)
assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
input1.dim()
)
assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
input2.dim()
)
assert input1.size(1) == input2.size(1)
if metric == "euclidean":
distmat = euclidean_squared_distance(input1, input2)
elif metric == "cosine":
distmat = cosine_distance(input1, input2)
else:
raise ValueError(
"Unknown distance metric: {}. "
'Please choose either "euclidean" or "cosine"'.format(metric)
)
return distmat
def euclidean_squared_distance(input1, input2):
"""Computes euclidean squared distance.
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
Returns:
torch.Tensor: distance matrix.
"""
m, n = input1.size(0), input2.size(0)
mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat = mat1 + mat2
distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
return distmat
def cosine_distance(input1, input2):
"""Computes cosine distance.
Args:
input1 (torch.Tensor): 2-D feature matrix.
input2 (torch.Tensor): 2-D feature matrix.
Returns:
torch.Tensor: distance matrix.
"""
input1_normed = F.normalize(input1, p=2, dim=1)
input2_normed = F.normalize(input2, p=2, dim=1)
distmat = 1 - torch.mm(input1_normed, input2_normed.t())
return distmat
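# A quick usage sketch: pairwise distances between two toy feature matrices.
if __name__ == "__main__":
    import torch

    feats1 = torch.randn(4, 128)
    feats2 = torch.randn(6, 128)
    print(compute_distance_matrix(feats1, feats2, "euclidean").shape)  # (4, 6)
    print(compute_distance_matrix(feats1, feats2, "cosine").shape)     # (4, 6)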

View File

@@ -0,0 +1,3 @@
from .head import HEAD_REGISTRY, build_head
from .network import NETWORK_REGISTRY, build_network
from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone

View File

@@ -0,0 +1,27 @@
from .build import build_backbone, BACKBONE_REGISTRY # isort:skip
from .backbone import Backbone # isort:skip
from .vgg import vgg16
from .resnet import (
resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,
resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,
resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,
resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,
resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,
resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123
)
from .alexnet import alexnet
from .mobilenetv2 import mobilenetv2
from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2
from .cnn_digitsdg import cnn_digitsdg
from .efficientnet import (
efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,
efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
)
from .shufflenetv2 import (
shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,
shufflenet_v2_x2_0
)
from .cnn_digitsingle import cnn_digitsingle
from .preact_resnet18 import preact_resnet18
from .cnn_digit5_m3sda import cnn_digit5_m3sda

View File

@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
}
class AlexNet(Backbone):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
# Note that self.classifier outputs features rather than logits
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
)
self._out_features = 4096
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return self.classifier(x)
def init_pretrained_weights(model, model_url):
pretrain_dict = model_zoo.load_url(model_url)
model.load_state_dict(pretrain_dict, strict=False)
@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
model = AlexNet()
if pretrained:
init_pretrained_weights(model, model_urls["alexnet"])
return model

View File

@@ -0,0 +1,17 @@
import torch.nn as nn
class Backbone(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
raise NotImplementedError
@property
def out_features(self):
"""Output feature dimension."""
if self.__dict__.get("_out_features") is None:
return None
return self._out_features
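# A minimal sketch of the contract above: subclasses set self._out_features
# and return a flat feature tensor from forward().
if __name__ == "__main__":
    import torch

    class TinyBackbone(Backbone):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 16)
            self._out_features = 16

        def forward(self, x):
            return self.fc(x)

    net = TinyBackbone()
    print(net.out_features, net(torch.randn(2, 8)).shape)  # 16, (2, 16)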

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
BACKBONE_REGISTRY = Registry("BACKBONE")
def build_backbone(name, verbose=True, **kwargs):
avai_backbones = BACKBONE_REGISTRY.registered_names()
check_availability(name, avai_backbones)
if verbose:
print("Backbone: {}".format(name))
return BACKBONE_REGISTRY.get(name)(**kwargs)
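# A usage sketch of the registry pattern above; importing the backbone
# package first runs all the @BACKBONE_REGISTRY.register() decorators.
if __name__ == "__main__":
    from dassl.modeling import build_backbone

    model = build_backbone("resnet18", pretrained=False)
    print(model.out_features)  # 512 for a standard resnet18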

View File

@@ -0,0 +1,58 @@
"""
Reference
https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA
"""
import torch.nn as nn
from torch.nn import functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class FeatureExtractor(Backbone):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
self.bn3 = nn.BatchNorm2d(128)
self.fc1 = nn.Linear(8192, 3072)
self.bn1_fc = nn.BatchNorm1d(3072)
self.fc2 = nn.Linear(3072, 2048)
self.bn2_fc = nn.BatchNorm1d(2048)
self._out_features = 2048
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
x = F.relu(self.bn3(self.conv3(x)))
x = x.view(x.size(0), 8192)
x = F.relu(self.bn1_fc(self.fc1(x)))
x = F.dropout(x, training=self.training)
x = F.relu(self.bn2_fc(self.fc2(x)))
return x
@BACKBONE_REGISTRY.register()
def cnn_digit5_m3sda(**kwargs):
"""
This architecture was used for the Digit-5 dataset in:
- Peng et al. Moment Matching for Multi-Source
Domain Adaptation. ICCV 2019.
"""
return FeatureExtractor()
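# Shape sanity check for the flatten above: with a 32x32 input, the two
# stride-2 max-pools give 128 x 8 x 8 = 8192 features entering fc1.
if __name__ == "__main__":
    import torch

    net = cnn_digit5_m3sda()
    print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 2048])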

View File

@@ -0,0 +1,61 @@
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class Convolution(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
self.relu = nn.ReLU(True)
def forward(self, x):
return self.relu(self.conv(x))
class ConvNet(Backbone):
def __init__(self, c_hidden=64):
super().__init__()
self.conv1 = Convolution(3, c_hidden)
self.conv2 = Convolution(c_hidden, c_hidden)
self.conv3 = Convolution(c_hidden, c_hidden)
self.conv4 = Convolution(c_hidden, c_hidden)
self._out_features = 2**2 * c_hidden
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = self.conv1(x)
x = F.max_pool2d(x, 2)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.conv3(x)
x = F.max_pool2d(x, 2)
x = self.conv4(x)
x = F.max_pool2d(x, 2)
return x.view(x.size(0), -1)
@BACKBONE_REGISTRY.register()
def cnn_digitsdg(**kwargs):
"""
This architecture was used for DigitsDG dataset in:
- Zhou et al. Deep Domain-Adversarial Image Generation
for Domain Generalisation. AAAI 2020.
"""
model = ConvNet(c_hidden=64)
init_network_weights(model, init_type="kaiming")
return model

View File

@@ -0,0 +1,56 @@
"""
This model is built based on
https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
"""
import torch.nn as nn
from torch.nn import functional as F
from dassl.utils import init_network_weights
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class CNN(Backbone):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 5)
self.conv2 = nn.Conv2d(64, 128, 5)
self.fc3 = nn.Linear(5 * 5 * 128, 1024)
self.fc4 = nn.Linear(1024, 1024)
self._out_features = 1024
def _check_input(self, x):
H, W = x.shape[2:]
assert (
H == 32 and W == 32
), "Input to network must be 32x32, " "but got {}x{}".format(H, W)
def forward(self, x):
self._check_input(x)
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = self.fc3(x)
x = F.relu(x)
x = self.fc4(x)
x = F.relu(x)
return x
@BACKBONE_REGISTRY.register()
def cnn_digitsingle(**kwargs):
model = CNN()
init_network_weights(model, init_type="kaiming")
return model

View File

@@ -0,0 +1,12 @@
"""
Source: https://github.com/lukemelas/EfficientNet-PyTorch.
"""
__version__ = "0.6.4"
from .model import (
EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
efficientnet_b7
)
from .utils import (
BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
)

View File

@@ -0,0 +1,371 @@
import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
get_model_params, efficientnet_params, get_same_padding_conv2d,
load_pretrained_weights, calculate_output_image_size
)
from ..build import BACKBONE_REGISTRY
from ..backbone import Backbone
class MBConvBlock(nn.Module):
"""
Mobile Inverted Residual Bottleneck Block
Args:
block_args (namedtuple): BlockArgs, see above
global_params (namedtuple): GlobalParam, see above
Attributes:
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
"""
def __init__(self, block_args, global_params, image_size=None):
super().__init__()
self._block_args = block_args
self._bn_mom = 1 - global_params.batch_norm_momentum
self._bn_eps = global_params.batch_norm_epsilon
self.has_se = (self._block_args.se_ratio
is not None) and (0 < self._block_args.se_ratio <= 1)
self.id_skip = block_args.id_skip # skip connection and drop connect
# Expansion phase
inp = self._block_args.input_filters # number of input channels
oup = (
self._block_args.input_filters * self._block_args.expand_ratio
) # number of output channels
if self._block_args.expand_ratio != 1:
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._expand_conv = Conv2d(
in_channels=inp, out_channels=oup, kernel_size=1, bias=False
)
self._bn0 = nn.BatchNorm2d(
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
)
# image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing
# Depthwise convolution phase
k = self._block_args.kernel_size
s = self._block_args.stride
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._depthwise_conv = Conv2d(
in_channels=oup,
out_channels=oup,
groups=oup, # groups makes it depthwise
kernel_size=k,
stride=s,
bias=False,
)
self._bn1 = nn.BatchNorm2d(
num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
)
image_size = calculate_output_image_size(image_size, s)
# Squeeze and Excitation layer, if desired
if self.has_se:
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
num_squeezed_channels = max(
1,
int(
self._block_args.input_filters * self._block_args.se_ratio
)
)
self._se_reduce = Conv2d(
in_channels=oup,
out_channels=num_squeezed_channels,
kernel_size=1
)
self._se_expand = Conv2d(
in_channels=num_squeezed_channels,
out_channels=oup,
kernel_size=1
)
# Output phase
final_oup = self._block_args.output_filters
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._project_conv = Conv2d(
in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
)
self._bn2 = nn.BatchNorm2d(
num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
)
self._swish = MemoryEfficientSwish()
def forward(self, inputs, drop_connect_rate=None):
"""
:param inputs: input tensor
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
:return: output of block
"""
# Expansion and Depthwise Convolution
x = inputs
if self._block_args.expand_ratio != 1:
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))
# Squeeze and Excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_expand(
self._swish(self._se_reduce(x_squeezed))
)
x = torch.sigmoid(x_squeezed) * x
x = self._bn2(self._project_conv(x))
# Skip connection and drop connect
input_filters, output_filters = (
self._block_args.input_filters,
self._block_args.output_filters,
)
if (
self.id_skip and self._block_args.stride == 1
and input_filters == output_filters
):
if drop_connect_rate:
x = drop_connect(
x, p=drop_connect_rate, training=self.training
)
x = x + inputs # skip connection
return x
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(Backbone):
"""
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
Args:
blocks_args (list): A list of BlockArgs to construct blocks
global_params (namedtuple): A set of GlobalParams shared between blocks
Example:
model = EfficientNet.from_pretrained('efficientnet-b0')
"""
def __init__(self, blocks_args=None, global_params=None):
super().__init__()
assert isinstance(blocks_args, list), "blocks_args should be a list"
assert len(blocks_args) > 0, "blocks_args must be non-empty"
self._global_params = global_params
self._blocks_args = blocks_args
# Batch norm parameters
bn_mom = 1 - self._global_params.batch_norm_momentum
bn_eps = self._global_params.batch_norm_epsilon
# Get stem static or dynamic convolution depending on image size
image_size = global_params.image_size
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
# Stem
in_channels = 3 # rgb
out_channels = round_filters(
32, self._global_params
) # number of output channels
self._conv_stem = Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=False
)
self._bn0 = nn.BatchNorm2d(
num_features=out_channels, momentum=bn_mom, eps=bn_eps
)
image_size = calculate_output_image_size(image_size, 2)
# Build blocks
self._blocks = nn.ModuleList([])
for block_args in self._blocks_args:
# Update block input and output filters based on depth multiplier.
block_args = block_args._replace(
input_filters=round_filters(
block_args.input_filters, self._global_params
),
output_filters=round_filters(
block_args.output_filters, self._global_params
),
num_repeat=round_repeats(
block_args.num_repeat, self._global_params
),
)
# The first block needs to take care of stride and filter size increase.
self._blocks.append(
MBConvBlock(
block_args, self._global_params, image_size=image_size
)
)
image_size = calculate_output_image_size(
image_size, block_args.stride
)
if block_args.num_repeat > 1:
block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1
)
for _ in range(block_args.num_repeat - 1):
self._blocks.append(
MBConvBlock(
block_args, self._global_params, image_size=image_size
)
)
# image_size = calculate_output_image_size(image_size, block_args.stride) # ?
# Head
in_channels = block_args.output_filters # output of final block
out_channels = round_filters(1280, self._global_params)
Conv2d = get_same_padding_conv2d(image_size=image_size)
self._conv_head = Conv2d(
in_channels, out_channels, kernel_size=1, bias=False
)
self._bn1 = nn.BatchNorm2d(
num_features=out_channels, momentum=bn_mom, eps=bn_eps
)
# Final linear layer
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
# self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()
self._out_features = out_channels
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
for block in self._blocks:
block.set_swish(memory_efficient)
def extract_features(self, inputs):
"""Returns output of the final convolution layer"""
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))
# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
# Head
x = self._swish(self._bn1(self._conv_head(x)))
return x
def forward(self, inputs):
"""
Calls extract_features to extract features, applies
final linear layer, and returns logits.
"""
bs = inputs.size(0)
# Convolution layers
x = self.extract_features(inputs)
# Pooling and final linear layer
x = self._avg_pooling(x)
x = x.view(bs, -1)
x = self._dropout(x)
# x = self._fc(x)
return x
@classmethod
def from_name(cls, model_name, override_params=None):
cls._check_model_name_is_valid(model_name)
blocks_args, global_params = get_model_params(
model_name, override_params
)
return cls(blocks_args, global_params)
@classmethod
def from_pretrained(
cls, model_name, advprop=False, num_classes=1000, in_channels=3
):
model = cls.from_name(
model_name, override_params={"num_classes": num_classes}
)
load_pretrained_weights(
model, model_name, load_fc=(num_classes == 1000), advprop=advprop
)
model._change_in_channels(in_channels)
return model
@classmethod
def get_image_size(cls, model_name):
cls._check_model_name_is_valid(model_name)
_, _, res, _ = efficientnet_params(model_name)
return res
@classmethod
def _check_model_name_is_valid(cls, model_name):
"""Validates model name."""
valid_models = ["efficientnet-b" + str(i) for i in range(9)]
if model_name not in valid_models:
raise ValueError(
"model_name should be one of: " + ", ".join(valid_models)
)
    def _change_in_channels(self, in_channels):
        """Rebuilds the stem conv when the input has != 3 channels."""
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=self._global_params.image_size
            )
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
def build_efficientnet(name, pretrained):
if pretrained:
return EfficientNet.from_pretrained("efficientnet-{}".format(name))
else:
return EfficientNet.from_name("efficientnet-{}".format(name))
@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
return build_efficientnet("b0", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
return build_efficientnet("b1", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
return build_efficientnet("b2", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
return build_efficientnet("b3", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
return build_efficientnet("b4", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
return build_efficientnet("b5", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
return build_efficientnet("b6", pretrained)
@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
return build_efficientnet("b7", pretrained)

View File

@@ -0,0 +1,477 @@
"""
This file contains helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""
import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
"GlobalParams",
[
"batch_norm_momentum",
"batch_norm_epsilon",
"dropout_rate",
"num_classes",
"width_coefficient",
"depth_coefficient",
"depth_divisor",
"min_depth",
"drop_connect_rate",
"image_size",
],
)
# Parameters for an individual model block
BlockArgs = collections.namedtuple(
"BlockArgs",
[
"kernel_size",
"num_repeat",
"input_filters",
"output_filters",
"expand_ratio",
"id_skip",
"stride",
"se_ratio",
],
)
# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result
    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]  # ctx.saved_variables is deprecated
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
def round_filters(filters, global_params):
"""Calculate and round number of filters based on depth multiplier."""
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
min_depth = min_depth or divisor
new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
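# Worked example (illustrative): with width_coefficient=1.4 and depth_divisor=8
# (the efficientnet-b4 setting, `gp` being such a GlobalParams instance),
# round_filters(32, gp) computes 32 * 1.4 = 44.8, rounds to the nearest
# multiple of 8 -> 48, and keeps 48 since 48 >= 0.9 * 44.8.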
def round_repeats(repeats, global_params):
"""Round number of filters based on depth multiplier."""
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
def drop_connect(inputs, p, training):
"""Drop connect."""
if not training:
return inputs
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += torch.rand(
[batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
)
binary_tensor = torch.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
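# Illustrative behaviour: with p=0.2 at training time, each sample in the batch
# is zeroed with probability 0.2 and the survivors are scaled by 1/0.8, so the
# expected activation is unchanged:
#   y = drop_connect(torch.randn(8, 16, 32, 32), p=0.2, training=True)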
def get_same_padding_conv2d(image_size=None):
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
Static padding is necessary for ONNX exporting of models."""
if image_size is None:
return Conv2dDynamicSamePadding
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)
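# Example (illustrative): with a known image size the returned class precomputes
# its padding once (required for ONNX export); with image_size=None the padding
# is recomputed from each input's shape:
#   ConvStatic = get_same_padding_conv2d(image_size=224)  # Conv2dStaticSamePadding
#   ConvDynamic = get_same_padding_conv2d()               # Conv2dDynamicSamePadding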
def get_width_and_height_from_size(x):
    """Obtains width and height from an int, list or tuple."""
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    else:
        raise TypeError("size must be an int, list or tuple")
def calculate_output_image_size(input_image_size, stride):
"""
Calculates the output image size when using Conv2dSamePadding with a stride.
Necessary for static padding. Thanks to mannatsingh for pointing this out.
"""
if input_image_size is None:
return None
image_height, image_width = get_width_and_height_from_size(
input_image_size
)
stride = stride if isinstance(stride, int) else stride[0]
image_height = int(math.ceil(image_height / stride))
image_width = int(math.ceil(image_width / stride))
return [image_height, image_width]
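# Example (illustrative): calculate_output_image_size(224, 2) -> [112, 112],
# and a stride-1 block leaves it unchanged:
# calculate_output_image_size([112, 112], 1) -> [112, 112].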
class Conv2dDynamicSamePadding(nn.Conv2d):
"""2D Convolutions like TensorFlow, for a dynamic image size"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
):
super().__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation,
groups, bias
)
        self.stride = (
            self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
        )
def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max(
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
)
pad_w = max(
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
)
if pad_h > 0 or pad_w > 0:
x = F.pad(
x,
[pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
)
return F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
class Conv2dStaticSamePadding(nn.Conv2d):
"""2D Convolutions like TensorFlow, for a fixed image size"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
image_size=None,
**kwargs
):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = (
            self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
        )
        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (
            (image_size, image_size)
            if isinstance(image_size, int) else image_size
        )
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max(
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
)
pad_w = max(
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
)
if pad_h > 0 or pad_w > 0:
self.static_padding = nn.ZeroPad2d(
(pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
)
else:
self.static_padding = Identity()
def forward(self, x):
x = self.static_padding(x)
x = F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return x
class Identity(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x
########################################################################
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
########################################################################
def efficientnet_params(model_name):
"""Map EfficientNet model name to parameter coefficients."""
params_dict = {
# Coefficients: width,depth,res,dropout
"efficientnet-b0": (1.0, 1.0, 224, 0.2),
"efficientnet-b1": (1.0, 1.1, 240, 0.2),
"efficientnet-b2": (1.1, 1.2, 260, 0.3),
"efficientnet-b3": (1.2, 1.4, 300, 0.3),
"efficientnet-b4": (1.4, 1.8, 380, 0.4),
"efficientnet-b5": (1.6, 2.2, 456, 0.4),
"efficientnet-b6": (1.8, 2.6, 528, 0.5),
"efficientnet-b7": (2.0, 3.1, 600, 0.5),
"efficientnet-b8": (2.2, 3.6, 672, 0.5),
"efficientnet-l2": (4.3, 5.3, 800, 0.5),
}
return params_dict[model_name]
class BlockDecoder(object):
"""Block Decoder for readability, straight from the official TensorFlow repository"""
@staticmethod
def _decode_block_string(block_string):
"""Gets a block through a string notation of arguments."""
assert isinstance(block_string, str)
ops = block_string.split("_")
options = {}
for op in ops:
splits = re.split(r"(\d.*)", op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# Check stride
assert ("s" in options and len(options["s"]) == 1) or (
len(options["s"]) == 2 and options["s"][0] == options["s"][1]
)
return BlockArgs(
kernel_size=int(options["k"]),
num_repeat=int(options["r"]),
input_filters=int(options["i"]),
output_filters=int(options["o"]),
expand_ratio=int(options["e"]),
id_skip=("noskip" not in block_string),
se_ratio=float(options["se"]) if "se" in options else None,
stride=[int(options["s"][0])],
)
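    # Example (illustrative): the first block string used by efficientnet(),
    # "r1_k3_s11_e1_i32_o16_se0.25", decodes to BlockArgs(kernel_size=3,
    # num_repeat=1, input_filters=32, output_filters=16, expand_ratio=1,
    # id_skip=True, se_ratio=0.25, stride=[1]).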
    @staticmethod
    def _encode_block_string(block):
        """Encodes a block to a string."""
        # block.stride is stored as a one-element list (both stride digits
        # were asserted equal when decoding), so the digit is repeated here
        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (block.stride[0], block.stride[0]),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)
@staticmethod
def decode(string_list):
"""
Decodes a list of string notations to specify blocks inside the network.
        :param string_list: a list of strings, each being the notation of a block
        :return: a list of BlockArgs namedtuples of block args
"""
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
@staticmethod
def encode(blocks_args):
"""
Encodes a list of BlockArgs to a list of strings.
        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each being the notation of a block
"""
block_strings = []
for block in blocks_args:
block_strings.append(BlockDecoder._encode_block_string(block))
return block_strings
def efficientnet(
width_coefficient=None,
depth_coefficient=None,
dropout_rate=0.2,
drop_connect_rate=0.2,
image_size=None,
num_classes=1000,
):
"""Creates a efficientnet model."""
blocks_args = [
"r1_k3_s11_e1_i32_o16_se0.25",
"r2_k3_s22_e6_i16_o24_se0.25",
"r2_k5_s22_e6_i24_o40_se0.25",
"r3_k3_s22_e6_i40_o80_se0.25",
"r3_k5_s11_e6_i80_o112_se0.25",
"r4_k5_s22_e6_i112_o192_se0.25",
"r1_k3_s11_e6_i192_o320_se0.25",
]
blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
# data_format='channels_last', # removed, this is always true in PyTorch
num_classes=num_classes,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
image_size=image_size,
)
return blocks_args, global_params
def get_model_params(model_name, override_params):
"""Get the block args and global params for a given model"""
if model_name.startswith("efficientnet"):
w, d, s, p = efficientnet_params(model_name)
# note: all models have drop connect rate = 0.2
blocks_args, global_params = efficientnet(
width_coefficient=w,
depth_coefficient=d,
dropout_rate=p,
image_size=s
)
else:
raise NotImplementedError(
"model name is not pre-defined: %s" % model_name
)
if override_params:
# ValueError will be raised here if override_params has fields not included in global_params.
global_params = global_params._replace(**override_params)
return blocks_args, global_params
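# Example (illustrative): get_model_params("efficientnet-b0", {"num_classes": 10})
# returns the seven decoded BlockArgs from efficientnet() plus a GlobalParams
# with width_coefficient=1.0, depth_coefficient=1.0, image_size=224,
# dropout_rate=0.2 and num_classes overridden to 10.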
url_map = {
"efficientnet-b0":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
"efficientnet-b1":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
"efficientnet-b2":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
"efficientnet-b3":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
"efficientnet-b4":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
"efficientnet-b5":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
"efficientnet-b6":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
"efficientnet-b7":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}
url_map_advprop = {
"efficientnet-b0":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
"efficientnet-b1":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
"efficientnet-b2":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
"efficientnet-b3":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
"efficientnet-b4":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
"efficientnet-b5":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
"efficientnet-b6":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
"efficientnet-b7":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
"efficientnet-b8":
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
}
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Loads pretrained weights, downloading them on first use.

    Note: weights are loaded with strict=False, so keys that do not match
    (e.g. the removed _fc layer) are skipped; load_fc is kept for API
    compatibility but has no effect here.
    """
    # AutoAugment or AdvProp weights (the two use different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    state_dict = model_zoo.load_url(url_map_[model_name])
    model.load_state_dict(state_dict, strict=False)
    print("Loaded pretrained weights for {}".format(model_name))
    # Original strict-loading logic, kept for reference:
    # if load_fc:
    #     model.load_state_dict(state_dict)
    # else:
    #     state_dict.pop('_fc.weight')
    #     state_dict.pop('_fc.bias')
    #     res = model.load_state_dict(state_dict, strict=False)
    #     assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'

View File

@@ -0,0 +1,217 @@
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"mobilenet_v2":
"https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}
def _make_divisible(v, divisor, min_value=None):
    """
    Taken from the original tf repo. It ensures that all layers have a
    channel number that is divisible by ``divisor`` (8 in MobileNetV2):
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: raw channel count (possibly scaled by a width multiplier)
    :param divisor: the value channels must be a multiple of
    :param min_value: lower bound for the result (defaults to divisor)
    :return: the adjusted channel count
    """
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor/2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
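# Worked example (illustrative): _make_divisible(37.44, 8) computes
# int(37.44 + 4) // 8 * 8 = 40 and keeps it because 40 >= 0.9 * 37.44;
# a result below 90% of v would be bumped up by one extra divisor.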
class ConvBNReLU(nn.Sequential):
def __init__(
self, in_planes, out_planes, kernel_size=3, stride=1, groups=1
):
padding = (kernel_size-1) // 2
super().__init__(
nn.Conv2d(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False,
),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True),
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super().__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend(
[
# dw
ConvBNReLU(
hidden_dim, hidden_dim, stride=stride, groups=hidden_dim
),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
]
)
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(Backbone):
def __init__(
self,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
):
"""
MobileNet V2.
Args:
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super().__init__()
if block is None:
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if (
len(inverted_residual_setting) == 0
or len(inverted_residual_setting[0]) != 4
):
            raise ValueError(
                "inverted_residual_setting should be a non-empty "
                "list of 4-element lists, got {}".
                format(inverted_residual_setting)
            )
# building first layer
input_channel = _make_divisible(
input_channel * width_mult, round_nearest
)
self.last_channel = _make_divisible(
last_channel * max(1.0, width_mult), round_nearest
)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(
block(
input_channel, output_channel, stride, expand_ratio=t
)
)
input_channel = output_channel
# building last several layers
features.append(
ConvBNReLU(input_channel, self.last_channel, kernel_size=1)
)
# make it nn.Sequential
self.features = nn.Sequential(*features)
self._out_features = self.last_channel
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
x = x.mean([2, 3])
return x
def forward(self, x):
return self._forward_impl(x)
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
if model_url is None:
import warnings
warnings.warn(
"ImageNet pretrained weights are unavailable for this model"
)
return
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
@BACKBONE_REGISTRY.register()
def mobilenetv2(pretrained=True, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["mobilenet_v2"])
return model
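# Usage sketch (illustrative; assumes torch is available in the caller):
#   model = mobilenetv2(pretrained=True)
#   feats = model(torch.randn(2, 3, 224, 224))  # shape: (2, 1280) at width_mult=1.0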

View File

@@ -0,0 +1,135 @@
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class PreActBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out += shortcut
return out
class PreActBottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, self.expansion * planes, kernel_size=1, bias=False
)
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
)
)
def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out)))
out = self.conv3(F.relu(self.bn3(out)))
out += shortcut
return out
class PreActResNet(Backbone):
def __init__(self, block, num_blocks):
super().__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self._out_features = 512 * block.expansion
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
return out
"""
Preact-ResNet18 was used for the CIFAR10 and
SVHN datasets (both are SSL tasks) in
- Wang et al. Semi-Supervised Learning by
Augmented Distribution Alignment. ICCV 2019.
"""
@BACKBONE_REGISTRY.register()
def preact_resnet18(**kwargs):
return PreActResNet(PreActBlock, [2, 2, 2, 2])
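# Usage sketch (illustrative): the stride-1 stem and the final F.avg_pool2d(out, 4)
# assume small inputs such as CIFAR's 32x32:
#   model = preact_resnet18()
#   feats = model(torch.randn(2, 3, 32, 32))  # shape: (2, 512)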

View File

@@ -0,0 +1,589 @@
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * self.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(Backbone):
def __init__(
self,
block,
layers,
ms_class=None,
ms_layers=[],
ms_p=0.5,
ms_a=0.1,
**kwargs
):
self.inplanes = 64
super().__init__()
# backbone network
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self._out_features = 512 * block.expansion
        self.mixstyle = None
        if ms_layers:
            self.mixstyle = ms_class(p=ms_p, alpha=ms_a)
            for layer_name in ms_layers:
                assert layer_name in ["layer1", "layer2", "layer3"]
            print(f"Insert {ms_class.__name__} after {ms_layers}")
self.ms_layers = ms_layers
self._init_params()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
if "layer1" in self.ms_layers:
x = self.mixstyle(x)
x = self.layer2(x)
if "layer2" in self.ms_layers:
x = self.mixstyle(x)
x = self.layer3(x)
if "layer3" in self.ms_layers:
x = self.mixstyle(x)
return self.layer4(x)
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
return v.view(v.size(0), -1)
def init_pretrained_weights(model, model_url):
pretrain_dict = model_zoo.load_url(model_url)
model.load_state_dict(pretrain_dict, strict=False)
"""
Residual network configurations:
--
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
"""
@BACKBONE_REGISTRY.register()
def resnet18(pretrained=True, **kwargs):
model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet34(pretrained=True, **kwargs):
model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet34"])
return model
@BACKBONE_REGISTRY.register()
def resnet50(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet152(pretrained=True, **kwargs):
model = ResNet(block=Bottleneck, layers=[3, 8, 36, 3])
if pretrained:
init_pretrained_weights(model, model_urls["resnet152"])
return model
"""
Residual networks with mixstyle
"""
@BACKBONE_REGISTRY.register()
def resnet18_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_ms_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import MixStyle
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=MixStyle,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
"""
Residual networks with efdmix
"""
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet18_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=BasicBlock,
layers=[2, 2, 2, 2],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet18"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet50_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 6, 3],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet50"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l123(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2", "layer3"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l12(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1", "layer2"],
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model
@BACKBONE_REGISTRY.register()
def resnet101_efdmix_l1(pretrained=True, **kwargs):
from dassl.modeling.ops import EFDMix
model = ResNet(
block=Bottleneck,
layers=[3, 4, 23, 3],
ms_class=EFDMix,
ms_layers=["layer1"]
)
if pretrained:
init_pretrained_weights(model, model_urls["resnet101"])
return model

View File

@@ -0,0 +1,229 @@
"""
Code source: https://github.com/pytorch/vision
"""
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
model_urls = {
"shufflenetv2_x0.5":
"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0":
"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": None,
"shufflenetv2_x2.0": None,
}
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
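# Example (illustrative): for x of shape (N, 4, H, W) and groups=2, channels
# [c0, c1, c2, c3] come out reordered as [c0, c2, c1, c3], interleaving the two
# branch halves so information mixes across them in the next block.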
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super().__init__()
if not (1 <= stride <= 3):
raise ValueError("illegal stride value")
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1
),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(
i, o, kernel_size, stride, padding, bias=bias, groups=i
)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(Backbone):
def __init__(self, stages_repeats, stages_out_channels, **kwargs):
super().__init__()
if len(stages_repeats) != 3:
raise ValueError(
"expected stages_repeats as list of 3 positive ints"
)
if len(stages_out_channels) != 5:
raise ValueError(
"expected stages_out_channels as list of 5 positive ints"
)
self._stage_out_channels = stages_out_channels
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]
):
seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(
InvertedResidual(output_channels, output_channels, 1)
)
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self._out_features = output_channels
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
return v.view(v.size(0), -1)
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
if model_url is None:
import warnings
warnings.warn(
"ImageNet pretrained weights are unavailable for this model"
)
return
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {
k: v
for k, v in pretrain_dict.items()
if k in model_dict and model_dict[k].size() == v.size()
}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x0_5(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x0.5"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_0(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x1.0"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x1_5(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x1.5"])
return model
@BACKBONE_REGISTRY.register()
def shufflenet_v2_x2_0(pretrained=True, **kwargs):
model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls["shufflenetv2_x2.0"])
return model

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
model_urls = {
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}
class VGG(Backbone):
def __init__(self, features, init_weights=True):
super().__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
# Note that self.classifier outputs features rather than logits
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
)
self._out_features = 4096
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return self.classifier(x)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == "M":
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfgs = {
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"B":
[64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"D": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
"M",
512,
512,
512,
"M",
512,
512,
512,
"M",
],
"E": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
256,
"M",
512,
512,
512,
512,
"M",
512,
512,
512,
512,
"M",
],
}
def _vgg(arch, cfg, batch_norm, pretrained):
    init_weights = not pretrained
model = VGG(
make_layers(cfgs[cfg], batch_norm=batch_norm),
init_weights=init_weights
)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
model.load_state_dict(state_dict, strict=False)
return model
@BACKBONE_REGISTRY.register()
def vgg16(pretrained=True, **kwargs):
return _vgg("vgg16", "D", False, pretrained)

View File

@@ -0,0 +1,150 @@
"""
Modified from https://github.com/xternalz/WideResNet-pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .build import BACKBONE_REGISTRY
from .backbone import Backbone
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.LeakyReLU(0.01, inplace=True)
self.conv1 = nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(out_planes)
self.relu2 = nn.LeakyReLU(0.01, inplace=True)
self.conv2 = nn.Conv2d(
out_planes,
out_planes,
kernel_size=3,
stride=1,
padding=1,
bias=False
)
self.droprate = dropRate
self.equalInOut = in_planes == out_planes
        self.convShortcut = (
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False,
            ) if not self.equalInOut else None
        )
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
def __init__(
self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
):
super().__init__()
self.layer = self._make_layer(
block, in_planes, out_planes, nb_layers, stride, dropRate
)
def _make_layer(
self, block, in_planes, out_planes, nb_layers, stride, dropRate
):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(
                block(
                    in_planes if i == 0 else out_planes,
                    out_planes,
                    stride if i == 0 else 1,
                    dropRate,
                )
            )
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
class WideResNet(Backbone):
def __init__(self, depth, widen_factor, dropRate=0.0):
super().__init__()
nChannels = [
16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
]
        assert (depth-4) % 6 == 0, "depth must be of the form 6n + 4"
        n = (depth-4) // 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(
3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
)
# 1st block
self.block1 = NetworkBlock(
n, nChannels[0], nChannels[1], block, 1, dropRate
)
# 2nd block
self.block2 = NetworkBlock(
n, nChannels[1], nChannels[2], block, 2, dropRate
)
# 3rd block
self.block3 = NetworkBlock(
n, nChannels[2], nChannels[3], block, 2, dropRate
)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.LeakyReLU(0.01, inplace=True)
self._out_features = nChannels[3]
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.adaptive_avg_pool2d(out, 1)
return out.view(out.size(0), -1)
@BACKBONE_REGISTRY.register()
def wide_resnet_28_2(**kwargs):
return WideResNet(28, 2)
@BACKBONE_REGISTRY.register()
def wide_resnet_16_4(**kwargs):
return WideResNet(16, 4)
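# Usage sketch (illustrative): WRN-28-2 is the standard SSL backbone for 32x32
# inputs such as CIFAR-10; the feature dimension is 64 * widen_factor:
#   model = wide_resnet_28_2()
#   feats = model(torch.randn(2, 3, 32, 32))  # shape: (2, 128)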

View File

@@ -0,0 +1,3 @@
from .build import build_head, HEAD_REGISTRY # isort:skip
from .mlp import mlp

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
HEAD_REGISTRY = Registry("HEAD")
def build_head(name, verbose=True, **kwargs):
avai_heads = HEAD_REGISTRY.registered_names()
check_availability(name, avai_heads)
if verbose:
print("Head: {}".format(name))
return HEAD_REGISTRY.get(name)(**kwargs)

View File

@@ -0,0 +1,50 @@
import functools
import torch.nn as nn
from .build import HEAD_REGISTRY
class MLP(nn.Module):
def __init__(
self,
in_features=2048,
hidden_layers=[],
activation="relu",
bn=True,
dropout=0.0,
):
super().__init__()
if isinstance(hidden_layers, int):
hidden_layers = [hidden_layers]
assert len(hidden_layers) > 0
self.out_features = hidden_layers[-1]
mlp = []
if activation == "relu":
act_fn = functools.partial(nn.ReLU, inplace=True)
elif activation == "leaky_relu":
act_fn = functools.partial(nn.LeakyReLU, inplace=True)
else:
raise NotImplementedError
for hidden_dim in hidden_layers:
mlp += [nn.Linear(in_features, hidden_dim)]
if bn:
mlp += [nn.BatchNorm1d(hidden_dim)]
mlp += [act_fn()]
if dropout > 0:
mlp += [nn.Dropout(dropout)]
in_features = hidden_dim
self.mlp = nn.Sequential(*mlp)
def forward(self, x):
return self.mlp(x)
@HEAD_REGISTRY.register()
def mlp(**kwargs):
return MLP(**kwargs)
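# Usage sketch (illustrative): a two-layer projection head built through the
# registry, mapping 512-d backbone features to 128-d:
#   head = HEAD_REGISTRY.get("mlp")(in_features=512, hidden_layers=[256, 128])
#   out = head(torch.randn(4, 512))  # shape: (4, head.out_features) == (4, 128)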

View File

@@ -0,0 +1,5 @@
from .build import build_network, NETWORK_REGISTRY # isort:skip
from .ddaig_fcn import (
fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn
)

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
NETWORK_REGISTRY = Registry("NETWORK")
def build_network(name, verbose=True, **kwargs):
avai_models = NETWORK_REGISTRY.registered_names()
check_availability(name, avai_models)
if verbose:
print("Network: {}".format(name))
return NETWORK_REGISTRY.get(name)(**kwargs)

View File

@@ -0,0 +1,329 @@
"""
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
import functools
import torch
import torch.nn as nn
from torch.nn import functional as F
from .build import NETWORK_REGISTRY
def init_network_weights(model, init_type="normal", gain=0.02):
def _init_func(m):
classname = m.__class__.__name__
if hasattr(m, "weight") and (
classname.find("Conv") != -1 or classname.find("Linear") != -1
):
if init_type == "normal":
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == "xavier":
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == "kaiming":
nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
elif init_type == "orthogonal":
nn.init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError(
"initialization method {} is not implemented".
format(init_type)
)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm2d") != -1:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
elif classname.find("InstanceNorm2d") != -1:
if m.weight is not None and m.bias is not None:
nn.init.constant_(m.weight.data, 1.0)
nn.init.constant_(m.bias.data, 0.0)
model.apply(_init_func)
def get_norm_layer(norm_type="instance"):
if norm_type == "batch":
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == "instance":
norm_layer = functools.partial(
nn.InstanceNorm2d, affine=False, track_running_stats=False
)
elif norm_type == "none":
norm_layer = None
else:
raise NotImplementedError(
"normalization layer [%s] is not found" % norm_type
)
return norm_layer
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super().__init__()
self.conv_block = self.build_conv_block(
dim, padding_type, norm_layer, use_dropout, use_bias
)
def build_conv_block(
self, dim, padding_type, norm_layer, use_dropout, use_bias
):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
"padding [%s] is not implemented" % padding_type
)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True),
]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError(
"padding [%s] is not implemented" % padding_type
)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
]
return nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
class LocNet(nn.Module):
"""Localization network."""
def __init__(
self,
input_nc,
nc=32,
n_blocks=3,
use_dropout=False,
padding_type="zero",
image_size=32,
):
super().__init__()
backbone = []
backbone += [
nn.Conv2d(
input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
)
]
backbone += [nn.BatchNorm2d(nc)]
backbone += [nn.ReLU(True)]
for _ in range(n_blocks):
backbone += [
ResnetBlock(
nc,
padding_type=padding_type,
norm_layer=nn.BatchNorm2d,
use_dropout=use_dropout,
use_bias=False,
)
]
backbone += [nn.MaxPool2d(2, stride=2)]
self.backbone = nn.Sequential(*backbone)
reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)
def forward(self, x):
x = self.backbone(x)
x = x.view(x.size(0), -1)
x = self.fc_loc(x)
x = torch.tanh(x)
x = x.view(-1, 2, 2)
theta = x.data.new_zeros(x.size(0), 2, 3)
theta[:, :, :2] = x
return theta
class FCN(nn.Module):
"""Fully convolutional network."""
def __init__(
self,
input_nc,
output_nc,
nc=32,
n_blocks=3,
norm_layer=nn.BatchNorm2d,
use_dropout=False,
padding_type="reflect",
gctx=True,
stn=False,
image_size=32,
):
super().__init__()
backbone = []
p = 0
if padding_type == "reflect":
backbone += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
backbone += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError
backbone += [
nn.Conv2d(
input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
)
]
backbone += [norm_layer(nc)]
backbone += [nn.ReLU(True)]
for _ in range(n_blocks):
backbone += [
ResnetBlock(
nc,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
use_bias=False,
)
]
self.backbone = nn.Sequential(*backbone)
# global context fusion layer
self.gctx_fusion = None
if gctx:
self.gctx_fusion = nn.Sequential(
nn.Conv2d(
2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
),
norm_layer(nc),
nn.ReLU(True),
)
self.regress = nn.Sequential(
nn.Conv2d(
nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
),
nn.Tanh(),
)
self.locnet = None
if stn:
self.locnet = LocNet(
input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
)
def init_loc_layer(self):
"""Initialize the weights/bias with identity transformation."""
if self.locnet is not None:
self.locnet.fc_loc.weight.data.zero_()
self.locnet.fc_loc.bias.data.copy_(
torch.tensor([1, 0, 0, 1], dtype=torch.float)
)
def stn(self, x):
"""Spatial transformer network."""
theta = self.locnet(x)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False), theta
def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
"""
Args:
x (torch.Tensor): input mini-batch.
lmda (float): multiplier for perturbation.
return_p (bool): return perturbation.
return_stn_output (bool): return the output of stn.
"""
        theta = None
        if self.locnet is not None:
            x, theta = self.stn(x)
        identity = x  # the (possibly warped) input; avoids shadowing builtin input
        x = self.backbone(x)
        if self.gctx_fusion is not None:
            c = F.adaptive_avg_pool2d(x, (1, 1))
            c = c.expand_as(x)
            x = torch.cat([x, c], 1)
            x = self.gctx_fusion(x)
        p = self.regress(x)
        x_p = identity + lmda*p
        if return_stn_output:
            return x_p, p, identity
        if return_p:
            return x_p, p
        return x_p
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)
init_network_weights(net, init_type="normal", gain=0.02)
return net
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)
init_network_weights(net, init_type="normal", gain=0.02)
return net
@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(
3,
3,
nc=32,
n_blocks=3,
norm_layer=norm_layer,
stn=True,
image_size=image_size
)
init_network_weights(net, init_type="normal", gain=0.02)
net.init_loc_layer()
return net
@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
norm_layer = get_norm_layer(norm_type="instance")
net = FCN(
3,
3,
nc=64,
n_blocks=3,
norm_layer=norm_layer,
stn=True,
image_size=image_size
)
init_network_weights(net, init_type="normal", gain=0.02)
net.init_loc_layer()
return net
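# Usage sketch (illustrative): the generator returns the perturbed image
# x + lmda * p, with p a tanh-bounded perturbation map (DDAIG-style usage):
#   G = fcn_3x32_gctx()
#   x_p, p = G(torch.randn(8, 3, 32, 32), lmda=0.3, return_p=True)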

View File

@@ -0,0 +1,16 @@
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance

View File

@@ -0,0 +1,30 @@
import torch
from torch.nn import functional as F
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
"""Cross entropy loss.
Args:
input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): label vector with shape (batch,).
label_smooth (float, optional): label smoothing hyper-parameter.
Default is 0.
reduction (str, optional): how the losses for a mini-batch
will be aggregated. Default is 'mean'.
"""
num_classes = input.shape[1]
log_prob = F.log_softmax(input, dim=1)
zeros = torch.zeros(log_prob.size())
target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1)
target = target.type_as(input)
target = (1-label_smooth) * target + label_smooth/num_classes
loss = (-target * log_prob).sum(1)
if reduction == "mean":
return loss.mean()
elif reduction == "sum":
return loss.sum()
elif reduction == "none":
return loss
else:
        raise ValueError(
            "reduction must be 'mean', 'sum' or 'none', "
            "but got {}".format(reduction)
        )
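# Minimal usage sketch (illustrative, not part of the original file): smoothed
# cross entropy on random logits.
if __name__ == "__main__":
    logits = torch.randn(8, 10)  # (batch, num_classes)
    labels = torch.randint(0, 10, (8,))
    loss = cross_entropy(logits, labels, label_smooth=0.1)
    print(loss.item())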

View File

@@ -0,0 +1,45 @@
import torch.nn as nn
class _DSBN(nn.Module):
"""Domain Specific Batch Normalization.
Args:
num_features (int): number of features.
n_domain (int): number of domains.
bn_type (str): type of bn. Choices are ['1d', '2d'].
"""
def __init__(self, num_features, n_domain, bn_type):
super().__init__()
if bn_type == "1d":
BN = nn.BatchNorm1d
elif bn_type == "2d":
BN = nn.BatchNorm2d
else:
            raise ValueError(
                "bn_type must be '1d' or '2d', but got {}".format(bn_type)
            )
self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))
self.valid_domain_idxs = list(range(n_domain))
self.n_domain = n_domain
self.domain_idx = 0
def select_bn(self, domain_idx=0):
assert domain_idx in self.valid_domain_idxs
self.domain_idx = domain_idx
def forward(self, x):
return self.bn[self.domain_idx](x)
class DSBN1d(_DSBN):
def __init__(self, num_features, n_domain):
super().__init__(num_features, n_domain, "1d")
class DSBN2d(_DSBN):
def __init__(self, num_features, n_domain):
super().__init__(num_features, n_domain, "2d")
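# Minimal usage sketch (illustrative, not part of the original file): one BN
# branch per domain; call select_bn() before forwarding each domain's batch.
if __name__ == "__main__":
    import torch
    bn = DSBN2d(num_features=16, n_domain=2)
    x_source, x_target = torch.rand(4, 16, 8, 8), torch.rand(4, 16, 8, 8)
    bn.select_bn(0)
    y_source = bn(x_source)  # normalized with domain-0 statistics
    bn.select_bn(1)
    y_target = bn(x_target)  # normalized with domain-1 statistics
    print(y_source.shape, y_target.shape)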

View File

@@ -0,0 +1,118 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_efdmix(m):
if type(m) == EFDMix:
m.set_activation_status(False)
def activate_efdmix(m):
if type(m) == EFDMix:
m.set_activation_status(True)
def random_efdmix(m):
if type(m) == EFDMix:
m.update_mix_method("random")
def crossdomain_efdmix(m):
if type(m) == EFDMix:
m.update_mix_method("crossdomain")
@contextmanager
def run_without_efdmix(model):
    # Assume EFDMix was initially activated
try:
model.apply(deactivate_efdmix)
yield
finally:
model.apply(activate_efdmix)
@contextmanager
def run_with_efdmix(model, mix=None):
    # Assume EFDMix was initially deactivated
if mix == "random":
model.apply(random_efdmix)
elif mix == "crossdomain":
model.apply(crossdomain_efdmix)
try:
model.apply(activate_efdmix)
yield
finally:
model.apply(deactivate_efdmix)
class EFDMix(nn.Module):
"""EFDMix.
Reference:
Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
"""
def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
"""
Args:
            p (float): probability of applying EFDMix.
alpha (float): parameter of the Beta distribution.
eps (float): scaling parameter to avoid numerical issues.
mix (str): how to mix.
"""
super().__init__()
self.p = p
self.beta = torch.distributions.Beta(alpha, alpha)
self.eps = eps
self.alpha = alpha
self.mix = mix
self._activated = True
    def __repr__(self):
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )
def set_activation_status(self, status=True):
self._activated = status
def update_mix_method(self, mix="random"):
self.mix = mix
def forward(self, x):
if not self.training or not self._activated:
return x
if random.random() > self.p:
return x
        B, C, H, W = x.size(0), x.size(1), x.size(2), x.size(3)  # (batch, channels, height, width)
x_view = x.view(B, C, -1)
value_x, index_x = torch.sort(x_view) # sort inputs
lmda = self.beta.sample((B, 1, 1))
lmda = lmda.to(x.device)
if self.mix == "random":
# random shuffle
perm = torch.randperm(B)
elif self.mix == "crossdomain":
# split into two halves and swap the order
perm = torch.arange(B - 1, -1, -1) # inverse index
perm_b, perm_a = perm.chunk(2)
perm_b = perm_b[torch.randperm(perm_b.shape[0])]
perm_a = perm_a[torch.randperm(perm_a.shape[0])]
perm = torch.cat([perm_b, perm_a], 0)
else:
raise NotImplementedError
inverse_index = index_x.argsort(-1)
x_view_copy = value_x[perm].gather(-1, inverse_index)
new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, H, W)
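# Minimal usage sketch (illustrative, not part of the original file): EFDMix is
# meant to sit after convolutional stages of a CNN; here it is applied directly
# to a feature map, with the context manager toggling it off temporarily.
if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), EFDMix(p=1.0))
    model.train()
    x = torch.rand(8, 3, 16, 16)
    y_mixed = model(x)
    with run_without_efdmix(model):  # EFDMix disabled inside this block
        y_plain = model(x)
    print(y_mixed.shape, y_plain.shape)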

View File

@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_mixstyle(m):
if type(m) == MixStyle:
m.set_activation_status(False)
def activate_mixstyle(m):
if type(m) == MixStyle:
m.set_activation_status(True)
def random_mixstyle(m):
if type(m) == MixStyle:
m.update_mix_method("random")
def crossdomain_mixstyle(m):
if type(m) == MixStyle:
m.update_mix_method("crossdomain")
@contextmanager
def run_without_mixstyle(model):
# Assume MixStyle was initially activated
try:
model.apply(deactivate_mixstyle)
yield
finally:
model.apply(activate_mixstyle)
@contextmanager
def run_with_mixstyle(model, mix=None):
# Assume MixStyle was initially deactivated
if mix == "random":
model.apply(random_mixstyle)
elif mix == "crossdomain":
model.apply(crossdomain_mixstyle)
try:
model.apply(activate_mixstyle)
yield
finally:
model.apply(deactivate_mixstyle)
class MixStyle(nn.Module):
"""MixStyle.
Reference:
Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
"""
def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
"""
Args:
p (float): probability of using MixStyle.
alpha (float): parameter of the Beta distribution.
eps (float): scaling parameter to avoid numerical issues.
mix (str): how to mix.
"""
super().__init__()
self.p = p
self.beta = torch.distributions.Beta(alpha, alpha)
self.eps = eps
self.alpha = alpha
self.mix = mix
self._activated = True
def __repr__(self):
return (
f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
)
def set_activation_status(self, status=True):
self._activated = status
def update_mix_method(self, mix="random"):
self.mix = mix
def forward(self, x):
if not self.training or not self._activated:
return x
if random.random() > self.p:
return x
B = x.size(0)
mu = x.mean(dim=[2, 3], keepdim=True)
var = x.var(dim=[2, 3], keepdim=True)
sig = (var + self.eps).sqrt()
mu, sig = mu.detach(), sig.detach()
x_normed = (x-mu) / sig
lmda = self.beta.sample((B, 1, 1, 1))
lmda = lmda.to(x.device)
if self.mix == "random":
# random shuffle
perm = torch.randperm(B)
elif self.mix == "crossdomain":
# split into two halves and swap the order
perm = torch.arange(B - 1, -1, -1) # inverse index
perm_b, perm_a = perm.chunk(2)
perm_b = perm_b[torch.randperm(perm_b.shape[0])]
perm_a = perm_a[torch.randperm(perm_a.shape[0])]
perm = torch.cat([perm_b, perm_a], 0)
else:
raise NotImplementedError
mu2, sig2 = mu[perm], sig[perm]
mu_mix = mu*lmda + mu2 * (1-lmda)
sig_mix = sig*lmda + sig2 * (1-lmda)
return x_normed*sig_mix + mu_mix
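# Minimal usage sketch (illustrative, not part of the original file): MixStyle
# mixes channel-wise feature statistics (mean and std) across instances, so it
# is typically inserted after early residual stages during training.
if __name__ == "__main__":
    mixstyle = MixStyle(p=1.0, alpha=0.1, mix="random")
    mixstyle.train()
    x = torch.rand(8, 32, 14, 14)  # (batch, channels, height, width)
    y = mixstyle(x)
    print(y.shape)  # statistics mixed, shape unchanged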

View File

@@ -0,0 +1,23 @@
import torch
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
"""Mixup.
Args:
x1 (torch.Tensor): data with shape of (b, c, h, w).
x2 (torch.Tensor): data with shape of (b, c, h, w).
y1 (torch.Tensor): label with shape of (b, n).
y2 (torch.Tensor): label with shape of (b, n).
beta (float): hyper-parameter for Beta sampling.
preserve_order (bool): apply lmda=max(lmda, 1-lmda).
Default is False.
"""
lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
if preserve_order:
lmda = torch.max(lmda, 1 - lmda)
lmda = lmda.to(x1.device)
xmix = x1*lmda + x2 * (1-lmda)
lmda = lmda[:, :, 0, 0]
ymix = y1*lmda + y2 * (1-lmda)
return xmix, ymix
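# Minimal usage sketch (illustrative, not part of the original file): mixup
# expects one-hot (or soft) labels; the 10 classes here are arbitrary.
if __name__ == "__main__":
    from torch.nn import functional as F
    x1, x2 = torch.rand(4, 3, 32, 32), torch.rand(4, 3, 32, 32)
    y1 = F.one_hot(torch.randint(0, 10, (4,)), num_classes=10).float()
    y2 = F.one_hot(torch.randint(0, 10, (4,)), num_classes=10).float()
    xmix, ymix = mixup(x1, x2, y1, y2, beta=1.0)
    print(xmix.shape, ymix.shape)  # (4, 3, 32, 32) and (4, 10)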

View File

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class MaximumMeanDiscrepancy(nn.Module):
def __init__(self, kernel_type="rbf", normalize=False):
super().__init__()
self.kernel_type = kernel_type
self.normalize = normalize
def forward(self, x, y):
# x, y: two batches of data with shape (batch, dim)
# MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')
if self.normalize:
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
if self.kernel_type == "linear":
return self.linear_mmd(x, y)
elif self.kernel_type == "poly":
return self.poly_mmd(x, y)
elif self.kernel_type == "rbf":
return self.rbf_mmd(x, y)
else:
raise NotImplementedError
def linear_mmd(self, x, y):
# k(x, y) = x^T y
k_xx = self.remove_self_distance(torch.mm(x, x.t()))
k_yy = self.remove_self_distance(torch.mm(y, y.t()))
k_xy = torch.mm(x, y.t())
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
# k(x, y) = (alpha * x^T y + c)^d
k_xx = self.remove_self_distance(torch.mm(x, x.t()))
k_xx = (alpha*k_xx + c).pow(d)
k_yy = self.remove_self_distance(torch.mm(y, y.t()))
k_yy = (alpha*k_yy + c).pow(d)
k_xy = torch.mm(x, y.t())
k_xy = (alpha*k_xy + c).pow(d)
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
def rbf_mmd(self, x, y):
# k_xx
d_xx = self.euclidean_squared_distance(x, x)
d_xx = self.remove_self_distance(d_xx)
k_xx = self.rbf_kernel_mixture(d_xx)
# k_yy
d_yy = self.euclidean_squared_distance(y, y)
d_yy = self.remove_self_distance(d_yy)
k_yy = self.rbf_kernel_mixture(d_yy)
# k_xy
d_xy = self.euclidean_squared_distance(x, y)
k_xy = self.rbf_kernel_mixture(d_xy)
return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
@staticmethod
def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):
K = 0
for sigma in sigmas:
gamma = 1.0 / (2.0 * sigma**2)
K += torch.exp(-gamma * exponent)
return K
@staticmethod
def remove_self_distance(distmat):
tmp_list = []
for i, row in enumerate(distmat):
row1 = torch.cat([row[:i], row[i + 1:]])
tmp_list.append(row1)
return torch.stack(tmp_list)
@staticmethod
def euclidean_squared_distance(x, y):
m, n = x.size(0), y.size(0)
distmat = (
torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
)
        # distmat += -2 * x @ y.t(), using the keyword-argument addmm_ signature
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
return distmat
if __name__ == "__main__":
mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
d = mmd(input1, input2)
print(d.item())

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
class OptimalTransport(nn.Module):
@staticmethod
def distance(batch1, batch2, dist_metric="cosine"):
if dist_metric == "cosine":
batch1 = F.normalize(batch1, p=2, dim=1)
batch2 = F.normalize(batch2, p=2, dim=1)
dist_mat = 1 - torch.mm(batch1, batch2.t())
elif dist_metric == "euclidean":
m, n = batch1.size(0), batch2.size(0)
dist_mat = (
torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
)
            # squared euclidean distance; the old positional addmm_(1, -2, ...)
            # signature is deprecated, so use the keyword form
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
elif dist_metric == "fast_euclidean":
batch1 = batch1.unsqueeze(-2)
batch2 = batch2.unsqueeze(-3)
dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
else:
            raise ValueError(
                "Unknown cost function: {}. Expected to be one of "
                "[cosine | euclidean | fast_euclidean]".format(dist_metric)
            )
return dist_mat
class SinkhornDivergence(OptimalTransport):
thre = 1e-3
def __init__(
self,
dist_metric="cosine",
eps=0.01,
max_iter=5,
bp_to_sinkhorn=False
):
super().__init__()
self.dist_metric = dist_metric
self.eps = eps
self.max_iter = max_iter
self.bp_to_sinkhorn = bp_to_sinkhorn
def forward(self, x, y):
# x, y: two batches of data with shape (batch, dim)
W_xy = self.transport_cost(x, y)
W_xx = self.transport_cost(x, x)
W_yy = self.transport_cost(y, y)
return 2*W_xy - W_xx - W_yy
def transport_cost(self, x, y, return_pi=False):
C = self.distance(x, y, dist_metric=self.dist_metric)
pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
if not self.bp_to_sinkhorn:
pi = pi.detach()
cost = torch.sum(pi * C)
if return_pi:
return cost, pi
return cost
@staticmethod
def sinkhorn_iterate(C, eps, max_iter, thre):
nx, ny = C.shape
mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
u = torch.zeros_like(mu)
v = torch.zeros_like(nu)
def M(_C, _u, _v):
"""Modified cost for logarithmic updates.
Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
"""
return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps
real_iter = 0 # check if algorithm terminates before max_iter
# Sinkhorn iterations
for i in range(max_iter):
u0 = u
u = eps * (
torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
) + u
v = (
eps * (
torch.log(nu + 1e-8) -
torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
) + v
)
err = (u - u0).abs().sum()
real_iter += 1
if err.item() < thre:
break
# Transport plan pi = diag(a)*K*diag(b)
return torch.exp(M(C, u, v))
class MinibatchEnergyDistance(SinkhornDivergence):
def __init__(
self,
dist_metric="cosine",
eps=0.01,
max_iter=5,
bp_to_sinkhorn=False
):
super().__init__(
dist_metric=dist_metric,
eps=eps,
max_iter=max_iter,
bp_to_sinkhorn=bp_to_sinkhorn,
)
def forward(self, x, y):
x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
cost = 0
cost += self.transport_cost(x1, y1)
cost += self.transport_cost(x1, y2)
cost += self.transport_cost(x2, y1)
cost += self.transport_cost(x2, y2)
cost -= 2 * self.transport_cost(x1, x2)
cost -= 2 * self.transport_cost(y1, y2)
return cost
if __name__ == "__main__":
# example: https://dfdazac.github.io/sinkhorn.html
import numpy as np
n_points = 5
a = np.array([[i, 0] for i in range(n_points)])
b = np.array([[i, 1] for i in range(n_points)])
x = torch.tensor(a, dtype=torch.float)
y = torch.tensor(b, dtype=torch.float)
sinkhorn = SinkhornDivergence(
dist_metric="euclidean", eps=0.01, max_iter=5
)
    dist, pi = sinkhorn.transport_cost(x, y, return_pi=True)
    print("transport cost:", dist.item())
    print("transport plan:\n", pi)

View File

@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function
class _ReverseGrad(Function):
@staticmethod
def forward(ctx, input, grad_scaling):
ctx.grad_scaling = grad_scaling
return input.view_as(input)
@staticmethod
def backward(ctx, grad_output):
grad_scaling = ctx.grad_scaling
return -grad_scaling * grad_output, None
reverse_grad = _ReverseGrad.apply
class ReverseGrad(nn.Module):
"""Gradient reversal layer.
It acts as an identity layer in the forward,
but reverses the sign of the gradient in
the backward.
"""
    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)
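# Minimal usage sketch (illustrative, not part of the original file): gradient
# reversal as used in domain-adversarial training. The features pass through
# unchanged, while the domain classifier's gradient is flipped in sign before
# reaching the feature extractor.
if __name__ == "__main__":
    import torch
    features = torch.rand(4, 16, requires_grad=True)
    domain_clf = nn.Linear(16, 2)
    revgrad = ReverseGrad()
    logits = domain_clf(revgrad(features, grad_scaling=1.0))
    logits.sum().backward()
    print(features.grad.abs().sum())  # gradients arrive with reversed sign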

View File

@@ -0,0 +1,15 @@
import torch.nn as nn
class Sequential2(nn.Sequential):
"""An alternative sequential container to nn.Sequential,
which accepts an arbitrary number of input arguments.
"""
def forward(self, *inputs):
for module in self._modules.values():
if isinstance(inputs, tuple):
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
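# Minimal usage sketch (illustrative, not part of the original file): a module
# returning a tuple can be chained with one consuming multiple arguments,
# which plain nn.Sequential does not allow.
if __name__ == "__main__":
    import torch

    class Split(nn.Module):
        def forward(self, x):
            return x, x.mean()

    class Combine(nn.Module):
        def forward(self, x, bias):
            return x + bias

    net = Sequential2(Split(), Combine())
    print(net(torch.rand(2, 3)).shape)  # (2, 3)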

View File

@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
class _TransNorm(nn.Module):
"""Transferable normalization.
Reference:
- Wang et al. Transferable Normalization: Towards Improving
Transferability of Deep Neural Networks. NeurIPS 2019.
Args:
num_features (int): number of features.
eps (float): epsilon.
momentum (float): value for updating running_mean and running_var.
adaptive_alpha (bool): apply domain adaptive alpha.
"""
def __init__(
self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.adaptive_alpha = adaptive_alpha
self.register_buffer("running_mean_s", torch.zeros(num_features))
self.register_buffer("running_var_s", torch.ones(num_features))
self.register_buffer("running_mean_t", torch.zeros(num_features))
self.register_buffer("running_var_t", torch.ones(num_features))
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
    def reset_running_stats(self):
self.running_mean_s.zero_()
self.running_var_s.fill_(1)
self.running_mean_t.zero_()
self.running_var_t.fill_(1)
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def _check_input(self, x):
raise NotImplementedError
def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
C = self.num_features
ratio_s = mean_s / (var_s + self.eps).sqrt()
ratio_t = mean_t / (var_t + self.eps).sqrt()
dist = (ratio_s - ratio_t).abs()
dist_inv = 1 / (1+dist)
return C * dist_inv / dist_inv.sum()
def forward(self, input):
self._check_input(input)
C = self.num_features
if input.dim() == 2:
new_shape = (1, C)
elif input.dim() == 4:
new_shape = (1, C, 1, 1)
else:
raise ValueError
weight = self.weight.view(*new_shape)
bias = self.bias.view(*new_shape)
if not self.training:
mean_t = self.running_mean_t.view(*new_shape)
var_t = self.running_var_t.view(*new_shape)
output = (input-mean_t) / (var_t + self.eps).sqrt()
output = output*weight + bias
if self.adaptive_alpha:
mean_s = self.running_mean_s.view(*new_shape)
var_s = self.running_var_s.view(*new_shape)
alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
alpha = alpha.reshape(*new_shape)
output = (1 + alpha.detach()) * output
return output
input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)
x_s = input_s.transpose(0, 1).reshape(C, -1)
mean_s = x_s.mean(1)
var_s = x_s.var(1)
self.running_mean_s.mul_(self.momentum)
self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
self.running_var_s.mul_(self.momentum)
self.running_var_s.add_((1 - self.momentum) * var_s.data)
mean_s = mean_s.reshape(*new_shape)
var_s = var_s.reshape(*new_shape)
output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
output_s = output_s*weight + bias
x_t = input_t.transpose(0, 1).reshape(C, -1)
mean_t = x_t.mean(1)
var_t = x_t.var(1)
self.running_mean_t.mul_(self.momentum)
self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
self.running_var_t.mul_(self.momentum)
self.running_var_t.add_((1 - self.momentum) * var_t.data)
mean_t = mean_t.reshape(*new_shape)
var_t = var_t.reshape(*new_shape)
output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
output_t = output_t*weight + bias
output = torch.cat([output_s, output_t], 0)
if self.adaptive_alpha:
alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
alpha = alpha.reshape(*new_shape)
output = (1 + alpha.detach()) * output
return output
class TransNorm1d(_TransNorm):
def _check_input(self, x):
if x.dim() != 2:
raise ValueError(
"Expected the input to be 2-D, "
"but got {}-D".format(x.dim())
)
class TransNorm2d(_TransNorm):
def _check_input(self, x):
if x.dim() != 4:
raise ValueError(
"Expected the input to be 4-D, "
"but got {}-D".format(x.dim())
)
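# Minimal usage sketch (illustrative, not part of the original file): during
# training the input must be the concatenation [source; target] along the
# batch dimension, which the layer splits in half internally; at test time the
# target running statistics are used.
if __name__ == "__main__":
    tn = TransNorm2d(num_features=8)
    tn.train()
    x_s, x_t = torch.rand(4, 8, 7, 7), torch.rand(4, 8, 7, 7)
    y = tn(torch.cat([x_s, x_t], 0))  # (8, 8, 7, 7)
    tn.eval()
    y_t = tn(x_t)
    print(y.shape, y_t.shape)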

View File

@@ -0,0 +1,75 @@
import numpy as np
import torch
def sharpen_prob(p, temperature=2):
"""Sharpening probability with a temperature.
Args:
p (torch.Tensor): probability matrix (batch_size, n_classes)
temperature (float): temperature.
"""
p = p.pow(temperature)
return p / p.sum(1, keepdim=True)
def reverse_index(data, label):
"""Reverse order."""
inv_idx = torch.arange(data.size(0) - 1, -1, -1).long()
return data[inv_idx], label[inv_idx]
def shuffle_index(data, label):
"""Shuffle order."""
rnd_idx = torch.randperm(data.shape[0])
return data[rnd_idx], label[rnd_idx]
def create_onehot(label, num_classes):
"""Create one-hot tensor.
We suggest using nn.functional.one_hot.
Args:
label (torch.Tensor): 1-D tensor.
num_classes (int): number of classes.
"""
onehot = torch.zeros(label.shape[0], num_classes)
return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1)
def sigmoid_rampup(current, rampup_length):
"""Exponential rampup.
Args:
current (int): current step.
rampup_length (int): maximum step.
"""
assert rampup_length > 0
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current/rampup_length
return float(np.exp(-5.0 * phase * phase))
def linear_rampup(current, rampup_length):
"""Linear rampup.
Args:
current (int): current step.
rampup_length (int): maximum step.
"""
assert rampup_length > 0
ratio = np.clip(current / rampup_length, 0.0, 1.0)
return float(ratio)
def ema_model_update(model, ema_model, alpha):
"""Exponential moving average of model parameters.
Args:
model (nn.Module): model being trained.
ema_model (nn.Module): ema of the model.
alpha (float): ema decay rate.
"""
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
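# Minimal usage sketch (illustrative, not part of the original file): a
# mean-teacher style loop where the consistency weight is ramped up and the
# teacher tracks an exponential moving average of the student's parameters.
if __name__ == "__main__":
    import copy
    import torch.nn as nn
    student = nn.Linear(4, 2)
    teacher = copy.deepcopy(student)
    for step in range(1, 11):
        w = sigmoid_rampup(step, rampup_length=10)  # consistency weight in [0, 1]
        ema_model_update(student, teacher, alpha=0.99)
    print(w)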

View File

@@ -0,0 +1,2 @@
from .optimizer import build_optimizer
from .lr_scheduler import build_lr_scheduler

View File

@@ -0,0 +1,154 @@
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.optim.lr_scheduler import _LRScheduler
AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]
class _BaseWarmupScheduler(_LRScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
last_epoch=-1,
verbose=False
):
self.successor = successor
self.warmup_epoch = warmup_epoch
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
raise NotImplementedError
def step(self, epoch=None):
if self.last_epoch >= self.warmup_epoch:
self.successor.step(epoch)
self._last_lr = self.successor.get_last_lr()
else:
super().step(epoch)
class ConstantWarmupScheduler(_BaseWarmupScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
cons_lr,
last_epoch=-1,
verbose=False
):
self.cons_lr = cons_lr
super().__init__(
optimizer, successor, warmup_epoch, last_epoch, verbose
)
def get_lr(self):
if self.last_epoch >= self.warmup_epoch:
return self.successor.get_last_lr()
return [self.cons_lr for _ in self.base_lrs]
class LinearWarmupScheduler(_BaseWarmupScheduler):
def __init__(
self,
optimizer,
successor,
warmup_epoch,
min_lr,
last_epoch=-1,
verbose=False
):
self.min_lr = min_lr
super().__init__(
optimizer, successor, warmup_epoch, last_epoch, verbose
)
def get_lr(self):
if self.last_epoch >= self.warmup_epoch:
return self.successor.get_last_lr()
if self.last_epoch == 0:
return [self.min_lr for _ in self.base_lrs]
return [
lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs
]
def build_lr_scheduler(optimizer, optim_cfg):
"""A function wrapper for building a learning rate scheduler.
Args:
optimizer (Optimizer): an Optimizer.
optim_cfg (CfgNode): optimization config.
"""
lr_scheduler = optim_cfg.LR_SCHEDULER
stepsize = optim_cfg.STEPSIZE
gamma = optim_cfg.GAMMA
max_epoch = optim_cfg.MAX_EPOCH
if lr_scheduler not in AVAI_SCHEDS:
raise ValueError(
"Unsupported scheduler: {}. Must be one of {}".format(
lr_scheduler, AVAI_SCHEDS
)
)
if lr_scheduler == "single_step":
if isinstance(stepsize, (list, tuple)):
stepsize = stepsize[-1]
if not isinstance(stepsize, int):
raise TypeError(
"For single_step lr_scheduler, stepsize must "
"be an integer, but got {}".format(type(stepsize))
)
if stepsize <= 0:
stepsize = max_epoch
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=stepsize, gamma=gamma
)
elif lr_scheduler == "multi_step":
if not isinstance(stepsize, (list, tuple)):
raise TypeError(
"For multi_step lr_scheduler, stepsize must "
"be a list, but got {}".format(type(stepsize))
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=stepsize, gamma=gamma
)
elif lr_scheduler == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, float(max_epoch)
)
if optim_cfg.WARMUP_EPOCH > 0:
if not optim_cfg.WARMUP_RECOUNT:
scheduler.last_epoch = optim_cfg.WARMUP_EPOCH
if optim_cfg.WARMUP_TYPE == "constant":
scheduler = ConstantWarmupScheduler(
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
optim_cfg.WARMUP_CONS_LR
)
elif optim_cfg.WARMUP_TYPE == "linear":
scheduler = LinearWarmupScheduler(
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
optim_cfg.WARMUP_MIN_LR
)
        else:
            raise ValueError(
                "Unknown warmup type: {}. Must be either "
                "'constant' or 'linear'".format(optim_cfg.WARMUP_TYPE)
            )
return scheduler
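# Minimal usage sketch (illustrative, not part of the original file): the config
# fields mirror the keys read above; the values here are arbitrary.
if __name__ == "__main__":
    from yacs.config import CfgNode as CN
    cfg = CN()
    cfg.LR_SCHEDULER = "cosine"
    cfg.STEPSIZE = (20, )
    cfg.GAMMA = 0.1
    cfg.MAX_EPOCH = 60
    cfg.WARMUP_EPOCH = 5
    cfg.WARMUP_TYPE = "linear"
    cfg.WARMUP_RECOUNT = True
    cfg.WARMUP_MIN_LR = 1e-5
    cfg.WARMUP_CONS_LR = 1e-5
    optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    scheduler = build_lr_scheduler(optimizer, cfg)
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())  # warmup lr ramping toward the base lr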

View File

@@ -0,0 +1,136 @@
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import warnings
import torch
import torch.nn as nn
from .radam import RAdam
AVAI_OPTIMS = ["adam", "amsgrad", "sgd", "rmsprop", "radam", "adamw"]
def build_optimizer(model, optim_cfg):
"""A function wrapper for building an optimizer.
Args:
model (nn.Module or iterable): model.
optim_cfg (CfgNode): optimization config.
"""
optim = optim_cfg.NAME
lr = optim_cfg.LR
weight_decay = optim_cfg.WEIGHT_DECAY
momentum = optim_cfg.MOMENTUM
sgd_dampening = optim_cfg.SGD_DAMPNING
sgd_nesterov = optim_cfg.SGD_NESTEROV
rmsprop_alpha = optim_cfg.RMSPROP_ALPHA
adam_beta1 = optim_cfg.ADAM_BETA1
adam_beta2 = optim_cfg.ADAM_BETA2
staged_lr = optim_cfg.STAGED_LR
new_layers = optim_cfg.NEW_LAYERS
base_lr_mult = optim_cfg.BASE_LR_MULT
if optim not in AVAI_OPTIMS:
raise ValueError(
"Unsupported optim: {}. Must be one of {}".format(
optim, AVAI_OPTIMS
)
)
if staged_lr:
if not isinstance(model, nn.Module):
raise TypeError(
"When staged_lr is True, model given to "
"build_optimizer() must be an instance of nn.Module"
)
if isinstance(model, nn.DataParallel):
model = model.module
        # the original nested None-check inside the isinstance(str) branch was
        # unreachable; warn on None first, then normalize a single name to a list
        if new_layers is None:
            warnings.warn(
                "new_layers is empty, therefore, staged_lr is useless"
            )
        if isinstance(new_layers, str):
            new_layers = [new_layers]
base_params = []
base_layers = []
new_params = []
for name, module in model.named_children():
if name in new_layers:
new_params += [p for p in module.parameters()]
else:
base_params += [p for p in module.parameters()]
base_layers.append(name)
param_groups = [
{
"params": base_params,
"lr": lr * base_lr_mult
},
{
"params": new_params
},
]
else:
if isinstance(model, nn.Module):
param_groups = model.parameters()
else:
param_groups = model
if optim == "adam":
optimizer = torch.optim.Adam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
elif optim == "amsgrad":
optimizer = torch.optim.Adam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
amsgrad=True,
)
elif optim == "sgd":
optimizer = torch.optim.SGD(
param_groups,
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
dampening=sgd_dampening,
nesterov=sgd_nesterov,
)
elif optim == "rmsprop":
optimizer = torch.optim.RMSprop(
param_groups,
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
alpha=rmsprop_alpha,
)
elif optim == "radam":
optimizer = RAdam(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
elif optim == "adamw":
optimizer = torch.optim.AdamW(
param_groups,
lr=lr,
weight_decay=weight_decay,
betas=(adam_beta1, adam_beta2),
)
return optimizer
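# Minimal usage sketch (illustrative, not part of the original file): the field
# names follow the keys read above; the values here are arbitrary.
if __name__ == "__main__":
    from yacs.config import CfgNode as CN
    cfg = CN()
    cfg.NAME = "sgd"
    cfg.LR = 0.01
    cfg.WEIGHT_DECAY = 5e-4
    cfg.MOMENTUM = 0.9
    cfg.SGD_DAMPNING = 0
    cfg.SGD_NESTEROV = True
    cfg.RMSPROP_ALPHA = 0.99
    cfg.ADAM_BETA1 = 0.9
    cfg.ADAM_BETA2 = 0.999
    cfg.STAGED_LR = False
    cfg.NEW_LAYERS = None
    cfg.BASE_LR_MULT = 0.1
    model = nn.Linear(8, 2)
    optimizer = build_optimizer(model, cfg)
    print(optimizer)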

Some files were not shown because too many files have changed in this diff