125 lines
4.2 KiB
Python
125 lines
4.2 KiB
Python
import os.path as osp
|
|
|
|
from dassl.utils import listdir_nohidden
|
|
|
|
from ..build import DATASET_REGISTRY
|
|
from ..base_dataset import Datum, DatasetBase
|
|
|
|
# Sub-folder that holds the images of each split, keyed by split name.
# Every dataset gets its own mapping object, even though the five layouts
# are currently identical.
MNIST = dict(train="train_images", test="test_images")
MNIST_M = dict(train="train_images", test="test_images")
SVHN = dict(train="train_images", test="test_images")
SYN = dict(train="train_images", test="test_images")
USPS = dict(train="train_images", test="test_images")
|
|
|
|
|
|
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """Collect ``(image path, label)`` pairs from a flat image folder.

    The label is parsed from each file name: the name without its extension
    is split on ``"_"`` and the second field is converted to ``int``.

    Args:
        im_dir (str): directory that contains the image files.
        n_max (int, optional): if given, keep only the first ``n_max``
            entries (in sorted file-name order).
        n_repeat (int, optional): if given, repeat the (possibly truncated)
            list this many times.

    Returns:
        list: ``(impath, label)`` tuples, where ``impath`` is ``im_dir``
        joined with the file name and ``label`` is an ``int``.

    Raises:
        IndexError: if a file name contains no ``"_"`` separator.
        ValueError: if the second ``"_"``-separated field is not an integer.
    """
    # Sort explicitly. A directory listing has no guaranteed order
    # (os.listdir documents its result as arbitrary), and nothing visible
    # here shows that listdir_nohidden sorts by default. Without a stable
    # order the "first n_max" subset below would vary between machines.
    imnames = sorted(listdir_nohidden(im_dir))

    items = []
    for imname in imnames:
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        items.append((osp.join(im_dir, imname), label))

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items
|
|
|
|
|
|
def load_mnist(dataset_dir, split="train"):
    """Return the ``(impath, label)`` list for one split of MNIST.

    Args:
        dataset_dir (str): root folder of the MNIST data.
        split (str): key into ``MNIST`` selecting the sub-folder to read.

    Returns:
        list: output of ``read_image_list``; for the ``"train"`` split it is
        limited to 10,000 entries, otherwise it is not limited.
    """
    split_folder = MNIST[split]

    if split == "train":
        cap = 10000
    else:
        cap = None

    return read_image_list(osp.join(dataset_dir, split_folder), n_max=cap)
|
|
|
|
|
|
def load_mnist_m(dataset_dir, split="train"):
    """Return the ``(impath, label)`` list for one split of MNIST-M.

    Args:
        dataset_dir (str): root folder of the MNIST-M data.
        split (str): key into ``MNIST_M`` selecting the sub-folder to read.

    Returns:
        list: output of ``read_image_list``; for the ``"train"`` split it is
        limited to 10,000 entries, otherwise it is not limited.
    """
    split_folder = MNIST_M[split]

    if split == "train":
        cap = 10000
    else:
        cap = None

    return read_image_list(osp.join(dataset_dir, split_folder), n_max=cap)
|
|
|
|
|
|
def load_svhn(dataset_dir, split="train"):
    """Return the ``(impath, label)`` list for one split of SVHN.

    Args:
        dataset_dir (str): root folder of the SVHN data.
        split (str): key into ``SVHN`` selecting the sub-folder to read.

    Returns:
        list: output of ``read_image_list``; for the ``"train"`` split it is
        limited to 10,000 entries, otherwise it is not limited.
    """
    split_folder = SVHN[split]

    if split == "train":
        cap = 10000
    else:
        cap = None

    return read_image_list(osp.join(dataset_dir, split_folder), n_max=cap)
|
|
|
|
|
|
def load_syn(dataset_dir, split="train"):
    """Return the ``(impath, label)`` list for one split of SYN.

    Args:
        dataset_dir (str): root folder of the SYN data.
        split (str): key into ``SYN`` selecting the sub-folder to read.

    Returns:
        list: output of ``read_image_list``; for the ``"train"`` split it is
        limited to 10,000 entries, otherwise it is not limited.
    """
    split_folder = SYN[split]

    if split == "train":
        cap = 10000
    else:
        cap = None

    return read_image_list(osp.join(dataset_dir, split_folder), n_max=cap)
|
|
|
|
|
|
def load_usps(dataset_dir, split="train"):
    """Return the ``(impath, label)`` list for one split of USPS.

    Unlike the other loaders in this file, no ``n_max`` is passed, so every
    image found in the split folder is returned.

    Args:
        dataset_dir (str): root folder of the USPS data.
        split (str): key into ``USPS`` selecting the sub-folder to read.

    Returns:
        list: output of ``read_image_list`` for the selected folder.
    """
    split_folder = USPS[split]
    return read_image_list(osp.join(dataset_dir, split_folder))
|
|
|
|
|
|
@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate the model on the test split of the other four datasets.
        However, the code does not restrict you to only use MNIST as the
        source dataset. Instead, you can use any dataset as the source. But
        note that at most 10,000 images will be taken from the training
        split of MNIST, MNIST-M, SVHN and SYN; the USPS loader applies no
        such cap and returns its whole training folder.

        The test split of the source domains is used as the validation set.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        """Build the train / val / test splits described by ``cfg``.

        Args:
            cfg: config object providing ``DATASET.ROOT``,
                ``DATASET.SOURCE_DOMAINS`` and ``DATASET.TARGET_DOMAINS``.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Load every requested domain and wrap its images as ``Datum`` items.

        Args:
            input_domains (list): domain names; each name ``d`` is served by
                the module-level function ``load_<d>``.
            split (str): split name forwarded to the loader.

        Returns:
            list: ``Datum`` objects whose ``domain`` is the position of the
            domain name inside ``input_domains``.

        Raises:
            ValueError: if no loader function exists for a domain name.
        """
        items = []

        for domain, dname in enumerate(input_domains):
            # Resolve the loader by name in the module namespace rather than
            # eval()-ing a string built from configuration input, so a
            # malformed domain name can never be executed as code.
            loader = globals().get("load_" + dname)
            if not callable(loader):
                raise ValueError(
                    "No loader available for domain '{}'".format(dname)
                )

            domain_dir = osp.join(self.dataset_dir, dname)

            for impath, label in loader(domain_dir, split=split):
                items.append(
                    Datum(impath=impath, label=label, domain=domain)
                )

        return items
|