release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
from .build import DATASET_REGISTRY, build_dataset # isort:skip
from .base_dataset import Datum, DatasetBase # isort:skip
from .da import *
from .dg import *
from .ssl import *

View File

@@ -0,0 +1,225 @@
import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown
from dassl.utils import check_isfile
class Datum:
    """A single data instance.

    Bundles the basic attributes every dataset item carries.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        # Validate the image path before storing anything.
        assert isinstance(impath, str)
        assert check_isfile(impath)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        """str: path to the image file."""
        return self._impath

    @property
    def label(self):
        """int: class label."""
        return self._label

    @property
    def domain(self):
        """int: domain label."""
        return self._domain

    @property
    def classname(self):
        """str: human-readable class name."""
        return self._classname
class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        """
        Args:
            train_x (list): labeled training data (list of Datum).
            train_u (list): unlabeled training data (optional).
            val (list): validation data (optional).
            test (list): test data.
        """
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data
        # Class statistics are derived from the labeled training set.
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        Labels are assumed to be 0-indexed, so the count is
        ``max(label) + 1``.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = {item.label for item in data_source}
        # Guard against an empty data source (max() would raise ValueError).
        if not label_set:
            return 0
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: (dict mapping label -> classname,
            list of classnames ordered by label).
        """
        container = {(item.label, item.classname) for item in data_source}
        mapping = {label: classname for label, classname in container}
        labels = sorted(mapping.keys())
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate that all requested domains exist in ``self.domains``."""
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive to ``dst`` and extract it next to itself.

        Args:
            url (str): download URL (Google Drive id URL by default).
            dst (str): destination file path for the archive.
            from_gdrive (bool): use gdown; other sources are unsupported.
        """
        # exist_ok avoids a race / crash when the directory already exists.
        os.makedirs(osp.dirname(dst), exist_ok=True)

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        # Try tar first; fall back to zip only for the expected failure
        # mode (not a tar archive) instead of a bare except.
        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=osp.dirname(dst))
        except tarfile.ReadError:
            with zipfile.ZipFile(dst, "r") as zip_ref:
                zip_ref.extractall(osp.dirname(dst))

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a few number of images.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample;
                values < 1 return the input unchanged.
            repeat (bool): repeat images (sample with replacement) if a
                class has fewer than ``num_shots`` images (default: False).
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    # Not enough images for this class.
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)
        for item in data_source:
            output[item.label].append(item)
        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)
        for item in data_source:
            output[item.domain].append(item)
        return output

View File

@@ -0,0 +1,11 @@
from dassl.utils import Registry, check_availability
DATASET_REGISTRY = Registry("DATASET")
def build_dataset(cfg):
    """Instantiate the dataset named by ``cfg.DATASET.NAME``.

    Args:
        cfg: config node; must provide DATASET.NAME and VERBOSE.

    Returns:
        The constructed dataset object.
    """
    dataset_name = cfg.DATASET.NAME
    # Fail early with a helpful message if the name is not registered.
    check_availability(dataset_name, DATASET_REGISTRY.registered_names())
    if cfg.VERBOSE:
        print("Loading dataset: {}".format(dataset_name))
    dataset_cls = DATASET_REGISTRY.get(dataset_name)
    return dataset_cls(cfg)

View File

@@ -0,0 +1,7 @@
from .digit5 import Digit5
from .visda17 import VisDA17
from .cifarstl import CIFARSTL
from .office31 import Office31
from .domainnet import DomainNet
from .office_home import OfficeHome
from .mini_domainnet import miniDomainNet

View File

@@ -0,0 +1,68 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
    """CIFAR-10 and STL-10.

    CIFAR-10:
        - 60,000 32x32 colour images.
        - 10 classes, with 6,000 images per class.
        - 50,000 training images and 10,000 test images.
        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.

    STL-10:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images (10 pre-defined folds), 800 test images
        per class.
        - URL: https://cs.stanford.edu/~acoates/stl10/.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "cifar_stl"
    domains = ["cifar", "stl"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Sources provide labeled data; targets provide unlabeled
        # training data plus the test set.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Collect Datum objects from ``<dataset_dir>/<domain>/<split>``."""
        items = []

        for domain_idx, domain_name in enumerate(input_domains):
            split_dir = osp.join(self.dataset_dir, domain_name, split)

            for class_name in listdir_nohidden(split_dir):
                class_dir = osp.join(split_dir, class_name)
                # Class folders are named "<label>_<classname>".
                label = int(class_name.split("_")[0])
                for imname in listdir_nohidden(class_dir):
                    items.append(
                        Datum(
                            impath=osp.join(class_dir, imname),
                            label=label,
                            domain=domain_idx,
                        )
                    )

        return items

