release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG

View File

@@ -0,0 +1,123 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Corruption types shipped with CIFAR-10-C / CIFAR-100-C (Hendrycks et al.,
# ICLR 2019). DATASET.CIFAR_C_TYPE must be one of these folder names.
AVAI_C_TYPES = [
    "brightness",
    "contrast",
    "defocus_blur",
    "elastic_transform",
    "fog",
    "frost",
    "gaussian_blur",
    "gaussian_noise",
    "glass_blur",
    "impulse_noise",
    "jpeg_compression",
    "motion_blur",
    "pixelate",
    "saturate",
    "shot_noise",
    "snow",
    "spatter",
    "speckle_noise",
    "zoom_blur",
]
@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
    """CIFAR-10 -> CIFAR-10-C.

    Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o

    Statistics:
        - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
        - 10 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    dataset_dir = ""
    domains = ["cifar10", "cifar10_c"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = root
        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Exactly one clean source domain and one corrupted target domain,
        # in the fixed order given by ``self.domains``.
        source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
        target_domain = cfg.DATASET.TARGET_DOMAINS[0]
        assert source_domain == self.domains[0]
        assert target_domain == self.domains[1]

        c_type = cfg.DATASET.CIFAR_C_TYPE
        c_level = cfg.DATASET.CIFAR_C_LEVEL

        if not c_type:
            raise ValueError(
                "Please specify DATASET.CIFAR_C_TYPE in the config file"
            )
        assert (
            c_type in AVAI_C_TYPES
        ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
        # Severity levels in CIFAR-C range from 1 (mild) to 5 (severe).
        assert 1 <= c_level <= 5

        train_dir = osp.join(self.dataset_dir, source_domain, "train")
        test_dir = osp.join(
            self.dataset_dir, target_domain, c_type, str(c_level)
        )

        if not osp.exists(test_dir):
            # Fix: the original raised a bare ValueError with no message,
            # leaving the user without the offending path.
            raise ValueError(f'The directory "{test_dir}" does not exist')

        train = self._read_data(train_dir)
        test = self._read_data(test_dir)

        super().__init__(train_x=train, test=test)

    def _read_data(self, data_dir):
        """Collect one Datum per image under ``data_dir/<class_name>/``.

        Class folders are sorted so labels are assigned deterministically.
        """
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)
            for imname in imnames:
                impath = osp.join(class_dir, imname)
                # NOTE(review): domain is always 0 — each call covers a
                # single domain directory, so no cross-domain index exists.
                item = Datum(impath=impath, label=label, domain=0)
                items.append(item)

        return items
@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
    """CIFAR-100 -> CIFAR-100-C.

    Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o

    Statistics:
        - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
        - 100 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    # All loading logic is inherited from CIFAR10C; only the domain
    # folder names differ.
    dataset_dir = ""
    domains = ["cifar100", "cifar100_c"]

    def __init__(self, cfg):
        super().__init__(cfg)

View File

@@ -0,0 +1,124 @@
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
# Folder names for train and test sets.
# Every digit dataset here ships with the same on-disk layout, so the
# mappings are identical; they are kept separate per dataset for clarity.
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """Return (impath, label) pairs for every image in ``im_dir``.

    File names are expected to look like ``<prefix>_<label>.<ext>``; the
    label is the integer after the first underscore.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): keep only the first ``n_max`` entries.
        n_repeat (int, optional): repeat the resulting list this many times.
    """
    items = []
    for imname in listdir_nohidden(im_dir):
        # Label is encoded in the file name, e.g. "img0001_5.jpg" -> 5.
        stem = osp.splitext(imname)[0]
        items.append((osp.join(im_dir, imname), int(stem.split("_")[1])))

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items
def load_mnist(dataset_dir, split="train"):
    """Load MNIST (impath, label) pairs; the train split is capped at 10,000."""
    # Only the training split is subsampled, following Volpi et al. NIPS'18.
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=cap)
def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M (impath, label) pairs; the train split is capped at 10,000."""
    # Only the training split is subsampled, following Volpi et al. NIPS'18.
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, MNIST_M[split]), n_max=cap)
def load_svhn(dataset_dir, split="train"):
    """Load SVHN (impath, label) pairs; the train split is capped at 10,000."""
    # Only the training split is subsampled, following Volpi et al. NIPS'18.
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SVHN[split]), n_max=cap)
def load_syn(dataset_dir, split="train"):
    """Load SYN (impath, label) pairs; the train split is capped at 10,000."""
    # Only the training split is subsampled, following Volpi et al. NIPS'18.
    cap = 10000 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, SYN[split]), n_max=cap)
def load_usps(dataset_dir, split="train"):
    """Load USPS (impath, label) pairs; no subsampling is applied."""
    return read_image_list(osp.join(dataset_dir, USPS[split]))
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate the model on the test split of the other four datasets.
        However, the code does not restrict you to only use MNIST as the
        source dataset. Instead, you can use any dataset as the source. But
        note that only 10,000 images will be sampled from the source
        dataset for training.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Validation reuses the source test split; evaluation is on the
        # target domain's test split.
        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Collect Datum items for each requested domain.

        Fix: the original resolved the loader via ``eval("load_" + dname)``,
        executing a constructed string as code. An explicit dispatch table
        is equivalent for all valid domain names and cannot execute
        arbitrary expressions.
        """
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loaders[dname](domain_dir, split=split)
            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items

