import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from tabulate import tabulate
from torch.utils.data import Dataset as TorchDataset

from dassl.data.datasets import build_dataset
from dassl.data.samplers import build_sampler
from dassl.data.transforms import INTERPOLATION_MODES, build_transform
from dassl.utils import read_image

from .new_da import RandomResizedCropPair, build_transform_pair


# Datasets whose mask files mirror the last two path components
# (parent directory + file name) of the image path under ``<root>/mask``.
_MASK_KEEP_PARENT_DATASETS = (
    'Food101', 'Caltech101', 'DescribableTextures', 'EuroSAT', 'UCF101'
)

# Datasets whose mask files mirror the image path from its 8th '/'-separated
# component onwards.
# NOTE(review): index 7 is tied to the absolute directory depth of the data
# root on the original author's machine -- confirm before moving the data.
_MASK_FIXED_DEPTH_DATASETS = ('SUN397', 'ImageNet')


def _select_subset(data, s_ind):
    """Return ``data`` unchanged, or the items picked by ``s_ind`` as a list.

    ``s_ind`` is applied through NumPy indexing, so anything NumPy accepts as
    an index for a 1-D array works.
    """
    if s_ind is None:
        return data
    return list(np.asarray(data)[s_ind])


def _build_plain_to_tensor(cfg):
    """Build the augmentation-free pipeline: resize, to-tensor, normalize.

    Normalization is appended only when "normalize" is listed in
    ``cfg.INPUT.TRANSFORMS``.
    """
    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
    steps = [
        T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode),
        T.ToTensor(),
    ]
    if "normalize" in cfg.INPUT.TRANSFORMS:
        steps.append(
            T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
        )
    return T.Compose(steps)


def _make_torch_loader(
    cfg, wrapped_dataset, sampler, data_source, batch_size, is_train
):
    """Wrap an already-built dataset and sampler in a ``DataLoader``.

    The last incomplete batch is dropped only for training, and only when the
    source holds at least one full batch (otherwise the loader would be
    empty).
    """
    data_loader = torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(data_loader) > 0, "data loader is empty"
    return data_loader


def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
    weight=None,
):
    """Build a data loader over ``data_source`` that yields images only.

    Args:
        cfg: config node (reads ``DATALOADER.NUM_WORKERS`` and ``USE_CUDA``).
        sampler_type: sampler name forwarded to ``build_sampler``.
        data_source: sequence of items wrapped by the dataset wrapper.
        batch_size: number of samples per batch.
        n_domain, n_ins: forwarded to ``build_sampler``.
        tfm: transform (or list/tuple of transforms) given to the wrapper.
        is_train: training mode flag; also controls ``drop_last``.
        dataset_wrapper: wrapper class called as
            ``wrapper(cfg, data_source, transform=, is_train=, weight=)``.
            Defaults to ``DatasetWrapper``.
        weight: optional per-sample weights, indexed by sample position.

    Returns:
        A non-empty ``torch.utils.data.DataLoader``.
    """
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    wrapped = dataset_wrapper(
        cfg, data_source, transform=tfm, is_train=is_train, weight=weight
    )
    return _make_torch_loader(
        cfg, wrapped, sampler, data_source, batch_size, is_train
    )


def build_data_loader_mask(
    cfg,
    dataset,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
    weight=None,
):
    """Build a data loader that yields images together with their masks.

    Same as ``build_data_loader`` except that the wrapper additionally
    receives ``dataset`` (it must expose ``dataset_dir``) and is called as
    ``wrapper(cfg, dataset, data_source, transform=, is_train=, weight=)``.
    Defaults to ``DatasetWrapperMask``.

    Returns:
        A non-empty ``torch.utils.data.DataLoader``.
    """
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapperMask

    # Build data loader
    wrapped = dataset_wrapper(
        cfg,
        dataset,
        data_source,
        transform=tfm,
        is_train=is_train,
        weight=weight,
    )
    return _make_torch_loader(
        cfg, wrapped, sampler, data_source, batch_size, is_train
    )


def select_dm_loader(cfg, dataset, s_ind=None, is_train=False):
    """Build an image-only loader over ``dataset`` (optionally a subset).

    Args:
        cfg: config node.
        dataset: sequence of data items (not a dataset object).
        s_ind: optional NumPy-style index selecting a subset of ``dataset``.
        is_train: when True use the TRAIN_X sampler and batch size; otherwise
            use the TEST sampler with ``cfg.DATASET.SELECTION_BATCH_SIZE``.

    Returns:
        The loader produced by ``build_data_loader``.
    """
    tfm = build_transform(cfg, is_train=is_train)

    # The two modes differ only in sampler type and batch size.
    if is_train:
        sampler_type = cfg.DATALOADER.TRAIN_X.SAMPLER
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
    else:
        sampler_type = cfg.DATALOADER.TEST.SAMPLER
        batch_size = cfg.DATASET.SELECTION_BATCH_SIZE

    return build_data_loader(
        cfg,
        sampler_type=sampler_type,
        data_source=_select_subset(dataset, s_ind),
        batch_size=batch_size,
        n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
        n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
        tfm=tfm,
        is_train=is_train,
        dataset_wrapper=None,
    )


