release code
18
Dassl.ProGrad.pytorch/dassl/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Dassl
------
PyTorch toolbox for domain adaptation and semi-supervised learning.

URL: https://github.com/KaiyangZhou/Dassl.pytorch

@article{zhou2020domain,
    title={Domain Adaptive Ensemble Learning},
    author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
    journal={arXiv preprint arXiv:2003.07325},
    year={2020}
}
"""

__version__ = "0.5.0"
__author__ = "Kaiyang Zhou"
__homepage__ = "https://kaiyangzhou.github.io/"
5
Dassl.ProGrad.pytorch/dassl/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .defaults import _C as cfg_default


def get_cfg_default():
    return cfg_default.clone()
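For context, a minimal usage sketch (not part of this commit; the YAML path and override values are illustrative): get_cfg_default() hands back a mutable clone of the defaults, which can then be customized through the standard yacs API.

    from dassl.config import get_cfg_default

    cfg = get_cfg_default()
    cfg.merge_from_file("configs/my_experiment.yaml")  # hypothetical YAML file
    cfg.merge_from_list(["OPTIM.LR", 0.001, "OPTIM.MAX_EPOCH", 50])
    cfg.freeze()  # make the config read-only before training starts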
275
Dassl.ProGrad.pytorch/dassl/config/defaults.py
Normal file
@@ -0,0 +1,275 @@
from yacs.config import CfgNode as CN

###########################
# Config definition
###########################

_C = CN()

_C.VERSION = 1

# Directory to save the output files (like log.txt and model weights)
_C.OUTPUT_DIR = "./output"
# Path to a directory where the files were saved previously
_C.RESUME = ""
# Set seed to negative value to randomize everything
# Set seed to positive value to use a fixed seed
_C.SEED = -1
_C.USE_CUDA = True
# Print detailed information
# E.g. trainer, dataset, and backbone
_C.VERBOSE = True

###########################
# Input
###########################
_C.INPUT = CN()
_C.INPUT.SIZE = (224, 224)
# Mode of interpolation in resize functions
_C.INPUT.INTERPOLATION = "bilinear"
# For available choices please refer to transforms.py
_C.INPUT.TRANSFORMS = ()
# If True, tfm_train and tfm_test will be None
_C.INPUT.NO_TRANSFORM = False
# Default mean and std come from ImageNet
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding for random crop
_C.INPUT.CROP_PADDING = 4
# Cutout
_C.INPUT.CUTOUT_N = 1
_C.INPUT.CUTOUT_LEN = 16
# Gaussian noise
_C.INPUT.GN_MEAN = 0.0
_C.INPUT.GN_STD = 0.15
# RandomAugment
_C.INPUT.RANDAUGMENT_N = 2
_C.INPUT.RANDAUGMENT_M = 10
# ColorJitter (brightness, contrast, saturation, hue)
_C.INPUT.COLORJITTER_B = 0.4
_C.INPUT.COLORJITTER_C = 0.4
_C.INPUT.COLORJITTER_S = 0.4
_C.INPUT.COLORJITTER_H = 0.1
# Random gray scale's probability
_C.INPUT.RGS_P = 0.2
# Gaussian blur
_C.INPUT.GB_P = 0.5  # probability of applying this operation
_C.INPUT.GB_K = 21  # kernel size (should be an odd number)

###########################
# Dataset
###########################
_C.DATASET = CN()
# Directory where datasets are stored
_C.DATASET.ROOT = ""
_C.DATASET.NAME = ""
# List of names of source domains
_C.DATASET.SOURCE_DOMAINS = ()
# List of names of target domains
_C.DATASET.TARGET_DOMAINS = ()
# Number of labeled instances in total
# Useful for semi-supervised learning
_C.DATASET.NUM_LABELED = -1
# Number of images per class
_C.DATASET.NUM_SHOTS = -1
# Percentage of validation data (only used for SSL datasets)
# Set to 0 if you do not want to use val data
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
_C.DATASET.VAL_PERCENT = 0.1
# Fold index for STL-10 dataset (normal range is 0 - 9)
# Negative number means None
_C.DATASET.STL10_FOLD = -1
# CIFAR-10/100-C's corruption type and intensity level
_C.DATASET.CIFAR_C_TYPE = ""
_C.DATASET.CIFAR_C_LEVEL = 1
# Use all data in the unlabeled data set (e.g. FixMatch)
_C.DATASET.ALL_AS_UNLABELED = False

###########################
# Dataloader
###########################
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Apply transformations to an image K times (during training)
_C.DATALOADER.K_TRANSFORMS = 1
# img0 denotes the image tensor without augmentation
# Useful for consistency learning
_C.DATALOADER.RETURN_IMG0 = False
# Settings for the train_x data loader
_C.DATALOADER.TRAIN_X = CN()
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
# Parameter for RandomDomainSampler
# 0 or -1 means sampling from all domains
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
# Parameter of RandomClassSampler
# Number of instances per class
_C.DATALOADER.TRAIN_X.N_INS = 16

# Settings for the train_u data loader
_C.DATALOADER.TRAIN_U = CN()
# Set to False if you want unique
# data loader params for train_u
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
_C.DATALOADER.TRAIN_U.N_INS = 16

# Settings for the test data loader
_C.DATALOADER.TEST = CN()
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
_C.DATALOADER.TEST.BATCH_SIZE = 32

###########################
# Model
###########################
_C.MODEL = CN()
# Path to model weights (for initialization)
_C.MODEL.INIT_WEIGHTS = ""
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = ""
_C.MODEL.BACKBONE.PRETRAINED = True
# Definition of embedding layers
_C.MODEL.HEAD = CN()
# If none, do not construct embedding layers; the
# backbone's output will be passed to the classifier
_C.MODEL.HEAD.NAME = ""
# Structure of hidden layers (a list), e.g. [512, 512]
# If undefined, no embedding layer will be constructed
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
_C.MODEL.HEAD.ACTIVATION = "relu"
_C.MODEL.HEAD.BN = True
_C.MODEL.HEAD.DROPOUT = 0.0

###########################
# Optimization
###########################
_C.OPTIM = CN()
_C.OPTIM.NAME = "adam"
_C.OPTIM.LR = 0.0003
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9
_C.OPTIM.SGD_DAMPNING = 0
_C.OPTIM.SGD_NESTEROV = False
_C.OPTIM.RMSPROP_ALPHA = 0.99
_C.OPTIM.ADAM_BETA1 = 0.9
_C.OPTIM.ADAM_BETA2 = 0.999
# STAGED_LR allows different layers to have
# different lr, e.g. pre-trained base layers
# can be assigned a smaller lr than the new
# classification layer
_C.OPTIM.STAGED_LR = False
_C.OPTIM.NEW_LAYERS = ()
_C.OPTIM.BASE_LR_MULT = 0.1
# Learning rate scheduler
_C.OPTIM.LR_SCHEDULER = "single_step"
# -1 or 0 means the stepsize is equal to max_epoch
_C.OPTIM.STEPSIZE = (-1, )
_C.OPTIM.GAMMA = 0.1
_C.OPTIM.MAX_EPOCH = 10
# Set WARMUP_EPOCH larger than 0 to activate warmup training
_C.OPTIM.WARMUP_EPOCH = -1
# Either linear or constant
_C.OPTIM.WARMUP_TYPE = "linear"
# Constant learning rate when type=constant
_C.OPTIM.WARMUP_CONS_LR = 1e-5
# Minimum learning rate when type=linear
_C.OPTIM.WARMUP_MIN_LR = 1e-5
# Recount epoch for the next scheduler (last_epoch=-1)
# Otherwise last_epoch=warmup_epoch
_C.OPTIM.WARMUP_RECOUNT = True

###########################
# Train
###########################
_C.TRAIN = CN()
# How often (epochs) to save the model during training
# Set to 0 or a negative value to only save the last one
_C.TRAIN.CHECKPOINT_FREQ = 0
# How often (batches) to print training information
_C.TRAIN.PRINT_FREQ = 10
# Use 'train_x', 'train_u' or 'smaller_one' to count
# the number of iterations in an epoch (for DA and SSL)
_C.TRAIN.COUNT_ITER = "train_x"

###########################
# Test
###########################
_C.TEST = CN()
_C.TEST.EVALUATOR = "Classification"
_C.TEST.PER_CLASS_RESULT = False
# Compute confusion matrix, which will be saved
# to $OUTPUT_DIR/cmat.pt
_C.TEST.COMPUTE_CMAT = False
# If NO_TEST=True, no testing will be conducted
_C.TEST.NO_TEST = False
# Use the test or val set for FINAL evaluation
_C.TEST.SPLIT = "test"
# Which model to test after training
# Either last_step or best_val
_C.TEST.FINAL_MODEL = "last_step"

###########################
# Trainer specifics
###########################
_C.TRAINER = CN()
_C.TRAINER.NAME = ""

# MCD
_C.TRAINER.MCD = CN()
_C.TRAINER.MCD.N_STEP_F = 4  # number of steps to train F
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1  # weight for the entropy loss
# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300

# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5  # weight for the moment distance loss
_C.TRAINER.M3SDA.N_STEP_F = 4  # follow MCD
# DAEL
_C.TRAINER.DAEL = CN()
_C.TRAINER.DAEL.WEIGHT_U = 0.5  # weight on the unlabeled loss
_C.TRAINER.DAEL.CONF_THRE = 0.95  # confidence threshold
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()

# CrossGrad
_C.TRAINER.CG = CN()
_C.TRAINER.CG.EPS_F = 1.0  # scaling parameter for D's gradients
_C.TRAINER.CG.EPS_D = 1.0  # scaling parameter for F's gradients
_C.TRAINER.CG.ALPHA_F = 0.5  # balancing weight for the label net's loss
_C.TRAINER.CG.ALPHA_D = 0.5  # balancing weight for the domain net's loss
# DDAIG
_C.TRAINER.DDAIG = CN()
_C.TRAINER.DDAIG.G_ARCH = ""  # generator's architecture
_C.TRAINER.DDAIG.LMDA = 0.3  # perturbation weight
_C.TRAINER.DDAIG.CLAMP = False  # clamp perturbation values
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
_C.TRAINER.DDAIG.WARMUP = 0
_C.TRAINER.DDAIG.ALPHA = 0.5  # balancing weight for the losses

# EntMin
_C.TRAINER.ENTMIN = CN()
_C.TRAINER.ENTMIN.LMDA = 1e-3  # weight on the entropy loss
# Mean Teacher
_C.TRAINER.MEANTEA = CN()
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0  # weight on the unlabeled loss
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
_C.TRAINER.MEANTEA.RAMPUP = 5  # epochs used to ramp up the loss_u weight
# MixMatch
_C.TRAINER.MIXMATCH = CN()
_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0  # weight on the unlabeled loss
_C.TRAINER.MIXMATCH.TEMP = 2.0  # temperature for sharpening the probability
_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
_C.TRAINER.MIXMATCH.RAMPUP = 20000  # steps used to ramp up the loss_u weight
# FixMatch
_C.TRAINER.FIXMATCH = CN()
_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0  # weight on the unlabeled loss
_C.TRAINER.FIXMATCH.CONF_THRE = 0.95  # confidence threshold
_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
1
Dassl.ProGrad.pytorch/dassl/data/__init__.py
Normal file
@@ -0,0 +1 @@
from .data_manager import DataManager, DatasetWrapper
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
@@ -0,0 +1,264 @@
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset

from dassl.utils import read_image

from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import build_transform

INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}


def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
):
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(data_loader) > 0

    return data_loader


class DataManager:

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        # Load dataset
        dataset = build_dataset(cfg)
        # Build transform
        if custom_tfm_train is None:
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        # Build train_loader_x
        train_loader_x = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=dataset.train_x,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
        )

        # Build train_loader_u
        train_loader_u = None
        if dataset.train_u:
            sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
            batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
            n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
            n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS

            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
                batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
                n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
                n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=sampler_type_,
                data_source=dataset.train_u,
                batch_size=batch_size_,
                n_domain=n_domain_,
                n_ins=n_ins_,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # Build val_loader
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper,
            )

        # Build test_loader
        test_loader = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def num_source_domains(self):
        return self._num_source_domains

    @property
    def lab2cname(self):
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        print("***** Dataset statistics *****")

        print(" Dataset: {}".format(cfg.DATASET.NAME))

        if cfg.DATASET.SOURCE_DOMAINS:
            print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
        if cfg.DATASET.TARGET_DOMAINS:
            print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))

        print(" # classes: {:,}".format(self.num_classes))

        print(" # train_x: {:,}".format(len(self.dataset.train_x)))

        if self.dataset.train_u:
            print(" # train_u: {:,}".format(len(self.dataset.train_u)))

        if self.dataset.val:
            print(" # val: {:,}".format(len(self.dataset.val)))

        print(" # test: {:,}".format(len(self.dataset.test)))


class DatasetWrapper(TorchDataset):

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        to_tensor = []
        to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            normalize = T.Normalize(
                mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
            )
            to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = "img"
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output["img"] = img

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)

        return output

    def _transform_image(self, tfm, img0):
        img_list = []

        for k in range(self.k_tfm):
            img_list.append(tfm(img0))

        img = img_list
        if len(img) == 1:
            img = img[0]

        return img
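For context, a minimal end-to-end sketch (not part of this commit; the dataset name, domains and root path are illustrative): DataManager builds every loader from the config, and each batch yielded by DatasetWrapper is a dict with "img", "label", "domain" and "impath" entries.

    from dassl.config import get_cfg_default
    from dassl.data import DataManager

    cfg = get_cfg_default()
    cfg.DATASET.NAME = "Office31"             # any registered dataset name
    cfg.DATASET.ROOT = "/path/to/data"        # placeholder path
    cfg.DATASET.SOURCE_DOMAINS = ["amazon"]
    cfg.DATASET.TARGET_DOMAINS = ["webcam"]
    cfg.INPUT.TRANSFORMS = ("normalize", )    # e.g. normalization only

    dm = DataManager(cfg)
    batch = next(iter(dm.train_loader_x))
    imgs, labels = batch["img"], batch["label"]  # image tensors and class labels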
6
Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .build import DATASET_REGISTRY, build_dataset  # isort:skip
from .base_dataset import Datum, DatasetBase  # isort:skip

from .da import *
from .dg import *
from .ssl import *
225
Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
Normal file
@@ -0,0 +1,225 @@
import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown

from dassl.utils import check_isfile


class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        assert isinstance(impath, str)
        assert check_isfile(impath)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count the number of classes.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        try:
            tar = tarfile.open(dst)
            tar.extractall(path=osp.dirname(dst))
            tar.close()
        except tarfile.ReadError:
            # Not a tar archive; fall back to zip
            zip_ref = zipfile.ZipFile(dst, "r")
            zip_ref.extractall(osp.dirname(dst))
            zip_ref.close()

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a small number of images.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample.
            repeat (bool): repeat images if needed (default: False).
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output
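For context, a minimal sketch of the few-shot helper above (not part of this commit; the image paths are placeholders and assumed to exist on disk, since Datum checks them): generate_fewshot_dataset groups the Datum objects by label and samples num_shots instances per class.

    from dassl.data.datasets import Datum, DatasetBase

    train_x = [
        Datum(impath="/data/cat/0.jpg", label=0, classname="cat"),
        Datum(impath="/data/cat/1.jpg", label=0, classname="cat"),
        Datum(impath="/data/dog/0.jpg", label=1, classname="dog"),
    ]
    base = DatasetBase(train_x=train_x, test=train_x)
    few_shot = base.generate_fewshot_dataset(train_x, num_shots=1)
    assert len(few_shot) == 2  # one Datum per class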
11
Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

DATASET_REGISTRY = Registry("DATASET")


def build_dataset(cfg):
    avai_datasets = DATASET_REGISTRY.registered_names()
    check_availability(cfg.DATASET.NAME, avai_datasets)
    if cfg.VERBOSE:
        print("Loading dataset: {}".format(cfg.DATASET.NAME))
    return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)
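For context, a sketch of how the registry is consumed (the dataset below is hypothetical, not part of this commit): decorate a DatasetBase subclass with DATASET_REGISTRY.register(), and build_dataset(cfg) will instantiate it when cfg.DATASET.NAME matches the class name.

    from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase


    @DATASET_REGISTRY.register()
    class MyToyDataset(DatasetBase):  # hypothetical example
        def __init__(self, cfg):
            # A real implementation would scan cfg.DATASET.ROOT here
            items = [Datum(impath="/data/toy/0.jpg", label=0, classname="toy")]
            super().__init__(train_x=items, test=items)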
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
@@ -0,0 +1,7 @@
from .digit5 import Digit5
from .visda17 import VisDA17
from .cifarstl import CIFARSTL
from .office31 import Office31
from .domainnet import DomainNet
from .office_home import OfficeHome
from .mini_domainnet import miniDomainNet
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
@@ -0,0 +1,68 @@
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
    """CIFAR-10 and STL-10.

    CIFAR-10:
        - 60,000 32x32 colour images.
        - 10 classes, with 6,000 images per class.
        - 50,000 training images and 10,000 test images.
        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.

    STL-10:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
          monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images (10 pre-defined folds), 800 test images
          per class.
        - URL: https://cs.stanford.edu/~acoates/stl10/.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
          from Tiny Images. Tech report.
        - Coates et al. An Analysis of Single Layer Networks in
          Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "cifar_stl"
    domains = ["cifar", "stl"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            data_dir = osp.join(self.dataset_dir, dname, split)
            class_names = listdir_nohidden(data_dir)

            for class_name in class_names:
                class_dir = osp.join(data_dir, class_name)
                imnames = listdir_nohidden(class_dir)
                label = int(class_name.split("_")[0])

                for imname in imnames:
                    impath = osp.join(class_dir, imname)
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
@@ -0,0 +1,124 @@
import random
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}


def read_image_list(im_dir, n_max=None, n_repeat=None):
    items = []

    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        items = random.sample(items, n_max)

    if n_repeat is not None:
        items *= n_repeat

    return items


def load_mnist(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST_M[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SVHN[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_syn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SYN[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_usps(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, USPS[split])
    n_repeat = 3 if split == "train" else None
    return read_image_list(data_dir, n_repeat=n_repeat)


@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
    """Five digit datasets.

    It contains:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS, which has
    only 9,298 images in total, we use the entire dataset but replicate its
    training set 3 times to match the training set size of the other domains.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
          recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
          JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
          feature learning. NIPS-W 2011.
    """

    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            func = "load_" + dname
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = eval(func)(domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=str(label)
                )
                items.append(item)

        return items
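A side note, not part of this commit: _read_data selects the loader by name via eval. If placed in this module, an explicit dispatch table over the loader functions defined above is an equivalent, safer pattern.

    def read_domain(dataset_dir, dname, split="train"):
        # Maps a domain name to its loader without eval()
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }
        return loaders[dname](osp.join(dataset_dir, dname), split=split)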
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
@@ -0,0 +1,69 @@
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
    """DomainNet.

    Statistics:
        - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
          Real, Sketch.
        - Around 0.6M images.
        - 345 categories.
        - URL: http://ai.bu.edu/M3SDA/.

    Special note: the t-shirt class (327) is missing in painting_train.txt.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
          Adaptation. ICCV 2019.
    """

    dataset_dir = "domainnet"
    domains = [
        "clipart", "infograph", "painting", "quickdraw", "real", "sketch"
    ]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            filename = dname + "_" + split + ".txt"
            split_file = osp.join(self.split_dir, filename)

            with open(split_file, "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()
                    impath, label = line.split(" ")
                    classname = impath.split("/")[1]
                    impath = osp.join(self.dataset_dir, impath)
                    label = int(label)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname
                    )
                    items.append(item)

        return items
58
Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py
Normal file
@@ -0,0 +1,58 @@
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
    """A subset of DomainNet.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
          Adaptation. ICCV 2019.
        - Zhou et al. Domain Adaptive Ensemble Learning.
    """

    dataset_dir = "domainnet"
    domains = ["clipart", "painting", "real", "sketch"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits_mini")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            filename = dname + "_" + split + ".txt"
            split_file = osp.join(self.split_dir, filename)

            with open(split_file, "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()
                    impath, label = line.split(" ")
                    classname = impath.split("/")[1]
                    impath = osp.join(self.dataset_dir, impath)
                    label = int(label)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname
                    )
                    items.append(item)

        return items
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
@@ -0,0 +1,63 @@
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class Office31(DatasetBase):
    """Office-31.

    Statistics:
        - 4,110 images.
        - 31 classes related to office objects.
        - 3 domains: Amazon, Webcam, Dslr.
        - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.

    Reference:
        - Saenko et al. Adapting visual category models to
          new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            class_names = listdir_nohidden(domain_dir)
            class_names.sort()

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                imnames = listdir_nohidden(class_path)

                for imname in imnames:
                    impath = osp.join(class_path, imname)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=class_name
                    )
                    items.append(item)

        return items
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
@@ -0,0 +1,63 @@
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
          Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home"
    domains = ["art", "clipart", "product", "real_world"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            class_names = listdir_nohidden(domain_dir)
            class_names.sort()

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                imnames = listdir_nohidden(class_path)

                for imname in imnames:
                    impath = osp.join(class_path, imname)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=class_name.lower(),
                    )
                    items.append(item)

        return items
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
@@ -0,0 +1,61 @@
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
    """VisDA17.

    Focusing on simulation-to-reality domain shift.

    URL: http://ai.bu.edu/visda-2017/.

    Reference:
        - Peng et al. VisDA: The Visual Domain Adaptation
          Challenge. ArXiv 2017.
    """

    dataset_dir = "visda17"
    domains = ["synthetic", "real"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data("synthetic")
        train_u = self._read_data("real")
        test = self._read_data("real")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, dname):
        filedir = "train" if dname == "synthetic" else "validation"
        image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
        items = []
        # Each call reads a single domain, so the domain index is always 0
        domain = 0

        with open(image_list, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                classname = impath.split("/")[0]
                impath = osp.join(self.dataset_dir, filedir, impath)
                label = int(label)
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items
6
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
123
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
AVAI_C_TYPES = [
|
||||
"brightness",
|
||||
"contrast",
|
||||
"defocus_blur",
|
||||
"elastic_transform",
|
||||
"fog",
|
||||
"frost",
|
||||
"gaussian_blur",
|
||||
"gaussian_noise",
|
||||
"glass_blur",
|
||||
"impulse_noise",
|
||||
"jpeg_compression",
|
||||
"motion_blur",
|
||||
"pixelate",
|
||||
"saturate",
|
||||
"shot_noise",
|
||||
"snow",
|
||||
"spatter",
|
||||
"speckle_noise",
|
||||
"zoom_blur",
|
||||
]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class CIFAR10C(DatasetBase):
|
||||
"""CIFAR-10 -> CIFAR-10-C.
|
||||
|
||||
Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o
|
||||
|
||||
Statistics:
|
||||
- 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
|
||||
- 10 categories
|
||||
|
||||
Reference:
|
||||
- Hendrycks et al. Benchmarking neural network robustness
|
||||
to common corruptions and perturbations. ICLR 2019.
|
||||
"""
|
||||
|
||||
dataset_dir = ""
|
||||
domains = ["cifar10", "cifar10_c"]
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = root
|
||||
|
||||
self.check_input_domains(
|
||||
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||
)
|
||||
source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
|
||||
target_domain = cfg.DATASET.TARGET_DOMAINS[0]
|
||||
assert source_domain == self.domains[0]
|
||||
assert target_domain == self.domains[1]
|
||||
|
||||
c_type = cfg.DATASET.CIFAR_C_TYPE
|
||||
c_level = cfg.DATASET.CIFAR_C_LEVEL
|
||||
|
||||
if not c_type:
|
||||
raise ValueError(
|
||||
"Please specify DATASET.CIFAR_C_TYPE in the config file"
|
||||
)
|
||||
|
||||
assert (
|
||||
c_type in AVAI_C_TYPES
|
||||
), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
|
||||
assert 1 <= c_level <= 5
|
||||
|
||||
train_dir = osp.join(self.dataset_dir, source_domain, "train")
|
||||
test_dir = osp.join(
|
||||
self.dataset_dir, target_domain, c_type, str(c_level)
|
||||
)
|
||||
|
||||
if not osp.exists(test_dir):
|
||||
raise ValueError
|
||||
|
||||
train = self._read_data(train_dir)
|
||||
test = self._read_data(test_dir)
|
||||
|
||||
super().__init__(train_x=train, test=test)
|
||||
|
||||
def _read_data(self, data_dir):
|
||||
class_names = listdir_nohidden(data_dir)
|
||||
class_names.sort()
|
||||
items = []
|
||||
|
||||
for label, class_name in enumerate(class_names):
|
||||
class_dir = osp.join(data_dir, class_name)
|
||||
imnames = listdir_nohidden(class_dir)
|
||||
|
||||
for imname in imnames:
|
||||
impath = osp.join(class_dir, imname)
|
||||
item = Datum(impath=impath, label=label, domain=0)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class CIFAR100C(CIFAR10C):
|
||||
"""CIFAR-100 -> CIFAR-100-C.
|
||||
|
||||
Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o
|
||||
|
||||
Statistics:
|
||||
- 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
|
||||
- 10 categories
|
||||
|
||||
Reference:
|
||||
- Hendrycks et al. Benchmarking neural network robustness
|
||||
to common corruptions and perturbations. ICLR 2019.
|
||||
"""
|
||||
|
||||
dataset_dir = ""
|
||||
domains = ["cifar100", "cifar100_c"]
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
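For context, a hedged usage sketch (not part of this commit; the root path is a placeholder): CIFAR10C refuses to build unless the corruption type and severity level are set, e.g. through yacs overrides.

    from dassl.config import get_cfg_default

    cfg = get_cfg_default()
    cfg.merge_from_list([
        "DATASET.NAME", "CIFAR10C",
        "DATASET.ROOT", "/path/to/data",
        "DATASET.SOURCE_DOMAINS", ["cifar10"],
        "DATASET.TARGET_DOMAINS", ["cifar10_c"],
        "DATASET.CIFAR_C_TYPE", "gaussian_noise",  # must be one of AVAI_C_TYPES
        "DATASET.CIFAR_C_LEVEL", 3,                # severity in [1, 5]
    ])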
124
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
Normal file
@@ -0,0 +1,124 @@
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}


def read_image_list(im_dir, n_max=None, n_repeat=None):
    items = []

    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        # Note that the sampling process is NOT random;
        # this follows Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items


def load_mnist(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST_M[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SVHN[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_syn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SYN[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_usps(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, USPS[split])
    return read_image_list(data_dir)


@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate it on the test splits of the other four datasets. However,
        the code does not restrict you to using only MNIST as the source;
        any dataset can serve as the source, though only 10,000 of its
        images will be sampled for training.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
          recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
          JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
          feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
          Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            func = "load_" + dname
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = eval(func)(domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items
97
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):

        def _load_data_from_directory(directory):
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))

                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                class_name = impath.split("/")[-2].lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items
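For reference, a minimal usage sketch of DigitsDG; the directory layout and the cfg values are assumptions inferred from read_data above and the config defaults, not documented behaviour.

# Hedged usage sketch; assumes images live under
#   <ROOT>/digits_dg/<domain>/<split>/<class_folder>/*.jpg
from dassl.config import get_cfg_default

cfg = get_cfg_default()
cfg.DATASET.ROOT = "~/data"  # assumption: dataset root
cfg.DATASET.SOURCE_DOMAINS = ["mnist", "mnist_m", "svhn"]
cfg.DATASET.TARGET_DOMAINS = ["syn"]

dataset = DigitsDG(cfg)  # downloads digits_dg.zip on first run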
@@ -0,0 +1,49 @@
import os.path as osp

from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase


@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)
94
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
Normal file
@@ -0,0 +1,94 @@
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # The following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                file_train = osp.join(
                    self.split_dir, dname + "_train_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file_train)
                file_val = osp.join(
                    self.split_dir, dname + "_crossval_kfold.txt"
                )
                impath_label_list += self._read_split_pacs(file_val)
            else:
                file = osp.join(
                    self.split_dir, dname + "_" + split + "_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file)

            for impath, label in impath_label_list:
                classname = impath.split("/")[-2]
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items

    def _read_split_pacs(self, split_file):
        items = []

        with open(split_file, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                if impath in self._error_paths:
                    continue
                impath = osp.join(self.image_dir, impath)
                label = int(label) - 1
                items.append((impath, label))

        return items
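Each PACS split file pairs a relative image path with a 1-based label, which _read_split_pacs shifts to 0-based. A hedged one-line illustration (the path and label here are invented, not taken from the dataset):

line = "art_painting/dog/pic_001.jpg 1"   # hypothetical split-file line
impath, label = line.strip().split(" ")
label = int(label) - 1                    # 1-based on disk -> 0-based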
60
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
Normal file
@@ -0,0 +1,60 @@
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            dname = dname.upper()
            path = osp.join(self.dataset_dir, dname, split)
            folders = listdir_nohidden(path)
            folders.sort()

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(path, folder, "*.jpg"))

                for impath in impaths:
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items
@@ -0,0 +1,3 @@
from .svhn import SVHN
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
108
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
Normal file
@@ -0,0 +1,108 @@
import math
import random
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
    """CIFAR10 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        if len(val) == 0:
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data_train(self, data_dir, num_labeled, val_percent):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        num_labeled_per_class = num_labeled / len(class_names)
        items_x, items_u, items_v = [], [], []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            # Split into train and val following Oliver et al. 2018
            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
            num_val = math.floor(len(imnames) * val_percent)
            imnames_train = imnames[num_val:]
            imnames_val = imnames[:num_val]

            # Note we do shuffle after split
            random.shuffle(imnames_train)

            for i, imname in enumerate(imnames_train):
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)

                if (i + 1) <= num_labeled_per_class:
                    items_x.append(item)
                else:
                    items_u.append(item)

            for imname in imnames_val:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items_v.append(item)

        return items_x, items_u, items_v

    def _read_data_test(self, data_dir):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items.append(item)

        return items


@DATASET_REGISTRY.register()
class CIFAR100(CIFAR10):
    """CIFAR100 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar100"

    def __init__(self, cfg):
        super().__init__(cfg)
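A quick worked example of the labeled/unlabeled/val split performed by _read_data_train; the per-class counts are illustrative, not prescribed by the code.

import math

num_classes, imgs_per_class = 10, 5000  # assumed CIFAR10-like layout
num_labeled, val_percent = 4000, 0.1    # cfg.DATASET.NUM_LABELED / VAL_PERCENT

num_labeled_per_class = num_labeled / num_classes   # 400.0
num_val = math.floor(imgs_per_class * val_percent)  # 500 images -> val
num_x = int(num_labeled_per_class)                  # 400 images -> train_x
num_u = imgs_per_class - num_val - num_x            # 4100 images -> train_u
print(num_x, num_u, num_val)  # per class: 400 4100 500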
87
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
Normal file
@@ -0,0 +1,87 @@
import numpy as np
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images per class, 800 test images per class.
        - 100,000 unlabeled images for unsupervised learning.

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five folds
        assert 0 <= cfg.DATASET.STL10_FOLD <= 4

        train_x = self._read_data_train(
            train_dir, cfg.DATASET.STL10_FOLD, fold_file
        )
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data_train(self, data_dir, fold, fold_file):
        imnames = listdir_nohidden(data_dir)
        imnames.sort()
        items = []

        list_idx = list(range(len(imnames)))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
            # int64 rather than uint8: fold indices exceed 255, and
            # np.fromstring is deprecated
            list_idx = np.array(str_idx.split(), dtype=np.int64)

        for i in list_idx:
            imname = imnames[i]
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

    def _read_data_all(self, data_dir):
        imnames = listdir_nohidden(data_dir)
        items = []

        for imname in imnames:
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            if label == "none":
                label = -1
            else:
                label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items
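fold_indices.txt stores one whitespace-separated list of training-image indices per line, and _read_data_train selects the line numbered `fold`. An equivalent standalone parser, for illustration only:

def read_fold(fold_file, fold):
    """Return the index list on line `fold` of fold_indices.txt."""
    with open(fold_file, "r") as f:
        str_idx = f.read().splitlines()[fold]
    # Plain ints (not uint8): STL-10 indices run up to 4999
    return [int(s) for s in str_idx.split()]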
17
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
Normal file
@@ -0,0 +1,17 @@
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
        Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"

    def __init__(self, cfg):
        super().__init__(cfg)
205
Dassl.ProGrad.pytorch/dassl/data/samplers.py
Normal file
@@ -0,0 +1,205 @@
import copy
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler


class RandomDomainSampler(Sampler):
    """Randomly samples N domains each with K images
    to form a minibatch of size N*K.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_domain (int): number of domains to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())

        # Make sure each domain has equal number of images
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)

            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class SeqDomainSampler(Sampler):
    """Sequential domain sampler, which randomly samples K
    images from each domain to form a minibatch.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())
        self.domains.sort()

        # Make sure each domain has equal number of images
        n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            for domain in self.domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class RandomClassSampler(Sampler):
    """Randomly samples N classes each with K instances to
    form a minibatch of size N*K.

    Modified from https://github.com/KaiyangZhou/deep-person-reid.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_ins (int): number of instances per class to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_ins):
        if batch_size < n_ins:
            raise ValueError(
                "batch_size={} must be no less "
                "than n_ins={}".format(batch_size, n_ins)
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.n_ins = n_ins
        self.ncls_per_batch = self.batch_size // self.n_ins
        self.index_dic = defaultdict(list)
        for index, item in enumerate(data_source):
            self.index_dic[item.label].append(index)
        self.labels = list(self.index_dic.keys())
        assert len(self.labels) >= self.ncls_per_batch

        # Estimate number of images in an epoch
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        batch_idxs_dict = defaultdict(list)

        for label in self.labels:
            idxs = copy.deepcopy(self.index_dic[label])
            if len(idxs) < self.n_ins:
                idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.n_ins:
                    batch_idxs_dict[label].append(batch_idxs)
                    batch_idxs = []

        avai_labels = copy.deepcopy(self.labels)
        final_idxs = []

        while len(avai_labels) >= self.ncls_per_batch:
            selected_labels = random.sample(avai_labels, self.ncls_per_batch)
            for label in selected_labels:
                batch_idxs = batch_idxs_dict[label].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[label]) == 0:
                    avai_labels.remove(label)

        return iter(final_idxs)

    def __len__(self):
        return self.length


def build_sampler(
    sampler_type,
    cfg=None,
    data_source=None,
    batch_size=32,
    n_domain=0,
    n_ins=16
):
    if sampler_type == "RandomSampler":
        return RandomSampler(data_source)

    elif sampler_type == "SequentialSampler":
        return SequentialSampler(data_source)

    elif sampler_type == "RandomDomainSampler":
        return RandomDomainSampler(data_source, batch_size, n_domain)

    elif sampler_type == "SeqDomainSampler":
        return SeqDomainSampler(data_source, batch_size)

    elif sampler_type == "RandomClassSampler":
        return RandomClassSampler(data_source, batch_size, n_ins)

    else:
        raise ValueError("Unknown sampler type: {}".format(sampler_type))
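A minimal sketch of plugging build_sampler into a DataLoader; train_data (a list of Datums) and the wrapping dataset object are assumed to exist already, and the batch_size/n_domain values are illustrative.

from torch.utils.data import DataLoader

sampler = build_sampler(
    "RandomDomainSampler",
    data_source=train_data,  # assumption: list of Datums
    batch_size=32,
    n_domain=4,              # 8 images per domain per batch
)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, drop_last=True)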
1
Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
from .transforms import build_transform
273
Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
Normal file
@@ -0,0 +1,273 @@
"""
|
||||
Source: https://github.com/DeepVoltaire/AutoAugment
|
||||
"""
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image, ImageOps, ImageEnhance
|
||||
|
||||
|
||||
class ImageNetPolicy:
|
||||
"""Randomly choose one of the best 24 Sub-policies on ImageNet.
|
||||
|
||||
Example:
|
||||
>>> policy = ImageNetPolicy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> ImageNetPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
|
||||
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
|
||||
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
|
||||
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
|
||||
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
|
||||
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
|
||||
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
|
||||
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment ImageNet Policy"
|
||||
|
||||
|
||||
class CIFAR10Policy:
|
||||
"""Randomly choose one of the best 25 Sub-policies on CIFAR10.
|
||||
|
||||
Example:
|
||||
>>> policy = CIFAR10Policy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> CIFAR10Policy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
|
||||
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
|
||||
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
|
||||
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
|
||||
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
|
||||
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
|
||||
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
|
||||
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
|
||||
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
|
||||
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
|
||||
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
|
||||
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
|
||||
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
|
||||
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
|
||||
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
|
||||
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
|
||||
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
|
||||
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment CIFAR10 Policy"
|
||||
|
||||
|
||||
class SVHNPolicy:
|
||||
"""Randomly choose one of the best 25 Sub-policies on SVHN.
|
||||
|
||||
Example:
|
||||
>>> policy = SVHNPolicy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> SVHNPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
|
||||
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
|
||||
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
|
||||
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
|
||||
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
|
||||
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
|
||||
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
|
||||
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment SVHN Policy"
|
||||
|
||||
|
||||
class SubPolicy(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
p1,
|
||||
operation1,
|
||||
magnitude_idx1,
|
||||
p2,
|
||||
operation2,
|
||||
magnitude_idx2,
|
||||
fillcolor=(128, 128, 128),
|
||||
):
|
||||
ranges = {
|
||||
"shearX": np.linspace(0, 0.3, 10),
|
||||
"shearY": np.linspace(0, 0.3, 10),
|
||||
"translateX": np.linspace(0, 150 / 331, 10),
|
||||
"translateY": np.linspace(0, 150 / 331, 10),
|
||||
"rotate": np.linspace(0, 30, 10),
|
||||
"color": np.linspace(0.0, 0.9, 10),
|
||||
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
|
||||
"solarize": np.linspace(256, 0, 10),
|
||||
"contrast": np.linspace(0.0, 0.9, 10),
|
||||
"sharpness": np.linspace(0.0, 0.9, 10),
|
||||
"brightness": np.linspace(0.0, 0.9, 10),
|
||||
"autocontrast": [0] * 10,
|
||||
"equalize": [0] * 10,
|
||||
"invert": [0] * 10,
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(
|
||||
rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
|
||||
).convert(img.mode)
|
||||
|
||||
func = {
|
||||
"shearX":
|
||||
lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor,
|
||||
),
|
||||
"shearY":
|
||||
lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor,
|
||||
),
|
||||
"translateX":
|
||||
lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(
|
||||
1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
|
||||
1, 0
|
||||
),
|
||||
fillcolor=fillcolor,
|
||||
),
|
||||
"translateY":
|
||||
lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(
|
||||
1, 0, 0, 0, 1, magnitude * img.size[1] * random.
|
||||
choice([-1, 1])
|
||||
),
|
||||
fillcolor=fillcolor,
|
||||
),
|
||||
"rotate":
|
||||
lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||
"color":
|
||||
lambda img, magnitude: ImageEnhance.Color(img).
|
||||
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"posterize":
|
||||
lambda img, magnitude: ImageOps.posterize(img, magnitude),
|
||||
"solarize":
|
||||
lambda img, magnitude: ImageOps.solarize(img, magnitude),
|
||||
"contrast":
|
||||
lambda img, magnitude: ImageEnhance.Contrast(img).
|
||||
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"sharpness":
|
||||
lambda img, magnitude: ImageEnhance.Sharpness(img).
|
||||
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"brightness":
|
||||
lambda img, magnitude: ImageEnhance.Brightness(img).
|
||||
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"autocontrast":
|
||||
lambda img, magnitude: ImageOps.autocontrast(img),
|
||||
"equalize":
|
||||
lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert":
|
||||
lambda img, magnitude: ImageOps.invert(img),
|
||||
}
|
||||
|
||||
self.p1 = p1
|
||||
self.operation1 = func[operation1]
|
||||
self.magnitude1 = ranges[operation1][magnitude_idx1]
|
||||
self.p2 = p2
|
||||
self.operation2 = func[operation2]
|
||||
self.magnitude2 = ranges[operation2][magnitude_idx2]
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p1:
|
||||
img = self.operation1(img, self.magnitude1)
|
||||
if random.random() < self.p2:
|
||||
img = self.operation2(img, self.magnitude2)
|
||||
return img
|
||||
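How a SubPolicy resolves its magnitude indices into concrete values, traced by hand from the ranges dict above (the numbers below just follow that dict):

import numpy as np

rotate_range = np.linspace(0, 30, 10)                             # degrees
posterize_range = np.round(np.linspace(8, 4, 10), 0).astype(int)  # bits kept
print(rotate_range[9], posterize_range[8])  # 30.0, 4
# So SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9) posterizes to 4 bits
# with probability 0.4, then rotates by 30 degrees with probability 0.6.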
363
Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
Normal file
@@ -0,0 +1,363 @@
"""
|
||||
Credit to
|
||||
1) https://github.com/ildoonet/pytorch-randaugment
|
||||
2) https://github.com/kakaobrain/fast-autoaugment
|
||||
"""
|
||||
import numpy as np
|
||||
import random
|
||||
import PIL
|
||||
import torch
|
||||
import PIL.ImageOps
|
||||
import PIL.ImageDraw
|
||||
import PIL.ImageEnhance
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def ShearX(img, v):
|
||||
assert -0.3 <= v <= 0.3
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
|
||||
|
||||
|
||||
def ShearY(img, v):
|
||||
assert -0.3 <= v <= 0.3
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
|
||||
|
||||
|
||||
def TranslateX(img, v):
|
||||
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||
assert -0.45 <= v <= 0.45
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
v = v * img.size[0]
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
|
||||
|
||||
|
||||
def TranslateXabs(img, v):
|
||||
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||
assert 0 <= v
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
|
||||
|
||||
|
||||
def TranslateY(img, v):
|
||||
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||
assert -0.45 <= v <= 0.45
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
v = v * img.size[1]
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
|
||||
|
||||
|
||||
def TranslateYabs(img, v):
|
||||
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||
assert 0 <= v
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
|
||||
|
||||
|
||||
def Rotate(img, v):
|
||||
assert -30 <= v <= 30
|
||||
if random.random() > 0.5:
|
||||
v = -v
|
||||
return img.rotate(v)
|
||||
|
||||
|
||||
def AutoContrast(img, _):
|
||||
return PIL.ImageOps.autocontrast(img)
|
||||
|
||||
|
||||
def Invert(img, _):
|
||||
return PIL.ImageOps.invert(img)
|
||||
|
||||
|
||||
def Equalize(img, _):
|
||||
return PIL.ImageOps.equalize(img)
|
||||
|
||||
|
||||
def Flip(img, _):
|
||||
return PIL.ImageOps.mirror(img)
|
||||
|
||||
|
||||
def Solarize(img, v):
|
||||
assert 0 <= v <= 256
|
||||
return PIL.ImageOps.solarize(img, v)
|
||||
|
||||
|
||||
def SolarizeAdd(img, addition=0, threshold=128):
|
||||
img_np = np.array(img).astype(np.int)
|
||||
img_np = img_np + addition
|
||||
img_np = np.clip(img_np, 0, 255)
|
||||
img_np = img_np.astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
return PIL.ImageOps.solarize(img, threshold)
|
||||
|
||||
|
||||
def Posterize(img, v):
|
||||
assert 4 <= v <= 8
|
||||
v = int(v)
|
||||
return PIL.ImageOps.posterize(img, v)
|
||||
|
||||
|
||||
def Contrast(img, v):
|
||||
assert 0.0 <= v <= 2.0
|
||||
return PIL.ImageEnhance.Contrast(img).enhance(v)
|
||||
|
||||
|
||||
def Color(img, v):
|
||||
assert 0.0 <= v <= 2.0
|
||||
return PIL.ImageEnhance.Color(img).enhance(v)
|
||||
|
||||
|
||||
def Brightness(img, v):
|
||||
assert 0.0 <= v <= 2.0
|
||||
return PIL.ImageEnhance.Brightness(img).enhance(v)
|
||||
|
||||
|
||||
def Sharpness(img, v):
|
||||
assert 0.0 <= v <= 2.0
|
||||
return PIL.ImageEnhance.Sharpness(img).enhance(v)
|
||||
|
||||
|
||||
def Cutout(img, v):
|
||||
# [0, 60] => percentage: [0, 0.2]
|
||||
assert 0.0 <= v <= 0.2
|
||||
if v <= 0.0:
|
||||
return img
|
||||
|
||||
v = v * img.size[0]
|
||||
return CutoutAbs(img, v)
|
||||
|
||||
|
||||
def CutoutAbs(img, v):
|
||||
# [0, 60] => percentage: [0, 0.2]
|
||||
# assert 0 <= v <= 20
|
||||
if v < 0:
|
||||
return img
|
||||
w, h = img.size
|
||||
x0 = np.random.uniform(w)
|
||||
y0 = np.random.uniform(h)
|
||||
|
||||
x0 = int(max(0, x0 - v/2.0))
|
||||
y0 = int(max(0, y0 - v/2.0))
|
||||
x1 = min(w, x0 + v)
|
||||
y1 = min(h, y0 + v)
|
||||
|
||||
xy = (x0, y0, x1, y1)
|
||||
color = (125, 123, 114)
|
||||
# color = (0, 0, 0)
|
||||
img = img.copy()
|
||||
PIL.ImageDraw.Draw(img).rectangle(xy, color)
|
||||
return img
|
||||
|
||||
|
||||
def SamplePairing(imgs):
|
||||
# [0, 0.4]
|
||||
def f(img1, v):
|
||||
i = np.random.choice(len(imgs))
|
||||
img2 = PIL.Image.fromarray(imgs[i])
|
||||
return PIL.Image.blend(img1, img2, v)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def Identity(img, v):
|
||||
return img
|
||||
|
||||
|
||||
class Lighting:
|
||||
"""Lighting noise (AlexNet - style PCA - based noise)."""
|
||||
|
||||
def __init__(self, alphastd, eigval, eigvec):
|
||||
self.alphastd = alphastd
|
||||
self.eigval = torch.Tensor(eigval)
|
||||
self.eigvec = torch.Tensor(eigvec)
|
||||
|
||||
def __call__(self, img):
|
||||
if self.alphastd == 0:
|
||||
return img
|
||||
|
||||
alpha = img.new().resize_(3).normal_(0, self.alphastd)
|
||||
rgb = (
|
||||
self.eigvec.type_as(img).clone().mul(
|
||||
alpha.view(1, 3).expand(3, 3)
|
||||
).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
|
||||
)
|
||||
|
||||
return img.add(rgb.view(3, 1, 1).expand_as(img))
|
||||
|
||||
|
||||
class CutoutDefault:
|
||||
"""
|
||||
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
|
||||
"""
|
||||
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
|
||||
def __call__(self, img):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1:y2, x1:x2] = 0.0
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
|
||||
|
||||
def randaugment_list():
|
||||
# 16 oeprations and their ranges
|
||||
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
|
||||
# augs = [
|
||||
# (Identity, 0., 1.0),
|
||||
# (ShearX, 0., 0.3), # 0
|
||||
# (ShearY, 0., 0.3), # 1
|
||||
# (TranslateX, 0., 0.33), # 2
|
||||
# (TranslateY, 0., 0.33), # 3
|
||||
# (Rotate, 0, 30), # 4
|
||||
# (AutoContrast, 0, 1), # 5
|
||||
# (Invert, 0, 1), # 6
|
||||
# (Equalize, 0, 1), # 7
|
||||
# (Solarize, 0, 110), # 8
|
||||
# (Posterize, 4, 8), # 9
|
||||
# # (Contrast, 0.1, 1.9), # 10
|
||||
# (Color, 0.1, 1.9), # 11
|
||||
# (Brightness, 0.1, 1.9), # 12
|
||||
# (Sharpness, 0.1, 1.9), # 13
|
||||
# # (Cutout, 0, 0.2), # 14
|
||||
# # (SamplePairing(imgs), 0, 0.4) # 15
|
||||
# ]
|
||||
|
||||
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
|
||||
augs = [
|
||||
(AutoContrast, 0, 1),
|
||||
(Equalize, 0, 1),
|
||||
(Invert, 0, 1),
|
||||
(Rotate, 0, 30),
|
||||
(Posterize, 4, 8),
|
||||
(Solarize, 0, 256),
|
||||
(SolarizeAdd, 0, 110),
|
||||
(Color, 0.1, 1.9),
|
||||
(Contrast, 0.1, 1.9),
|
||||
(Brightness, 0.1, 1.9),
|
||||
(Sharpness, 0.1, 1.9),
|
||||
(ShearX, 0.0, 0.3),
|
||||
(ShearY, 0.0, 0.3),
|
||||
(CutoutAbs, 0, 40),
|
||||
(TranslateXabs, 0.0, 100),
|
||||
(TranslateYabs, 0.0, 100),
|
||||
]
|
||||
|
||||
return augs
|
||||
|
||||
|
||||
def randaugment_list2():
|
||||
augs = [
|
||||
(AutoContrast, 0, 1),
|
||||
(Brightness, 0.1, 1.9),
|
||||
(Color, 0.1, 1.9),
|
||||
(Contrast, 0.1, 1.9),
|
||||
(Equalize, 0, 1),
|
||||
(Identity, 0, 1),
|
||||
(Invert, 0, 1),
|
||||
(Posterize, 4, 8),
|
||||
(Rotate, -30, 30),
|
||||
(Sharpness, 0.1, 1.9),
|
||||
(ShearX, -0.3, 0.3),
|
||||
(ShearY, -0.3, 0.3),
|
||||
(Solarize, 0, 256),
|
||||
(TranslateX, -0.3, 0.3),
|
||||
(TranslateY, -0.3, 0.3),
|
||||
]
|
||||
|
||||
return augs
|
||||
|
||||
|
||||
def fixmatch_list():
|
||||
# https://arxiv.org/abs/2001.07685
|
||||
augs = [
|
||||
(AutoContrast, 0, 1),
|
||||
(Brightness, 0.05, 0.95),
|
||||
(Color, 0.05, 0.95),
|
||||
(Contrast, 0.05, 0.95),
|
||||
(Equalize, 0, 1),
|
||||
(Identity, 0, 1),
|
||||
(Posterize, 4, 8),
|
||||
(Rotate, -30, 30),
|
||||
(Sharpness, 0.05, 0.95),
|
||||
(ShearX, -0.3, 0.3),
|
||||
(ShearY, -0.3, 0.3),
|
||||
(Solarize, 0, 256),
|
||||
(TranslateX, -0.3, 0.3),
|
||||
(TranslateY, -0.3, 0.3),
|
||||
]
|
||||
|
||||
return augs
|
||||
|
||||
|
||||
class RandAugment:
|
||||
|
||||
def __init__(self, n=2, m=10):
|
||||
assert 0 <= m <= 30
|
||||
self.n = n
|
||||
self.m = m
|
||||
self.augment_list = randaugment_list()
|
||||
|
||||
def __call__(self, img):
|
||||
ops = random.choices(self.augment_list, k=self.n)
|
||||
|
||||
for op, minval, maxval in ops:
|
||||
val = (self.m / 30) * (maxval-minval) + minval
|
||||
img = op(img, val)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class RandAugment2:
|
||||
|
||||
def __init__(self, n=2, p=0.6):
|
||||
self.n = n
|
||||
self.p = p
|
||||
self.augment_list = randaugment_list2()
|
||||
|
||||
def __call__(self, img):
|
||||
ops = random.choices(self.augment_list, k=self.n)
|
||||
|
||||
for op, minval, maxval in ops:
|
||||
if random.random() > self.p:
|
||||
continue
|
||||
m = random.random()
|
||||
val = m * (maxval-minval) + minval
|
||||
img = op(img, val)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class RandAugmentFixMatch:
|
||||
|
||||
def __init__(self, n=2):
|
||||
self.n = n
|
||||
self.augment_list = fixmatch_list()
|
||||
|
||||
def __call__(self, img):
|
||||
ops = random.choices(self.augment_list, k=self.n)
|
||||
|
||||
for op, minval, maxval in ops:
|
||||
m = random.random()
|
||||
val = m * (maxval-minval) + minval
|
||||
img = op(img, val)
|
||||
|
||||
return img
|
||||
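A minimal sketch of using RandAugment inside a torchvision pipeline; n and m mirror the INPUT.RANDAUGMENT_N/M defaults, and the Resize size is illustrative.

from torchvision import transforms

tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    RandAugment(n=2, m=10),  # 2 random ops, each at 10/30 of its range
    transforms.ToTensor(),
])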
341
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
Normal file
@@ -0,0 +1,341 @@
import numpy as np
import random
import torch
from PIL import Image
from torchvision.transforms import (
    Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
    RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
    RandomHorizontalFlip
)

from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch

AVAI_CHOICES = [
    "random_flip",
    "random_resized_crop",
    "normalize",
    "instance_norm",
    "random_crop",
    "random_translation",
    "center_crop",  # This has become a default operation for test
    "cutout",
    "imagenet_policy",
    "cifar10_policy",
    "svhn_policy",
    "randaugment",
    "randaugment_fixmatch",
    "randaugment2",
    "gaussian_noise",
    "colorjitter",
    "randomgrayscale",
    "gaussian_blur",
]

INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}


class Random2DTranslation:
    """Given an image of (height, width), we resize it to
    (height*1.125, width*1.125), and then perform random cropping.

    Args:
        height (int): target image height.
        width (int): target image width.
        p (float, optional): probability that this operation takes place.
            Default is 0.5.
        interpolation (int, optional): desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        if random.uniform(0, 1) > self.p:
            return img.resize((self.width, self.height), self.interpolation)

        new_width = int(round(self.width * 1.125))
        new_height = int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)

        x_maxrange = new_width - self.width
        y_maxrange = new_height - self.height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))
        cropped_img = resized_img.crop(
            (x1, y1, x1 + self.width, y1 + self.height)
        )

        return cropped_img


class InstanceNormalization:
    """Normalize data using per-channel mean and standard deviation.

    Reference:
        - Ulyanov et al. Instance normalization: The missing ingredient
        for fast stylization. ArXiv 2016.
        - Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
        ICLR 2018.
    """

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, img):
        C, H, W = img.shape
        img_re = img.reshape(C, H * W)
        mean = img_re.mean(1).view(C, 1, 1)
        std = img_re.std(1).view(C, 1, 1)
        return (img-mean) / (std + self.eps)


class Cutout:
    """Randomly mask out one or more patches from an image.

    https://github.com/uoguelph-mlrg/Cutout

    Args:
        n_holes (int, optional): number of patches to cut out
            of each image. Default is 1.
        length (int, optional): length (in pixels) of each square
            patch. Default is 16.
    """

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): tensor image of size (C, H, W).

        Returns:
            Tensor: image with n_holes of dimension
                length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1:y2, x1:x2] = 0.0

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        return img * mask


class GaussianNoise:
    """Add Gaussian noise."""

    def __init__(self, mean=0, std=0.15, p=0.5):
        self.mean = mean
        self.std = std
        self.p = p

    def __call__(self, img):
        if random.uniform(0, 1) > self.p:
            return img
        noise = torch.randn(img.size()) * self.std + self.mean
        return img + noise


def build_transform(cfg, is_train=True, choices=None):
    """Build transformation function.

    Args:
        cfg (CfgNode): config.
        is_train (bool, optional): for training (True) or test (False).
            Default is True.
        choices (list, optional): list of strings which will overwrite
            cfg.INPUT.TRANSFORMS if given. Default is None.
    """
    if cfg.INPUT.NO_TRANSFORM:
        print("Note: no transform is applied!")
        return None

    if choices is None:
        choices = cfg.INPUT.TRANSFORMS

    for choice in choices:
        assert choice in AVAI_CHOICES

    target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"

    normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if is_train:
        return _build_transform_train(cfg, choices, target_size, normalize)
    else:
        return _build_transform_test(cfg, choices, target_size, normalize)


def _build_transform_train(cfg, choices, target_size, normalize):
    print("Building transform_train")
    tfm_train = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    # Make sure the image size matches the target size
    conditions = []
    conditions += ["random_crop" not in choices]
    conditions += ["random_resized_crop" not in choices]
    if all(conditions):
        print(f"+ resize to {target_size}")
        tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]

    if "random_translation" in choices:
        print("+ random translation")
        tfm_train += [
            Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
        ]

    if "random_crop" in choices:
        crop_padding = cfg.INPUT.CROP_PADDING
        print("+ random crop (padding = {})".format(crop_padding))
        tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]

    if "random_resized_crop" in choices:
        print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
        tfm_train += [
            RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
        ]

    if "center_crop" in choices:
        print(f"+ center crop (size={cfg.INPUT.SIZE})")
        tfm_train += [CenterCrop(cfg.INPUT.SIZE)]

    if "random_flip" in choices:
        print("+ random flip")
        tfm_train += [RandomHorizontalFlip()]

    if "imagenet_policy" in choices:
        print("+ imagenet policy")
        tfm_train += [ImageNetPolicy()]

    if "cifar10_policy" in choices:
        print("+ cifar10 policy")
        tfm_train += [CIFAR10Policy()]

    if "svhn_policy" in choices:
        print("+ svhn policy")
        tfm_train += [SVHNPolicy()]

    if "randaugment" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        m_ = cfg.INPUT.RANDAUGMENT_M
        print("+ randaugment (n={}, m={})".format(n_, m_))
        tfm_train += [RandAugment(n_, m_)]

    if "randaugment_fixmatch" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment_fixmatch (n={})".format(n_))
        tfm_train += [RandAugmentFixMatch(n_)]

    if "randaugment2" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment2 (n={})".format(n_))
        tfm_train += [RandAugment2(n_)]

    if "colorjitter" in choices:
        print("+ color jitter")
        tfm_train += [
            ColorJitter(
                brightness=cfg.INPUT.COLORJITTER_B,
                contrast=cfg.INPUT.COLORJITTER_C,
                saturation=cfg.INPUT.COLORJITTER_S,
                hue=cfg.INPUT.COLORJITTER_H,
            )
        ]

    if "randomgrayscale" in choices:
        print("+ random gray scale")
        tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]

    if "gaussian_blur" in choices:
        print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
        tfm_train += [
            RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
        ]

    print("+ to torch tensor of range [0, 1]")
    tfm_train += [ToTensor()]

    if "cutout" in choices:
        cutout_n = cfg.INPUT.CUTOUT_N
        cutout_len = cfg.INPUT.CUTOUT_LEN
        print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
        tfm_train += [Cutout(cutout_n, cutout_len)]

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        tfm_train += [normalize]

    if "gaussian_noise" in choices:
        print(
            "+ gaussian noise (mean={}, std={})".format(
                cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
            )
        )
        tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]

    if "instance_norm" in choices:
        print("+ instance normalization")
        tfm_train += [InstanceNormalization()]

    tfm_train = Compose(tfm_train)

    return tfm_train


def _build_transform_test(cfg, choices, target_size, normalize):
    print("Building transform_test")
    tfm_test = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
    tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]

    print(f"+ {target_size} center crop")
    tfm_test += [CenterCrop(cfg.INPUT.SIZE)]

    print("+ to torch tensor of range [0, 1]")
    tfm_test += [ToTensor()]

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        tfm_test += [normalize]

    if "instance_norm" in choices:
        print("+ instance normalization")
        tfm_test += [InstanceNormalization()]

    tfm_test = Compose(tfm_test)

    return tfm_test
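A minimal sketch of driving build_transform from the default config; the transform names are drawn from AVAI_CHOICES above.

from dassl.config import get_cfg_default

cfg = get_cfg_default()
cfg.INPUT.TRANSFORMS = ("random_flip", "random_crop", "normalize")

tfm_train = build_transform(cfg, is_train=True)  # full augmentation stack
tfm_test = build_transform(cfg, is_train=False)  # resize + center crop (+ normalize)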
6
Dassl.ProGrad.pytorch/dassl/engine/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .build import TRAINER_REGISTRY, build_trainer  # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet  # isort:skip

from .da import *
from .dg import *
from .ssl import *
11
Dassl.ProGrad.pytorch/dassl/engine/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

TRAINER_REGISTRY = Registry("TRAINER")


def build_trainer(cfg):
    avai_trainers = TRAINER_REGISTRY.registered_names()
    check_availability(cfg.TRAINER.NAME, avai_trainers)
    if cfg.VERBOSE:
        print("Loading trainer: {}".format(cfg.TRAINER.NAME))
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
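A minimal sketch of how a new trainer enters the registry and is later built by name; MyTrainer is a hypothetical class, not part of the release.

from dassl.engine import TRAINER_REGISTRY, TrainerX


@TRAINER_REGISTRY.register()
class MyTrainer(TrainerX):  # hypothetical custom trainer
    def forward_backward(self, batch):
        ...


# With cfg.TRAINER.NAME = "MyTrainer":
# trainer = build_trainer(cfg)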
9
Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
Normal file
@@ -0,0 +1,9 @@
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
@@ -0,0 +1,38 @@
import torch

from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU


@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        if not self.done_reset_bn_stats:
            for m in self.model.modules():
                classname = m.__class__.__name__
                if classname.find("BatchNorm") != -1:
                    m.reset_running_stats()

            self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        input_u = batch_u["img"].to(self.device)

        with torch.no_grad():
            self.model(input_u)

        return None
85
Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
Normal file
@@ -0,0 +1,85 @@
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from dassl.utils import check_isfile, count_num_param, open_specified_layers
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.modeling import build_head
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
|
||||
class ADDA(TrainerXU):
|
||||
"""Adversarial Discriminative Domain Adaptation.
|
||||
|
||||
https://arxiv.org/abs/1702.05464.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.open_layers = ["backbone"]
|
||||
if isinstance(self.model.head, nn.Module):
|
||||
self.open_layers.append("head")
|
||||
|
||||
self.source_model = copy.deepcopy(self.model)
|
||||
self.source_model.eval()
|
||||
for param in self.source_model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
self.build_critic()
|
||||
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def check_cfg(self, cfg):
|
||||
assert check_isfile(
|
||||
cfg.MODEL.INIT_WEIGHTS
|
||||
), "The weights of source model must be provided"
|
||||
|
||||
def build_critic(self):
|
||||
cfg = self.cfg
|
||||
|
||||
print("Building critic network")
|
||||
fdim = self.model.fdim
|
||||
critic_body = build_head(
|
||||
"mlp",
|
||||
verbose=cfg.VERBOSE,
|
||||
in_features=fdim,
|
||||
hidden_layers=[fdim, fdim // 2],
|
||||
activation="leaky_relu",
|
||||
)
|
||||
self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
|
||||
print("# params: {:,}".format(count_num_param(self.critic)))
|
||||
self.critic.to(self.device)
|
||||
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
|
||||
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
|
||||
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
|
||||
|
||||
def forward_backward(self, batch_x, batch_u):
|
||||
open_specified_layers(self.model, self.open_layers)
|
||||
input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
|
||||
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
|
||||
|
||||
_, feat_x = self.source_model(input_x, return_feature=True)
|
||||
_, feat_u = self.model(input_u, return_feature=True)
|
||||
|
||||
logit_xd = self.critic(feat_x)
|
||||
logit_ud = self.critic(feat_u.detach())
|
||||
|
||||
loss_critic = self.bce(logit_xd, domain_x)
|
||||
loss_critic += self.bce(logit_ud, domain_u)
|
||||
self.model_backward_and_update(loss_critic, "critic")
|
||||
|
||||
logit_ud = self.critic(feat_u)
|
||||
loss_model = self.bce(logit_ud, 1 - domain_u)
|
||||
self.model_backward_and_update(loss_model, "model")
|
||||
|
||||
loss_summary = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"loss_model": loss_model.item(),
|
||||
}
|
||||
|
||||
if (self.batch_idx + 1) == self.num_batches:
|
||||
self.update_lr()
|
||||
|
||||
return loss_summary
|
||||
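The adversarial recipe above, shown on plain tensors: the critic learns to separate frozen source features (label 1) from target features (label 0), then the target encoder is updated so the critic calls its features "source". A minimal sketch with a toy critic; feat_x and feat_u stand in for backbone features:

import torch
import torch.nn as nn

critic = nn.Linear(128, 1)
bce = nn.BCEWithLogitsLoss()
feat_x, feat_u = torch.randn(8, 128), torch.randn(8, 128)
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

# Critic step: detach target features so only the critic learns
loss_critic = bce(critic(feat_x), ones) + bce(critic(feat_u.detach()), zeros)
# Encoder step: flipped label (1 - 0) drives target features toward "source"
loss_model = bce(critic(feat_u), ones)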
210
Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
Normal file
@@ -0,0 +1,210 @@
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(
            feat_x, feat_x2, label_x, domain_x
        ):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": acc_x,
            "loss_cr": loss_cr.item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]

        label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
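How the most-confident-expert pseudo-labeling above works, on plain tensors: each of K experts outputs class probabilities, and the expert with the single highest probability supplies the label (the loop in forward_backward is equivalent to a gather). A minimal sketch with random data:

import torch

B, K, C = 4, 3, 10
pred_u = torch.rand(B, K, C).softmax(2)              # (B, K, C) expert probabilities
experts_max_p, experts_max_idx = pred_u.max(2)       # best class per expert
max_expert_p, max_expert_idx = experts_max_p.max(1)  # most confident expert
pseudo_label = experts_max_idx.gather(1, max_expert_idx.unsqueeze(1)).squeeze(1)
mask = (max_expert_p >= 0.95).float()                # keep only confident labels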
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
@@ -0,0 +1,78 @@
import numpy as np
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad


@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks.

    https://arxiv.org/abs/1505.07818.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        loss_x = self.ce(logit_x, label_x)

        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
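Two details worth noting above: the schedule lmda = 2 / (1 + exp(-10 * progress)) - 1 ramps the adversarial weight smoothly from 0 to 1 over training, and ReverseGrad is a gradient reversal layer: identity in the forward pass, gradient multiplied by -lmda in the backward pass. A minimal sketch of such a layer (the class name here is illustrative, not dassl's implementation):

import torch

class _ReverseGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, grad_scaling):
        ctx.grad_scaling = grad_scaling
        return x.view_as(x)  # identity forward

    @staticmethod
    def backward(ctx, grad_output):
        # flip and scale the gradient; None matches the grad_scaling input
        return -ctx.grad_scaling * grad_output, None

feat = torch.randn(4, 16, requires_grad=True)
out = _ReverseGrad.apply(feat, 0.5)
out.sum().backward()  # feat.grad is now -0.5 everywhere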
208
Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
Normal file
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


class PairClassifiers(nn.Module):

    def __init__(self, fdim, num_classes):
        super().__init__()
        self.c1 = nn.Linear(fdim, num_classes)
        self.c2 = nn.Linear(fdim, num_classes)

    def forward(self, x):
        z1 = self.c1(x)
        if not self.training:
            return z1
        z2 = self.c2(x)
        return z1, z2


@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        self.lmda = cfg.TRAINER.M3SDA.LMDA

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C")
        self.C = nn.ModuleList(
            [
                PairClassifiers(fdim, self.num_classes)
                for _ in range(self.num_source_domains)
            ]
        )
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed

        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Step A
        loss_x = 0
        feat_x = []

        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            feat_x.append(f)

        loss_x /= self.n_domain

        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)

        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_u = self.F(input_u)

        loss_x, loss_dis = 0, 0

        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)

        loss_x /= self.n_domain
        loss_dis /= self.n_domain

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, "C")

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)

            loss_dis = 0

            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)

            loss_dis /= self.n_domain
            loss_step_C = loss_dis

            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def moment_distance(self, x, u):
        # x (list): a list of feature matrix.
        # u (torch.Tensor): feature matrix.
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)

        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)

        return (dist1+dist2) / 2

    def pairwise_distance(self, x, u):
        # x (list): a list of feature vector.
        # u (torch.Tensor): feature vector.
        dist = 0
        count = 0

        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1

        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1

        return dist / count

    def euclidean(self, input1, input2):
        return ((input1 - input2)**2).sum().sqrt()

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, domain_x, input_u

    def model_inference(self, input):
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p
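The first-moment half of moment_distance, worked on toy features: the average pairwise Euclidean distance between domain means, covering each source vs. the target plus every source pair. A minimal sketch with two source domains:

import torch

feats = [torch.randn(8, 16), torch.randn(8, 16)]  # two source feature batches
feat_u = torch.randn(8, 16)                       # target feature batch

means = [f.mean(0) for f in feats]
u_mean = feat_u.mean(0)
# 2 source-vs-target pairs + 1 source-vs-source pair = 3 pairs, as in pairwise_distance
pairs = [(m, u_mean) for m in means] + [(means[0], means[1])]
dist1 = sum(((a - b)**2).sum().sqrt() for a, b in pairs) / len(pairs)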
105
Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
Normal file
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)

        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)
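The quantity driving steps B and C above, in isolation: the L1 discrepancy between the two classifiers' softmax outputs. Step B maximizes it w.r.t. the classifiers (hence the minus sign in loss_x - loss_dis), step C minimizes it w.r.t. the feature extractor. On toy logits:

import torch
import torch.nn.functional as F

z1, z2 = torch.randn(4, 10), torch.randn(4, 10)
loss_dis = (F.softmax(z1, 1) - F.softmax(z2, 1)).abs().mean()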
86
Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
Normal file
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet


class Prototypes(nn.Module):

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        x = F.normalize(x, p=2, dim=1)
        out = self.prototypes(x)
        out = out / self.temp
        return out


@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.C(self.F(input))
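The entropy term above, in isolation: H(p) = -sum(p log p). loss_u = -H, so the classifier C directly maximizes entropy on unlabeled data, while the reversed gradient makes the feature extractor F minimize it, which is the minimax in MME. On toy probabilities:

import torch

prob_u = torch.rand(4, 10).softmax(1)
entropy = (-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
loss_u = -entropy  # as in MME.forward_backward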
78
Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
Normal file
@@ -0,0 +1,78 @@
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS == 2

    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

        if self.conf_thre:
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u1, input_u2 = input_u

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, label_x, input_u1, input_u2
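The EMA teacher update called above: teacher = alpha * teacher + (1 - alpha) * student, parameter-wise. A minimal sketch of this standard update (how dassl's ema_model_update is implemented internally is an assumption here; this is the common form):

import torch

@torch.no_grad()
def ema_update(student, teacher, alpha):
    # teacher <- alpha * teacher + (1 - alpha) * student
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1 - alpha)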
34
Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
Normal file
@@ -0,0 +1,34 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
4
Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad
83
Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
Normal file
@@ -0,0 +1,83 @@
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
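The perturbation step used above is FGSM-like: take the gradient of a network's loss with respect to the input, clamp it, and step the input along it. A minimal sketch on toy data (the network, shapes, and eps value are illustrative):

import torch
import torch.nn.functional as F

net = torch.nn.Linear(16, 3)
x = torch.randn(4, 16, requires_grad=True)
y = torch.tensor([0, 1, 2, 0])

loss = F.cross_entropy(net(x), y)
loss.backward()                      # populates x.grad
grad = torch.clamp(x.grad.data, min=-0.1, max=0.1)
x_perturbed = x.data + 0.5 * grad    # eps = 0.5 here, illustrative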
169
Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
Normal file
@@ -0,0 +1,169 @@
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss_x = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        loss = 0
        loss += loss_x
        loss += loss_cr
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc": acc,
            "loss_cr": loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        input2 = batch["img2"]
        label = batch["label"]
        domain = batch["domain"]

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
107
Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
Normal file
@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
    """Deep Domain-Adversarial Image Generation.

    https://arxiv.org/abs/2003.06054.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.DDAIG.LMDA
        self.clamp = cfg.TRAINER.DDAIG.CLAMP
        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
        self.warmup = cfg.TRAINER.DDAIG.WARMUP
        self.alpha = cfg.TRAINER.DDAIG.ALPHA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

        print("Building G")
        self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
        self.G.to(self.device)
        print("# params: {:,}".format(count_num_param(self.G)))
        self.optim_G = build_optimizer(self.G, cfg.OPTIM)
        self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
        self.register_model("G", self.G, self.optim_G, self.sched_G)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
        input_p = self.G(input, lmda=self.lmda)
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )
        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
        self.model_backward_and_update(loss_g, "G")

        # Perturb data with new G
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, "F")

        #############
        # Update D
        #############
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {
            "loss_g": loss_g.item(),
            "loss_f": loss_f.item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
32
Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
Normal file
@@ -0,0 +1,32 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
    """Vanilla baseline."""

    def forward_backward(self, batch):
        input, label = self.parse_batch_train(batch)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
@@ -0,0 +1,41 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda

        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
@@ -0,0 +1,112 @@
import torch
from torch.nn import functional as F

from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform


@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

        # Evaluate pseudo labels' accuracy
        y_u_pred_stats = self.assess_y_pred_quality(
            label_u_pred[n_x:], label_u, mask_u[n_x:]
        )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u
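FixMatch's core on plain tensors: pseudo-label from the weakly augmented view, mask by confidence, and train the strongly augmented view toward the surviving labels. A minimal sketch with random logits standing in for the two model passes:

import torch
import torch.nn.functional as F

logits_weak = torch.randn(8, 10)    # model(weak view); no grad in practice
logits_strong = torch.randn(8, 10)  # model(strong view)

probs = F.softmax(logits_weak, 1)
max_prob, pseudo = probs.max(1)
mask = (max_prob >= 0.95).float()   # conf_thre = 0.95, the paper's default
loss_u = (F.cross_entropy(logits_strong, pseudo, reduction="none") * mask).mean()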
54
Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
Normal file
@@ -0,0 +1,54 @@
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP

        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
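The consistency weight above grows with a sigmoid ramp-up. The exact form of dassl's sigmoid_rampup is an assumption here, but the function popularized with mean teachers is exp(-5 (1 - t)^2) for t in [0, 1]:

import math

def sigmoid_rampup(current, rampup_length):
    if rampup_length == 0:
        return 1.0
    t = max(0.0, min(1.0, current / rampup_length))
    return math.exp(-5.0 * (1.0 - t)**2)  # 0 -> ~0.007, rampup_length -> 1.0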
98
Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
Normal file
@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
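The sharpening step used above, written out (this is the standard MixMatch form; that sharpen_prob matches it exactly is an assumption): raise the averaged prediction to the power 1/T and renormalize, so T < 1 pushes the distribution toward one-hot.

import torch

def sharpen(p, T=0.5):
    p = p**(1 / T)
    return p / p.sum(1, keepdim=True)

p = torch.rand(4, 10).softmax(1)
p_sharp = sharpen(p)  # more peaked than p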
32
Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
Normal file
@@ -0,0 +1,32 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised Baseline."""

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
735
Dassl.ProGrad.pytorch/dassl/engine/trainer.py
Normal file
@@ -0,0 +1,735 @@
import json
import time
import numpy as np
import os.path as osp
import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import (
    MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,
    save_checkpoint, mkdir_if_missing, resume_from_checkpoint,
    load_pretrained_weights
)
from dassl.modeling import build_head, build_backbone
from dassl.evaluation import build_evaluator


class SimpleNet(nn.Module):
    """A simple neural network composed of a CNN backbone
    and optionally a head such as an MLP for classification.
    """

    def __init__(self, cfg, model_cfg, num_classes, **kwargs):
        super().__init__()
        self.backbone = build_backbone(
            model_cfg.BACKBONE.NAME,
            verbose=cfg.VERBOSE,
            pretrained=model_cfg.BACKBONE.PRETRAINED,
            **kwargs,
        )
        fdim = self.backbone.out_features

        self.head = None
        if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:
            self.head = build_head(
                model_cfg.HEAD.NAME,
                verbose=cfg.VERBOSE,
                in_features=fdim,
                hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,
                activation=model_cfg.HEAD.ACTIVATION,
                bn=model_cfg.HEAD.BN,
                dropout=model_cfg.HEAD.DROPOUT,
                **kwargs,
            )
            fdim = self.head.out_features

        self.classifier = None
        if num_classes > 0:
            self.classifier = nn.Linear(fdim, num_classes)

        self._fdim = fdim

    @property
    def fdim(self):
        return self._fdim

    def forward(self, x, return_feature=False):
        f = self.backbone(x)
        if self.head is not None:
            f = self.head(f)

        if self.classifier is None:
            return f

        y = self.classifier(f)

        if return_feature:
            return y, f

        return y
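A usage sketch for SimpleNet, assuming the default config defines the MODEL.BACKBONE fields and that "resnet18" is a registered backbone (which the backbone package later in this commit suggests):

from dassl.config import get_cfg_default

cfg = get_cfg_default()
cfg.MODEL.BACKBONE.NAME = "resnet18"
cfg.MODEL.BACKBONE.PRETRAINED = False

net = SimpleNet(cfg, cfg.MODEL, num_classes=10)
x = torch.randn(2, 3, 224, 224)
logits = net(x)                               # [2, 10]
logits, feats = net(x, return_feature=True)   # feats: [2, net.fdim]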
class TrainerBase:
    """Base class for iterative trainer."""

    def __init__(self):
        self._models = OrderedDict()
        self._optims = OrderedDict()
        self._scheds = OrderedDict()
        self._writer = None

    def register_model(self, name="model", model=None, optim=None, sched=None):
        if self.__dict__.get("_models") is None:
            raise AttributeError(
                "Cannot assign model before super().__init__() call"
            )

        if self.__dict__.get("_optims") is None:
            raise AttributeError(
                "Cannot assign optim before super().__init__() call"
            )

        if self.__dict__.get("_scheds") is None:
            raise AttributeError(
                "Cannot assign sched before super().__init__() call"
            )

        assert name not in self._models, "Found duplicate model names"

        self._models[name] = model
        self._optims[name] = optim
        self._scheds[name] = sched

    def get_model_names(self, names=None):
        names_real = list(self._models.keys())
        if names is not None:
            names = tolist_if_not(names)
            for name in names:
                assert name in names_real
            return names
        else:
            return names_real
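The registry lets one trainer own several model/optimizer/scheduler triplets that are saved, resumed, and stepped together. A hedged sketch of how an adversarial trainer might use it (the names and builder functions are illustrative, not from the library):

class AdversarialTrainer(TrainerBase):

    def __init__(self):
        super().__init__()  # must run before register_model()
        G, D = build_generator(), build_discriminator()  # hypothetical helpers
        self.register_model("G", G, torch.optim.SGD(G.parameters(), lr=0.01))
        self.register_model("D", D, torch.optim.SGD(D.parameters(), lr=0.01))

    def forward_backward(self, batch):
        loss_d = ...  # discriminator loss
        # update only the discriminator's parameters for this loss
        self.model_backward_and_update(loss_d, names="D")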
    def save_model(self, epoch, directory, is_best=False, model_name=""):
        names = self.get_model_names()

        for name in names:
            model_dict = self._models[name].state_dict()

            optim_dict = None
            if self._optims[name] is not None:
                optim_dict = self._optims[name].state_dict()

            sched_dict = None
            if self._scheds[name] is not None:
                sched_dict = self._scheds[name].state_dict()

            save_checkpoint(
                {
                    "state_dict": model_dict,
                    "epoch": epoch + 1,
                    "optimizer": optim_dict,
                    "scheduler": sched_dict,
                },
                osp.join(directory, name),
                is_best=is_best,
                model_name=model_name,
            )

    def resume_model_if_exist(self, directory):
        names = self.get_model_names()
        file_missing = False

        for name in names:
            path = osp.join(directory, name)
            if not osp.exists(path):
                file_missing = True
                break

        if file_missing:
            print("No checkpoint found, train from scratch")
            return 0

        print(
            'Found checkpoint in "{}". Will resume training'.format(directory)
        )

        for name in names:
            path = osp.join(directory, name)
            start_epoch = resume_from_checkpoint(
                path, self._models[name], self._optims[name],
                self._scheds[name]
            )

        return start_epoch

    def load_model(self, directory, epoch=None):
        if not directory:
            print(
                "Note that load_model() is skipped as no pretrained "
                "model is given (ignore this if it's done on purpose)"
            )
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            print(
                "Loading weights to {} "
                'from "{}" (epoch = {})'.format(name, model_path, epoch)
            )
            self._models[name].load_state_dict(state_dict)
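Given the file naming load_model() expects, a run directory is assumed to end up looking roughly like this (one subdirectory per registered model name; the exact layout depends on save_checkpoint):

output/
    model/                   # the default registered model name
        model.pth.tar-5      # checkpoint written after epoch 5
        model.pth.tar-10
        model-best.pth.tar   # kept when is_best=True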
def set_model_mode(self, mode="train", names=None):
|
||||
names = self.get_model_names(names)
|
||||
|
||||
for name in names:
|
||||
if mode == "train":
|
||||
self._models[name].train()
|
||||
elif mode in ["test", "eval"]:
|
||||
self._models[name].eval()
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
def update_lr(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
|
||||
for name in names:
|
||||
if self._scheds[name] is not None:
|
||||
self._scheds[name].step()
|
||||
|
||||
def detect_anomaly(self, loss):
|
||||
if not torch.isfinite(loss).all():
|
||||
raise FloatingPointError("Loss is infinite or NaN!")
|
||||
|
||||
def init_writer(self, log_dir):
|
||||
if self.__dict__.get("_writer") is None or self._writer is None:
|
||||
print(
|
||||
"Initializing summary writer for tensorboard "
|
||||
"with log_dir={}".format(log_dir)
|
||||
)
|
||||
self._writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def close_writer(self):
|
||||
if self._writer is not None:
|
||||
self._writer.close()
|
||||
|
||||
def write_scalar(self, tag, scalar_value, global_step=None):
|
||||
if self._writer is None:
|
||||
# Do nothing if writer is not initialized
|
||||
# Note that writer is only used when training is needed
|
||||
pass
|
||||
else:
|
||||
self._writer.add_scalar(tag, scalar_value, global_step)
|
||||
|
||||
def train(self, start_epoch, max_epoch):
|
||||
"""Generic training loops."""
|
||||
self.start_epoch = start_epoch
|
||||
self.max_epoch = max_epoch
|
||||
|
||||
self.before_train()
|
||||
for self.epoch in range(self.start_epoch, self.max_epoch):
|
||||
self.before_epoch()
|
||||
self.run_epoch()
|
||||
self.after_epoch()
|
||||
self.after_train()
|
||||
|
||||
def before_train(self):
|
||||
pass
|
||||
|
||||
def after_train(self):
|
||||
pass
|
||||
|
||||
def before_epoch(self):
|
||||
pass
|
||||
|
||||
def after_epoch(self):
|
||||
pass
|
||||
|
||||
def run_epoch(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_batch_train(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_batch_test(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_backward(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def model_inference(self, input):
|
||||
raise NotImplementedError
|
||||
|
||||
def model_zero_grad(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
for name in names:
|
||||
if self._optims[name] is not None:
|
||||
self._optims[name].zero_grad()
|
||||
|
||||
def model_backward(self, loss):
|
||||
self.detect_anomaly(loss)
|
||||
loss.backward()
|
||||
|
||||
def model_update(self, names=None):
|
||||
names = self.get_model_names(names)
|
||||
for name in names:
|
||||
if self._optims[name] is not None:
|
||||
self._optims[name].step()
|
||||
|
||||
def model_backward_and_update(self, loss, names=None):
|
||||
self.model_zero_grad(names)
|
||||
self.model_backward(loss)
|
||||
self.model_update(names)
|
||||
|
||||
def prograd_backward_and_update(
|
||||
self, loss_a, loss_b, lambda_=1, names=None
|
||||
):
|
||||
# loss_b not increase is okay
|
||||
# loss_a has to decline
|
||||
self.model_zero_grad(names)
|
||||
# get name of the model parameters
|
||||
names = self.get_model_names(names)
|
||||
# backward loss_a
|
||||
self.detect_anomaly(loss_b)
|
||||
loss_b.backward(retain_graph=True)
|
||||
# normalize gradient
|
||||
b_grads = []
|
||||
for name in names:
|
||||
for p in self._models[name].parameters():
|
||||
b_grads.append(p.grad.clone())
|
||||
|
||||
# optimizer don't step
|
||||
for name in names:
|
||||
self._optims[name].zero_grad()
|
||||
|
||||
# backward loss_a
|
||||
self.detect_anomaly(loss_a)
|
||||
loss_a.backward()
|
||||
for name in names:
|
||||
for p, b_grad in zip(self._models[name].parameters(), b_grads):
|
||||
# calculate cosine distance
|
||||
b_grad_norm = b_grad / torch.linalg.norm(b_grad)
|
||||
a_grad = p.grad.clone()
|
||||
a_grad_norm = a_grad / torch.linalg.norm(a_grad)
|
||||
|
||||
if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:
|
||||
p.grad = a_grad - lambda_ * torch.dot(
|
||||
a_grad.flatten(), b_grad_norm.flatten()
|
||||
) * b_grad_norm
|
||||
|
||||
# optimizer
|
||||
for name in names:
|
||||
self._optims[name].step()
|
||||
|
||||
|
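A self-contained toy sketch of the projection rule above (plain tensors in place of model parameters, with lambda_ = 1), showing that the updated gradient no longer points against loss_b's gradient:

import torch

g_a = torch.tensor([1.0, -2.0])   # gradient of loss_a
g_b = torch.tensor([0.0, 1.0])    # gradient of loss_b
b_hat = g_b / torch.linalg.norm(g_b)

g_proj = g_a
if torch.dot(g_a, b_hat) < 0:                      # conflict: dot = -2
    g_proj = g_a - torch.dot(g_a, b_hat) * b_hat   # -> [1.0, 0.0]

print(torch.dot(g_proj, b_hat))  # 0.0: no component against loss_b left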
class SimpleTrainer(TrainerBase):
    """A simple trainer class implementing generic functions."""

    def __init__(self, cfg):
        super().__init__()
        self.check_cfg(cfg)

        if torch.cuda.is_available() and cfg.USE_CUDA:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Save as attributes some frequently used variables
        self.start_epoch = self.epoch = 0
        self.max_epoch = cfg.OPTIM.MAX_EPOCH
        self.output_dir = cfg.OUTPUT_DIR

        self.cfg = cfg
        self.build_data_loader()
        self.build_model()
        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
        self.best_result = -np.inf

    def check_cfg(self, cfg):
        """Check whether some variables are set correctly for
        the trainer (optional).

        For example, a trainer might require a particular sampler
        for training, such as 'RandomDomainSampler', so it is good
        to do the checking:

        assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
        """
        pass

    def build_data_loader(self):
        """Create essential data-related attributes.

        A re-implementation of this method must create the
        same attributes (except self.dm).
        """
        dm = DataManager(self.cfg)

        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u  # optional, can be None
        self.val_loader = dm.val_loader  # optional, can be None
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname  # dict {label: classname}

        self.dm = dm

    def build_model(self):
        """Build and register model.

        The default builds a classification model along with its
        optimizer and scheduler.

        Custom trainers can re-implement this method if necessary.
        """
        cfg = self.cfg

        print("Building model")
        self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
        self.model.to(self.device)
        print("# params: {:,}".format(count_num_param(self.model)))
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("model", self.model, self.optim, self.sched)

        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(
                f"Detected {device_count} GPUs. Wrap the model with nn.DataParallel"
            )
            self.model = nn.DataParallel(self.model)

    def train(self):
        super().train(self.start_epoch, self.max_epoch)

    def before_train(self):
        directory = self.cfg.OUTPUT_DIR
        if self.cfg.RESUME:
            directory = self.cfg.RESUME
        self.start_epoch = self.resume_model_if_exist(directory)

        # Initialize summary writer
        writer_dir = osp.join(self.output_dir, "tensorboard")
        mkdir_if_missing(writer_dir)
        self.init_writer(writer_dir)

        # Remember the starting time (for computing the elapsed time)
        self.time_start = time.time()

    def after_train(self):
        print("Finished training")

        do_test = not self.cfg.TEST.NO_TEST
        if do_test:
            if self.cfg.TEST.FINAL_MODEL == "best_val":
                print("Deploy the model with the best val performance")
                self.load_model(self.output_dir)
            self.test()

        # Show elapsed time
        elapsed = round(time.time() - self.time_start)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed: {}".format(elapsed))

        # Close writer
        self.close_writer()
    def after_epoch(self):
        last_epoch = (self.epoch + 1) == self.max_epoch
        do_test = not self.cfg.TEST.NO_TEST
        meet_checkpoint_freq = (
            (self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
            if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False
        )

        if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
            curr_result = self.test(split="val")
            is_best = curr_result > self.best_result
            if is_best:
                self.best_result = curr_result
                self.save_model(
                    self.epoch,
                    self.output_dir,
                    model_name="model-best.pth.tar"
                )

        if meet_checkpoint_freq or last_epoch:
            self.save_model(self.epoch, self.output_dir)

    @torch.no_grad()
    def output_test(self, split=None):
        """Testing pipeline that also dumps per-image predictions to a JSON file."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        output_file = osp.join(self.cfg.OUTPUT_DIR, "output.json")
        res_json = {}

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            img_path = batch["impath"]
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)
            for i in range(len(img_path)):
                res_json[img_path[i]] = {
                    "predict": output[i].cpu().numpy().tolist(),
                    "gt": label[i].cpu().numpy().tolist()
                }
        with open(output_file, "w") as f:
            json.dump(res_json, f)
        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]

    @torch.no_grad()
    def test(self, split=None):
        """A generic testing pipeline."""
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            print("Do evaluation on {} set".format(split))
        else:
            data_loader = self.test_loader
            print("Do evaluation on test set")

        for batch_idx, batch in enumerate(tqdm(data_loader)):
            input, label = self.parse_batch_test(batch)
            output = self.model_inference(input)
            self.evaluator.process(output, label)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]

    def model_inference(self, input):
        return self.model(input)

    def parse_batch_test(self, batch):
        input = batch["img"]
        label = batch["label"]

        input = input.to(self.device)
        label = label.to(self.device)

        return input, label

    def get_current_lr(self, names=None):
        names = self.get_model_names(names)
        name = names[0]
        return self._optims[name].param_groups[0]["lr"]


class TrainerXU(SimpleTrainer):
    """A base trainer using both labeled and unlabeled data.

    In the context of domain adaptation, labeled and unlabeled data
    come from source and target domains, respectively.

    In semi-supervised learning, all data come from the same domain.
    """

    def run_epoch(self):
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        # Decide which loader determines the number of iterations per epoch
        len_train_loader_x = len(self.train_loader_x)
        len_train_loader_u = len(self.train_loader_u)
        if self.cfg.TRAIN.COUNT_ITER == "train_x":
            self.num_batches = len_train_loader_x
        elif self.cfg.TRAIN.COUNT_ITER == "train_u":
            self.num_batches = len_train_loader_u
        elif self.cfg.TRAIN.COUNT_ITER == "smaller_one":
            self.num_batches = min(len_train_loader_x, len_train_loader_u)
        else:
            raise ValueError

        train_loader_x_iter = iter(self.train_loader_x)
        train_loader_u_iter = iter(self.train_loader_u)

        end = time.time()
        for self.batch_idx in range(self.num_batches):
            try:
                batch_x = next(train_loader_x_iter)
            except StopIteration:
                train_loader_x_iter = iter(self.train_loader_x)
                batch_x = next(train_loader_x_iter)

            try:
                batch_u = next(train_loader_u_iter)
            except StopIteration:
                train_loader_u_iter = iter(self.train_loader_u)
                batch_u = next(train_loader_u_iter)

            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch_x, batch_u)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, input_u


class TrainerX(SimpleTrainer):
    """A base trainer using labeled data only."""

    def run_epoch(self):
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            loss_summary = self.forward_backward(batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)

            if (
                self.batch_idx + 1
            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:
                nb_remain = 0
                nb_remain += self.num_batches - self.batch_idx - 1
                nb_remain += (
                    self.max_epoch - self.epoch - 1
                ) * self.num_batches
                eta_seconds = batch_time.avg * nb_remain
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    "epoch [{0}/{1}][{2}/{3}]\t"
                    "time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "eta {eta}\t"
                    "{losses}\t"
                    "lr {lr:.6e}".format(
                        self.epoch + 1,
                        self.max_epoch,
                        self.batch_idx + 1,
                        self.num_batches,
                        batch_time=batch_time,
                        data_time=data_time,
                        eta=eta,
                        losses=losses,
                        lr=self.get_current_lr(),
                    )
                )

            n_iter = self.epoch * self.num_batches + self.batch_idx
            for name, meter in losses.meters.items():
                self.write_scalar("train/" + name, meter.avg, n_iter)
            self.write_scalar("train/lr", self.get_current_lr(), n_iter)

            end = time.time()

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        domain = batch["domain"]

        input = input.to(self.device)
        label = label.to(self.device)
        domain = domain.to(self.device)

        return input, label, domain
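Putting the pieces together, a minimal custom trainer only has to supply forward_backward. A hedged sketch, assuming TRAINER_REGISTRY and F are imported as in the SSL trainers above (the class name is illustrative):

@TRAINER_REGISTRY.register()
class CrossEntropyTrainer(TrainerX):
    """Plain cross-entropy training on labeled data (sketch)."""

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {"loss": loss.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary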
3
Dassl.ProGrad.pytorch/dassl/evaluation/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .build import build_evaluator, EVALUATOR_REGISTRY  # isort:skip

from .evaluator import EvaluatorBase, Classification
11
Dassl.ProGrad.pytorch/dassl/evaluation/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

EVALUATOR_REGISTRY = Registry("EVALUATOR")


def build_evaluator(cfg, **kwargs):
    avai_evaluators = EVALUATOR_REGISTRY.registered_names()
    check_availability(cfg.TEST.EVALUATOR, avai_evaluators)
    if cfg.VERBOSE:
        print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR))
    return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs)
127
Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py
Normal file
@@ -0,0 +1,127 @@
import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix

from .build import EVALUATOR_REGISTRY


class EvaluatorBase:
    """Base evaluator."""

    def __init__(self, cfg):
        self.cfg = cfg

    def reset(self):
        raise NotImplementedError

    def process(self, mo, gt):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError


@EVALUATOR_REGISTRY.register()
class Classification(EvaluatorBase):
    """Evaluator for classification."""

    def __init__(self, cfg, lab2cname=None, **kwargs):
        super().__init__(cfg)
        self._lab2cname = lab2cname
        self._correct = 0
        self._total = 0
        self._per_class_res = None
        self._y_true = []
        self._y_pred = []
        if cfg.TEST.PER_CLASS_RESULT:
            assert lab2cname is not None
            self._per_class_res = defaultdict(list)

    def reset(self):
        self._correct = 0
        self._total = 0
        self._y_true = []
        self._y_pred = []
        if self._per_class_res is not None:
            self._per_class_res = defaultdict(list)

    def process(self, mo, gt):
        # mo (torch.Tensor): model output [batch, num_classes]
        # gt (torch.LongTensor): ground truth [batch]
        pred = mo.max(1)[1]
        matches = pred.eq(gt).float()
        self._correct += int(matches.sum().item())
        self._total += gt.shape[0]

        self._y_true.extend(gt.data.cpu().numpy().tolist())
        self._y_pred.extend(pred.data.cpu().numpy().tolist())

        if self._per_class_res is not None:
            for i, label in enumerate(gt):
                label = label.item()
                matches_i = int(matches[i].item())
                self._per_class_res[label].append(matches_i)

    def evaluate(self):
        results = OrderedDict()
        acc = 100.0 * self._correct / self._total
        err = 100.0 - acc
        macro_f1 = 100.0 * f1_score(
            self._y_true,
            self._y_pred,
            average="macro",
            labels=np.unique(self._y_true)
        )

        # The first value will be returned by trainer.test()
        results["accuracy"] = acc
        results["error_rate"] = err
        results["macro_f1"] = macro_f1

        print(
            "=> result\n"
            f"* total: {self._total:,}\n"
            f"* correct: {self._correct:,}\n"
            f"* accuracy: {acc:.2f}%\n"
            f"* error: {err:.2f}%\n"
            f"* macro_f1: {macro_f1:.2f}%"
        )

        if self._per_class_res is not None:
            labels = list(self._per_class_res.keys())
            labels.sort()

            print("=> per-class result")
            accs = []

            for label in labels:
                classname = self._lab2cname[label]
                res = self._per_class_res[label]
                correct = sum(res)
                total = len(res)
                acc = 100.0 * correct / total
                accs.append(acc)
                print(
                    "* class: {} ({})\t"
                    "total: {:,}\t"
                    "correct: {:,}\t"
                    "acc: {:.2f}%".format(
                        label, classname, total, correct, acc
                    )
                )
            mean_acc = np.mean(accs)
            print("* average: {:.2f}%".format(mean_acc))

            results["perclass_accuracy"] = mean_acc

        if self.cfg.TEST.COMPUTE_CMAT:
            cmat = confusion_matrix(
                self._y_true, self._y_pred, normalize="true"
            )
            save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
            torch.save(cmat, save_path)
            print('Confusion matrix is saved to "{}"'.format(save_path))

        return results
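A minimal stand-alone sketch of the reset/process/evaluate protocol (the toy config only sets the fields Classification reads; the logits and labels are made up):

import torch
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TEST = CN()
cfg.TEST.PER_CLASS_RESULT = False
cfg.TEST.COMPUTE_CMAT = False

evaluator = Classification(cfg)
evaluator.reset()
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5]])  # predicts classes 0 and 1
labels = torch.tensor([0, 0])
evaluator.process(logits, labels)
print(evaluator.evaluate()["accuracy"])  # 50.0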
4
Dassl.ProGrad.pytorch/dassl/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .accuracy import compute_accuracy
from .distance import (
    cosine_distance, compute_distance_matrix, euclidean_squared_distance
)
30
Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py
Normal file
@@ -0,0 +1,30 @@
def compute_accuracy(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for
    the specified values of k.

    Args:
        output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).
        target (torch.LongTensor): ground truth labels with shape (batch_size).
        topk (tuple, optional): accuracy at top-k will be computed. For example,
            topk=(1, 5) means accuracy at top-1 and top-5 will be computed.

    Returns:
        list: accuracy at top-k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    if isinstance(output, (tuple, list)):
        output = output[0]

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape() instead of view() to avoid contiguity errors on
        # some PyTorch versions
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        acc = correct_k.mul_(100.0 / batch_size)
        res.append(acc)

    return res
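A quick usage check (made-up logits; with three classes, top-2 accuracy counts a hit whenever the true class is among the two highest scores):

import torch

output = torch.tensor([
    [0.1, 0.7, 0.2],   # top-2: classes 1, 2
    [0.5, 0.3, 0.2],   # top-2: classes 0, 1
])
target = torch.tensor([2, 1])

top1, top2 = compute_accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())  # 0.0 100.0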
77
Dassl.ProGrad.pytorch/dassl/metrics/distance.py
Normal file
@@ -0,0 +1,77 @@
"""
Source: https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.nn import functional as F


def compute_distance_matrix(input1, input2, metric="euclidean"):
    """A wrapper function for computing distance matrix.

    Each input matrix has the shape (n_data, feature_dim).

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.
        metric (str, optional): "euclidean" or "cosine".
            Default is "euclidean".

    Returns:
        torch.Tensor: distance matrix.
    """
    # check input
    assert isinstance(input1, torch.Tensor)
    assert isinstance(input2, torch.Tensor)
    assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
        input1.dim()
    )
    assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
        input2.dim()
    )
    assert input1.size(1) == input2.size(1)

    if metric == "euclidean":
        distmat = euclidean_squared_distance(input1, input2)
    elif metric == "cosine":
        distmat = cosine_distance(input1, input2)
    else:
        raise ValueError(
            "Unknown distance metric: {}. "
            'Please choose either "euclidean" or "cosine"'.format(metric)
        )

    return distmat


def euclidean_squared_distance(input1, input2):
    """Computes euclidean squared distance.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.

    Returns:
        torch.Tensor: distance matrix.
    """
    m, n = input1.size(0), input2.size(0)
    mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
    mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat = mat1 + mat2
    # distmat += -2 * input1 @ input2.t(); the keyword form replaces the
    # deprecated positional addmm_(beta, alpha, mat1, mat2) signature
    distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
    return distmat


def cosine_distance(input1, input2):
    """Computes cosine distance.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.

    Returns:
        torch.Tensor: distance matrix.
    """
    input1_normed = F.normalize(input1, p=2, dim=1)
    input2_normed = F.normalize(input2, p=2, dim=1)
    distmat = 1 - torch.mm(input1_normed, input2_normed.t())
    return distmat
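The identity behind euclidean_squared_distance is ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; a quick numerical sanity check against torch.cdist:

import torch

a = torch.randn(4, 8)
b = torch.randn(5, 8)
d1 = euclidean_squared_distance(a, b)
d2 = torch.cdist(a, b)**2
print(torch.allclose(d1, d2, atol=1e-5))  # True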
3
Dassl.ProGrad.pytorch/dassl/modeling/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .head import HEAD_REGISTRY, build_head
from .network import NETWORK_REGISTRY, build_network
from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone
27
Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py
Normal file
@@ -0,0 +1,27 @@
from .build import build_backbone, BACKBONE_REGISTRY  # isort:skip
from .backbone import Backbone  # isort:skip

from .vgg import vgg16
from .resnet import (
    resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,
    resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,
    resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,
    resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,
    resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,
    resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123
)
from .alexnet import alexnet
from .mobilenetv2 import mobilenetv2
from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2
from .cnn_digitsdg import cnn_digitsdg
from .efficientnet import (
    efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,
    efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
)
from .shufflenetv2 import (
    shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,
    shufflenet_v2_x2_0
)
from .cnn_digitsingle import cnn_digitsingle
from .preact_resnet18 import preact_resnet18
from .cnn_digit5_m3sda import cnn_digit5_m3sda
64
Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py
Normal file
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from .build import BACKBONE_REGISTRY
from .backbone import Backbone

model_urls = {
    "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
}


class AlexNet(Backbone):

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Note that self.classifier outputs features rather than logits
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )

        self._out_features = 4096

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


def init_pretrained_weights(model, model_url):
    pretrain_dict = model_zoo.load_url(model_url)
    model.load_state_dict(pretrain_dict, strict=False)


@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
    model = AlexNet()

    if pretrained:
        init_pretrained_weights(model, model_urls["alexnet"])

    return model
17
Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py
Normal file
@@ -0,0 +1,17 @@
import torch.nn as nn


class Backbone(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self):
        pass

    @property
    def out_features(self):
        """Output feature dimension."""
        if self.__dict__.get("_out_features") is None:
            return None
        return self._out_features
11
Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

BACKBONE_REGISTRY = Registry("BACKBONE")


def build_backbone(name, verbose=True, **kwargs):
    avai_backbones = BACKBONE_REGISTRY.registered_names()
    check_availability(name, avai_backbones)
    if verbose:
        print("Backbone: {}".format(name))
    return BACKBONE_REGISTRY.get(name)(**kwargs)
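New backbones plug in through the same registry pattern: decorate a builder function and its name becomes valid for cfg.MODEL.BACKBONE.NAME. A hedged sketch (the network itself is illustrative, not part of the library):

import torch.nn as nn

from dassl.modeling.backbone import BACKBONE_REGISTRY, Backbone


class TinyNet(Backbone):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self._out_features = 16

    def forward(self, x):
        return self.pool(self.conv(x)).flatten(1)


@BACKBONE_REGISTRY.register()
def tinynet(pretrained=False, **kwargs):
    return TinyNet()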
58
Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py
Normal file
@@ -0,0 +1,58 @@
"""
Reference

https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA
"""
import torch.nn as nn
from torch.nn import functional as F

from .build import BACKBONE_REGISTRY
from .backbone import Backbone


class FeatureExtractor(Backbone):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)
        self.fc2 = nn.Linear(3072, 2048)
        self.bn2_fc = nn.BatchNorm1d(2048)

        self._out_features = 2048

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, but got {}x{}".format(H, W)

    def forward(self, x):
        self._check_input(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), 8192)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        return x


@BACKBONE_REGISTRY.register()
def cnn_digit5_m3sda(**kwargs):
    """
    This architecture was used for the Digit-5 dataset in:

        - Peng et al. Moment Matching for Multi-Source
          Domain Adaptation. ICCV 2019.
    """
    return FeatureExtractor()
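To see where the 8192 in fc1 comes from: each 3x3/stride-2/pad-1 max-pool halves a 32x32 map, so after two pools the conv3 output is 8 x 8 x 128 = 8192. A quick shape check:

import torch

net = FeatureExtractor()
net.eval()
x = torch.randn(2, 3, 32, 32)
print(net(x).shape)  # torch.Size([2, 2048]), matching net.out_features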
61
Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py
Normal file
@@ -0,0 +1,61 @@
import torch.nn as nn
from torch.nn import functional as F

from dassl.utils import init_network_weights

from .build import BACKBONE_REGISTRY
from .backbone import Backbone


class Convolution(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(self.conv(x))


class ConvNet(Backbone):

    def __init__(self, c_hidden=64):
        super().__init__()
        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)

        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, but got {}x{}".format(H, W)

    def forward(self, x):
        self._check_input(x)
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        return x.view(x.size(0), -1)


@BACKBONE_REGISTRY.register()
def cnn_digitsdg(**kwargs):
    """
    This architecture was used for the DigitsDG dataset in:

        - Zhou et al. Deep Domain-Adversarial Image Generation
          for Domain Generalisation. AAAI 2020.
    """
    model = ConvNet(c_hidden=64)
    init_network_weights(model, init_type="kaiming")
    return model
56
Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py
Normal file
@@ -0,0 +1,56 @@
"""
This model is built based on
https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
"""
import torch.nn as nn
from torch.nn import functional as F

from dassl.utils import init_network_weights

from .build import BACKBONE_REGISTRY
from .backbone import Backbone


class CNN(Backbone):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.fc3 = nn.Linear(5 * 5 * 128, 1024)
        self.fc4 = nn.Linear(1024, 1024)

        self._out_features = 1024

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, but got {}x{}".format(H, W)

    def forward(self, x):
        self._check_input(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = x.view(x.size(0), -1)

        x = self.fc3(x)
        x = F.relu(x)

        x = self.fc4(x)
        x = F.relu(x)

        return x


@BACKBONE_REGISTRY.register()
def cnn_digitsingle(**kwargs):
    model = CNN()
    init_network_weights(model, init_type="kaiming")
    return model
12
Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/__init__.py
Normal file
@@ -0,0 +1,12 @@
"""
Source: https://github.com/lukemelas/EfficientNet-PyTorch
"""
__version__ = "0.6.4"
from .model import (
    EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,
    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,
    efficientnet_b7
)
from .utils import (
    BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params
)
371
Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py
Normal file
@@ -0,0 +1,371 @@
import torch
from torch import nn
from torch.nn import functional as F

from .utils import (
    Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,
    get_model_params, efficientnet_params, get_same_padding_conv2d,
    load_pretrained_weights, calculate_output_image_size
)
from ..build import BACKBONE_REGISTRY
from ..backbone import Backbone


class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParams, defined in utils.py.

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio
                       is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Expansion phase
        inp = self._block_args.input_filters  # number of input channels
        oup = (
            self._block_args.input_filters * self._block_args.expand_ratio
        )  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1,
                int(
                    self._block_args.input_filters * self._block_args.se_ratio
                )
            )
            self._se_reduce = Conv2d(
                in_channels=oup,
                out_channels=num_squeezed_channels,
                kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels,
                out_channels=oup,
                kernel_size=1
            )

        # Output phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(
                self._swish(self._se_reduce(x_squeezed))
            )
            x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            if drop_connect_rate:
                x = drop_connect(
                    x, p=drop_connect_rate, training=self.training
                )
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)."""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(Backbone):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "blocks_args must be non-empty"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(
                    block_args.num_repeat, self._global_params
                ),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(
                    block_args, self._global_params, image_size=image_size
                )
            )
            image_size = calculate_output_image_size(
                image_size, block_args.stride
            )
            if block_args.num_repeat > 1:
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(
                        block_args, self._global_params, image_size=image_size
                    )
                )
                # image_size = calculate_output_image_size(image_size, block_args.stride)
                # not needed: stride is 1 for the repeated blocks

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )

        # Final pooling and dropout
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

        self._out_features = out_channels

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)."""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """Returns output of the final convolution layer."""

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # scale drop connect rate linearly with block depth
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """
        Calls extract_features, applies pooling and dropout, and returns
        a feature vector (the fc layer is disabled; classification is
        handled by the toolbox's own head).
        """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and dropout
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        # x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None):
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(
            model_name, override_params
        )
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(
        cls, model_name, advprop=False, num_classes=1000, in_channels=3
    ):
        model = cls.from_name(
            model_name, override_params={"num_classes": num_classes}
        )
        load_pretrained_weights(
            model, model_name, load_fc=(num_classes == 1000), advprop=advprop
        )
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name."""
        valid_models = ["efficientnet-b" + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError(
                "model_name should be one of: " + ", ".join(valid_models)
            )

    def _change_in_channels(model, in_channels):
        # Instance method; the first argument plays the role of `self`
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=model._global_params.image_size
            )
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )


def build_efficientnet(name, pretrained):
    if pretrained:
        return EfficientNet.from_pretrained("efficientnet-{}".format(name))
    else:
        return EfficientNet.from_name("efficientnet-{}".format(name))


@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
    return build_efficientnet("b0", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
    return build_efficientnet("b1", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
    return build_efficientnet("b2", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
    return build_efficientnet("b3", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
    return build_efficientnet("b4", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
    return build_efficientnet("b5", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
    return build_efficientnet("b6", pretrained)


@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
    return build_efficientnet("b7", pretrained)
477
Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py
Normal file
@@ -0,0 +1,477 @@
"""
This file contains helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""

import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo

########################################################################
################ HELPER FUNCTIONS FOR MODEL ARCHITECTURE ##############
########################################################################

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
    "GlobalParams",
    [
        "batch_norm_momentum",
        "batch_norm_epsilon",
        "dropout_rate",
        "num_classes",
        "width_coefficient",
        "depth_coefficient",
        "depth_divisor",
        "min_depth",
        "drop_connect_rate",
        "image_size",
    ],
)

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    "BlockArgs",
    [
        "kernel_size",
        "num_repeat",
        "input_filters",
        "output_filters",
        "expand_ratio",
        "id_skip",
        "stride",
        "se_ratio",
    ],
)

# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)


class SwishImplementation(torch.autograd.Function):

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the deprecated saved_variables
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):

    def forward(self, x):
        return SwishImplementation.apply(x)


class Swish(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(x)


def round_filters(filters, global_params):
    """Calculate and round the number of filters based on the width multiplier."""
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor
    new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)
|
||||
def round_repeats(repeats, global_params):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
multiplier = global_params.depth_coefficient
|
||||
if not multiplier:
|
||||
return repeats
|
||||
return int(math.ceil(multiplier * repeats))
|
||||
|
||||
|
||||
def drop_connect(inputs, p, training):
|
||||
"""Drop connect."""
|
||||
if not training:
|
||||
return inputs
|
||||
batch_size = inputs.shape[0]
|
||||
keep_prob = 1 - p
|
||||
random_tensor = keep_prob
|
||||
random_tensor += torch.rand(
|
||||
[batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
|
||||
)
|
||||
binary_tensor = torch.floor(random_tensor)
|
||||
output = inputs / keep_prob * binary_tensor
|
||||
return output
|
||||
|
||||
|
||||
def get_same_padding_conv2d(image_size=None):
|
||||
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
|
||||
Static padding is necessary for ONNX exporting of models."""
|
||||
if image_size is None:
|
||||
return Conv2dDynamicSamePadding
|
||||
else:
|
||||
return partial(Conv2dStaticSamePadding, image_size=image_size)
|
||||
|
||||
|
||||
def get_width_and_height_from_size(x):
|
||||
"""Obtains width and height from a int or tuple"""
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
if isinstance(x, list) or isinstance(x, tuple):
|
||||
return x
|
||||
else:
|
||||
raise TypeError()
|
||||
|
||||
|
||||
def calculate_output_image_size(input_image_size, stride):
|
||||
"""
|
||||
Calculates the output image size when using Conv2dSamePadding with a stride.
|
||||
Necessary for static padding. Thanks to mannatsingh for pointing this out.
|
||||
"""
|
||||
if input_image_size is None:
|
||||
return None
|
||||
image_height, image_width = get_width_and_height_from_size(
|
||||
input_image_size
|
||||
)
|
||||
stride = stride if isinstance(stride, int) else stride[0]
|
||||
image_height = int(math.ceil(image_height / stride))
|
||||
image_width = int(math.ceil(image_width / stride))
|
||||
return [image_height, image_width]
|
||||
|
||||
|
||||
class Conv2dDynamicSamePadding(nn.Conv2d):
|
||||
"""2D Convolutions like TensorFlow, for a dynamic image size"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation,
|
||||
groups, bias
|
||||
)
|
||||
self.stride = self.stride if len(self.stride
|
||||
) == 2 else [self.stride[0]] * 2
|
||||
|
||||
def forward(self, x):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
sh, sw = self.stride
|
||||
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
||||
pad_h = max(
|
||||
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
|
||||
)
|
||||
pad_w = max(
|
||||
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
|
||||
)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(
|
||||
x,
|
||||
[pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]
|
||||
)
|
||||
return F.conv2d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
class Conv2dStaticSamePadding(nn.Conv2d):
|
||||
"""2D Convolutions like TensorFlow, for a fixed image size"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
image_size=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
||||
self.stride = self.stride if len(self.stride
|
||||
) == 2 else [self.stride[0]] * 2
|
||||
|
||||
# Calculate padding based on image size and save it
|
||||
assert image_size is not None
|
||||
ih, iw = (image_size,
|
||||
image_size) if isinstance(image_size, int) else image_size
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
sh, sw = self.stride
|
||||
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
||||
pad_h = max(
|
||||
(oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0
|
||||
)
|
||||
pad_w = max(
|
||||
(ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0
|
||||
)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
self.static_padding = nn.ZeroPad2d(
|
||||
(pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)
|
||||
)
|
||||
else:
|
||||
self.static_padding = Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.static_padding(x)
|
||||
x = F.conv2d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self, ):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
########################################################################
|
||||
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
|
||||
########################################################################
|
||||
|
||||
|
||||
def efficientnet_params(model_name):
|
||||
"""Map EfficientNet model name to parameter coefficients."""
|
||||
params_dict = {
|
||||
# Coefficients: width,depth,res,dropout
|
||||
"efficientnet-b0": (1.0, 1.0, 224, 0.2),
|
||||
"efficientnet-b1": (1.0, 1.1, 240, 0.2),
|
||||
"efficientnet-b2": (1.1, 1.2, 260, 0.3),
|
||||
"efficientnet-b3": (1.2, 1.4, 300, 0.3),
|
||||
"efficientnet-b4": (1.4, 1.8, 380, 0.4),
|
||||
"efficientnet-b5": (1.6, 2.2, 456, 0.4),
|
||||
"efficientnet-b6": (1.8, 2.6, 528, 0.5),
|
||||
"efficientnet-b7": (2.0, 3.1, 600, 0.5),
|
||||
"efficientnet-b8": (2.2, 3.6, 672, 0.5),
|
||||
"efficientnet-l2": (4.3, 5.3, 800, 0.5),
|
||||
}
|
||||
return params_dict[model_name]
|
||||
|
||||
|
||||
class BlockDecoder(object):
|
||||
"""Block Decoder for readability, straight from the official TensorFlow repository"""
|
||||
|
||||
@staticmethod
|
||||
def _decode_block_string(block_string):
|
||||
"""Gets a block through a string notation of arguments."""
|
||||
assert isinstance(block_string, str)
|
||||
|
||||
ops = block_string.split("_")
|
||||
options = {}
|
||||
for op in ops:
|
||||
splits = re.split(r"(\d.*)", op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
# Check stride
|
||||
assert ("s" in options and len(options["s"]) == 1) or (
|
||||
len(options["s"]) == 2 and options["s"][0] == options["s"][1]
|
||||
)
|
||||
|
||||
return BlockArgs(
|
||||
kernel_size=int(options["k"]),
|
||||
num_repeat=int(options["r"]),
|
||||
input_filters=int(options["i"]),
|
||||
output_filters=int(options["o"]),
|
||||
expand_ratio=int(options["e"]),
|
||||
id_skip=("noskip" not in block_string),
|
||||
se_ratio=float(options["se"]) if "se" in options else None,
|
||||
stride=[int(options["s"][0])],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _encode_block_string(block):
|
||||
"""Encodes a block to a string."""
|
||||
args = [
|
||||
"r%d" % block.num_repeat,
|
||||
"k%d" % block.kernel_size,
|
||||
"s%d%d" % (block.strides[0], block.strides[1]),
|
||||
"e%s" % block.expand_ratio,
|
||||
"i%d" % block.input_filters,
|
||||
"o%d" % block.output_filters,
|
||||
]
|
||||
if 0 < block.se_ratio <= 1:
|
||||
args.append("se%s" % block.se_ratio)
|
||||
if block.id_skip is False:
|
||||
args.append("noskip")
|
||||
return "_".join(args)
|
||||
|
||||
@staticmethod
|
||||
def decode(string_list):
|
||||
"""
|
||||
Decodes a list of string notations to specify blocks inside the network.
|
||||
|
||||
:param string_list: a list of strings, each string is a notation of block
|
||||
:return: a list of BlockArgs namedtuples of block args
|
||||
"""
|
||||
assert isinstance(string_list, list)
|
||||
blocks_args = []
|
||||
for block_string in string_list:
|
||||
blocks_args.append(BlockDecoder._decode_block_string(block_string))
|
||||
return blocks_args
|
||||
|
||||
@staticmethod
|
||||
def encode(blocks_args):
|
||||
"""
|
||||
Encodes a list of BlockArgs to a list of strings.
|
||||
|
||||
:param blocks_args: a list of BlockArgs namedtuples of block args
|
||||
:return: a list of strings, each string is a notation of block
|
||||
"""
|
||||
block_strings = []
|
||||
for block in blocks_args:
|
||||
block_strings.append(BlockDecoder._encode_block_string(block))
|
||||
return block_strings
|
||||
|
||||
|
||||
def efficientnet(
|
||||
width_coefficient=None,
|
||||
depth_coefficient=None,
|
||||
dropout_rate=0.2,
|
||||
drop_connect_rate=0.2,
|
||||
image_size=None,
|
||||
num_classes=1000,
|
||||
):
|
||||
"""Creates a efficientnet model."""
|
||||
|
||||
blocks_args = [
|
||||
"r1_k3_s11_e1_i32_o16_se0.25",
|
||||
"r2_k3_s22_e6_i16_o24_se0.25",
|
||||
"r2_k5_s22_e6_i24_o40_se0.25",
|
||||
"r3_k3_s22_e6_i40_o80_se0.25",
|
||||
"r3_k5_s11_e6_i80_o112_se0.25",
|
||||
"r4_k5_s22_e6_i112_o192_se0.25",
|
||||
"r1_k3_s11_e6_i192_o320_se0.25",
|
||||
]
|
||||
blocks_args = BlockDecoder.decode(blocks_args)
|
||||
|
||||
global_params = GlobalParams(
|
||||
batch_norm_momentum=0.99,
|
||||
batch_norm_epsilon=1e-3,
|
||||
dropout_rate=dropout_rate,
|
||||
drop_connect_rate=drop_connect_rate,
|
||||
# data_format='channels_last', # removed, this is always true in PyTorch
|
||||
num_classes=num_classes,
|
||||
width_coefficient=width_coefficient,
|
||||
depth_coefficient=depth_coefficient,
|
||||
depth_divisor=8,
|
||||
min_depth=None,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
return blocks_args, global_params
|
||||
|
||||
|
||||
def get_model_params(model_name, override_params):
|
||||
"""Get the block args and global params for a given model"""
|
||||
if model_name.startswith("efficientnet"):
|
||||
w, d, s, p = efficientnet_params(model_name)
|
||||
# note: all models have drop connect rate = 0.2
|
||||
blocks_args, global_params = efficientnet(
|
||||
width_coefficient=w,
|
||||
depth_coefficient=d,
|
||||
dropout_rate=p,
|
||||
image_size=s
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"model name is not pre-defined: %s" % model_name
|
||||
)
|
||||
if override_params:
|
||||
# ValueError will be raised here if override_params has fields not included in global_params.
|
||||
global_params = global_params._replace(**override_params)
|
||||
return blocks_args, global_params
|
||||
|
||||
|
||||
url_map = {
|
||||
"efficientnet-b0":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||
"efficientnet-b1":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||
"efficientnet-b2":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||
"efficientnet-b3":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||
"efficientnet-b4":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
"efficientnet-b5":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
"efficientnet-b6":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
"efficientnet-b7":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
|
||||
}
|
||||
|
||||
url_map_advprop = {
|
||||
"efficientnet-b0":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
|
||||
"efficientnet-b1":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
|
||||
"efficientnet-b2":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
|
||||
"efficientnet-b3":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
|
||||
"efficientnet-b4":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
|
||||
"efficientnet-b5":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
|
||||
"efficientnet-b6":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
|
||||
"efficientnet-b7":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
|
||||
"efficientnet-b8":
|
||||
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
|
||||
}
|
||||
|
||||
|
||||
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
|
||||
"""Loads pretrained weights, and downloads if loading for the first time."""
|
||||
# AutoAugment or Advprop (different preprocessing)
|
||||
url_map_ = url_map_advprop if advprop else url_map
|
||||
state_dict = model_zoo.load_url(url_map_[model_name])
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
"""
|
||||
if load_fc:
|
||||
model.load_state_dict(state_dict)
|
||||
else:
|
||||
state_dict.pop('_fc.weight')
|
||||
state_dict.pop('_fc.bias')
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
||||
|
||||
print('Loaded pretrained weights for {}'.format(model_name))
|
||||
"""
|
||||
217
Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
Normal file
217
Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch import nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
model_urls = {
|
||||
"mobilenet_v2":
|
||||
"https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
|
||||
}
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor/2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self, in_planes, out_planes, kernel_size=3, stride=1, groups=1
|
||||
):
|
||||
padding = (kernel_size-1) // 2
|
||||
super().__init__(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True),
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend(
|
||||
[
|
||||
# dw
|
||||
ConvBNReLU(
|
||||
hidden_dim, hidden_dim, stride=stride, groups=hidden_dim
|
||||
),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
]
|
||||
)
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(Backbone):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None,
|
||||
):
|
||||
"""
|
||||
MobileNet V2.
|
||||
|
||||
Args:
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if (
|
||||
len(inverted_residual_setting) == 0
|
||||
or len(inverted_residual_setting[0]) != 4
|
||||
):
|
||||
raise ValueError(
|
||||
"inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".
|
||||
format(inverted_residual_setting)
|
||||
)
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(
|
||||
input_channel * width_mult, round_nearest
|
||||
)
|
||||
self.last_channel = _make_divisible(
|
||||
last_channel * max(1.0, width_mult), round_nearest
|
||||
)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(
|
||||
block(
|
||||
input_channel, output_channel, stride, expand_ratio=t
|
||||
)
|
||||
)
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(
|
||||
ConvBNReLU(input_channel, self.last_channel, kernel_size=1)
|
||||
)
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
self._out_features = self.last_channel
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
x = x.mean([2, 3])
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
"""Initializes model with pretrained weights.
|
||||
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
if model_url is None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"ImageNet pretrained weights are unavailable for this model"
|
||||
)
|
||||
return
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {
|
||||
k: v
|
||||
for k, v in pretrain_dict.items()
|
||||
if k in model_dict and model_dict[k].size() == v.size()
|
||||
}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def mobilenetv2(pretrained=True, **kwargs):
|
||||
model = MobileNetV2(**kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["mobilenet_v2"])
|
||||
return model
|
||||
135
Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
Normal file
135
Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class PreActBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super().__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super().__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, self.expansion * planes, kernel_size=1, bias=False
|
||||
)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = self.conv3(F.relu(self.bn3(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class PreActResNet(Backbone):
|
||||
|
||||
def __init__(self, block, num_blocks):
|
||||
super().__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, 64, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
|
||||
self._out_features = 512 * block.expansion
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
return out
|
||||
|
||||
|
||||
"""
|
||||
Preact-ResNet18 was used for the CIFAR10 and
|
||||
SVHN datasets (both are SSL tasks) in
|
||||
|
||||
- Wang et al. Semi-Supervised Learning by
|
||||
Augmented Distribution Alignment. ICCV 2019.
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def preact_resnet18(**kwargs):
|
||||
return PreActResNet(PreActBlock, [2, 2, 2, 2])
|
||||
589
Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
Normal file
589
Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
Normal file
@@ -0,0 +1,589 @@
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
model_urls = {
|
||||
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
||||
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
||||
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
||||
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
||||
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, planes * self.expansion, kernel_size=1, bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(Backbone):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block,
|
||||
layers,
|
||||
ms_class=None,
|
||||
ms_layers=[],
|
||||
ms_p=0.5,
|
||||
ms_a=0.1,
|
||||
**kwargs
|
||||
):
|
||||
self.inplanes = 64
|
||||
super().__init__()
|
||||
|
||||
# backbone network
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, 64, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self._out_features = 512 * block.expansion
|
||||
|
||||
self.mixstyle = None
|
||||
if ms_layers:
|
||||
self.mixstyle = ms_class(p=ms_p, alpha=ms_a)
|
||||
for layer_name in ms_layers:
|
||||
assert layer_name in ["layer1", "layer2", "layer3"]
|
||||
print(f"Insert MixStyle after {ms_layers}")
|
||||
self.ms_layers = ms_layers
|
||||
|
||||
self._init_params()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def featuremaps(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.layer1(x)
|
||||
if "layer1" in self.ms_layers:
|
||||
x = self.mixstyle(x)
|
||||
x = self.layer2(x)
|
||||
if "layer2" in self.ms_layers:
|
||||
x = self.mixstyle(x)
|
||||
x = self.layer3(x)
|
||||
if "layer3" in self.ms_layers:
|
||||
x = self.mixstyle(x)
|
||||
return self.layer4(x)
|
||||
|
||||
def forward(self, x):
|
||||
f = self.featuremaps(x)
|
||||
v = self.global_avgpool(f)
|
||||
return v.view(v.size(0), -1)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model.load_state_dict(pretrain_dict, strict=False)
|
||||
|
||||
|
||||
"""
|
||||
Residual network configurations:
|
||||
--
|
||||
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
|
||||
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
|
||||
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
|
||||
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
|
||||
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18(pretrained=True, **kwargs):
|
||||
model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet34(pretrained=True, **kwargs):
|
||||
model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3])
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet34"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50(pretrained=True, **kwargs):
|
||||
model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101(pretrained=True, **kwargs):
|
||||
model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3])
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet152(pretrained=True, **kwargs):
|
||||
model = ResNet(block=Bottleneck, layers=[3, 8, 36, 3])
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet152"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
"""
|
||||
Residual networks with mixstyle
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_ms_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_ms_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_ms_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_ms_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_ms_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_ms_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_ms_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_ms_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_ms_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import MixStyle
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=MixStyle,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
"""
|
||||
Residual networks with efdmix
|
||||
"""
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_efdmix_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_efdmix_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet18_efdmix_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=BasicBlock,
|
||||
layers=[2, 2, 2, 2],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet18"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_efdmix_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_efdmix_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet50_efdmix_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 6, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet50"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_efdmix_l123(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2", "layer3"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_efdmix_l12(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1", "layer2"],
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def resnet101_efdmix_l1(pretrained=True, **kwargs):
|
||||
from dassl.modeling.ops import EFDMix
|
||||
|
||||
model = ResNet(
|
||||
block=Bottleneck,
|
||||
layers=[3, 4, 23, 3],
|
||||
ms_class=EFDMix,
|
||||
ms_layers=["layer1"]
|
||||
)
|
||||
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["resnet101"])
|
||||
|
||||
return model
|
||||
229
Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
Normal file
229
Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Code source: https://github.com/pytorch/vision
|
||||
"""
|
||||
import torch
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch import nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
model_urls = {
|
||||
"shufflenetv2_x0.5":
|
||||
"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
|
||||
"shufflenetv2_x1.0":
|
||||
"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
|
||||
"shufflenetv2_x1.5": None,
|
||||
"shufflenetv2_x2.0": None,
|
||||
}
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
|
||||
batchsize, num_channels, height, width = x.data.size()
|
||||
channels_per_group = num_channels // groups
|
||||
|
||||
# reshape
|
||||
x = x.view(batchsize, groups, channels_per_group, height, width)
|
||||
|
||||
x = torch.transpose(x, 1, 2).contiguous()
|
||||
|
||||
# flatten
|
||||
x = x.view(batchsize, -1, height, width)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
|
||||
def __init__(self, inp, oup, stride):
|
||||
super().__init__()
|
||||
|
||||
if not (1 <= stride <= 3):
|
||||
raise ValueError("illegal stride value")
|
||||
self.stride = stride
|
||||
|
||||
branch_features = oup // 2
|
||||
assert (self.stride != 1) or (inp == branch_features << 1)
|
||||
|
||||
if self.stride > 1:
|
||||
self.branch1 = nn.Sequential(
|
||||
self.depthwise_conv(
|
||||
inp, inp, kernel_size=3, stride=self.stride, padding=1
|
||||
),
|
||||
nn.BatchNorm2d(inp),
|
||||
nn.Conv2d(
|
||||
inp,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False
|
||||
),
|
||||
nn.BatchNorm2d(branch_features),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inp if (self.stride > 1) else branch_features,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(branch_features),
|
||||
nn.ReLU(inplace=True),
|
||||
self.depthwise_conv(
|
||||
branch_features,
|
||||
branch_features,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
),
|
||||
nn.BatchNorm2d(branch_features),
|
||||
nn.Conv2d(
|
||||
branch_features,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(branch_features),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
|
||||
return nn.Conv2d(
|
||||
i, o, kernel_size, stride, padding, bias=bias, groups=i
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
x1, x2 = x.chunk(2, dim=1)
|
||||
out = torch.cat((x1, self.branch2(x2)), dim=1)
|
||||
else:
|
||||
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
|
||||
|
||||
out = channel_shuffle(out, 2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ShuffleNetV2(Backbone):
|
||||
|
||||
def __init__(self, stages_repeats, stages_out_channels, **kwargs):
|
||||
super().__init__()
|
||||
if len(stages_repeats) != 3:
|
||||
raise ValueError(
|
||||
"expected stages_repeats as list of 3 positive ints"
|
||||
)
|
||||
if len(stages_out_channels) != 5:
|
||||
raise ValueError(
|
||||
"expected stages_out_channels as list of 5 positive ints"
|
||||
)
|
||||
self._stage_out_channels = stages_out_channels
|
||||
|
||||
input_channels = 3
|
||||
output_channels = self._stage_out_channels[0]
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
|
||||
for name, repeats, output_channels in zip(
|
||||
stage_names, stages_repeats, self._stage_out_channels[1:]
|
||||
):
|
||||
seq = [InvertedResidual(input_channels, output_channels, 2)]
|
||||
for i in range(repeats - 1):
|
||||
seq.append(
|
||||
InvertedResidual(output_channels, output_channels, 1)
|
||||
)
|
||||
setattr(self, name, nn.Sequential(*seq))
|
||||
input_channels = output_channels
|
||||
|
||||
output_channels = self._stage_out_channels[-1]
|
||||
self.conv5 = nn.Sequential(
|
||||
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(output_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
|
||||
self._out_features = output_channels
|
||||
|
||||
def featuremaps(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.stage2(x)
|
||||
x = self.stage3(x)
|
||||
x = self.stage4(x)
|
||||
x = self.conv5(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
f = self.featuremaps(x)
|
||||
v = self.global_avgpool(f)
|
||||
return v.view(v.size(0), -1)
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
"""Initializes model with pretrained weights.
|
||||
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
if model_url is None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"ImageNet pretrained weights are unavailable for this model"
|
||||
)
|
||||
return
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {
|
||||
k: v
|
||||
for k, v in pretrain_dict.items()
|
||||
if k in model_dict and model_dict[k].size() == v.size()
|
||||
}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def shufflenet_v2_x0_5(pretrained=True, **kwargs):
|
||||
model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["shufflenetv2_x0.5"])
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def shufflenet_v2_x1_0(pretrained=True, **kwargs):
|
||||
model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["shufflenetv2_x1.0"])
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def shufflenet_v2_x1_5(pretrained=True, **kwargs):
|
||||
model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["shufflenetv2_x1.5"])
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def shufflenet_v2_x2_0(pretrained=True, **kwargs):
|
||||
model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, model_urls["shufflenetv2_x2.0"])
|
||||
return model
|
||||
147
Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
Normal file
147
Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
model_urls = {
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
|
||||
}
|
||||
|
||||
|
||||
class VGG(Backbone):
|
||||
|
||||
def __init__(self, features, init_weights=True):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
|
||||
# Note that self.classifier outputs features rather than logits
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(512 * 7 * 7, 4096),
|
||||
nn.ReLU(True),
|
||||
nn.Dropout(),
|
||||
nn.Linear(4096, 4096),
|
||||
nn.ReLU(True),
|
||||
nn.Dropout(),
|
||||
)
|
||||
|
||||
self._out_features = 4096
|
||||
|
||||
if init_weights:
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
return self.classifier(x)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def make_layers(cfg, batch_norm=False):
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for v in cfg:
|
||||
if v == "M":
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
|
||||
if batch_norm:
|
||||
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
|
||||
else:
|
||||
layers += [conv2d, nn.ReLU(inplace=True)]
|
||||
in_channels = v
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
cfgs = {
|
||||
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
|
||||
"B":
|
||||
[64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
|
||||
"D": [
|
||||
64,
|
||||
64,
|
||||
"M",
|
||||
128,
|
||||
128,
|
||||
"M",
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
],
|
||||
"E": [
|
||||
64,
|
||||
64,
|
||||
"M",
|
||||
128,
|
||||
128,
|
||||
"M",
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _vgg(arch, cfg, batch_norm, pretrained):
|
||||
init_weights = False if pretrained else True
|
||||
model = VGG(
|
||||
make_layers(cfgs[cfg], batch_norm=batch_norm),
|
||||
init_weights=init_weights
|
||||
)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def vgg16(pretrained=True, **kwargs):
|
||||
return _vgg("vgg16", "D", False, pretrained)
|
||||
150
Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
Normal file
150
Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Modified from https://github.com/xternalz/WideResNet-pytorch
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
from .backbone import Backbone
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
|
||||
super().__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.relu1 = nn.LeakyReLU(0.01, inplace=True)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
self.relu2 = nn.LeakyReLU(0.01, inplace=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False
|
||||
)
|
||||
self.droprate = dropRate
|
||||
self.equalInOut = in_planes == out_planes
|
||||
self.convShortcut = (
|
||||
(not self.equalInOut) and nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=False,
|
||||
) or None
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if not self.equalInOut:
|
||||
x = self.relu1(self.bn1(x))
|
||||
else:
|
||||
out = self.relu1(self.bn1(x))
|
||||
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
|
||||
if self.droprate > 0:
|
||||
out = F.dropout(out, p=self.droprate, training=self.training)
|
||||
out = self.conv2(out)
|
||||
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
|
||||
|
||||
|
||||
class NetworkBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = self._make_layer(
|
||||
block, in_planes, out_planes, nb_layers, stride, dropRate
|
||||
)
|
||||
|
||||
def _make_layer(
|
||||
self, block, in_planes, out_planes, nb_layers, stride, dropRate
|
||||
):
|
||||
layers = []
|
||||
for i in range(int(nb_layers)):
|
||||
layers.append(
|
||||
block(
|
||||
i == 0 and in_planes or out_planes,
|
||||
out_planes,
|
||||
i == 0 and stride or 1,
|
||||
dropRate,
|
||||
)
|
||||
)
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class WideResNet(Backbone):
|
||||
|
||||
def __init__(self, depth, widen_factor, dropRate=0.0):
|
||||
super().__init__()
|
||||
nChannels = [
|
||||
16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
|
||||
]
|
||||
assert (depth-4) % 6 == 0
|
||||
n = (depth-4) / 6
|
||||
block = BasicBlock
|
||||
# 1st conv before any network block
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
# 1st block
|
||||
self.block1 = NetworkBlock(
|
||||
n, nChannels[0], nChannels[1], block, 1, dropRate
|
||||
)
|
||||
# 2nd block
|
||||
self.block2 = NetworkBlock(
|
||||
n, nChannels[1], nChannels[2], block, 2, dropRate
|
||||
)
|
||||
# 3rd block
|
||||
self.block3 = NetworkBlock(
|
||||
n, nChannels[2], nChannels[3], block, 2, dropRate
|
||||
)
|
||||
# global average pooling and classifier
|
||||
self.bn1 = nn.BatchNorm2d(nChannels[3])
|
||||
self.relu = nn.LeakyReLU(0.01, inplace=True)
|
||||
|
||||
self._out_features = nChannels[3]
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.block1(out)
|
||||
out = self.block2(out)
|
||||
out = self.block3(out)
|
||||
out = self.relu(self.bn1(out))
|
||||
out = F.adaptive_avg_pool2d(out, 1)
|
||||
return out.view(out.size(0), -1)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def wide_resnet_28_2(**kwargs):
|
||||
return WideResNet(28, 2)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def wide_resnet_16_4(**kwargs):
|
||||
return WideResNet(16, 4)
|
||||
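A quick shape check for the backbone above (a minimal sketch, assuming the registered constructors are importable as defined; the flattened feature size is 64 * widen_factor, i.e. 128 for wide_resnet_28_2):

import torch

net = wide_resnet_28_2()
x = torch.randn(4, 3, 32, 32)   # CIFAR-sized input
feat = net(x)                   # conv blocks + global average pooling
print(feat.shape)               # torch.Size([4, 128])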
3
Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .build import build_head, HEAD_REGISTRY # isort:skip

from .mlp import mlp
11
Dassl.ProGrad.pytorch/dassl/modeling/head/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

HEAD_REGISTRY = Registry("HEAD")


def build_head(name, verbose=True, **kwargs):
    avai_heads = HEAD_REGISTRY.registered_names()
    check_availability(name, avai_heads)
    if verbose:
        print("Head: {}".format(name))
    return HEAD_REGISTRY.get(name)(**kwargs)
50
Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
Normal file
@@ -0,0 +1,50 @@
import functools
import torch.nn as nn

from .build import HEAD_REGISTRY


class MLP(nn.Module):

    def __init__(
        self,
        in_features=2048,
        hidden_layers=[],
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]

        assert len(hidden_layers) > 0
        self.out_features = hidden_layers[-1]

        mlp = []

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError

        for hidden_dim in hidden_layers:
            mlp += [nn.Linear(in_features, hidden_dim)]
            if bn:
                mlp += [nn.BatchNorm1d(hidden_dim)]
            mlp += [act_fn()]
            if dropout > 0:
                mlp += [nn.Dropout(dropout)]
            in_features = hidden_dim

        self.mlp = nn.Sequential(*mlp)

    def forward(self, x):
        return self.mlp(x)


@HEAD_REGISTRY.register()
def mlp(**kwargs):
    return MLP(**kwargs)
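A minimal usage sketch of the head registry above (the in_features/hidden_layers values are illustrative, not defaults used elsewhere):

import torch

head = build_head("mlp", verbose=False, in_features=512, hidden_layers=[256, 128])
y = head(torch.randn(8, 512))
print(y.shape, head.out_features)   # torch.Size([8, 128]) 128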
5
Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .build import build_network, NETWORK_REGISTRY # isort:skip

from .ddaig_fcn import (
    fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn
)
11
Dassl.ProGrad.pytorch/dassl/modeling/network/build.py
Normal file
@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability

NETWORK_REGISTRY = Registry("NETWORK")


def build_network(name, verbose=True, **kwargs):
    avai_models = NETWORK_REGISTRY.registered_names()
    check_availability(name, avai_models)
    if verbose:
        print("Network: {}".format(name))
    return NETWORK_REGISTRY.get(name)(**kwargs)
329
Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py
Normal file
@@ -0,0 +1,329 @@
"""
|
||||
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
||||
"""
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .build import NETWORK_REGISTRY
|
||||
|
||||
|
||||
def init_network_weights(model, init_type="normal", gain=0.02):
|
||||
|
||||
def _init_func(m):
|
||||
classname = m.__class__.__name__
|
||||
if hasattr(m, "weight") and (
|
||||
classname.find("Conv") != -1 or classname.find("Linear") != -1
|
||||
):
|
||||
if init_type == "normal":
|
||||
nn.init.normal_(m.weight.data, 0.0, gain)
|
||||
elif init_type == "xavier":
|
||||
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
||||
elif init_type == "kaiming":
|
||||
nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
|
||||
elif init_type == "orthogonal":
|
||||
nn.init.orthogonal_(m.weight.data, gain=gain)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"initialization method {} is not implemented".
|
||||
format(init_type)
|
||||
)
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
elif classname.find("BatchNorm2d") != -1:
|
||||
nn.init.constant_(m.weight.data, 1.0)
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
elif classname.find("InstanceNorm2d") != -1:
|
||||
if m.weight is not None and m.bias is not None:
|
||||
nn.init.constant_(m.weight.data, 1.0)
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
|
||||
model.apply(_init_func)
|
||||
|
||||
|
||||
def get_norm_layer(norm_type="instance"):
|
||||
if norm_type == "batch":
|
||||
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
||||
elif norm_type == "instance":
|
||||
norm_layer = functools.partial(
|
||||
nn.InstanceNorm2d, affine=False, track_running_stats=False
|
||||
)
|
||||
elif norm_type == "none":
|
||||
norm_layer = None
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"normalization layer [%s] is not found" % norm_type
|
||||
)
|
||||
return norm_layer
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
||||
super().__init__()
|
||||
self.conv_block = self.build_conv_block(
|
||||
dim, padding_type, norm_layer, use_dropout, use_bias
|
||||
)
|
||||
|
||||
def build_conv_block(
|
||||
self, dim, padding_type, norm_layer, use_dropout, use_bias
|
||||
):
|
||||
conv_block = []
|
||||
p = 0
|
||||
if padding_type == "reflect":
|
||||
conv_block += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == "replicate":
|
||||
conv_block += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == "zero":
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"padding [%s] is not implemented" % padding_type
|
||||
)
|
||||
|
||||
conv_block += [
|
||||
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
||||
norm_layer(dim),
|
||||
nn.ReLU(True),
|
||||
]
|
||||
if use_dropout:
|
||||
conv_block += [nn.Dropout(0.5)]
|
||||
|
||||
p = 0
|
||||
if padding_type == "reflect":
|
||||
conv_block += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == "replicate":
|
||||
conv_block += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == "zero":
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"padding [%s] is not implemented" % padding_type
|
||||
)
|
||||
conv_block += [
|
||||
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
||||
norm_layer(dim),
|
||||
]
|
||||
|
||||
return nn.Sequential(*conv_block)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.conv_block(x)
|
||||
|
||||
|
||||
class LocNet(nn.Module):
|
||||
"""Localization network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_nc,
|
||||
nc=32,
|
||||
n_blocks=3,
|
||||
use_dropout=False,
|
||||
padding_type="zero",
|
||||
image_size=32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
backbone = []
|
||||
backbone += [
|
||||
nn.Conv2d(
|
||||
input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
|
||||
)
|
||||
]
|
||||
backbone += [nn.BatchNorm2d(nc)]
|
||||
backbone += [nn.ReLU(True)]
|
||||
for _ in range(n_blocks):
|
||||
backbone += [
|
||||
ResnetBlock(
|
||||
nc,
|
||||
padding_type=padding_type,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
use_dropout=use_dropout,
|
||||
use_bias=False,
|
||||
)
|
||||
]
|
||||
backbone += [nn.MaxPool2d(2, stride=2)]
|
||||
self.backbone = nn.Sequential(*backbone)
|
||||
reduced_imsize = int(image_size * 0.5**(n_blocks + 1))
|
||||
self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc_loc(x)
|
||||
x = torch.tanh(x)
|
||||
x = x.view(-1, 2, 2)
|
||||
theta = x.data.new_zeros(x.size(0), 2, 3)
|
||||
theta[:, :, :2] = x
|
||||
return theta
|
||||
|
||||
|
||||
class FCN(nn.Module):
|
||||
"""Fully convolutional network."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_nc,
|
||||
output_nc,
|
||||
nc=32,
|
||||
n_blocks=3,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
use_dropout=False,
|
||||
padding_type="reflect",
|
||||
gctx=True,
|
||||
stn=False,
|
||||
image_size=32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
backbone = []
|
||||
|
||||
p = 0
|
||||
if padding_type == "reflect":
|
||||
backbone += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == "replicate":
|
||||
backbone += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == "zero":
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
backbone += [
|
||||
nn.Conv2d(
|
||||
input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
|
||||
)
|
||||
]
|
||||
backbone += [norm_layer(nc)]
|
||||
backbone += [nn.ReLU(True)]
|
||||
|
||||
for _ in range(n_blocks):
|
||||
backbone += [
|
||||
ResnetBlock(
|
||||
nc,
|
||||
padding_type=padding_type,
|
||||
norm_layer=norm_layer,
|
||||
use_dropout=use_dropout,
|
||||
use_bias=False,
|
||||
)
|
||||
]
|
||||
self.backbone = nn.Sequential(*backbone)
|
||||
|
||||
# global context fusion layer
|
||||
self.gctx_fusion = None
|
||||
if gctx:
|
||||
self.gctx_fusion = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
|
||||
),
|
||||
norm_layer(nc),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
self.regress = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
|
||||
),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.locnet = None
|
||||
if stn:
|
||||
self.locnet = LocNet(
|
||||
input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
|
||||
)
|
||||
|
||||
def init_loc_layer(self):
|
||||
"""Initialize the weights/bias with identity transformation."""
|
||||
if self.locnet is not None:
|
||||
self.locnet.fc_loc.weight.data.zero_()
|
||||
self.locnet.fc_loc.bias.data.copy_(
|
||||
torch.tensor([1, 0, 0, 1], dtype=torch.float)
|
||||
)
|
||||
|
||||
def stn(self, x):
|
||||
"""Spatial transformer network."""
|
||||
theta = self.locnet(x)
|
||||
grid = F.affine_grid(theta, x.size())
|
||||
return F.grid_sample(x, grid), theta
|
||||
|
||||
def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): input mini-batch.
|
||||
lmda (float): multiplier for perturbation.
|
||||
return_p (bool): return perturbation.
|
||||
return_stn_output (bool): return the output of stn.
|
||||
"""
|
||||
theta = None
|
||||
if self.locnet is not None:
|
||||
x, theta = self.stn(x)
|
||||
input = x
|
||||
|
||||
x = self.backbone(x)
|
||||
if self.gctx_fusion is not None:
|
||||
c = F.adaptive_avg_pool2d(x, (1, 1))
|
||||
c = c.expand_as(x)
|
||||
x = torch.cat([x, c], 1)
|
||||
x = self.gctx_fusion(x)
|
||||
|
||||
p = self.regress(x)
|
||||
x_p = input + lmda*p
|
||||
|
||||
if return_stn_output:
|
||||
return x_p, p, input
|
||||
|
||||
if return_p:
|
||||
return x_p, p
|
||||
|
||||
return x_p
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
|
||||
def fcn_3x32_gctx(**kwargs):
|
||||
norm_layer = get_norm_layer(norm_type="instance")
|
||||
net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)
|
||||
init_network_weights(net, init_type="normal", gain=0.02)
|
||||
return net
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
|
||||
def fcn_3x64_gctx(**kwargs):
|
||||
norm_layer = get_norm_layer(norm_type="instance")
|
||||
net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)
|
||||
init_network_weights(net, init_type="normal", gain=0.02)
|
||||
return net
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
|
||||
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
|
||||
norm_layer = get_norm_layer(norm_type="instance")
|
||||
net = FCN(
|
||||
3,
|
||||
3,
|
||||
nc=32,
|
||||
n_blocks=3,
|
||||
norm_layer=norm_layer,
|
||||
stn=True,
|
||||
image_size=image_size
|
||||
)
|
||||
init_network_weights(net, init_type="normal", gain=0.02)
|
||||
net.init_loc_layer()
|
||||
return net
|
||||
|
||||
|
||||
@NETWORK_REGISTRY.register()
|
||||
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
|
||||
norm_layer = get_norm_layer(norm_type="instance")
|
||||
net = FCN(
|
||||
3,
|
||||
3,
|
||||
nc=64,
|
||||
n_blocks=3,
|
||||
norm_layer=norm_layer,
|
||||
stn=True,
|
||||
image_size=image_size
|
||||
)
|
||||
init_network_weights(net, init_type="normal", gain=0.02)
|
||||
net.init_loc_layer()
|
||||
return net
|
||||
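A minimal sketch of how the generator above is used to perturb a batch (sizes are illustrative; lmda scales the Tanh-bounded perturbation that is added to the input):

import torch

G = build_network("fcn_3x32_gctx", verbose=False)
x = torch.randn(2, 3, 32, 32)
x_p, p = G(x, lmda=0.3, return_p=True)       # perturbed batch and raw perturbation
print(x_p.shape, bool(p.abs().max() <= 1))   # torch.Size([2, 3, 32, 32]) True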
16
Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
from .mmd import MaximumMeanDiscrepancy
from .dsbn import DSBN1d, DSBN2d
from .mixup import mixup
from .efdmix import (
    EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,
    crossdomain_efdmix, run_without_efdmix
)
from .mixstyle import (
    MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,
    deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle
)
from .transnorm import TransNorm1d, TransNorm2d
from .sequential2 import Sequential2
from .reverse_grad import ReverseGrad
from .cross_entropy import cross_entropy
from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance
30
Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py
Normal file
@@ -0,0 +1,30 @@
import torch
from torch.nn import functional as F


def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated. Default is 'mean'.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    zeros = torch.zeros(log_prob.size())
    target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1)
    target = target.type_as(input)
    target = (1-label_smooth) * target + label_smooth/num_classes
    loss = (-target * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError
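A short usage sketch; with label_smooth=0 the result matches F.cross_entropy up to numerical precision:

import torch

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss_smooth = cross_entropy(logits, target, label_smooth=0.1)  # smoothed one-hot targets
loss_plain = cross_entropy(logits, target)                     # standard cross entropy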
45
Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py
Normal file
@@ -0,0 +1,45 @@
import torch.nn as nn


class _DSBN(nn.Module):
    """Domain Specific Batch Normalization.

    Args:
        num_features (int): number of features.
        n_domain (int): number of domains.
        bn_type (str): type of bn. Choices are ['1d', '2d'].
    """

    def __init__(self, num_features, n_domain, bn_type):
        super().__init__()
        if bn_type == "1d":
            BN = nn.BatchNorm1d
        elif bn_type == "2d":
            BN = nn.BatchNorm2d
        else:
            raise ValueError

        self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))

        self.valid_domain_idxs = list(range(n_domain))
        self.n_domain = n_domain
        self.domain_idx = 0

    def select_bn(self, domain_idx=0):
        assert domain_idx in self.valid_domain_idxs
        self.domain_idx = domain_idx

    def forward(self, x):
        return self.bn[self.domain_idx](x)


class DSBN1d(_DSBN):

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "1d")


class DSBN2d(_DSBN):

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "2d")
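A minimal routing sketch: a training loop would typically call select_bn before forwarding each domain's mini-batch:

import torch

bn = DSBN2d(num_features=64, n_domain=2)
bn.select_bn(1)                    # route through domain 1's BatchNorm branch
y = bn(torch.randn(8, 64, 7, 7))   # normalized with domain-1 statistics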
118
Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py
Normal file
@@ -0,0 +1,118 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn


def deactivate_efdmix(m):
    if type(m) == EFDMix:
        m.set_activation_status(False)


def activate_efdmix(m):
    if type(m) == EFDMix:
        m.set_activation_status(True)


def random_efdmix(m):
    if type(m) == EFDMix:
        m.update_mix_method("random")


def crossdomain_efdmix(m):
    if type(m) == EFDMix:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_efdmix(model):
    # Assume EFDMix was initially activated
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        model.apply(activate_efdmix)


@contextmanager
def run_with_efdmix(model, mix=None):
    # Assume EFDMix was initially deactivated
    if mix == "random":
        model.apply(random_efdmix)

    elif mix == "crossdomain":
        model.apply(crossdomain_efdmix)

    try:
        model.apply(activate_efdmix)
        yield
    finally:
        model.apply(deactivate_efdmix)


class EFDMix(nn.Module):
    """EFDMix.

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using EFDMix.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)
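A minimal sketch of the context managers above; `model` is assumed to be a network that already contains EFDMix modules between its stages, and `x` a mini-batch:

# EFDMix is active only inside this block, then switched off again
with run_with_efdmix(model, mix="crossdomain"):
    out = model(x)

# Temporarily disable EFDMix, e.g. for a clean forward pass
with run_without_efdmix(model):
    out_clean = model(x)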
124
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py
Normal file
@@ -0,0 +1,124 @@
import random
from contextlib import contextmanager
import torch
import torch.nn as nn


def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_mixstyle(model):
    # Assume MixStyle was initially activated
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model, mix=None):
    # Assume MixStyle was initially deactivated
    if mix == "random":
        model.apply(random_mixstyle)

    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix
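A hypothetical placement sketch: MixStyle sits between CNN stages so that feature statistics are mixed at training time (the stage definitions are illustrative):

import torch.nn as nn

class TwoStageNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(True))
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.stage2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(True))

    def forward(self, x):
        return self.stage2(self.mixstyle(self.stage1(x)))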
23
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
Normal file
@@ -0,0 +1,23 @@
import torch


def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup.

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda = max(lmda, 1-lmda).
            Default is False.
    """
    lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        lmda = torch.max(lmda, 1 - lmda)
    lmda = lmda.to(x1.device)
    xmix = x1*lmda + x2 * (1-lmda)
    lmda = lmda[:, :, 0, 0]
    ymix = y1*lmda + y2 * (1-lmda)
    return xmix, ymix
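A short usage sketch with one-hot labels (batch size and class count are illustrative):

import torch
from torch.nn import functional as F

x1, x2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
y1 = F.one_hot(torch.randint(0, 10, (8,)), 10).float()
y2 = F.one_hot(torch.randint(0, 10, (8,)), 10).float()
xmix, ymix = mixup(x1, x2, y1, y2, beta=1.0, preserve_order=True)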
91
Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py
Normal file
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class MaximumMeanDiscrepancy(nn.Module):

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        tmp_list = []
        for i, row in enumerate(distmat):
            row1 = torch.cat([row[:i], row[i + 1:]])
            tmp_list.append(row1)
        return torch.stack(tmp_list)

    @staticmethod
    def euclidean_squared_distance(x, y):
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        # distmat.addmm_(1, -2, x, y.t())  # deprecated signature
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        return distmat


if __name__ == "__main__":
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
    d = mmd(input1, input2)
    print(d.item())
147
Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py
Normal file
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class OptimalTransport(nn.Module):

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # squared euclidean distance (non-deprecated addmm_ signature)
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean]".format(dist_metric)
            )
        return dist_mat


class SinkhornDivergence(OptimalTransport):
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        nx, ny = C.shape
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        real_iter = 0  # check if algorithm terminates before max_iter
        # Sinkhorn iterations
        for i in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            real_iter += 1
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))


class MinibatchEnergyDistance(SinkhornDivergence):

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost


if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    import numpy as np

    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    import pdb

    pdb.set_trace()
34
Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py
Normal file
@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function


class _ReverseGrad(Function):

    @staticmethod
    def forward(ctx, input, grad_scaling):
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        grad_scaling = ctx.grad_scaling
        return -grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward,
    but reverses the sign of the gradient in
    the backward.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, \
            "grad_scaling must be non-negative, but got {}".format(grad_scaling)
        return reverse_grad(x, grad_scaling)
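A small check of the reversal behavior: the gradient of a plain sum is all ones, so after reversal and scaling every entry of x.grad is -0.5:

import torch

grl = ReverseGrad()
x = torch.randn(5, 8, requires_grad=True)
y = grl(x, grad_scaling=0.5)   # identity in the forward pass
y.sum().backward()
print(x.grad[0, 0].item())     # -0.5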
15
Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py
Normal file
@@ -0,0 +1,15 @@
import torch.nn as nn


class Sequential2(nn.Sequential):
    """An alternative sequential container to nn.Sequential,
    which accepts an arbitrary number of input arguments.
    """

    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
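A minimal sketch with a hypothetical two-input module; because each stage returns a tuple, Sequential2 unpacks it into the next stage:

import torch
import torch.nn as nn

class AddPair(nn.Module):   # hypothetical example module

    def forward(self, a, b):
        return a + b, b

seq = Sequential2(AddPair(), AddPair())
out, carried = seq(torch.ones(2), torch.ones(2))
print(out)   # tensor([3., 3.])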
138
Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py
Normal file
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn


class _TransNorm(nn.Module):
    """Transferable normalization.

    Reference:
        - Wang et al. Transferable Normalization: Towards Improving
        Transferability of Deep Neural Networks. NeurIPS 2019.

    Args:
        num_features (int): number of features.
        eps (float): epsilon.
        momentum (float): value for updating running_mean and running_var.
        adaptive_alpha (bool): apply domain adaptive alpha.
    """

    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.adaptive_alpha = adaptive_alpha

        self.register_buffer("running_mean_s", torch.zeros(num_features))
        self.register_buffer("running_var_s", torch.ones(num_features))
        self.register_buffer("running_mean_t", torch.zeros(num_features))
        self.register_buffer("running_var_t", torch.ones(num_features))

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def reset_running_stats(self):
        self.running_mean_s.zero_()
        self.running_var_s.fill_(1)
        self.running_mean_t.zero_()
        self.running_var_t.fill_(1)

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _check_input(self, x):
        raise NotImplementedError

    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
        C = self.num_features
        ratio_s = mean_s / (var_s + self.eps).sqrt()
        ratio_t = mean_t / (var_t + self.eps).sqrt()
        dist = (ratio_s - ratio_t).abs()
        dist_inv = 1 / (1+dist)
        return C * dist_inv / dist_inv.sum()

    def forward(self, input):
        self._check_input(input)
        C = self.num_features
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError

        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)

        if not self.training:
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input-mean_t) / (var_t + self.eps).sqrt()
            output = output*weight + bias

            if self.adaptive_alpha:
                mean_s = self.running_mean_s.view(*new_shape)
                var_s = self.running_var_s.view(*new_shape)
                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
                alpha = alpha.reshape(*new_shape)
                output = (1 + alpha.detach()) * output

            return output

        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)

        x_s = input_s.transpose(0, 1).reshape(C, -1)
        mean_s = x_s.mean(1)
        var_s = x_s.var(1)
        self.running_mean_s.mul_(self.momentum)
        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)
        self.running_var_s.mul_(self.momentum)
        self.running_var_s.add_((1 - self.momentum) * var_s.data)
        mean_s = mean_s.reshape(*new_shape)
        var_s = var_s.reshape(*new_shape)
        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()
        output_s = output_s*weight + bias

        x_t = input_t.transpose(0, 1).reshape(C, -1)
        mean_t = x_t.mean(1)
        var_t = x_t.var(1)
        self.running_mean_t.mul_(self.momentum)
        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)
        self.running_var_t.mul_(self.momentum)
        self.running_var_t.add_((1 - self.momentum) * var_t.data)
        mean_t = mean_t.reshape(*new_shape)
        var_t = var_t.reshape(*new_shape)
        output_t = (input_t-mean_t) / (var_t + self.eps).sqrt()
        output_t = output_t*weight + bias

        output = torch.cat([output_s, output_t], 0)

        if self.adaptive_alpha:
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output

        return output


class TransNorm1d(_TransNorm):

    def _check_input(self, x):
        if x.dim() != 2:
            raise ValueError(
                "Expected the input to be 2-D, "
                "but got {}-D".format(x.dim())
            )


class TransNorm2d(_TransNorm):

    def _check_input(self, x):
        if x.dim() != 4:
            raise ValueError(
                "Expected the input to be 4-D, "
                "but got {}-D".format(x.dim())
            )
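Note that in training mode the layer splits the batch in half, so the input is expected to be the concatenation of a source half and a target half. A minimal sketch:

import torch

tn = TransNorm2d(num_features=32)
x_s = torch.randn(8, 32, 4, 4)       # source mini-batch
x_t = torch.randn(8, 32, 4, 4)       # target mini-batch
out = tn(torch.cat([x_s, x_t], 0))   # halves are normalized separately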
75
Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py
Normal file
@@ -0,0 +1,75 @@
import numpy as np
import torch


def sharpen_prob(p, temperature=2):
    """Sharpening probability with a temperature.

    Args:
        p (torch.Tensor): probability matrix (batch_size, n_classes)
        temperature (float): temperature.
    """
    p = p.pow(temperature)
    return p / p.sum(1, keepdim=True)


def reverse_index(data, label):
    """Reverse order."""
    inv_idx = torch.arange(data.size(0) - 1, -1, -1).long()
    return data[inv_idx], label[inv_idx]


def shuffle_index(data, label):
    """Shuffle order."""
    rnd_idx = torch.randperm(data.shape[0])
    return data[rnd_idx], label[rnd_idx]


def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor.
        num_classes (int): number of classes.
    """
    onehot = torch.zeros(label.shape[0], num_classes)
    return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current/rampup_length
    return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup.

    Args:
        current (int): current step.
        rampup_length (int): maximum step.
    """
    assert rampup_length > 0
    ratio = np.clip(current / rampup_length, 0.0, 1.0)
    return float(ratio)


def ema_model_update(model, ema_model, alpha):
    """Exponential moving average of model parameters.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
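A quick worked example of sharpen_prob: squaring [0.6, 0.4] gives [0.36, 0.16], and renormalizing by 0.52 yields roughly [0.692, 0.308]:

import torch

p = torch.tensor([[0.6, 0.4]])
print(sharpen_prob(p, temperature=2))   # tensor([[0.6923, 0.3077]])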
2
Dassl.ProGrad.pytorch/dassl/optim/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .optimizer import build_optimizer
from .lr_scheduler import build_lr_scheduler
154
Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py
Normal file
@@ -0,0 +1,154 @@
"""
|
||||
Modified from https://github.com/KaiyangZhou/deep-person-reid
|
||||
"""
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]
|
||||
|
||||
|
||||
class _BaseWarmupScheduler(_LRScheduler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
successor,
|
||||
warmup_epoch,
|
||||
last_epoch=-1,
|
||||
verbose=False
|
||||
):
|
||||
self.successor = successor
|
||||
self.warmup_epoch = warmup_epoch
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def step(self, epoch=None):
|
||||
if self.last_epoch >= self.warmup_epoch:
|
||||
self.successor.step(epoch)
|
||||
self._last_lr = self.successor.get_last_lr()
|
||||
else:
|
||||
super().step(epoch)
|
||||
|
||||
|
||||
class ConstantWarmupScheduler(_BaseWarmupScheduler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
successor,
|
||||
warmup_epoch,
|
||||
cons_lr,
|
||||
last_epoch=-1,
|
||||
verbose=False
|
||||
):
|
||||
self.cons_lr = cons_lr
|
||||
super().__init__(
|
||||
optimizer, successor, warmup_epoch, last_epoch, verbose
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epoch:
|
||||
return self.successor.get_last_lr()
|
||||
return [self.cons_lr for _ in self.base_lrs]
|
||||
|
||||
|
||||
class LinearWarmupScheduler(_BaseWarmupScheduler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
successor,
|
||||
warmup_epoch,
|
||||
min_lr,
|
||||
last_epoch=-1,
|
||||
verbose=False
|
||||
):
|
||||
self.min_lr = min_lr
|
||||
super().__init__(
|
||||
optimizer, successor, warmup_epoch, last_epoch, verbose
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epoch:
|
||||
return self.successor.get_last_lr()
|
||||
if self.last_epoch == 0:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
return [
|
||||
lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs
|
||||
]
|
||||
|
||||
|
||||
def build_lr_scheduler(optimizer, optim_cfg):
|
||||
"""A function wrapper for building a learning rate scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): an Optimizer.
|
||||
optim_cfg (CfgNode): optimization config.
|
||||
"""
|
||||
lr_scheduler = optim_cfg.LR_SCHEDULER
|
||||
stepsize = optim_cfg.STEPSIZE
|
||||
gamma = optim_cfg.GAMMA
|
||||
max_epoch = optim_cfg.MAX_EPOCH
|
||||
|
||||
if lr_scheduler not in AVAI_SCHEDS:
|
||||
raise ValueError(
|
||||
"Unsupported scheduler: {}. Must be one of {}".format(
|
||||
lr_scheduler, AVAI_SCHEDS
|
||||
)
|
||||
)
|
||||
|
||||
if lr_scheduler == "single_step":
|
||||
if isinstance(stepsize, (list, tuple)):
|
||||
stepsize = stepsize[-1]
|
||||
|
||||
if not isinstance(stepsize, int):
|
||||
raise TypeError(
|
||||
"For single_step lr_scheduler, stepsize must "
|
||||
"be an integer, but got {}".format(type(stepsize))
|
||||
)
|
||||
|
||||
if stepsize <= 0:
|
||||
stepsize = max_epoch
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(
|
||||
optimizer, step_size=stepsize, gamma=gamma
|
||||
)
|
||||
|
||||
elif lr_scheduler == "multi_step":
|
||||
if not isinstance(stepsize, (list, tuple)):
|
||||
raise TypeError(
|
||||
"For multi_step lr_scheduler, stepsize must "
|
||||
"be a list, but got {}".format(type(stepsize))
|
||||
)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer, milestones=stepsize, gamma=gamma
|
||||
)
|
||||
|
||||
elif lr_scheduler == "cosine":
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, float(max_epoch)
|
||||
)
|
||||
|
||||
if optim_cfg.WARMUP_EPOCH > 0:
|
||||
if not optim_cfg.WARMUP_RECOUNT:
|
||||
scheduler.last_epoch = optim_cfg.WARMUP_EPOCH
|
||||
|
||||
if optim_cfg.WARMUP_TYPE == "constant":
|
||||
scheduler = ConstantWarmupScheduler(
|
||||
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
|
||||
optim_cfg.WARMUP_CONS_LR
|
||||
)
|
||||
|
||||
elif optim_cfg.WARMUP_TYPE == "linear":
|
||||
scheduler = LinearWarmupScheduler(
|
||||
optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
|
||||
optim_cfg.WARMUP_MIN_LR
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return scheduler
|
||||
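A minimal sketch of building a cosine schedule with linear warmup (the config fields mirror exactly the ones read by build_lr_scheduler above; the values are illustrative):

import torch
from yacs.config import CfgNode as CN

optim_cfg = CN()
optim_cfg.LR_SCHEDULER = "cosine"
optim_cfg.STEPSIZE = (-1, )
optim_cfg.GAMMA = 0.1
optim_cfg.MAX_EPOCH = 50
optim_cfg.WARMUP_EPOCH = 5
optim_cfg.WARMUP_TYPE = "linear"
optim_cfg.WARMUP_MIN_LR = 1e-5
optim_cfg.WARMUP_RECOUNT = True

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = build_lr_scheduler(optimizer, optim_cfg)   # LinearWarmupScheduler wrapping cosine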
136
Dassl.ProGrad.pytorch/dassl/optim/optimizer.py
Normal file
@@ -0,0 +1,136 @@
"""
|
||||
Modified from https://github.com/KaiyangZhou/deep-person-reid
|
||||
"""
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .radam import RAdam
|
||||
|
||||
AVAI_OPTIMS = ["adam", "amsgrad", "sgd", "rmsprop", "radam", "adamw"]
|
||||
|
||||
|
||||
def build_optimizer(model, optim_cfg):
|
||||
"""A function wrapper for building an optimizer.
|
||||
|
||||
Args:
|
||||
model (nn.Module or iterable): model.
|
||||
optim_cfg (CfgNode): optimization config.
|
||||
"""
|
||||
optim = optim_cfg.NAME
|
||||
lr = optim_cfg.LR
|
||||
weight_decay = optim_cfg.WEIGHT_DECAY
|
||||
momentum = optim_cfg.MOMENTUM
|
||||
sgd_dampening = optim_cfg.SGD_DAMPNING
|
||||
sgd_nesterov = optim_cfg.SGD_NESTEROV
|
||||
rmsprop_alpha = optim_cfg.RMSPROP_ALPHA
|
||||
adam_beta1 = optim_cfg.ADAM_BETA1
|
||||
adam_beta2 = optim_cfg.ADAM_BETA2
|
||||
staged_lr = optim_cfg.STAGED_LR
|
||||
new_layers = optim_cfg.NEW_LAYERS
|
||||
base_lr_mult = optim_cfg.BASE_LR_MULT
|
||||
|
||||
if optim not in AVAI_OPTIMS:
|
||||
raise ValueError(
|
||||
"Unsupported optim: {}. Must be one of {}".format(
|
||||
optim, AVAI_OPTIMS
|
||||
)
|
||||
)
|
||||
|
||||
if staged_lr:
|
||||
if not isinstance(model, nn.Module):
|
||||
raise TypeError(
|
||||
"When staged_lr is True, model given to "
|
||||
"build_optimizer() must be an instance of nn.Module"
|
||||
)
|
||||
|
||||
if isinstance(model, nn.DataParallel):
|
||||
model = model.module
|
||||
|
||||
if isinstance(new_layers, str):
|
||||
if new_layers is None:
|
||||
warnings.warn(
|
||||
"new_layers is empty, therefore, staged_lr is useless"
|
||||
)
|
||||
new_layers = [new_layers]
|
||||
|
||||
base_params = []
|
||||
base_layers = []
|
||||
new_params = []
|
||||
|
||||
for name, module in model.named_children():
|
||||
if name in new_layers:
|
||||
new_params += [p for p in module.parameters()]
|
||||
else:
|
||||
base_params += [p for p in module.parameters()]
|
||||
base_layers.append(name)
|
||||
|
||||
param_groups = [
|
||||
{
|
||||
"params": base_params,
|
||||
"lr": lr * base_lr_mult
|
||||
},
|
||||
{
|
||||
"params": new_params
|
||||
},
|
||||
]
|
||||
|
||||
else:
|
||||
if isinstance(model, nn.Module):
|
||||
param_groups = model.parameters()
|
||||
else:
|
||||
param_groups = model
|
||||
|
||||
if optim == "adam":
|
||||
optimizer = torch.optim.Adam(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
)
|
||||
|
||||
elif optim == "amsgrad":
|
||||
optimizer = torch.optim.Adam(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
amsgrad=True,
|
||||
)
|
||||
|
||||
elif optim == "sgd":
|
||||
optimizer = torch.optim.SGD(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
dampening=sgd_dampening,
|
||||
nesterov=sgd_nesterov,
|
||||
)
|
||||
|
||||
elif optim == "rmsprop":
|
||||
optimizer = torch.optim.RMSprop(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
alpha=rmsprop_alpha,
|
||||
)
|
||||
|
||||
elif optim == "radam":
|
||||
optimizer = RAdam(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
)
|
||||
|
||||
elif optim == "adamw":
|
||||
optimizer = torch.optim.AdamW(
|
||||
param_groups,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
)
|
||||
|
||||
return optimizer
|
||||
Some files were not shown because too many files have changed in this diff.