View File

@@ -0,0 +1,124 @@
import random
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """List (impath, label) pairs found in ``im_dir``.

    File names are expected to encode the label after an underscore,
    i.e. ``<prefix>_<label>.<ext>``.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): randomly keep at most this many items.
        n_repeat (int, optional): replicate the list this many times.
    """
    items = []

    for imname in listdir_nohidden(im_dir):
        stem = osp.splitext(imname)[0]
        # The label sits after the underscore in the file name.
        label = int(stem.split("_")[1])
        items.append((osp.join(im_dir, imname), label))

    if n_max is not None:
        items = random.sample(items, n_max)

    if n_repeat is not None:
        items = items * n_repeat

    return items
def load_mnist(dataset_dir, split="train"):
    """Load MNIST; subsample to 25k train / 9k test images."""
    n_max = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M; subsample to 25k train / 9k test images."""
    n_max = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, MNIST_M[split]), n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    """Load SVHN; subsample to 25k train / 9k test images."""
    n_max = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, SVHN[split]), n_max=n_max)


def load_syn(dataset_dir, split="train"):
    """Load SYN; subsample to 25k train / 9k test images."""
    n_max = 25000 if split == "train" else 9000
    return read_image_list(osp.join(dataset_dir, SYN[split]), n_max=n_max)


def load_usps(dataset_dir, split="train"):
    """Load USPS; replicate the train set 3x to match other domains."""
    n_repeat = 3 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, USPS[split]), n_repeat=n_repeat)
@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
    """Five digit datasets.

    It contains:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS which has only
    9,298 images in total, we use the entire dataset but replicate its training
    set for 3 times so as to match the training set size of other domains.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
    """

    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Read data for each requested domain.

        Uses an explicit loader table instead of ``eval`` so that a
        malformed domain name cannot execute arbitrary code and fails
        with a clear KeyError instead.
        """
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loaders[dname](domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=str(label)
                )
                items.append(item)

        return items

View File

@@ -0,0 +1,69 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
    """DomainNet.

    Statistics:
        - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
        Real, Sketch.
        - Around 0.6M images.
        - 345 categories.
        - URL: http://ai.bu.edu/M3SDA/.

    Special note: the t-shirt class (327) is missing in painting_train.txt.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
    """

    dataset_dir = "domainnet"
    domains = [
        "clipart", "infograph", "painting", "quickdraw", "real", "sketch"
    ]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        # Source test splits act as validation; target test splits as test.
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Parse ``<domain>_<split>.txt`` files into Datum lists."""
        items = []

        for domain, dname in enumerate(input_domains):
            split_file = osp.join(self.split_dir, f"{dname}_{split}.txt")

            with open(split_file, "r") as f:
                for line in f:
                    impath, label = line.strip().split(" ")
                    # Paths look like "<domain>/<classname>/<image>".
                    classname = impath.split("/")[1]
                    items.append(
                        Datum(
                            impath=osp.join(self.dataset_dir, impath),
                            label=int(label),
                            domain=domain,
                            classname=classname,
                        )
                    )

        return items

View File

@@ -0,0 +1,58 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
    """A subset of DomainNet.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
        - Zhou et al. Domain Adaptive Ensemble Learning.
    """

    dataset_dir = "domainnet"
    domains = ["clipart", "painting", "real", "sketch"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        # The mini version ships its own split files.
        self.split_dir = osp.join(self.dataset_dir, "splits_mini")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Parse ``<domain>_<split>.txt`` files into Datum lists."""
        items = []

        for domain, dname in enumerate(input_domains):
            split_file = osp.join(self.split_dir, f"{dname}_{split}.txt")

            with open(split_file, "r") as f:
                for line in f:
                    impath, label = line.strip().split(" ")
                    # Paths look like "<domain>/<classname>/<image>".
                    classname = impath.split("/")[1]
                    items.append(
                        Datum(
                            impath=osp.join(self.dataset_dir, impath),
                            label=int(label),
                            domain=domain,
                            classname=classname,
                        )
                    )

        return items

View File

