# NOTE: file-viewer metadata carried over from the original listing:
# 265 lines, 7.9 KiB, Python.
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset

from dassl.utils import read_image

from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import build_transform

# Maps the interpolation name stored in cfg.INPUT.INTERPOLATION to the
# matching PIL resampling constant (looked up when DatasetWrapper builds
# its augmentation-free resize transform).
INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}


def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        cfg: config object. ``cfg.DATALOADER.NUM_WORKERS`` and
            ``cfg.USE_CUDA`` are read here; ``cfg`` is also forwarded to
            ``build_sampler`` and to the dataset wrapper.
        sampler_type: sampler name forwarded to ``build_sampler``.
        data_source: sized collection of data items. ``len()`` is taken
            on it, so the ``None`` default must be overridden by callers.
        batch_size: number of items per batch.
        n_domain: forwarded to ``build_sampler``.
        n_ins: forwarded to ``build_sampler``.
        tfm: transform (or list/tuple of transforms) handed to the
            dataset wrapper as ``transform``.
        is_train: whether the loader is used for training.
        dataset_wrapper: callable invoked as
            ``dataset_wrapper(cfg, data_source, transform=..., is_train=...)``;
            defaults to ``DatasetWrapper`` when None.

    Returns:
        The constructed data loader, asserted to contain at least one batch.
    """
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Drop the incomplete last batch only during training, and only
        # when there is at least one full batch; otherwise dropping it
        # would leave the loader with nothing to yield.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(data_loader) > 0, "data loader has no batches"

    return data_loader


class DataManager:
    """Build a dataset from ``cfg`` together with its data loaders.

    After construction the instance exposes ``dataset``,
    ``train_loader_x``, ``train_loader_u`` (None when the dataset has no
    ``train_u`` data), ``val_loader`` (None when the dataset has no
    ``val`` data) and ``test_loader``.
    """

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        """Load the dataset and create all data loaders.

        Args:
            cfg: config object; the ``DATALOADER``, ``DATASET`` and
                ``VERBOSE`` entries are read here.
            custom_tfm_train: transform used for the training loaders
                instead of the one built by ``build_transform``.
            custom_tfm_test: transform used for the val/test loaders
                instead of the one built by ``build_transform``.
            dataset_wrapper: forwarded to ``build_data_loader``.
        """
        # Load dataset
        dataset = build_dataset(cfg)

        # Build transform
        if custom_tfm_train is None:
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        # Build train_loader_x
        train_loader_x = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=dataset.train_x,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
        )

        # Build train_loader_u
        train_loader_u = None
        if dataset.train_u:
            # SAME_AS_X makes the train_u loader reuse the TRAIN_X
            # sampler settings instead of its own TRAIN_U settings.
            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                loader_cfg_u = cfg.DATALOADER.TRAIN_X
            else:
                loader_cfg_u = cfg.DATALOADER.TRAIN_U

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=loader_cfg_u.SAMPLER,
                data_source=dataset.train_u,
                batch_size=loader_cfg_u.BATCH_SIZE,
                n_domain=loader_cfg_u.N_DOMAIN,
                n_ins=loader_cfg_u.N_INS,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # The val and test loaders share every setting except the data.
        eval_loader_kwargs = dict(
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Build val_loader
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg, data_source=dataset.val, **eval_loader_kwargs
            )

        # Build test_loader
        test_loader = build_data_loader(
            cfg, data_source=dataset.test, **eval_loader_kwargs
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        """Value of ``dataset.num_classes`` captured at construction."""
        return self._num_classes

    @property
    def num_source_domains(self):
        """Number of entries in ``cfg.DATASET.SOURCE_DOMAINS``."""
        return self._num_source_domains

    @property
    def lab2cname(self):
        """Value of ``dataset.lab2cname`` captured at construction."""
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print the dataset name, domains, class count and split sizes.

        Optional entries (source/target domains, train_u, val) are only
        printed when they are non-empty.
        """
        print("***** Dataset statistics *****")

        print(" Dataset: {}".format(cfg.DATASET.NAME))

        if cfg.DATASET.SOURCE_DOMAINS:
            print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
        if cfg.DATASET.TARGET_DOMAINS:
            print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))

        print(" # classes: {:,}".format(self.num_classes))

        print(" # train_x: {:,}".format(len(self.dataset.train_x)))

        if self.dataset.train_u:
            print(" # train_u: {:,}".format(len(self.dataset.train_u)))

        if self.dataset.val:
            print(" # val: {:,}".format(len(self.dataset.val)))

        print(" # test: {:,}".format(len(self.dataset.test)))


class DatasetWrapper(TorchDataset):
    """Torch dataset that reads images from ``data_source`` items.

    Each item is expected to expose ``label``, ``domain`` and ``impath``
    attributes. ``__getitem__`` returns a dict with those three values
    plus the transformed image(s).
    """

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        """Store the settings and build the augmentation-free transform.

        Args:
            cfg: config object; ``DATALOADER.K_TRANSFORMS``,
                ``DATALOADER.RETURN_IMG0`` and the ``INPUT`` entries
                (``INTERPOLATION``, ``SIZE``, ``TRANSFORMS``,
                ``PIXEL_MEAN``, ``PIXEL_STD``) are read here.
            data_source: indexable, sized collection of data items.
            transform: a single transform, a list/tuple of transforms,
                or None.
            is_train: whether the dataset is used for training.

        Raises:
            ValueError: if more than one augmentation per image is
                requested while ``transform`` is None.
        """
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        to_tensor = [
            T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode),
            T.ToTensor(),
        ]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            to_tensor.append(
                T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
            )
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        """Return the number of items in the data source."""
        return len(self.data_source)

    def __getitem__(self, idx):
        """Return a dict describing item ``idx``.

        Keys: ``label``, ``domain``, ``impath``; ``img`` (plus ``img2``,
        ``img3``, ... when ``transform`` is a list/tuple) if a transform
        is set; ``img0`` (the un-augmented tensor) if ``return_img0`` is
        true.
        """
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath
        }

        # NOTE(review): presumably read_image returns a PIL image, since
        # the result is fed to torchvision transforms - confirm.
        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                # The first transform writes "img"; later ones write
                # "img2", "img3", ... (1-based position as the suffix).
                for position, tfm in enumerate(self.transform, start=1):
                    keyname = "img"
                    if position > 1:
                        keyname += str(position)
                    output[keyname] = self._transform_image(tfm, img0)
            else:
                output["img"] = self._transform_image(self.transform, img0)

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)

        return output

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``self.k_tfm`` times.

        Returns the single result when exactly one image was produced,
        otherwise the list of results.
        """
        img_list = [tfm(img0) for _ in range(self.k_tfm)]

        if len(img_list) == 1:
            return img_list[0]

        return img_list