# NOTE: stray pagination metadata ("482 lines / 16 KiB / Python") removed;
# it was an extraction artifact, not part of the module.
import torch
|
|
import torchvision.transforms as T
|
|
import numpy as np
|
|
from tabulate import tabulate
|
|
from torch.utils.data import Dataset as TorchDataset
|
|
import os
|
|
from dassl.utils import read_image
|
|
|
|
from dassl.data.datasets import build_dataset
|
|
from dassl.data.samplers import build_sampler
|
|
from dassl.data.transforms import INTERPOLATION_MODES, build_transform
|
|
from .new_da import RandomResizedCropPair, build_transform_pair
|
|
from PIL import Image
|
|
|
|
def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
    weight=None,
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        cfg: experiment config (reads ``DATALOADER.NUM_WORKERS`` and ``USE_CUDA``).
        sampler_type (str): name understood by ``build_sampler``.
        data_source (list): list of Datum-like items to load from.
        batch_size (int): loader batch size.
        n_domain (int): number of domains per batch (sampler-specific).
        n_ins (int): number of instances per class (sampler-specific).
        tfm: transform (or list of transforms) applied to each image.
        is_train (bool): enables K-augmentation and drop_last behaviour.
        dataset_wrapper: Dataset class wrapping ``data_source``;
            defaults to ``DatasetWrapper``.
        weight: optional per-sample weights forwarded to the wrapper.

    Returns:
        torch.utils.data.DataLoader: a non-empty data loader.
    """
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    # Fall back to the default wrapper when the caller supplies none.
    wrapper_cls = DatasetWrapper if dataset_wrapper is None else dataset_wrapper

    loader = torch.utils.data.DataLoader(
        wrapper_cls(cfg, data_source, transform=tfm, is_train=is_train, weight=weight),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Only drop the last incomplete batch when training AND at least one
        # full batch exists; otherwise the loader would be empty.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(loader) > 0

    return loader
|
|
def build_data_loader_mask(
    cfg,
    dataset,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
    weight=None,
):
    """Build a ``DataLoader`` whose samples also carry segmentation masks.

    Same contract as :func:`build_data_loader`, except that ``dataset`` (the
    full dataset object, needed for ``dataset_dir``) is forwarded to the
    wrapper, and the default wrapper is ``DatasetWrapperMask``.

    Returns:
        torch.utils.data.DataLoader: a non-empty data loader.
    """
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    # Default to the mask-aware wrapper when none is supplied.
    wrapper_cls = DatasetWrapperMask if dataset_wrapper is None else dataset_wrapper

    loader = torch.utils.data.DataLoader(
        wrapper_cls(cfg, dataset, data_source, transform=tfm, is_train=is_train, weight=weight),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Drop the last incomplete batch only in training with >= one full batch.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(loader) > 0

    return loader
|
def select_dm_loader(cfg, dataset, s_ind=None, is_train=False):
    """Build a plain loader over ``dataset`` (optionally a selected subset).

    Args:
        cfg: experiment config.
        dataset (list): list of Datum-like items.
        s_ind: optional index array selecting a subset of ``dataset``.
        is_train (bool): selects the TRAIN_X sampler/batch size when True,
            otherwise the TEST sampler with ``DATASET.SELECTION_BATCH_SIZE``.

    Returns:
        torch.utils.data.DataLoader
    """
    tfm = build_transform(cfg, is_train=is_train)

    # Optionally restrict the data to the selected indices.
    if s_ind is not None:
        source = list(np.asarray(dataset)[s_ind])
    else:
        source = dataset

    # Only the sampler type and batch size differ between the two modes.
    if is_train:
        sampler_type = cfg.DATALOADER.TRAIN_X.SAMPLER
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
    else:
        sampler_type = cfg.DATALOADER.TEST.SAMPLER
        batch_size = cfg.DATASET.SELECTION_BATCH_SIZE

    return build_data_loader(
        cfg,
        sampler_type=sampler_type,
        data_source=source,
        batch_size=batch_size,
        n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
        n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
        tfm=tfm,
        is_train=is_train,
        dataset_wrapper=None,
    )
|
class DataManager:
    """Build and hold all data loaders for one experiment.

    Creates:
      * ``train_loader_x``    — labeled data with paired image/mask transform,
      * ``train_loader_xmore``— labeled data with the plain image transform,
      * ``train_loader_u``    — unlabeled data (``None`` if absent),
      * ``val_loader``        — validation data (``None`` if absent),
      * ``test_loader``       — test data.

    Args:
        cfg: experiment config (DATALOADER/DATASET/VERBOSE options are read).
        dataset: dataset object exposing ``train_x``, ``train_u``, ``val``,
            ``test``, ``num_classes``, ``lab2cname`` and ``dataset_dir``.
        s_ind: optional index array selecting a subset of ``dataset.train_x``.
        custom_tfm_train: optional transform overriding the default train
            transforms (used for both the plain and the pair loader).
        custom_tfm_test: optional transform overriding the default test one.
        dataset_wrapper: optional Dataset class forwarded to the builders.
        weight: optional per-sample weights for the labeled loaders.
    """

    def __init__(
        self,
        cfg,
        dataset,
        s_ind=None,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None,
        weight=None,
    ):
        # Build transforms. The "pair" transform jointly augments an image
        # together with its mask (see build_transform_pair).
        if custom_tfm_train is None:
            tfm_train_pair = build_transform_pair(cfg, is_train=True)
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train
            # BUGFIX: tfm_train_pair was previously left undefined on this
            # branch, causing a NameError when building train_loader_x.
            # Reuse the custom transform for the pair loader as well.
            tfm_train_pair = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        # Optionally restrict the labeled training set to selected indices
        # (computed once; both labeled loaders share it).
        if s_ind is not None:
            train_x_source = list(np.asarray(dataset.train_x)[s_ind])
        else:
            train_x_source = dataset.train_x

        # Labeled loader whose samples include segmentation masks.
        train_loader_x = build_data_loader_mask(
            cfg,
            dataset,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=train_x_source,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train_pair,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
            weight=weight,
        )

        # Second labeled loader: image-only transform, selection batch size.
        train_loader_xmore = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=train_x_source,
            batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
            weight=weight,
        )

        # Unlabeled loader (semi-supervised settings only).
        train_loader_u = None
        if dataset.train_u:
            sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
            batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
            n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
            n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS

            # Mirror the labeled-loader settings when requested.
            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
                batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
                n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
                n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=sampler_type_,
                data_source=dataset.train_u,
                batch_size=batch_size_,
                n_domain=n_domain_,
                n_ins=n_ins_,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # Validation loader (optional).
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper,
            )

        # Test loader.
        test_loader = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.train_loader_xmore = train_loader_xmore
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        """int: number of classes in the dataset."""
        return self._num_classes

    @property
    def num_source_domains(self):
        """int: number of configured source domains."""
        return self._num_source_domains

    @property
    def lab2cname(self):
        """dict: mapping from label index to class name."""
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print a table summarizing dataset name, domains and split sizes."""
        dataset_name = cfg.DATASET.NAME
        source_domains = cfg.DATASET.SOURCE_DOMAINS
        target_domains = cfg.DATASET.TARGET_DOMAINS

        table = []
        table.append(["Dataset", dataset_name])
        if source_domains:
            table.append(["Source", source_domains])
        if target_domains:
            table.append(["Target", target_domains])
        table.append(["# classes", f"{self.num_classes:,}"])
        table.append(["# train_x", f"{len(self.dataset.train_x):,}"])
        if self.dataset.train_u:
            table.append(["# train_u", f"{len(self.dataset.train_u):,}"])
        if self.dataset.val:
            table.append(["# val", f"{len(self.dataset.val):,}"])
        table.append(["# test", f"{len(self.dataset.test):,}"])

        print(tabulate(table))
