release code
This commit is contained in:
1
Dassl.ProGrad.pytorch/dassl/data/__init__.py
Normal file
1
Dassl.ProGrad.pytorch/dassl/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .data_manager import DataManager, DatasetWrapper
|
||||
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
from dassl.utils import read_image
|
||||
|
||||
from .datasets import build_dataset
|
||||
from .samplers import build_sampler
|
||||
from .transforms import build_transform
|
||||
|
||||
# Map config strings (cfg.INPUT.INTERPOLATION) to PIL resampling filters.
# NOTE(review): Image.BILINEAR etc. are deprecated aliases in newer Pillow
# (moved to Image.Resampling) — confirm against the pinned Pillow version.
INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}
|
||||
|
||||
|
||||
def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
):
    """Construct a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        cfg (CfgNode): global config (reads DATALOADER.NUM_WORKERS and
            USE_CUDA here; more fields are read by the wrapper/sampler).
        sampler_type (str): sampler name understood by ``build_sampler``.
        data_source (list): list of Datum objects.
        batch_size (int): batch size.
        n_domain (int): number of domains per batch (sampler-specific).
        n_ins (int): number of instances per class (sampler-specific).
        tfm: transform (or list/tuple of transforms) applied per image.
        is_train (bool): training mode; enables drop_last below and K>1
            transforms inside DatasetWrapper.
        dataset_wrapper: Dataset class wrapping ``data_source``; defaults
            to :class:`DatasetWrapper`.

    Returns:
        torch.utils.data.DataLoader: a non-empty data loader.

    Raises:
        ValueError: if the resulting loader would yield zero batches.
    """
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Drop the trailing incomplete batch only during training, and only
        # when at least one full batch exists.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )

    # An empty loader indicates a misconfiguration; fail loudly with an
    # explicit exception rather than ``assert``, which is stripped when
    # Python runs with -O.
    if len(data_loader) == 0:
        raise ValueError(
            "The data loader is empty (data_source has {} items, "
            "batch_size={})".format(len(data_source), batch_size)
        )

    return data_loader
|
||||
|
||||
|
||||
class DataManager:
    """Build the dataset and every data loader the trainer needs.

    Given a config, this constructs the dataset (via ``build_dataset``),
    the train/test transforms, and four loaders:

    - ``train_loader_x``: labeled training data (always built).
    - ``train_loader_u``: unlabeled training data (only if the dataset
      provides ``train_u``).
    - ``val_loader``: validation data (only if the dataset provides ``val``).
    - ``test_loader``: test data (always built).
    """

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        """
        Args:
            cfg (CfgNode): global config.
            custom_tfm_train: optional transform that overrides the one
                built from cfg for training data.
            custom_tfm_test: optional transform that overrides the one
                built from cfg for val/test data.
            dataset_wrapper: optional Dataset class forwarded to
                ``build_data_loader``.
        """
        # Load dataset
        dataset = build_dataset(cfg)
        # Build transform
        if custom_tfm_train is None:
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        # Build train_loader_x
        train_loader_x = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=dataset.train_x,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
        )

        # Build train_loader_u (optional; only when unlabeled data exists)
        train_loader_u = None
        if dataset.train_u:
            sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
            batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
            n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
            n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS

            # Optionally reuse the labeled-branch loader settings.
            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
                batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
                n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
                n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=sampler_type_,
                data_source=dataset.train_u,
                batch_size=batch_size_,
                n_domain=n_domain_,
                n_ins=n_ins_,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # Build val_loader (optional; uses the TEST loader settings)
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper,
            )

        # Build test_loader
        test_loader = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        # Number of classes in the labeled training set.
        return self._num_classes

    @property
    def num_source_domains(self):
        # Number of source domains requested in the config.
        return self._num_source_domains

    @property
    def lab2cname(self):
        # Dict mapping integer label -> class name.
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print dataset name, domains and split sizes to stdout."""
        print("***** Dataset statistics *****")

        print("  Dataset: {}".format(cfg.DATASET.NAME))

        if cfg.DATASET.SOURCE_DOMAINS:
            print("  Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
        if cfg.DATASET.TARGET_DOMAINS:
            print("  Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))

        print("  # classes: {:,}".format(self.num_classes))

        print("  # train_x: {:,}".format(len(self.dataset.train_x)))

        if self.dataset.train_u:
            print("  # train_u: {:,}".format(len(self.dataset.train_u)))

        if self.dataset.val:
            print("  # val: {:,}".format(len(self.dataset.val)))

        print("  # test: {:,}".format(len(self.dataset.test)))
|
||||
|
||||
|
||||
class DatasetWrapper(TorchDataset):
    """Wrap a list of Datum objects as a map-style torch Dataset.

    Each ``__getitem__`` returns a dict with keys ``label``, ``domain``,
    ``impath``, plus ``img`` (and ``img2``, ``img3``, ... when a list of
    transforms is given), and optionally ``img0`` (the untransformed,
    resized/normalized image) when cfg.DATALOADER.RETURN_IMG0 is set.
    """

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        """
        Args:
            cfg (CfgNode): global config.
            data_source (list): list of Datum objects.
            transform: a single transform or a list/tuple of transforms.
            is_train (bool): enables K>1 repeated augmentation.
        """
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        # (used for the optional raw "img0" output).
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        to_tensor = []
        to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            normalize = T.Normalize(
                mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
            )
            to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath
        }

        # Image is loaded lazily, per item, from disk.
        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                # Multiple transforms produce keys "img", "img2", "img3", ...
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = "img"
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output["img"] = img

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)

        return output

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` k_tfm times.

        Returns a single image when k_tfm == 1, otherwise a list of
        augmented images.
        """
        img_list = []

        for k in range(self.k_tfm):
            img_list.append(tfm(img0))

        img = img_list
        if len(img) == 1:
            img = img[0]

        return img