@@ -0,0 +1,63 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class Office31(DatasetBase):
    """Office-31.

    Statistics:
        - 4,110 images.
        - 31 classes related to office objects.
        - 3 domains: Amazon, Webcam, Dslr.
        - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.

    Reference:
        - Saenko et al. Adapting visual category models to
        new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # No held-out split: targets supply both unlabeled data and test.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        """Scan ``<dataset_dir>/<domain>/<class>/*`` into Datum lists."""
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            # Labels follow the sorted order of class folders.
            class_names = sorted(listdir_nohidden(domain_dir))

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                for imname in listdir_nohidden(class_path):
                    items.append(
                        Datum(
                            impath=osp.join(class_path, imname),
                            label=label,
                            domain=domain,
                            classname=class_name,
                        )
                    )

        return items

View File

@@ -0,0 +1,63 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home"
    domains = ["art", "clipart", "product", "real_world"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # No held-out split: targets supply both unlabeled data and test.
        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        """Scan ``<dataset_dir>/<domain>/<class>/*`` into Datum lists."""
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            # Labels follow the sorted order of class folders.
            class_names = sorted(listdir_nohidden(domain_dir))

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                for imname in listdir_nohidden(class_path):
                    items.append(
                        Datum(
                            impath=osp.join(class_path, imname),
                            label=label,
                            domain=domain,
                            classname=class_name.lower(),
                        )
                    )

        return items

View File

@@ -0,0 +1,61 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
    """VisDA17.

    Focusing on simulation-to-reality domain shift.

    URL: http://ai.bu.edu/visda-2017/.

    Reference:
        - Peng et al. VisDA: The Visual Domain Adaptation
        Challenge. ArXiv 2017.
    """

    dataset_dir = "visda17"
    domains = ["synthetic", "real"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # The benchmark is fixed: synthetic -> real.
        train_x = self._read_data("synthetic")
        train_u = self._read_data("real")
        test = self._read_data("real")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, dname):
        """Parse the domain's ``image_list.txt`` into Datum objects."""
        # Synthetic images live under "train", real ones under "validation".
        filedir = "train" if dname == "synthetic" else "validation"
        image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
        # There is only one source domain
        domain = 0
        items = []

        with open(image_list, "r") as f:
            for line in f:
                impath, label = line.strip().split(" ")
                # Paths look like "<classname>/<image>".
                classname = impath.split("/")[0]
                items.append(
                    Datum(
                        impath=osp.join(self.dataset_dir, filedir, impath),
                        label=int(label),
                        domain=domain,
                        classname=classname,
                    )
                )

        return items

View File

@@ -0,0 +1,6 @@
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG

View File

@@ -0,0 +1,123 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Corruption types shipped with the CIFAR-10-C / CIFAR-100-C benchmarks
# (Hendrycks et al., ICLR 2019); used to validate cfg.DATASET.CIFAR_C_TYPE.
AVAI_C_TYPES = [
    "brightness",
    "contrast",
    "defocus_blur",
    "elastic_transform",
    "fog",
    "frost",
    "gaussian_blur",
    "gaussian_noise",
    "glass_blur",
    "impulse_noise",
    "jpeg_compression",
    "motion_blur",
    "pixelate",
    "saturate",
    "shot_noise",
    "snow",
    "spatter",
    "speckle_noise",
    "zoom_blur",
]
@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
    """CIFAR-10 -> CIFAR-10-C.

    Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o

    Statistics:
        - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
        - 10 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    dataset_dir = ""  # images live directly under the data root
    domains = ["cifar10", "cifar10_c"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = root

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # The benchmark is single-source/single-target with fixed roles:
        # clean data is the source, corrupted data is the target.
        source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
        target_domain = cfg.DATASET.TARGET_DOMAINS[0]
        assert source_domain == self.domains[0]
        assert target_domain == self.domains[1]

        c_type = cfg.DATASET.CIFAR_C_TYPE  # corruption type
        c_level = cfg.DATASET.CIFAR_C_LEVEL  # severity level in [1, 5]

        if not c_type:
            raise ValueError(
                "Please specify DATASET.CIFAR_C_TYPE in the config file"
            )
        assert (
            c_type in AVAI_C_TYPES
        ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
        assert 1 <= c_level <= 5

        train_dir = osp.join(self.dataset_dir, source_domain, "train")
        test_dir = osp.join(
            self.dataset_dir, target_domain, c_type, str(c_level)
        )

        if not osp.exists(test_dir):
            # Include the offending path so the failure is actionable.
            raise ValueError(f"Test data not found at {test_dir}")

        train = self._read_data(train_dir)
        test = self._read_data(test_dir)

        super().__init__(train_x=train, test=test)

    def _read_data(self, data_dir):
        """Read ``data_dir/<class>/*`` images; labels follow sorted
        class-folder order."""
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label, domain=0)
                items.append(item)

        return items
@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
    """CIFAR-100 -> CIFAR-100-C.

    Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o

    Statistics:
        - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
        - 100 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    # Same loading logic as CIFAR10C; only the domain folder names differ.
    dataset_dir = ""
    domains = ["cifar100", "cifar100_c"]

    def __init__(self, cfg):
        super().__init__(cfg)

View File

@@ -0,0 +1,124 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """List (impath, label) pairs found in ``im_dir``.

    File names are expected to encode the label after an underscore,
    i.e. ``<prefix>_<label>.<ext>``.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): keep only the first ``n_max`` items.
        n_repeat (int, optional): replicate the list this many times.
    """
    items = []

    for imname in listdir_nohidden(im_dir):
        stem = osp.splitext(imname)[0]
        # The label sits after the underscore in the file name.
        label = int(stem.split("_")[1])
        items.append((osp.join(im_dir, imname), label))

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items = items * n_repeat

    return items
def load_mnist(dataset_dir, split="train"):
    """Load MNIST; keep the first 10k training images."""
    n_max = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M; keep the first 10k training images."""
    n_max = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST_M[split]), n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    """Load SVHN; keep the first 10k training images."""
    n_max = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SVHN[split]), n_max=n_max)


def load_syn(dataset_dir, split="train"):
    """Load SYN; keep the first 10k training images."""
    n_max = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SYN[split]), n_max=n_max)


def load_usps(dataset_dir, split="train"):
    """Load USPS in full (no subsampling)."""
    return read_image_list(osp.join(dataset_dir, USPS[split]))
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
    evaluate the model on the test split of the other four datasets. However,
    the code does not restrict you to only use MNIST as the source dataset.
    Instead, you can use any dataset as the source. But note that only 10,000
    images will be sampled from the source dataset for training.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        # The source test split serves as validation data.
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Read data for each requested domain.

        Uses an explicit loader table instead of ``eval`` so that a
        malformed domain name cannot execute arbitrary code and fails
        with a clear KeyError instead.
        """
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loaders[dname](domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items

View File

@@ -0,0 +1,97 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the dataset on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        # Target domains are evaluated on train + val combined ("all").
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):
        """Collect Datum objects for the given domains.

        Args:
            dataset_dir (str): root directory of the dataset.
            input_domains (list): domain names to read.
            split (str): "train", "val", or "all" (train + val).
        """

        def _load_data_from_directory(directory):
            # Folders are class names; labels follow sorted folder order.
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))

                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                # Use os.path helpers instead of splitting on "/" so the
                # class folder is extracted correctly on Windows too,
                # where glob returns backslash-separated paths.
                class_name = osp.basename(osp.dirname(impath)).lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items

View File

@@ -0,0 +1,49 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase
@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the dataset on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # The directory layout matches Digits-DG, so its reader is reused.
        read = DigitsDG.read_data
        train = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train")
        val = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val")
        test = read(self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

View File

@@ -0,0 +1,94 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # the following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        # Download and extract the dataset on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Build Datum lists from the official k-fold split files."""
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                # "all" combines the train and crossval folds (targets).
                pairs = self._read_split_pacs(
                    osp.join(self.split_dir, dname + "_train_kfold.txt")
                )
                pairs += self._read_split_pacs(
                    osp.join(self.split_dir, dname + "_crossval_kfold.txt")
                )
            else:
                pairs = self._read_split_pacs(
                    osp.join(
                        self.split_dir, dname + "_" + split + "_kfold.txt"
                    )
                )

            for impath, label in pairs:
                # Split files keep "/" separators, so the class folder is
                # the second-to-last path component.
                classname = impath.split("/")[-2]
                items.append(
                    Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname,
                    )
                )

        return items

    def _read_split_pacs(self, split_file):
        """Parse one split file into (impath, label) pairs.

        Skips images listed in ``_error_paths`` and shifts labels from
        1-indexed (file format) to 0-indexed.
        """
        items = []

        with open(split_file, "r") as f:
            for line in f:
                impath, label = line.strip().split(" ")
                if impath in self._error_paths:
                    continue
                items.append(
                    (osp.join(self.image_dir, impath), int(label) - 1)
                )

        return items

View File

@@ -0,0 +1,60 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        dataset_root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(dataset_root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            # First use: fetch the archive from Google Drive.
            archive = osp.join(dataset_root, "vlcs.zip")
            self.download_data(self.data_url, archive, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Collect items from ``<DOMAIN>/<split>/<class>/*.jpg``.

        Class labels follow the sorted order of the class folders.
        """
        items = []

        for domain_label, domain_name in enumerate(input_domains):
            # Domain folders on disk are upper-case (e.g. "CALTECH").
            split_dir = osp.join(self.dataset_dir, domain_name.upper(), split)
            class_folders = sorted(listdir_nohidden(split_dir))

            for label, folder in enumerate(class_folders):
                for impath in glob.glob(osp.join(split_dir, folder, "*.jpg")):
                    items.append(
                        Datum(impath=impath, label=label, domain=domain_label)
                    )

        return items

View File

@@ -0,0 +1,3 @@
from .svhn import SVHN
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10

View File

@@ -0,0 +1,108 @@
import math
import random
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
    """CIFAR10 for SSL.

    Expects the images extracted to folders::

        cifar10/
            train/<class_name>/...
            test/<class_name>/...

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            # Optionally treat the labeled images as unlabeled data too.
            train_u = train_u + train_x

        if len(val) == 0:
            # VAL_PERCENT == 0 produced no val items; signal "no val set".
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data_train(self, data_dir, num_labeled, val_percent):
        """Split the train folder into labeled/unlabeled/val item lists.

        Args:
            data_dir (str): directory containing one sub-folder per class.
            num_labeled (int): total labeled budget, spread evenly over
                the classes.
            val_percent (float): fraction of each class reserved for
                validation (0 disables the val split).

        Returns:
            tuple: (labeled items, unlabeled items, val items).
        """
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        # Floor division: the per-class quota is an integer count (any
        # remainder of the labeled budget simply stays unlabeled).
        num_labeled_per_class = num_labeled // len(class_names)
        items_x, items_u, items_v = [], [], []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            # Split into train and val following Oliver et al. 2018
            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
            num_val = math.floor(len(imnames) * val_percent)
            imnames_train = imnames[num_val:]
            imnames_val = imnames[:num_val]

            # Note we do shuffle after split
            random.shuffle(imnames_train)

            for i, imname in enumerate(imnames_train):
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                # The first num_labeled_per_class images (after the
                # shuffle) are labeled; the remainder are unlabeled.
                if (i + 1) <= num_labeled_per_class:
                    items_x.append(item)
                else:
                    items_u.append(item)

            for imname in imnames_val:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items_v.append(item)

        return items_x, items_u, items_v

    def _read_data_test(self, data_dir):
        """Read every image under ``data_dir`` as a labeled test item."""
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)
            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items.append(item)

        return items
@DATASET_REGISTRY.register()
class CIFAR100(DatasetBase):
    """CIFAR100 for SSL.

    Identical to :class:`CIFAR10` except for the dataset directory;
    all loading logic is inherited, so no ``__init__`` override is
    needed (the previous one only delegated to ``super().__init__``).

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar100"

View File

@@ -0,0 +1,87 @@
import numpy as np
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images per class, 800 test images per class.
        - 100,000 unlabeled images for unsupervised learning.

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five splits
        assert 0 <= cfg.DATASET.STL10_FOLD <= 4

        train_x = self._read_data_train(
            train_dir, cfg.DATASET.STL10_FOLD, fold_file
        )
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            # Optionally treat the labeled fold as unlabeled data too.
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data_train(self, data_dir, fold, fold_file):
        """Read the labeled training images of the selected fold.

        Args:
            data_dir (str): directory with files named ``<id>_<label>.<ext>``.
            fold (int): fold index in [0, 4]; a negative value selects
                every image.
            fold_file (str): text file whose ``fold``-th line lists the
                space-separated image indices belonging to that fold.
        """
        imnames = listdir_nohidden(data_dir)
        imnames.sort()
        items = []

        list_idx = list(range(len(imnames)))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
            # BUGFIX: these indices were parsed with
            # np.fromstring(str_idx, dtype=np.uint8, sep=" "), which
            # silently wraps any index >= 256 -- with 5,000 training
            # images (500 per class x 10 classes) indices reach 4,999.
            # np.fromstring is also deprecated, so parse plain ints.
            list_idx = [int(tok) for tok in str_idx.split()]

        for i in list_idx:
            imname = imnames[i]
            impath = osp.join(data_dir, imname)
            # The label is encoded in the file name: "<id>_<label>.<ext>".
            label = osp.splitext(imname)[0].split("_")[1]
            label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

    def _read_data_all(self, data_dir):
        """Read every image in ``data_dir``; label "none" maps to -1."""
        imnames = listdir_nohidden(data_dir)
        items = []

        for imname in imnames:
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            if label == "none":
                # Unlabeled images carry the literal label "none".
                label = -1
            else:
                label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

View File

@@ -0,0 +1,17 @@
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reuses the CIFAR10 folder-based loading logic; only the dataset
    directory differs, so no ``__init__`` override is needed (the
    previous one only delegated to ``super().__init__``).

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
        Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"