|
class DatasetWrapperMask(TorchDataset):
    """Dataset wrapper that yields each image together with its mask.

    Each sample is a dict with ``label``, ``domain``, ``impath``, ``index``,
    optionally ``weight``, plus ``img`` (and ``mask`` for a single pair
    transform) and optionally the untransformed ``img0`` tensor.
    """

    def __init__(self, cfg, dataset, data_source, transform=None, is_train=False, weight=None):
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accepts a single transform or a list/tuple
        self.is_train = is_train
        self.data_path = dataset.dataset_dir
        # Masks live under "<dataset_dir>/mask" (except VOC12, see __getitem__).
        self.mask_path = os.path.join(dataset.dataset_dir, 'mask')
        # Augmenting an image K>1 times is only allowed during training.
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0
        self.weight = weight

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                f"Cannot augment the image {self.k_tfm} times "
                "because transform is None"
            )

        # Plain tensor pipeline (no augmentation) used for "img0".
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        pipeline = [
            T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode),
            T.ToTensor(),
        ]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            pipeline.append(
                T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
            )
        self.to_tensor = T.Compose(pipeline)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath,
            "index": idx,
        }
        if self.weight is not None:
            output["weight"] = self.weight[idx]

        # Resolve the mask file for this image; the layout of the mask
        # directory differs between datasets.
        parts = item.impath.split('/')
        name = self.cfg.DATASET.NAME
        if name in ['Food101', 'Caltech101', 'DescribableTextures', 'EuroSAT', 'UCF101']:
            mask = read_image(os.path.join(self.mask_path, '/'.join(parts[-2:])))
        elif name in ['SUN397']:
            mask = read_image(os.path.join(self.mask_path, '/'.join(parts[7:])))
        elif name in ['ImageNet']:
            mask = read_image(os.path.join(self.mask_path, '/'.join(parts[7:])))
        elif name in ['VOC12']:
            # VOC keeps segmentation masks as .png under a fixed subtree.
            voc_mask = os.path.join(
                self.data_path,
                'VOCdevkit/VOC2012/SegmentationClass_All',
                parts[-1][:-3] + 'png',
            )
            mask = read_image(voc_mask)
        else:
            mask = read_image(os.path.join(self.mask_path, parts[-1]))

        img0 = read_image(item.impath)
        # Align the mask resolution with the image before transforming.
        mask = mask.resize(img0.size)

        if self.transform is None:
            output["img"] = img0
        elif isinstance(self.transform, (list, tuple)):
            # NOTE(review): img0 is passed in place of the mask here, so the
            # real mask is ignored on this branch — confirm this is intended.
            for i, tfm in enumerate(self.transform):
                keyname = "img" if i == 0 else "img" + str(i + 1)
                output[keyname] = self._transform_image(tfm, img0, img0)
        else:
            img, mask = self._transform_image(self.transform, img0, mask)
            output["img"] = img
            output["mask"] = mask

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)  # without any augmentation

        return output

    def _transform_image(self, tfm, img0, mask):
        """Apply the pair transform ``tfm`` ``k_tfm`` times.

        Returns a single ``(img, mask)`` pair when ``k_tfm == 1``; otherwise
        returns ``(list_of_pairs, mask)`` with the input mask unchanged.
        """
        results = [tfm(img0, mask) for _ in range(self.k_tfm)]
        if len(results) == 1:
            return results[0][0], results[0][1]
        return results, mask
