# File listing metadata: 226 lines, 6.1 KiB, Python.
import os
import os.path as osp
import random
import tarfile
import zipfile
from collections import defaultdict

import gdown

from dassl.utils import check_isfile


class Datum:
    """Data instance which defines the basic attributes.

    The four values are stored on private attributes and exposed through
    read-only properties.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.

    Raises:
        AssertionError: if ``impath`` is not a ``str`` or is rejected by
            ``check_isfile``.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        assert isinstance(impath, str), "impath must be a str"
        # NOTE(review): check_isfile comes from dassl.utils; presumably it is
        # truthy only for a path to an existing file, which would mean the
        # default impath="" never passes -- confirm against that helper.
        assert check_isfile(impath), "impath failed check_isfile: {!r}".format(
            impath
        )

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    def __repr__(self):
        return (
            f"{type(self).__name__}(impath={self._impath!r}, "
            f"label={self._label!r}, domain={self._domain!r}, "
            f"classname={self._classname!r})"
        )

    @property
    def impath(self):
        """Image path given at construction."""
        return self._impath

    @property
    def label(self):
        """Class label given at construction."""
        return self._label

    @property
    def domain(self):
        """Domain label given at construction."""
        return self._domain

    @property
    def classname(self):
        """Class name given at construction."""
        return self._classname


class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning

    The class count and the label-to-classname mapping are derived from
    ``train_x`` only; ``train_u``, ``val`` and ``test`` are stored as given.

    Args:
        train_x (list): labeled training data (Datum objects).
        train_u (list): unlabeled training data (optional).
        val (list): validation data (optional).
        test (list): test data.
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        """Labeled training data as passed to the constructor."""
        return self._train_x

    @property
    def train_u(self):
        """Unlabeled training data as passed to the constructor."""
        return self._train_u

    @property
    def val(self):
        """Validation data as passed to the constructor."""
        return self._val

    @property
    def test(self):
        """Test data as passed to the constructor."""
        return self._test

    @property
    def lab2cname(self):
        """Dict mapping each label in ``train_x`` to its class name."""
        return self._lab2cname

    @property
    def classnames(self):
        """Class names of ``train_x`` ordered by ascending label."""
        return self._classnames

    @property
    def num_classes(self):
        """Number of classes computed from ``train_x``."""
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        The count is ``max(label) + 1``, so labels are assumed to be
        zero-based integers (gaps still count towards the total).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            int: the number of classes, or 0 when ``data_source`` is
            ``None`` or empty.
        """
        if data_source is None:
            return 0
        label_set = {item.label for item in data_source}
        # max() raises ValueError on an empty set, so guard explicitly.
        return max(label_set) + 1 if label_set else 0

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: ``(mapping, classnames)`` where ``mapping`` is a dict
            from label to class name and ``classnames`` lists the class
            names ordered by ascending label. Both are empty when
            ``data_source`` is ``None`` or empty.
        """
        if data_source is None:
            data_source = ()
        # Built in data order so that, should one label carry conflicting
        # class names, the last occurrence wins deterministically.
        mapping = {item.label: item.classname for item in data_source}
        classnames = [mapping[label] for label in sorted(mapping)]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate both the source and the target domain names.

        Raises:
            ValueError: if any name is not listed in ``self.domains``.
        """
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        """Raise ``ValueError`` on the first name not in ``self.domains``.

        Args:
            input_domains (list): domain names to check.
        """
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive to ``dst`` and extract it into its directory.

        The file is opened as a tar archive first; if ``tarfile`` cannot
        read it, it is opened as a zip archive instead.

        Args:
            url (str): location passed to ``gdown.download``.
            dst (str): path where the downloaded archive is written.
            from_gdrive (bool): only ``True`` is supported.

        Raises:
            NotImplementedError: if ``from_gdrive`` is false (the
                destination directory has already been created by then).
        """
        # A bare filename has an empty dirname; use the current directory
        # rather than failing in os.makedirs("").
        dst_dir = osp.dirname(dst) or os.curdir
        os.makedirs(dst_dir, exist_ok=True)

        if not from_gdrive:
            raise NotImplementedError
        gdown.download(url, dst, quiet=False)

        print("Extracting file ...")

        # Only "not a tar archive" triggers the zip fallback; other errors
        # (missing file, interrupts) propagate unchanged.
        try:
            archive = tarfile.open(dst)
        except tarfile.ReadError:
            archive = zipfile.ZipFile(dst, "r")

        # NOTE(security): extractall trusts the member paths stored in the
        # downloaded archive; only use this with trusted sources.
        with archive:  # closes the handle even if extraction fails
            archive.extractall(path=dst_dir)

        print("File extracted to {}".format(dst_dir))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a few number of images.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample.
            repeat (bool): repeat images if needed (default: False).

        Returns:
            If ``num_shots < 1`` the input is returned untouched: the single
            data source itself, or the tuple of all data sources. Otherwise
            a sampled list for a single data source, or a list of sampled
            lists (one per data source).
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for items in tracker.values():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                elif repeat:
                    # Too few items: sample with replacement to reach
                    # num_shots.
                    sampled_items = random.choices(items, k=num_shots)
                else:
                    # Too few items and no repetition: keep them all.
                    sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            defaultdict: maps each label to the list of items carrying it.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            defaultdict: maps each domain to the list of items carrying it.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output