release code
This commit is contained in:
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
from dassl.utils import read_image
|
||||
|
||||
from .datasets import build_dataset
|
||||
from .samplers import build_sampler
|
||||
from .transforms import build_transform
|
||||
|
||||
INTERPOLATION_MODES = {
|
||||
"bilinear": Image.BILINEAR,
|
||||
"bicubic": Image.BICUBIC,
|
||||
"nearest": Image.NEAREST,
|
||||
}
|
||||
|
||||
|
||||
def build_data_loader(
|
||||
cfg,
|
||||
sampler_type="SequentialSampler",
|
||||
data_source=None,
|
||||
batch_size=64,
|
||||
n_domain=0,
|
||||
n_ins=2,
|
||||
tfm=None,
|
||||
is_train=True,
|
||||
dataset_wrapper=None,
|
||||
):
|
||||
# Build sampler
|
||||
sampler = build_sampler(
|
||||
sampler_type,
|
||||
cfg=cfg,
|
||||
data_source=data_source,
|
||||
batch_size=batch_size,
|
||||
n_domain=n_domain,
|
||||
n_ins=n_ins,
|
||||
)
|
||||
|
||||
if dataset_wrapper is None:
|
||||
dataset_wrapper = DatasetWrapper
|
||||
|
||||
# Build data loader
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||
drop_last=is_train and len(data_source) >= batch_size,
|
||||
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
|
||||
)
|
||||
assert len(data_loader) > 0
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
class DataManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
custom_tfm_train=None,
|
||||
custom_tfm_test=None,
|
||||
dataset_wrapper=None
|
||||
):
|
||||
# Load dataset
|
||||
dataset = build_dataset(cfg)
|
||||
# Build transform
|
||||
if custom_tfm_train is None:
|
||||
tfm_train = build_transform(cfg, is_train=True)
|
||||
else:
|
||||
print("* Using custom transform for training")
|
||||
tfm_train = custom_tfm_train
|
||||
|
||||
if custom_tfm_test is None:
|
||||
tfm_test = build_transform(cfg, is_train=False)
|
||||
else:
|
||||
print("* Using custom transform for testing")
|
||||
tfm_test = custom_tfm_test
|
||||
|
||||
# Build train_loader_x
|
||||
train_loader_x = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=dataset.train_x,
|
||||
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
)
|
||||
|
||||
# Build train_loader_u
|
||||
train_loader_u = None
|
||||
if dataset.train_u:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
|
||||
|
||||
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
|
||||
|
||||
train_loader_u = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=sampler_type_,
|
||||
data_source=dataset.train_u,
|
||||
batch_size=batch_size_,
|
||||
n_domain=n_domain_,
|
||||
n_ins=n_ins_,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
)
|
||||
|
||||
# Build val_loader
|
||||
val_loader = None
|
||||
if dataset.val:
|
||||
val_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.val,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
)
|
||||
|
||||
# Build test_loader
|
||||
test_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.test,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
)
|
||||
|
||||
# Attributes
|
||||
self._num_classes = dataset.num_classes
|
||||
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
|
||||
self._lab2cname = dataset.lab2cname
|
||||
|
||||
# Dataset and data-loaders
|
||||
self.dataset = dataset
|
||||
self.train_loader_x = train_loader_x
|
||||
self.train_loader_u = train_loader_u
|
||||
self.val_loader = val_loader
|
||||
self.test_loader = test_loader
|
||||
|
||||
if cfg.VERBOSE:
|
||||
self.show_dataset_summary(cfg)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._num_classes
|
||||
|
||||
@property
|
||||
def num_source_domains(self):
|
||||
return self._num_source_domains
|
||||
|
||||
@property
|
||||
def lab2cname(self):
|
||||
return self._lab2cname
|
||||
|
||||
def show_dataset_summary(self, cfg):
|
||||
print("***** Dataset statistics *****")
|
||||
|
||||
print(" Dataset: {}".format(cfg.DATASET.NAME))
|
||||
|
||||
if cfg.DATASET.SOURCE_DOMAINS:
|
||||
print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
|
||||
if cfg.DATASET.TARGET_DOMAINS:
|
||||
print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))
|
||||
|
||||
print(" # classes: {:,}".format(self.num_classes))
|
||||
|
||||
print(" # train_x: {:,}".format(len(self.dataset.train_x)))
|
||||
|
||||
if self.dataset.train_u:
|
||||
print(" # train_u: {:,}".format(len(self.dataset.train_u)))
|
||||
|
||||
if self.dataset.val:
|
||||
print(" # val: {:,}".format(len(self.dataset.val)))
|
||||
|
||||
print(" # test: {:,}".format(len(self.dataset.test)))
|
||||
|
||||
|
||||
class DatasetWrapper(TorchDataset):
|
||||
|
||||
def __init__(self, cfg, data_source, transform=None, is_train=False):
|
||||
self.cfg = cfg
|
||||
self.data_source = data_source
|
||||
self.transform = transform # accept list (tuple) as input
|
||||
self.is_train = is_train
|
||||
# Augmenting an image K>1 times is only allowed during training
|
||||
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||
|
||||
if self.k_tfm > 1 and transform is None:
|
||||
raise ValueError(
|
||||
"Cannot augment the image {} times "
|
||||
"because transform is None".format(self.k_tfm)
|
||||
)
|
||||
|
||||
# Build transform that doesn't apply any data augmentation
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
to_tensor = []
|
||||
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||
to_tensor += [T.ToTensor()]
|
||||
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||
normalize = T.Normalize(
|
||||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||
)
|
||||
to_tensor += [normalize]
|
||||
self.to_tensor = T.Compose(to_tensor)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_source[idx]
|
||||
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath
|
||||
}
|
||||
|
||||
img0 = read_image(item.impath)
|
||||
|
||||
if self.transform is not None:
|
||||
if isinstance(self.transform, (list, tuple)):
|
||||
for i, tfm in enumerate(self.transform):
|
||||
img = self._transform_image(tfm, img0)
|
||||
keyname = "img"
|
||||
if (i + 1) > 1:
|
||||
keyname += str(i + 1)
|
||||
output[keyname] = img
|
||||
else:
|
||||
img = self._transform_image(self.transform, img0)
|
||||
output["img"] = img
|
||||
|
||||
if self.return_img0:
|
||||
output["img0"] = self.to_tensor(img0)
|
||||
|
||||
return output
|
||||
|
||||
def _transform_image(self, tfm, img0):
|
||||
img_list = []
|
||||
|
||||
for k in range(self.k_tfm):
|
||||
img_list.append(tfm(img0))
|
||||
|
||||
img = img_list
|
||||
if len(img) == 1:
|
||||
img = img[0]
|
||||
|
||||
return img
|
||||
Reference in New Issue
Block a user