|
class DatasetWrapper(TorchDataset):
    """Standard dataset wrapper: yields transformed images without masks.

    Each sample is a dict with ``label``, ``domain``, ``impath``, ``index``,
    optionally ``weight``, plus ``img`` (``img2``, ``img3``, ... for a list of
    transforms) and optionally the untransformed ``img0`` tensor.
    """

    def __init__(self, cfg, data_source, transform=None, is_train=False, weight=None):
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accepts a single transform or a list/tuple
        self.is_train = is_train
        # Mask directory derived from the first item's path (kept for parity
        # with DatasetWrapperMask; not used in __getitem__).
        self.mask_path = '/'.join(data_source[0].impath.split('/')[:-2]) + '/mask'
        # Augmenting an image K>1 times is only allowed during training.
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0
        self.weight = weight

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                f"Cannot augment the image {self.k_tfm} times "
                "because transform is None"
            )

        # Plain tensor pipeline (no augmentation) used for "img0".
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        pipeline = [
            T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode),
            T.ToTensor(),
        ]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            pipeline.append(
                T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
            )
        self.to_tensor = T.Compose(pipeline)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath,
            "index": idx,
        }
        if self.weight is not None:
            output["weight"] = self.weight[idx]

        img0 = read_image(item.impath)

        if self.transform is None:
            output["img"] = img0
        elif isinstance(self.transform, (list, tuple)):
            for i, tfm in enumerate(self.transform):
                keyname = "img" if i == 0 else "img" + str(i + 1)
                output[keyname] = self._transform_image(tfm, img0)
        else:
            output["img"] = self._transform_image(self.transform, img0)
            # Placeholder so batches share keys with DatasetWrapperMask.
            output['mask'] = 1

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)  # without any augmentation

        return output

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` ``k_tfm`` times; return one image or a list of them."""
        results = [tfm(img0) for _ in range(self.k_tfm)]
        return results[0] if len(results) == 1 else results