View File

@@ -0,0 +1,97 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):
        """Collect Datum items from <dataset_dir>/<domain>/<split>/<class>/*.jpg.

        Args:
            dataset_dir (str): root directory of the dataset.
            input_domains (list[str]): domain folder names.
            split (str): "train", "val", or "all" (train + val merged).
        """

        def _load_data_from_directory(directory):
            # Class folders are sorted so labels are assigned
            # deterministically across runs.
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))
                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                # Fix: derive the class name with os.path instead of
                # splitting on "/", which breaks on Windows path separators.
                class_name = osp.basename(osp.dirname(impath)).lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items

View File

@@ -0,0 +1,49 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase
@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # This dataset shares DigitsDG's on-disk layout
        # (<root>/<domain>/<split>/<class>/*.jpg), so its reader is reused.
        read = DigitsDG.read_data
        train = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train")
        val = read(self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val")
        test = read(self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

View File

@@ -0,0 +1,94 @@
import os.path as osp
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # the following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Build Datum items from the official kfold split files.

        ``split`` is "train", "crossval", or "all" (train + crossval merged).
        """
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                # "all" merges the train and crossval split files.
                pairs = []
                for part in ("train", "crossval"):
                    split_file = osp.join(
                        self.split_dir, f"{dname}_{part}_kfold.txt"
                    )
                    pairs += self._read_split_pacs(split_file)
            else:
                split_file = osp.join(
                    self.split_dir, f"{dname}_{split}_kfold.txt"
                )
                pairs = self._read_split_pacs(split_file)

            for impath, label in pairs:
                # Split files use "/" separators, so the class name is the
                # second-to-last path component.
                classname = impath.split("/")[-2]
                items.append(
                    Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname
                    )
                )

        return items

    def _read_split_pacs(self, split_file):
        """Parse one kfold split file into (impath, label) pairs.

        Known-broken images listed in ``_error_paths`` are skipped, and the
        file's 1-based labels are converted to 0-based.
        """
        items = []

        with open(split_file, "r") as f:
            for raw in f:
                relpath, label_str = raw.strip().split(" ")
                if relpath in self._error_paths:
                    continue
                items.append(
                    (osp.join(self.image_dir, relpath), int(label_str) - 1)
                )

        return items

View File

@@ -0,0 +1,60 @@
import glob
import os.path as osp
from dassl.utils import listdir_nohidden
from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase
@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the archive on first use.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        """Gather Datum items from <dataset_dir>/<DOMAIN>/<split>/<class>/*.jpg."""
        items = []

        for domain, dname in enumerate(input_domains):
            # Domain folders on disk are upper-case (e.g. "CALTECH").
            split_dir = osp.join(self.dataset_dir, dname.upper(), split)
            # Sort class folders so labels are assigned deterministically.
            class_folders = sorted(listdir_nohidden(split_dir))

            for label, folder in enumerate(class_folders):
                for impath in glob.glob(osp.join(split_dir, folder, "*.jpg")):
                    items.append(
                        Datum(impath=impath, label=label, domain=domain)
                    )

        return items