import json
import os
import os.path as osp
import random
import tarfile
import zipfile
from collections import defaultdict

import gdown
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset


def read_json(fpath):
    """Read json file from a path."""
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Writes to a json file."""
    if not osp.exists(osp.dirname(fpath)):
        os.makedirs(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def read_image(path):
    """Read image from path using ``PIL.Image``.

    Args:
        path (str): path to an image.

    Returns:
        PIL image
    """
    if not osp.exists(path):
        raise IOError('No file exists at {}'.format(path))

    # Retry indefinitely: transient read failures can occur under heavy IO,
    # e.g. when many dataloader workers hit a shared filesystem at once.
    while True:
        try:
            img = Image.open(path).convert('RGB')
            return img
        except IOError:
            print(
                'Cannot read image from {}, '
                'probably due to heavy IO. Will re-try'.format(path)
            )


def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Args:
        path (str): directory path.
        sort (bool): sort the items.
    """
    # Skip hidden files and shell scripts. The previous filter
    # ('sh' not in f) also dropped any name merely containing the
    # substring 'sh' (e.g. 'fish'), which was almost certainly a bug.
    items = [
        f for f in os.listdir(path)
        if not f.startswith('.') and not f.endswith('.sh')
    ]
    if sort:
        items.sort()
    return items


class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath='', label=0, domain=-1, classname=''):
        assert isinstance(impath, str)
        assert isinstance(label, int)
        assert isinstance(domain, int)
        assert isinstance(classname, str)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


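# A minimal usage sketch for Datum (the path and names below are
# hypothetical, for illustration only):
#
#   d = Datum(impath='images/cat/001.jpg', label=0, domain=0, classname='cat')
#   print(d.impath, d.label, d.domain, d.classname)
#
# All attributes are exposed as read-only properties, so a Datum cannot
# be mutated after construction.

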
class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ''  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None, t_sne=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data
        self._t_sne = t_sne  # optional split for t-SNE visualization (not used here)
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count the number of classes.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        # Assumes labels are 0-indexed integers, so the count is max + 1.
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

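    # For example, given Datum objects labeled {0: 'cat', 1: 'dog'},
    # get_lab2cname returns ({0: 'cat', 1: 'dog'}, ['cat', 'dog']),
    # with classnames ordered by label.
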
    def check_input_domains(self, source_domains, target_domains):
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    'Input domain must belong to {}, '
                    'but got [{}]'.format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print('Extracting file ...')

        # Try tar first; fall back to zip if the file is not a tar archive.
        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=osp.dirname(dst))
        except tarfile.ReadError:
            with zipfile.ZipFile(dst, 'r') as zip_ref:
                zip_ref.extractall(osp.dirname(dst))

        print('File extracted to {}'.format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=True
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a small number of images.

        Args:
            data_sources: each element is a list of Datum objects.
            num_shots (int): number of instances per class to sample.
            repeat (bool): repeat images if needed.
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f'Creating a {num_shots}-shot dataset')

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

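    # A usage sketch (``train`` stands for any list of Datum objects):
    #
    #   few_shot_train = self.generate_fewshot_dataset(train, num_shots=16)
    #
    # Each class then contributes exactly 16 items; classes with fewer
    # than 16 images are re-sampled with replacement (random.choices)
    # when repeat=True, and otherwise keep all of their images.
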
    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output


class DatasetWrapper(TorchDataset):

    def __init__(self, data_source, input_size, transform=None, is_train=False,
                 return_img0=False, k_tfm=1):
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = k_tfm if is_train else 1
        self.return_img0 = return_img0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                'Cannot augment the image {} times '
                'because transform is None'.format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = T.InterpolationMode.BICUBIC
        to_tensor = []
        to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        # Normalization statistics used by CLIP's image preprocessing
        normalize = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
        to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            'label': item.label,
            'domain': item.domain,
            'impath': item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = 'img'
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output['img'] = img
        else:
            # Fall back to the no-augmentation pipeline so that 'img' is
            # always present (previously the return below raised KeyError
            # whenever transform was None).
            output['img'] = self.to_tensor(img0)

        if self.return_img0:
            output['img0'] = self.to_tensor(img0)

        return output['img'], output['label'], output['impath']

    def _transform_image(self, tfm, img0):
        img_list = []

        for k in range(self.k_tfm):
            img_list.append(tfm(img0))

        # Return a single image when k_tfm == 1, otherwise a list of views
        img = img_list
        if len(img) == 1:
            img = img[0]

        return img


def build_data_loader(
    data_source=None,
    batch_size=64,
    input_size=224,
    tfm=None,
    is_train=True,
    shuffle=False,
    dataset_wrapper=None
):
    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        num_workers=8,
        shuffle=shuffle,
        drop_last=False,
        pin_memory=torch.cuda.is_available()
    )
    assert len(data_loader) > 0

    return data_loader
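

# A minimal end-to-end sketch, assuming a hypothetical list of Datum
# objects (train_data) and a torchvision transform (tfm); nothing below
# is required by this module:
#
#   train_data = [
#       Datum(impath='images/cat/001.jpg', label=0, domain=0, classname='cat'),
#       Datum(impath='images/dog/001.jpg', label=1, domain=0, classname='dog'),
#   ]
#   tfm = T.Compose([T.RandomResizedCrop(224), T.ToTensor()])
#   loader = build_data_loader(train_data, batch_size=2, tfm=tfm, shuffle=True)
#   for img, label, impath in loader:
#       print(img.shape, label, impath)
#       break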