class DataManager:
    """Bundle the train / val / test data loaders built from one dataset.

    Attributes set by ``__init__``: ``dataset``, ``train_loader_x`` (images
    with masks), ``train_loader_xmore`` (images only, selection batch size),
    ``train_loader_u``, ``val_loader`` and ``test_loader``.
    """

    def __init__(
        self,
        cfg,
        dataset,
        s_ind=None,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None,
        weight=None,
    ):
        """Build every loader.

        Args:
            cfg: config node.
            dataset: already-built dataset object exposing ``train_x``,
                ``train_u``, ``val``, ``test``, ``num_classes``,
                ``lab2cname`` and ``dataset_dir``.
            s_ind: optional NumPy-style index restricting ``train_x``.
            custom_tfm_train: optional replacement for the image-only
                training transform.
            custom_tfm_test: optional replacement for the test transform.
            dataset_wrapper: optional wrapper class forwarded to every loader
                builder. NOTE(review): the masked and image-only builders
                call the wrapper with different positional arguments, so one
                custom class must cope with both -- confirm with callers.
            weight: optional per-sample weights for the train_x loaders.
        """
        # Build transform.
        # The paired (image, mask) transform is always built from cfg: a
        # custom image-only transform cannot be invoked with (img, mask), and
        # leaving the name unbound used to raise NameError below whenever
        # ``custom_tfm_train`` was supplied.
        tfm_train_pair = build_transform_pair(cfg, is_train=True)
        if custom_tfm_train is None:
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        train_x_source = _select_subset(dataset.train_x, s_ind)

        # Build train_loader_x (images + masks)
        train_loader_x = build_data_loader_mask(
            cfg,
            dataset,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=train_x_source,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train_pair,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
            weight=weight,
        )

        # Same source, image-only, with the selection batch size
        train_loader_xmore = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=train_x_source,
            batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
            weight=weight,
        )

        # Build train_loader_u
        train_loader_u = None
        if dataset.train_u:
            # SAME_AS_X makes the unlabeled loader reuse the TRAIN_X settings
            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                u_cfg = cfg.DATALOADER.TRAIN_X
            else:
                u_cfg = cfg.DATALOADER.TRAIN_U

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=u_cfg.SAMPLER,
                data_source=dataset.train_u,
                batch_size=u_cfg.BATCH_SIZE,
                n_domain=u_cfg.N_DOMAIN,
                n_ins=u_cfg.N_INS,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # Build val_loader
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper,
            )

        # Build test_loader
        test_loader = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.train_loader_xmore = train_loader_xmore
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        """Number of classes reported by the dataset."""
        return self._num_classes

    @property
    def num_source_domains(self):
        """Length of ``cfg.DATASET.SOURCE_DOMAINS``."""
        return self._num_source_domains

    @property
    def lab2cname(self):
        """The dataset's ``lab2cname`` attribute."""
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print a table with the dataset name, domains and split sizes."""
        source_domains = cfg.DATASET.SOURCE_DOMAINS
        target_domains = cfg.DATASET.TARGET_DOMAINS

        table = [["Dataset", cfg.DATASET.NAME]]
        if source_domains:
            table.append(["Source", source_domains])
        if target_domains:
            table.append(["Target", target_domains])
        table.append(["# classes", f"{self.num_classes:,}"])
        table.append(["# train_x", f"{len(self.dataset.train_x):,}"])
        if self.dataset.train_u:
            table.append(["# train_u", f"{len(self.dataset.train_u):,}"])
        if self.dataset.val:
            table.append(["# val", f"{len(self.dataset.val):,}"])
        table.append(["# test", f"{len(self.dataset.test):,}"])

        print(tabulate(table))


class DatasetWrapperMask(TorchDataset):
    """Dataset wrapper that returns each image together with its mask.

    Transforms given to this wrapper are called as ``tfm(img, mask)`` and
    must return an ``(img, mask)`` pair.
    """

    def __init__(
        self,
        cfg,
        dataset,
        data_source,
        transform=None,
        is_train=False,
        weight=None,
    ):
        """Store the configuration and build the augmentation-free pipeline.

        Args:
            cfg: config node.
            dataset: object exposing ``dataset_dir``; masks are looked up
                under ``<dataset_dir>/mask`` (VOC12 uses its own folder).
            data_source: sequence of items with ``label``, ``domain`` and
                ``impath`` attributes.
            transform: paired transform or list/tuple of paired transforms.
            is_train: training mode flag.
            weight: optional per-sample weights indexed by position.

        Raises:
            ValueError: if more than one augmentation per image is requested
                but no transform is given.
        """
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        self.data_path = dataset.dataset_dir
        self.mask_path = os.path.join(dataset.dataset_dir, 'mask')
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0
        self.weight = weight

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        self.to_tensor = _build_plain_to_tensor(cfg)

    def __len__(self):
        return len(self.data_source)

    def _mask_file(self, impath):
        """Return the mask file path that corresponds to image ``impath``."""
        parts = impath.split('/')
        name = self.cfg.DATASET.NAME

        if name == 'VOC12':
            # Masks are PNG files named after the image; splitext keeps this
            # correct for extensions that are not exactly three characters.
            stem = os.path.splitext(parts[-1])[0]
            return os.path.join(
                self.data_path,
                'VOCdevkit/VOC2012/SegmentationClass_All',
                stem + '.png',
            )

        if name in _MASK_KEEP_PARENT_DATASETS:
            relative = '/'.join(parts[-2:])
        elif name in _MASK_FIXED_DEPTH_DATASETS:
            relative = '/'.join(parts[7:])
        else:
            relative = parts[-1]
        return os.path.join(self.mask_path, relative)

    def __getitem__(self, idx):
        """Return the sample dict for position ``idx``.

        Keys: ``label``, ``domain``, ``impath``, ``index``; ``weight`` when
        weights were given; ``img``/``mask`` (plus ``img2``/``mask2``, ... for
        the 2nd, 3rd, ... transform of a list); ``img0`` when
        ``cfg.DATALOADER.RETURN_IMG0`` is set.
        """
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath,
            "index": idx,
        }
        if self.weight is not None:
            output["weight"] = self.weight[idx]

        mask = read_image(self._mask_file(item.impath))
        img0 = read_image(item.impath)
        # Bring the mask to the image size so both can be cropped together.
        # NOTE(review): presumably both are PIL images and this uses PIL's
        # default resampling filter -- confirm that is wanted for masks.
        mask = mask.resize(img0.size)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    # Transform the real mask (not a copy of the image) and
                    # store image and mask under separate keys.
                    img, tfm_mask = self._transform_image(tfm, img0, mask)
                    suffix = "" if i == 0 else str(i + 1)
                    output["img" + suffix] = img
                    output["mask" + suffix] = tfm_mask
            else:
                img, tfm_mask = self._transform_image(
                    self.transform, img0, mask
                )
                output["img"] = img
                output["mask"] = tfm_mask
        else:
            output["img"] = img0

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)  # without any augmentation

        return output

    def _transform_image(self, tfm, img0, mask):
        """Apply ``tfm`` to ``(img0, mask)`` ``self.k_tfm`` times.

        Returns:
            ``(img, mask)`` when exactly one augmentation is produced,
            otherwise ``(list_of_imgs, list_of_masks)`` aligned by position.
        """
        pairs = [tfm(img0, mask) for _ in range(self.k_tfm)]

        if len(pairs) == 1:
            return pairs[0][0], pairs[0][1]

        imgs = [pair[0] for pair in pairs]
        masks = [pair[1] for pair in pairs]
        return imgs, masks


class DatasetWrapper(TorchDataset):
    """Dataset wrapper that returns images only (``mask`` is the constant 1
    when a single transform is used)."""

    def __init__(
        self, cfg, data_source, transform=None, is_train=False, weight=None
    ):
        """Store the configuration and build the augmentation-free pipeline.

        Args:
            cfg: config node.
            data_source: sequence of items with ``label``, ``domain`` and
                ``impath`` attributes.
            transform: transform or list/tuple of transforms, each called as
                ``tfm(img)``.
            is_train: training mode flag.
            weight: optional per-sample weights indexed by position.

        Raises:
            ValueError: if more than one augmentation per image is requested
                but no transform is given.
        """
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Derived from the first item: two directory levels above the image,
        # plus '/mask'. Not read anywhere in this class; an empty source
        # yields None instead of raising IndexError.
        if len(data_source) > 0:
            first_parts = data_source[0].impath.split('/')
            self.mask_path = '/'.join(first_parts[:-2]) + '/mask'
        else:
            self.mask_path = None
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0
        self.weight = weight

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        self.to_tensor = _build_plain_to_tensor(cfg)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        """Return the sample dict for position ``idx``.

        Keys: ``label``, ``domain``, ``impath``, ``index``; ``weight`` when
        weights were given; ``img`` (plus ``img2``, ... for a transform
        list); ``mask`` (constant 1, single-transform case only); ``img0``
        when ``cfg.DATALOADER.RETURN_IMG0`` is set.
        """
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath,
            "index": idx,
        }
        if self.weight is not None:
            output["weight"] = self.weight[idx]

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    suffix = "" if i == 0 else str(i + 1)
                    output["img" + suffix] = self._transform_image(tfm, img0)
            else:
                output["img"] = self._transform_image(self.transform, img0)
                # Placeholder so batches carry a "mask" entry as well
                output['mask'] = 1
        else:
            output["img"] = img0

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)  # without any augmentation

        return output

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``self.k_tfm`` times.

        Returns the single result when exactly one augmentation is produced,
        otherwise the list of results.
        """
        augmented = [tfm(img0) for _ in range(self.k_tfm)]
        if len(augmented) == 1:
            return augmented[0]
        return augmented