|
||||
6
Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
Normal file
6
Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .build import DATASET_REGISTRY, build_dataset # isort:skip
|
||||
from .base_dataset import Datum, DatasetBase # isort:skip
|
||||
|
||||
from .da import *
|
||||
from .dg import *
|
||||
from .ssl import *
|
||||
225
Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
Normal file
225
Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import random
|
||||
import os.path as osp
|
||||
import tarfile
|
||||
import zipfile
|
||||
from collections import defaultdict
|
||||
import gdown
|
||||
|
||||
from dassl.utils import check_isfile
|
||||
|
||||
|
||||
class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path (must point to an existing file).
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.

    Raises:
        TypeError: if ``impath`` is not a string.
        ValueError: if ``impath`` does not point to a readable file.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        # Validate with explicit exceptions rather than ``assert``:
        # asserts are stripped when Python runs with -O, which would let
        # broken paths through silently.
        if not isinstance(impath, str):
            raise TypeError(
                "impath must be a string, got {}".format(type(impath))
            )
        if not check_isfile(impath):
            raise ValueError("impath is not a valid file: {}".format(impath))

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname
|
||||
|
||||
|
||||
class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        """
        Args:
            train_x (list): labeled training data (required for the class
                statistics computed below).
            train_u (list): unlabeled training data (optional).
            val (list): validation data (optional).
            test (list): test data.
        """
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

        # Both helpers iterate train_x, so a non-empty labeled training
        # set is expected here.
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            int: max label + 1 (labels are assumed 0-based).
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: (dict mapping label -> classname,
                    list of classnames sorted by label).
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate that both domain lists only contain known domains."""
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        """Raise ValueError if any requested domain is not in ``self.domains``."""
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive to ``dst`` and extract it alongside.

        Args:
            url (str): download URL (Google-Drive style when from_gdrive).
            dst (str): destination file path for the archive.
            from_gdrive (bool): only gdown-based download is implemented.
        """
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        # Choose the extractor by inspecting the archive instead of the
        # previous bare try/except fallback, which swallowed *every*
        # exception (including KeyboardInterrupt) before retrying as zip.
        if tarfile.is_tarfile(dst):
            with tarfile.open(dst) as tar:
                tar.extractall(path=osp.dirname(dst))
        else:
            with zipfile.ZipFile(dst, "r") as zip_ref:
                zip_ref.extractall(osp.dirname(dst))

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a few number of images.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample;
                values < 1 return the input unchanged.
            repeat (bool): repeat images if needed (default: False).
        """
        if num_shots < 1:
            # Nothing to subsample; mirror the input arity on return.
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    # Fewer images than requested: optionally sample with
                    # replacement, otherwise keep all available items.
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output
|
||||
11
Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dassl.utils import Registry, check_availability
|
||||
|
||||
DATASET_REGISTRY = Registry("DATASET")
|
||||
|
||||
|
||||
def build_dataset(cfg):
    """Instantiate the dataset class registered under ``cfg.DATASET.NAME``.

    Raises via ``check_availability`` when the name is unknown.
    """
    dataset_name = cfg.DATASET.NAME
    check_availability(dataset_name, DATASET_REGISTRY.registered_names())

    if cfg.VERBOSE:
        print("Loading dataset: {}".format(dataset_name))

    dataset_cls = DATASET_REGISTRY.get(dataset_name)
    return dataset_cls(cfg)
|
||||
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .digit5 import Digit5
|
||||
from .visda17 import VisDA17
|
||||
from .cifarstl import CIFARSTL
|
||||
from .office31 import Office31
|
||||
from .domainnet import DomainNet
|
||||
from .office_home import OfficeHome
|
||||
from .mini_domainnet import miniDomainNet
|
||||
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
    """CIFAR-10 and STL-10.

    CIFAR-10:
        - 60,000 32x32 colour images.
        - 10 classes, with 6,000 images per class.
        - 50,000 training images and 10,000 test images.
        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.

    STL-10:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images (10 pre-defined folds), 800 test images
        per class.
        - URL: https://cs.stanford.edu/~acoates/stl10/.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "cifar_stl"
    domains = ["cifar", "stl"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Target-domain train split doubles as unlabeled training data.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Collect one Datum per image under ``<dataset_dir>/<domain>/<split>``."""
        items = []

        for domain, dname in enumerate(input_domains):
            data_dir = osp.join(self.dataset_dir, dname, split)
            class_names = listdir_nohidden(data_dir)

            for class_name in class_names:
                class_dir = osp.join(data_dir, class_name)
                imnames = listdir_nohidden(class_dir)
                # Class folders are expected to start with the numeric
                # label, e.g. "0_airplane" — TODO confirm naming scheme.
                label = int(class_name.split("_")[0])

                for imname in imnames:
                    impath = osp.join(class_dir, imname)
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items
|
||||
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import random
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
# Folder names for train and test sets
# Each domain keeps its images under these sub-folders; all five domains
# currently share the same layout, but are kept as separate dicts so a
# single domain's layout can change independently.
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
|
||||
|
||||
|
||||
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """Collect (impath, label) pairs for images named ``*_<label>.<ext>``.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): randomly subsample to this many items.
        n_repeat (int, optional): replicate the resulting list this many
            times (used to balance small domains).
    """
    items = [
        (
            osp.join(im_dir, imname),
            int(osp.splitext(imname)[0].split("_")[1]),
        )
        for imname in listdir_nohidden(im_dir)
    ]

    if n_max is not None:
        items = random.sample(items, n_max)

    if n_repeat is not None:
        items *= n_repeat

    return items
|
||||
|
||||
|
||||
def load_mnist(dataset_dir, split="train"):
    """Load MNIST items, capped at 25k train / 9k test images."""
    cap = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M items, capped at 25k train / 9k test images."""
    cap = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, MNIST_M[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_svhn(dataset_dir, split="train"):
    """Load SVHN items, capped at 25k train / 9k test images."""
    cap = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, SVHN[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_syn(dataset_dir, split="train"):
    """Load SYN items, capped at 25k train / 9k test images."""
    cap = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, SYN[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_usps(dataset_dir, split="train"):
    """Load USPS items; the training list is replicated 3x to roughly
    match the size of the other digit domains."""
    repeats = 3 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, USPS[split]), n_repeat=repeats)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
    """Five digit datasets.

    It contains:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS which has only
    9,298 images in total, we use the entire dataset but replicate its training
    set for 3 times so as to match the training set size of other domains.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
    """

    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Load each requested domain and tag items with a domain index.

        Args:
            input_domains (list): domain names from ``self.domains``.
            split (str): "train" or "test".
        """
        # Explicit dispatch table instead of ``eval("load_" + dname)``:
        # avoids executing config-derived strings and fails with a clear
        # KeyError on unknown domain names.
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }

        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loaders[dname](domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=str(label)
                )
                items.append(item)

        return items
|
||||
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
    """DomainNet multi-source domain adaptation benchmark.

    Roughly 0.6M images over 345 categories, spanning six domains
    (clipart, infograph, painting, quickdraw, real, sketch).  Splits are
    read from the text files under ``<dataset_dir>/splits``.
    URL: http://ai.bu.edu/M3SDA/.

    Special note: the t-shirt class (327) is missing in
    painting_train.txt.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
    """

    dataset_dir = "domainnet"
    domains = [
        "clipart", "infograph", "painting", "quickdraw", "real", "sketch"
    ]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        # The source-domain test split serves as validation data.
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Parse ``<domain>_<split>.txt`` files into a list of Datum."""
        items = []

        for domain_label, domain_name in enumerate(input_domains):
            split_file = osp.join(
                self.split_dir, "{}_{}.txt".format(domain_name, split)
            )

            with open(split_file, "r") as f:
                for raw_line in f.readlines():
                    rel_path, label_str = raw_line.strip().split(" ")
                    # Lines look like "<domain>/<classname>/<image> <label>".
                    items.append(
                        Datum(
                            impath=osp.join(self.dataset_dir, rel_path),
                            label=int(label_str),
                            domain=domain_label,
                            classname=rel_path.split("/")[1],
                        )
                    )

        return items
|
||||
@@ -0,0 +1,58 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
    """A subset of DomainNet with four domains, read from the split
    files under ``<dataset_dir>/splits_mini``.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
        - Zhou et al. Domain Adaptive Ensemble Learning.
    """

    dataset_dir = "domainnet"
    domains = ["clipart", "painting", "real", "sketch"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits_mini")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Target-domain train split doubles as unlabeled training data.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Parse ``<domain>_<split>.txt`` files into a list of Datum."""
        items = []

        for dom_idx, dom_name in enumerate(input_domains):
            list_path = osp.join(self.split_dir, dom_name + "_" + split + ".txt")

            with open(list_path, "r") as fh:
                entries = [ln.strip() for ln in fh.readlines()]

            for entry in entries:
                rel_path, raw_label = entry.split(" ")
                # Relative paths look like "<domain>/<classname>/<image>".
                datum = Datum(
                    impath=osp.join(self.dataset_dir, rel_path),
                    label=int(raw_label),
                    domain=dom_idx,
                    classname=rel_path.split("/")[1],
                )
                items.append(datum)

        return items
|
||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class Office31(DatasetBase):
    """Office-31 domain adaptation benchmark.

    4,110 images over 31 office-object classes, spanning three domains
    (amazon, webcam, dslr), laid out as
    ``<dataset_dir>/<domain>/<class_name>/<image>``.
    URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.

    Reference:
        - Saenko et al. Adapting visual category models to
        new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Target-domain data doubles as unlabeled train data and test data.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        """Collect one Datum per image; labels follow sorted folder order."""
        items = []

        for dom_idx, dom_name in enumerate(input_domains):
            dom_root = osp.join(self.dataset_dir, dom_name)
            # Sorting the class folders makes label assignment deterministic.
            category_names = sorted(listdir_nohidden(dom_root))

            for label, category in enumerate(category_names):
                category_root = osp.join(dom_root, category)

                for image_name in listdir_nohidden(category_root):
                    items.append(
                        Datum(
                            impath=osp.join(category_root, image_name),
                            label=label,
                            domain=dom_idx,
                            classname=category,
                        )
                    )

        return items
|
||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home"
    domains = ["art", "clipart", "product", "real_world"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Target-domain data doubles as unlabeled train data and test data.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        """Collect one Datum per image; labels follow sorted folder order."""
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            class_names = listdir_nohidden(domain_dir)
            # Sorting makes the folder-order -> label assignment deterministic.
            class_names.sort()

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                imnames = listdir_nohidden(class_path)

                for imname in imnames:
                    impath = osp.join(class_path, imname)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        # NOTE(review): classname is lower-cased here, unlike
                        # Office31 which keeps the raw folder name — confirm
                        # this asymmetry is intentional.
                        classname=class_name.lower(),
                    )
                    items.append(item)

        return items
|
||||
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
    """VisDA17.

    Focusing on simulation-to-reality domain shift.

    URL: http://ai.bu.edu/visda-2017/.

    Reference:
        - Peng et al. VisDA: The Visual Domain Adaptation
        Challenge. ArXiv 2017.
    """

    dataset_dir = "visda17"
    domains = ["synthetic", "real"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Synthetic renderings are the labeled source; real photos serve
        # as both the unlabeled pool and the test set.
        train_x = self._read_data("synthetic")
        train_u = self._read_data("real")
        test = self._read_data("real")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, dname):
        """Parse image_list.txt of the chosen split into Datum objects."""
        # Synthetic images live under train/, real ones under validation/.
        filedir = "train" if dname == "synthetic" else "validation"
        image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
        items = []
        # There is only one source domain
        domain = 0

        with open(image_list, "r") as f:
            for line in f:
                impath, label = line.strip().split(" ")
                # The first path component in the list file is the class
                # folder name (the list uses "/" separators).
                classname = impath.split("/")[0]
                items.append(
                    Datum(
                        impath=osp.join(self.dataset_dir, filedir, impath),
                        label=int(label),
                        domain=domain,
                        classname=classname,
                    )
                )

        return items
|
||||
6
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
Normal file
6
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .pacs import PACS
|
||||
from .vlcs import VLCS
|
||||
from .cifar_c import CIFAR10C, CIFAR100C
|
||||
from .digits_dg import DigitsDG
|
||||
from .digit_single import DigitSingle
|
||||
from .office_home_dg import OfficeHomeDG
|
||||
123
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
123
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
# The 19 corruption types shipped with CIFAR-10-C / CIFAR-100-C
# (Hendrycks & Dietterich, ICLR 2019), in alphabetical order.
AVAI_C_TYPES = [
    "brightness", "contrast", "defocus_blur", "elastic_transform",
    "fog", "frost", "gaussian_blur", "gaussian_noise", "glass_blur",
    "impulse_noise", "jpeg_compression", "motion_blur", "pixelate",
    "saturate", "shot_noise", "snow", "spatter", "speckle_noise",
    "zoom_blur",
]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
    """CIFAR-10 -> CIFAR-10-C.

    Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o

    Statistics:
        - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
        - 10 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    dataset_dir = ""
    domains = ["cifar10", "cifar10_c"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = root

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )
        # The protocol is fixed: train on the clean set, evaluate on the
        # corrupted one, so the domain order is not configurable.
        source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
        target_domain = cfg.DATASET.TARGET_DOMAINS[0]
        assert source_domain == self.domains[0]
        assert target_domain == self.domains[1]

        c_type = cfg.DATASET.CIFAR_C_TYPE
        c_level = cfg.DATASET.CIFAR_C_LEVEL

        if not c_type:
            raise ValueError(
                "Please specify DATASET.CIFAR_C_TYPE in the config file"
            )

        assert (
            c_type in AVAI_C_TYPES
        ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
        assert 1 <= c_level <= 5

        train_dir = osp.join(self.dataset_dir, source_domain, "train")
        test_dir = osp.join(
            self.dataset_dir, target_domain, c_type, str(c_level)
        )

        if not osp.exists(test_dir):
            # Fix: the original raised a bare ValueError with no message,
            # which made the failure impossible to diagnose.
            raise ValueError(
                f'Test data directory does not exist: "{test_dir}"'
            )

        train = self._read_data(train_dir)
        test = self._read_data(test_dir)

        super().__init__(train_x=train, test=test)

    def _read_data(self, data_dir):
        """Build a flat list of Datums from a class-per-folder layout.

        The label is the alphabetical index of the class folder; the
        domain index is always 0 (single-domain reading).
        """
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label, domain=0)
                items.append(item)

        return items
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
    """CIFAR-100 -> CIFAR-100-C.

    Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o

    Statistics:
        - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
        - 100 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    # Fix: the original docstring claimed "10 categories"; CIFAR-100
    # has 100 classes. Loading logic is inherited from CIFAR10C, which
    # derives labels from the folder layout, so no code change is needed.
    dataset_dir = ""
    domains = ["cifar100", "cifar100_c"]

    def __init__(self, cfg):
        super().__init__(cfg)
|
||||
124
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
# Folder names for train and test sets
|
||||
# Folder names for the train/test splits of each digit dataset.
# All five datasets share the same on-disk layout.
MNIST = dict(train="train_images", test="test_images")
MNIST_M = dict(train="train_images", test="test_images")
SVHN = dict(train="train_images", test="test_images")
SYN = dict(train="train_images", test="test_images")
USPS = dict(train="train_images", test="test_images")
|
||||
|
||||
|
||||
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """List (impath, label) pairs found in *im_dir*.

    File names are expected to look like ``<index>_<label>.<ext>``;
    the label is the integer after the first underscore.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): keep only the first n_max entries.
        n_repeat (int, optional): repeat the resulting list n_repeat times.
    """
    items = [
        (
            osp.join(im_dir, imname),
            int(osp.splitext(imname)[0].split("_")[1]),
        )
        for imname in listdir_nohidden(im_dir)
    ]

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items
|
||||
|
||||
|
||||
def load_mnist(dataset_dir, split="train"):
    """Load MNIST; the training split is capped at 10,000 images."""
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M; the training split is capped at 10,000 images."""
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST_M[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_svhn(dataset_dir, split="train"):
    """Load SVHN; the training split is capped at 10,000 images."""
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SVHN[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_syn(dataset_dir, split="train"):
    """Load SYN; the training split is capped at 10,000 images."""
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SYN[split]), n_max=cap)
|
||||
|
||||
|
||||
def load_usps(dataset_dir, split="train"):
    """Load USPS; no cap is applied to either split."""
    return read_image_list(osp.join(dataset_dir, USPS[split]))
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate the model on the test split of the other four datasets. However,
        the code does not restrict you to only use MNIST as the source dataset.
        Instead, you can use any dataset as the source. But note that only 10,000
        images will be sampled from the source dataset for training.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Source domains: train + their own test split as validation;
        # target domains: test split only.
        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Load every requested domain and tag items with a domain index.

        Args:
            input_domains (list): domain names (must be in self.domains).
            split (str): "train" or "test".
        """
        # Fix: explicit dispatch table instead of the original
        # eval("load_" + dname), which was unsafe and opaque.
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loaders[dname](domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items
|
||||
97
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
Normal file
97
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import glob
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Sources provide train/val; the target's combined train+val
        # ("all") forms the test set.
        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):
        """Read (image, label) pairs for the given domains and split.

        Args:
            dataset_dir (str): dataset root directory.
            input_domains (list): domain names to load.
            split (str): "train", "val", or "all" (train + val).
        """

        def _load_data_from_directory(directory):
            # Each class is a sub-folder of jpg images; the label is the
            # alphabetical index of the folder.
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))

                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                # Fix: derive the class name with os.path functions instead
                # of splitting on "/", which breaks on Windows where glob
                # returns backslash-separated paths.
                class_name = osp.basename(osp.dirname(impath)).lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items
|
||||
@@ -0,0 +1,49 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from .digits_dg import DigitsDG
|
||||
from ..base_dataset import DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Fetch and unpack the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # The folder layout matches Digits-DG, so its reader is reused:
        # sources give train/val, the target's train+val is the test set.
        read = DigitsDG.read_data
        train = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train")
        val = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val")
        test = read(self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)
|
||||
94
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
Normal file
94
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # the following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Official kfold splits: sources give train/crossval; the target's
        # combined train+crossval ("all") forms the test set.
        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Build Datums from the official split files of each domain."""
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                file_train = osp.join(
                    self.split_dir, dname + "_train_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file_train)
                file_val = osp.join(
                    self.split_dir, dname + "_crossval_kfold.txt"
                )
                impath_label_list += self._read_split_pacs(file_val)
            else:
                file = osp.join(
                    self.split_dir, dname + "_" + split + "_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file)

            for impath, label in impath_label_list:
                # Fix: derive the class name with os.path functions rather
                # than splitting on "/", which is wrong on Windows once
                # osp.join has inserted backslashes.
                classname = osp.basename(osp.dirname(impath))
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items

    def _read_split_pacs(self, split_file):
        """Parse "relative/path label" lines; labels become 0-based."""
        items = []

        with open(split_file, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                # Skip images known to be corrupted.
                if impath in self._error_paths:
                    continue
                impath = osp.join(self.image_dir, impath)
                # Split files use 1-based labels.
                label = int(label) - 1
                items.append((impath, label))

        return items
|
||||
60
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
Normal file
60
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import glob
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and unpack the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Collect Datums from <DOMAIN>/<split>/<class>/*.jpg folders."""
        items = []

        for domain_idx, domain_name in enumerate(input_domains):
            # Domain folders are stored in upper case on disk.
            split_path = osp.join(
                self.dataset_dir, domain_name.upper(), split
            )
            class_folders = sorted(listdir_nohidden(split_path))

            for label, folder in enumerate(class_folders):
                pattern = osp.join(split_path, folder, "*.jpg")

                for impath in glob.glob(pattern):
                    items.append(
                        Datum(impath=impath, label=label, domain=domain_idx)
                    )

        return items
|
||||
@@ -0,0 +1,3 @@
|
||||
from .svhn import SVHN
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .stl10 import STL10
|
||||
108
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
Normal file
108
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import math
|
||||
import random
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
    """CIFAR10 for SSL.

    Expects a folder layout of ``<root>/cifar10/{train,test}/<class>/<img>``.
    The labeled/unlabeled/validation split follows Oliver et al. 2018.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        # SSL protocol requires at least one labeled example.
        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        # Optionally let the labeled images also appear in the unlabeled pool.
        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        # An empty validation set is passed as None to the base class.
        if len(val) == 0:
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data_train(self, data_dir, num_labeled, val_percent):
        """Split the train folder into labeled, unlabeled and val lists.

        Args:
            data_dir (str): directory with one sub-folder per class.
            num_labeled (int): total number of labeled images across classes.
            val_percent (float): fraction of each class held out for validation.

        Returns:
            tuple: (labeled items, unlabeled items, validation items).
        """
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        # NOTE: float division — if num_labeled is not a multiple of the
        # class count, the per-class quota is fractional and the
        # (i + 1) <= quota check below effectively floors it.
        num_labeled_per_class = num_labeled / len(class_names)
        items_x, items_u, items_v = [], [], []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            # Split into train and val following Oliver et al. 2018
            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
            num_val = math.floor(len(imnames) * val_percent)
            imnames_train = imnames[num_val:]
            imnames_val = imnames[:num_val]

            # Note we do shuffle after split
            # (so the val set is deterministic but the labeled/unlabeled
            # assignment below is randomized per class).
            random.shuffle(imnames_train)

            for i, imname in enumerate(imnames_train):
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)

                # The first num_labeled_per_class shuffled images become
                # the labeled set; the rest go to the unlabeled pool.
                if (i + 1) <= num_labeled_per_class:
                    items_x.append(item)

                else:
                    items_u.append(item)

            for imname in imnames_val:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items_v.append(item)

        return items_x, items_u, items_v

    def _read_data_test(self, data_dir):
        """Read the whole test folder; label = alphabetical class index."""
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items.append(item)

        return items
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFAR100(CIFAR10):
    """CIFAR100 for SSL.

    Reuses the CIFAR10 loading logic; only the data folder differs.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar100"

    def __init__(self, cfg):
        super().__init__(cfg)
|
||||
87
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
Normal file
87
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images per class, 800 test images per class.
        - 100,000 unlabeled images for unsupervised learning.

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five splits
        assert 0 <= cfg.DATASET.STL10_FOLD <= 4

        train_x = self._read_data_train(
            train_dir, cfg.DATASET.STL10_FOLD, fold_file
        )
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        # Optionally add the labeled images to the unlabeled pool too.
        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data_train(self, data_dir, fold, fold_file):
        """Return labeled Datums for one predefined fold (all if fold < 0).

        File names encode the label as ``<index>_<label>.<ext>``; each line
        of *fold_file* lists whitespace-separated image indices for a fold.
        """
        imnames = listdir_nohidden(data_dir)
        imnames.sort()
        items = []

        list_idx = list(range(len(imnames)))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
            # Fix: the original used np.fromstring(..., dtype=np.uint8),
            # which is deprecated AND silently wraps any index above 255 —
            # STL-10 has 5,000 training images, so most fold indices were
            # corrupted and the wrong images were selected. Parse the
            # whitespace-separated indices as plain Python ints instead.
            list_idx = [int(s) for s in str_idx.split()]

        for i in list_idx:
            imname = imnames[i]
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

    def _read_data_all(self, data_dir):
        """Return Datums for every image; unlabeled images get label -1."""
        imnames = listdir_nohidden(data_dir)
        items = []

        for imname in imnames:
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            # Unlabeled images carry the token "none" instead of a digit.
            if label == "none":
                label = -1
            else:
                label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items
|
||||
17
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
Normal file
17
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .cifar import CIFAR10
|
||||
from ..build import DATASET_REGISTRY
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reuses the CIFAR10 loading logic; only the data folder differs.

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
        Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"

    def __init__(self, cfg):
        super().__init__(cfg)
|
||||
205
Dassl.ProGrad.pytorch/dassl/data/samplers.py
Normal file
205
Dassl.ProGrad.pytorch/dassl/data/samplers.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import copy
|
||||
import numpy as np
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
|
||||
|
||||
|
||||
class RandomDomainSampler(Sampler):
    """Randomly samples N domains each with K images
    to form a minibatch of size N*K.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_domain (int): number of domains to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())

        # Make sure each domain has equal number of images
        # A non-positive n_domain means "use every domain present".
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        # NOTE: materializes one full epoch of indices just to measure its
        # length; this also consumes random state during construction.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # Work on a copy: indices are removed as they are drawn so each
        # image appears at most once per epoch.
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)

            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                # Stop once any sampled domain can no longer fill a full
                # quota of n_img_per_domain images.
                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        # Epoch length precomputed in __init__.
        return self.length
|
||||
|
||||
|
||||
class SeqDomainSampler(Sampler):
    """Sequential domain sampler, which randomly samples K
    images from each domain to form a minibatch.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())
        # Domains are visited in sorted order on every batch (unlike
        # RandomDomainSampler, which picks domains at random).
        self.domains.sort()

        # Make sure each domain has equal number of images
        n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        # NOTE: materializes one full epoch of indices just to measure its
        # length; this also consumes random state during construction.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # Work on a copy: indices are removed as they are drawn so each
        # image appears at most once per epoch.
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            for domain in self.domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                # Stop once any domain can no longer fill a full quota.
                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        # Epoch length precomputed in __init__.
        return self.length
|
||||
|
||||
|
||||
class RandomClassSampler(Sampler):
    """Randomly samples N classes each with K instances to
    form a minibatch of size N*K.

    Modified from https://github.com/KaiyangZhou/deep-person-reid.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_ins (int): number of instances per class to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_ins):
        if batch_size < n_ins:
            raise ValueError(
                "batch_size={} must be no less "
                "than n_ins={}".format(batch_size, n_ins)
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.n_ins = n_ins
        self.ncls_per_batch = self.batch_size // self.n_ins
        # Map each class label to the indices of its images.
        self.index_dic = defaultdict(list)
        for idx, item in enumerate(data_source):
            self.index_dic[item.label].append(idx)
        self.labels = list(self.index_dic.keys())
        assert len(self.labels) >= self.ncls_per_batch

        # estimate number of images in an epoch
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # Pre-chunk every class into groups of exactly n_ins indices.
        grouped = defaultdict(list)
        for label in self.labels:
            candidates = copy.deepcopy(self.index_dic[label])
            if len(candidates) < self.n_ins:
                # Oversample (with replacement) classes that are too small.
                candidates = np.random.choice(
                    candidates, size=self.n_ins, replace=True
                )
            random.shuffle(candidates)
            # Keep only full groups; the remainder is dropped.
            last_start = len(candidates) - self.n_ins + 1
            for start in range(0, last_start, self.n_ins):
                grouped[label].append(
                    list(candidates[start:start + self.n_ins])
                )

        remaining = copy.deepcopy(self.labels)
        epoch_idxs = []

        # Draw ncls_per_batch classes per batch until too few remain.
        while len(remaining) >= self.ncls_per_batch:
            for label in random.sample(remaining, self.ncls_per_batch):
                group = grouped[label].pop(0)
                epoch_idxs.extend(group)
                if not grouped[label]:
                    remaining.remove(label)

        return iter(epoch_idxs)

    def __len__(self):
        return self.length
|
||||
|
||||
|
||||
def build_sampler(
    sampler_type,
    cfg=None,
    data_source=None,
    batch_size=32,
    n_domain=0,
    n_ins=16
):
    """Build a data sampler by name.

    Args:
        sampler_type (str): one of "RandomSampler", "SequentialSampler",
            "RandomDomainSampler", "SeqDomainSampler", "RandomClassSampler".
        cfg (CfgNode, optional): config (currently unused here).
        data_source (list, optional): list of Datums.
        batch_size (int, optional): batch size. Default 32.
        n_domain (int, optional): domains per minibatch. Default 0.
        n_ins (int, optional): instances per class per minibatch. Default 16.

    Raises:
        ValueError: if ``sampler_type`` is not recognized.
    """
    # Lazy factory table: only the requested sampler is constructed.
    factories = {
        "RandomSampler":
        lambda: RandomSampler(data_source),
        "SequentialSampler":
        lambda: SequentialSampler(data_source),
        "RandomDomainSampler":
        lambda: RandomDomainSampler(data_source, batch_size, n_domain),
        "SeqDomainSampler":
        lambda: SeqDomainSampler(data_source, batch_size),
        "RandomClassSampler":
        lambda: RandomClassSampler(data_source, batch_size, n_ins),
    }

    if sampler_type not in factories:
        raise ValueError("Unknown sampler type: {}".format(sampler_type))
    return factories[sampler_type]()
|
||||
1
Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
Normal file
1
Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .transforms import build_transform
|
||||
273
Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
Normal file
273
Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Source: https://github.com/DeepVoltaire/AutoAugment
|
||||
"""
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image, ImageOps, ImageEnhance
|
||||
|
||||
|
||||
class ImageNetPolicy:
    """Randomly choose one of the best 24 Sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>> transforms.Resize(256),
    >>> ImageNetPolicy(),
    >>> transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each entry is SubPolicy(p1, op1, mag_idx1, p2, op2, mag_idx2, fill).
        # NOTE(review): the docstring says 24 sub-policies but this list has
        # 25 entries (with duplicates), matching the upstream
        # DeepVoltaire/AutoAugment implementation.
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
        ]

    def __call__(self, img):
        # Pick one sub-policy uniformly at random and apply it.
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
|
||||
|
||||
|
||||
class CIFAR10Policy:
    """Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>> transforms.Resize(256),
    >>> CIFAR10Policy(),
    >>> transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each entry is SubPolicy(p1, op1, mag_idx1, p2, op2, mag_idx2, fill).
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        # Pick one sub-policy uniformly at random and apply it.
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"
|
||||
|
||||
|
||||
class SVHNPolicy:
    """Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>> transforms.Resize(256),
    >>> SVHNPolicy(),
    >>> transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each entry is SubPolicy(p1, op1, mag_idx1, p2, op2, mag_idx2, fill).
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
        ]

    def __call__(self, img):
        # Pick one sub-policy uniformly at random and apply it.
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
|
||||
|
||||
|
||||
class SubPolicy(object):
    """One AutoAugment sub-policy: two ops, each applied with its own
    probability at a fixed, discretized magnitude.

    Args:
        p1 (float): probability of applying the first op.
        operation1 (str): name of the first op (key of the op table below).
        magnitude_idx1 (int): index in [0, 9] into the op's magnitude range.
        p2 (float): probability of applying the second op.
        operation2 (str): name of the second op.
        magnitude_idx2 (int): index in [0, 9] into the op's magnitude range.
        fillcolor (tuple, optional): RGB fill color for affine transforms.
    """

    def __init__(
        self,
        p1,
        operation1,
        magnitude_idx1,
        p2,
        operation2,
        magnitude_idx2,
        fillcolor=(128, 128, 128),
    ):
        # Ten discrete magnitudes per op, indexed by magnitude_idx*.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # FIX: np.int (deprecated in NumPy 1.20, removed in 1.24) ->
            # builtin int; produces the same integer dtype.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # NOTE(review): fills with a fixed mid-gray (128,)*4 instead of
            # ``fillcolor`` -- this matches the upstream implementation.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(
                rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
            ).convert(img.mode)

        # Op name -> callable(img, magnitude). Signs of shear/translate/
        # color-style magnitudes are randomized per call.
        func = {
            "shearX":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "shearY":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "translateX":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
                    1, 0
                ),
                fillcolor=fillcolor,
            ),
            "translateY":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1, 0, 0, 0, 1, magnitude * img.size[1] * random.
                    choice([-1, 1])
                ),
                fillcolor=fillcolor,
            ),
            "rotate":
            lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color":
            lambda img, magnitude: ImageEnhance.Color(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize":
            lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize":
            lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast":
            lambda img, magnitude: ImageEnhance.Contrast(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "sharpness":
            lambda img, magnitude: ImageEnhance.Sharpness(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "brightness":
            lambda img, magnitude: ImageEnhance.Brightness(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "autocontrast":
            lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize":
            lambda img, magnitude: ImageOps.equalize(img),
            "invert":
            lambda img, magnitude: ImageOps.invert(img),
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply op1 with probability p1, then op2 with probability p2."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
|
||||
363
Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
Normal file
363
Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Credit to
|
||||
1) https://github.com/ildoonet/pytorch-randaugment
|
||||
2) https://github.com/kakaobrain/fast-autoaugment
|
||||
"""
|
||||
import numpy as np
|
||||
import random
|
||||
import PIL
|
||||
import torch
|
||||
import PIL.ImageOps
|
||||
import PIL.ImageDraw
|
||||
import PIL.ImageEnhance
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def ShearX(img, v):
    """Shear horizontally by factor |v|; the sign is chosen at random."""
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):
    """Shear vertically by factor |v|; the sign is chosen at random."""
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def TranslateX(img, v):
    """Translate horizontally by |v| * image width; random sign."""
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateXabs(img, v):
    """Translate horizontally by |v| pixels; random sign."""
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):
    """Translate vertically by |v| * image height; random sign."""
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def TranslateYabs(img, v):
    """Translate vertically by |v| pixels; random sign."""
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):
    """Rotate by |v| degrees; the direction is chosen at random."""
    assert -30 <= v <= 30
    if random.random() > 0.5:
        v = -v
    return img.rotate(v)
|
||||
|
||||
|
||||
def AutoContrast(img, _):
    """Maximize image contrast (the magnitude argument is ignored)."""
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    """Invert pixel values (the magnitude argument is ignored)."""
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    """Equalize the image histogram (the magnitude argument is ignored)."""
    return PIL.ImageOps.equalize(img)


def Flip(img, _):
    """Mirror the image left-right (the magnitude argument is ignored)."""
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):
    """Invert all pixel values above threshold ``v``."""
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)
|
||||
|
||||
|
||||
def SolarizeAdd(img, addition=0, threshold=128):
    """Add ``addition`` to every pixel, clip to [0, 255], then solarize.

    Args:
        img (PIL.Image.Image): input image.
        addition (int, optional): value added to each pixel. Default 0.
        threshold (int, optional): solarize threshold. Default 128.
    """
    # FIX: np.int (deprecated in NumPy 1.20, removed in 1.24) -> builtin
    # int; the resulting integer dtype is the same.
    img_np = np.array(img).astype(int)
    img_np = img_np + addition
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)
|
||||
|
||||
|
||||
def Posterize(img, v):
    """Reduce each channel to ``v`` bits (v in [4, 8])."""
    assert 4 <= v <= 8
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):
    """Scale contrast by factor ``v`` (1.0 = original image)."""
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):
    """Scale color saturation by factor ``v`` (1.0 = original image)."""
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):
    """Scale brightness by factor ``v`` (1.0 = original image)."""
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):
    """Scale sharpness by factor ``v`` (1.0 = original image)."""
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Sharpness(img).enhance(v)
|
||||
|
||||
|
||||
def Cutout(img, v):
    """Cut out a square patch whose side is ``v`` * image width."""
    # [0, 60] => percentage: [0, 0.2]
    assert 0.0 <= v <= 0.2
    if v <= 0.0:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)


def CutoutAbs(img, v):
    """Cut out a square of side ``v`` pixels at a random location."""
    # [0, 60] => percentage: [0, 0.2]
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    # Clamp the patch so it stays inside the image bounds.
    x0 = int(max(0, x0 - v/2.0))
    y0 = int(max(0, y0 - v/2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    # Draw on a copy so the caller's image is left untouched.
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def SamplePairing(imgs):
    """Return an op that blends its input with a random image from *imgs*."""
    # [0, 0.4]
    def f(img1, v):
        i = np.random.choice(len(imgs))
        img2 = PIL.Image.fromarray(imgs[i])
        return PIL.Image.blend(img1, img2, v)

    return f


def Identity(img, v):
    """No-op augmentation (the magnitude argument is ignored)."""
    return img
|
||||
|
||||
|
||||
class Lighting:
    """Lighting noise (AlexNet - style PCA - based noise)."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        # A zero noise scale makes this a no-op.
        if self.alphastd == 0:
            return img

        # Per-call random weights for the three principal components.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        basis = self.eigvec.type_as(img).clone()
        weighted = basis.mul(alpha.view(1, 3).expand(3, 3))
        weighted = weighted.mul(self.eigval.view(1, 3).expand(3, 3))
        rgb = weighted.sum(1).squeeze()

        # Add the same per-channel offset to every pixel.
        return img.add(rgb.view(3, 1, 1).expand_as(img))
|
||||
|
||||
|
||||
class CutoutDefault:
    """
    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        # img is (C, H, W); the mask is shared across channels.
        height, width = img.size(1), img.size(2)
        mask = np.ones((height, width), np.float32)
        cy = np.random.randint(height)
        cx = np.random.randint(width)

        half = self.length // 2
        top = np.clip(cy - half, 0, height)
        bottom = np.clip(cy + half, 0, height)
        left = np.clip(cx - half, 0, width)
        right = np.clip(cx + half, 0, width)

        mask[top:bottom, left:right] = 0.0
        mask = torch.from_numpy(mask).expand_as(img)
        img *= mask
        return img
|
||||
|
||||
|
||||
def randaugment_list():
    """Default RandAugment op pool as (op, min_magnitude, max_magnitude)."""
    # 16 operations and their ranges
    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
    # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
    augs = [
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        (Invert, 0, 1),
        (Rotate, 0, 30),
        (Posterize, 4, 8),
        (Solarize, 0, 256),
        (SolarizeAdd, 0, 110),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Brightness, 0.1, 1.9),
        (Sharpness, 0.1, 1.9),
        (ShearX, 0.0, 0.3),
        (ShearY, 0.0, 0.3),
        (CutoutAbs, 0, 40),
        (TranslateXabs, 0.0, 100),
        (TranslateYabs, 0.0, 100),
    ]

    return augs


def randaugment_list2():
    """Alternative op pool (signed ranges) as (op, min_mag, max_mag)."""
    augs = [
        (AutoContrast, 0, 1),
        (Brightness, 0.1, 1.9),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Equalize, 0, 1),
        (Identity, 0, 1),
        (Invert, 0, 1),
        (Posterize, 4, 8),
        (Rotate, -30, 30),
        (Sharpness, 0.1, 1.9),
        (ShearX, -0.3, 0.3),
        (ShearY, -0.3, 0.3),
        (Solarize, 0, 256),
        (TranslateX, -0.3, 0.3),
        (TranslateY, -0.3, 0.3),
    ]

    return augs


def fixmatch_list():
    """FixMatch op pool as (op, min_mag, max_mag)."""
    # https://arxiv.org/abs/2001.07685
    augs = [
        (AutoContrast, 0, 1),
        (Brightness, 0.05, 0.95),
        (Color, 0.05, 0.95),
        (Contrast, 0.05, 0.95),
        (Equalize, 0, 1),
        (Identity, 0, 1),
        (Posterize, 4, 8),
        (Rotate, -30, 30),
        (Sharpness, 0.05, 0.95),
        (ShearX, -0.3, 0.3),
        (ShearY, -0.3, 0.3),
        (Solarize, 0, 256),
        (TranslateX, -0.3, 0.3),
        (TranslateY, -0.3, 0.3),
    ]

    return augs
|
||||
|
||||
|
||||
class RandAugment:
    """Apply ``n`` ops drawn from the RandAugment pool at magnitude ``m``."""

    def __init__(self, n=2, m=10):
        assert 0 <= m <= 30
        self.n = n
        self.m = m
        self.augment_list = randaugment_list()

    def __call__(self, img):
        chosen = random.choices(self.augment_list, k=self.n)

        for op, low, high in chosen:
            # Map magnitude m in [0, 30] linearly onto [low, high].
            magnitude = (self.m / 30) * (high-low) + low
            img = op(img, magnitude)

        return img
|
||||
|
||||
|
||||
class RandAugment2:
    """Apply up to ``n`` random ops, each kept with probability ``p``,
    at a uniformly random magnitude within the op's range."""

    def __init__(self, n=2, p=0.6):
        self.n = n
        self.p = p
        self.augment_list = randaugment_list2()

    def __call__(self, img):
        for op, low, high in random.choices(self.augment_list, k=self.n):
            # Skip this op with probability 1-p.
            if random.random() > self.p:
                continue
            frac = random.random()
            img = op(img, frac * (high-low) + low)

        return img
|
||||
|
||||
|
||||
class RandAugmentFixMatch:
    """Apply ``n`` random FixMatch ops at uniformly random magnitudes."""

    def __init__(self, n=2):
        self.n = n
        self.augment_list = fixmatch_list()

    def __call__(self, img):
        for op, low, high in random.choices(self.augment_list, k=self.n):
            frac = random.random()
            img = op(img, frac * (high-low) + low)

        return img
|
||||
341
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
Normal file
341
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import (
|
||||
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
|
||||
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
|
||||
RandomHorizontalFlip
|
||||
)
|
||||
|
||||
from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
|
||||
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch
|
||||
|
||||
# Transform names accepted via cfg.INPUT.TRANSFORMS (validated in
# build_transform()).
AVAI_CHOICES = [
    "random_flip",
    "random_resized_crop",
    "normalize",
    "instance_norm",
    "random_crop",
    "random_translation",
    "center_crop",  # This has become a default operation for test
    "cutout",
    "imagenet_policy",
    "cifar10_policy",
    "svhn_policy",
    "randaugment",
    "randaugment_fixmatch",
    "randaugment2",
    "gaussian_noise",
    "colorjitter",
    "randomgrayscale",
    "gaussian_blur",
]

# Map cfg.INPUT.INTERPOLATION strings to PIL resampling constants.
INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}
|
||||
|
||||
|
||||
class Random2DTranslation:
    """Given an image of (height, width), we resize it to
    (height*1.125, width*1.125), and then perform random cropping.

    Args:
        height (int): target image height.
        width (int): target image width.
        p (float, optional): probability that this operation takes place.
            Default is 0.5.
        interpolation (int, optional): desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        # With probability 1-p, skip the jitter and just resize.
        if random.uniform(0, 1) > self.p:
            return img.resize((self.width, self.height), self.interpolation)

        # Upscale by 12.5% ...
        big_w = int(round(self.width * 1.125))
        big_h = int(round(self.height * 1.125))
        enlarged = img.resize((big_w, big_h), self.interpolation)

        # ... then crop a (width, height) window at a random offset.
        max_dx = big_w - self.width
        max_dy = big_h - self.height
        x1 = int(round(random.uniform(0, max_dx)))
        y1 = int(round(random.uniform(0, max_dy)))
        return enlarged.crop((x1, y1, x1 + self.width, y1 + self.height))
|
||||
|
||||
|
||||
class InstanceNormalization:
    """Normalize data using per-channel mean and standard deviation.

    Reference:
        - Ulyanov et al. Instance normalization: The missing in- gredient
          for fast stylization. ArXiv 2016.
        - Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
          ICLR 2018.
    """

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, img):
        # img is (C, H, W); statistics are computed per channel.
        channels, height, width = img.shape
        flat = img.reshape(channels, height * width)
        mu = flat.mean(1).view(channels, 1, 1)
        sigma = flat.std(1).view(channels, 1, 1)
        return (img-mu) / (sigma + self.eps)
|
||||
|
||||
|
||||
class Cutout:
    """Randomly mask out one or more patches from an image.

    https://github.com/uoguelph-mlrg/Cutout

    Args:
        n_holes (int, optional): number of patches to cut out
            of each image. Default is 1.
        length (int, optional): length (in pixels) of each square
            patch. Default is 16.
    """

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): tensor image of size (C, H, W).

        Returns:
            Tensor: image with n_holes of dimension
            length x length cut out of it.
        """
        height = img.size(1)
        width = img.size(2)

        mask = np.ones((height, width), np.float32)
        half = self.length // 2

        for _ in range(self.n_holes):
            # Patch center is uniform over the image; edges are clipped.
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            top = np.clip(cy - half, 0, height)
            bottom = np.clip(cy + half, 0, height)
            left = np.clip(cx - half, 0, width)
            right = np.clip(cx + half, 0, width)
            mask[top:bottom, left:right] = 0.0

        # Broadcast the 2-D mask across channels; returns a new tensor.
        mask = torch.from_numpy(mask).expand_as(img)
        return img * mask
|
||||
|
||||
|
||||
class GaussianNoise:
    """Add gaussian noise."""

    def __init__(self, mean=0, std=0.15, p=0.5):
        self.mean = mean
        self.std = std
        self.p = p

    def __call__(self, img):
        # Apply with probability p; otherwise pass through untouched.
        if random.uniform(0, 1) > self.p:
            return img
        perturbation = torch.randn(img.size()) * self.std + self.mean
        return img + perturbation
|
||||
|
||||
|
||||
def build_transform(cfg, is_train=True, choices=None):
    """Build transformation function.

    Args:
        cfg (CfgNode): config.
        is_train (bool, optional): for training (True) or test (False).
            Default is True.
        choices (list, optional): list of strings which will overwrite
            cfg.INPUT.TRANSFORMS if given. Default is None.
    """
    if cfg.INPUT.NO_TRANSFORM:
        print("Note: no transform is applied!")
        return None

    choices = cfg.INPUT.TRANSFORMS if choices is None else choices
    # Fail fast on any transform name we do not recognize.
    for name in choices:
        assert name in AVAI_CHOICES

    target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
    normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    builder = _build_transform_train if is_train else _build_transform_test
    return builder(cfg, choices, target_size, normalize)
|
||||
|
||||
|
||||
def _build_transform_train(cfg, choices, target_size, normalize):
    """Assemble the training-time augmentation pipeline from ``choices``.

    Transforms are appended in a fixed order: PIL-space geometric and
    color augmentations first, then ToTensor, then tensor-space ops
    (cutout, normalization, noise, instance norm).
    """
    print("Building transform_train")
    transforms = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    # An explicit resize is only needed when no cropping transform will
    # already produce images of the target size.
    needs_resize = (
        "random_crop" not in choices and "random_resized_crop" not in choices
    )
    if needs_resize:
        print(f"+ resize to {target_size}")
        transforms.append(Resize(cfg.INPUT.SIZE, interpolation=interp_mode))

    if "random_translation" in choices:
        print("+ random translation")
        transforms.append(
            Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
        )

    if "random_crop" in choices:
        crop_padding = cfg.INPUT.CROP_PADDING
        print("+ random crop (padding = {})".format(crop_padding))
        transforms.append(RandomCrop(cfg.INPUT.SIZE, padding=crop_padding))

    if "random_resized_crop" in choices:
        print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
        transforms.append(
            RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
        )

    if "center_crop" in choices:
        print(f"+ center crop (size={cfg.INPUT.SIZE})")
        transforms.append(CenterCrop(cfg.INPUT.SIZE))

    if "random_flip" in choices:
        print("+ random flip")
        transforms.append(RandomHorizontalFlip())

    if "imagenet_policy" in choices:
        print("+ imagenet policy")
        transforms.append(ImageNetPolicy())

    if "cifar10_policy" in choices:
        print("+ cifar10 policy")
        transforms.append(CIFAR10Policy())

    if "svhn_policy" in choices:
        print("+ svhn policy")
        transforms.append(SVHNPolicy())

    if "randaugment" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        m_ = cfg.INPUT.RANDAUGMENT_M
        print("+ randaugment (n={}, m={})".format(n_, m_))
        transforms.append(RandAugment(n_, m_))

    if "randaugment_fixmatch" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment_fixmatch (n={})".format(n_))
        transforms.append(RandAugmentFixMatch(n_))

    if "randaugment2" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment2 (n={})".format(n_))
        transforms.append(RandAugment2(n_))

    if "colorjitter" in choices:
        print("+ color jitter")
        transforms.append(
            ColorJitter(
                brightness=cfg.INPUT.COLORJITTER_B,
                contrast=cfg.INPUT.COLORJITTER_C,
                saturation=cfg.INPUT.COLORJITTER_S,
                hue=cfg.INPUT.COLORJITTER_H,
            )
        )

    if "randomgrayscale" in choices:
        print("+ random gray scale")
        transforms.append(RandomGrayscale(p=cfg.INPUT.RGS_P))

    if "gaussian_blur" in choices:
        print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
        transforms.append(
            RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
        )

    # Everything below operates on tensors, so ToTensor comes first.
    print("+ to torch tensor of range [0, 1]")
    transforms.append(ToTensor())

    if "cutout" in choices:
        cutout_n = cfg.INPUT.CUTOUT_N
        cutout_len = cfg.INPUT.CUTOUT_LEN
        print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
        transforms.append(Cutout(cutout_n, cutout_len))

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        transforms.append(normalize)

    if "gaussian_noise" in choices:
        print(
            "+ gaussian noise (mean={}, std={})".format(
                cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
            )
        )
        transforms.append(GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD))

    if "instance_norm" in choices:
        print("+ instance normalization")
        transforms.append(InstanceNormalization())

    return Compose(transforms)
|
||||
|
||||
|
||||
def _build_transform_test(cfg, choices, target_size, normalize):
    """Assemble the deterministic test-time pipeline.

    Resize the smaller edge, center-crop to the target size, convert to
    tensor, then optionally normalize / instance-normalize per ``choices``.
    """
    print("Building transform_test")
    pipeline = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
    pipeline.append(Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode))

    print(f"+ {target_size} center crop")
    pipeline.append(CenterCrop(cfg.INPUT.SIZE))

    print("+ to torch tensor of range [0, 1]")
    pipeline.append(ToTensor())

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        pipeline.append(normalize)

    if "instance_norm" in choices:
        print("+ instance normalization")
        pipeline.append(InstanceNormalization())

    return Compose(pipeline)
|
||||
Reference in New Issue
Block a user