release code
This commit is contained in:
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .digit5 import Digit5
|
||||
from .visda17 import VisDA17
|
||||
from .cifarstl import CIFARSTL
|
||||
from .office31 import Office31
|
||||
from .domainnet import DomainNet
|
||||
from .office_home import OfficeHome
|
||||
from .mini_domainnet import miniDomainNet
|
||||
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
    """CIFAR-10 and STL-10.

    CIFAR-10:
        - 60,000 32x32 colour images.
        - 10 classes, with 6,000 images per class.
        - 50,000 training images and 10,000 test images.
        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.

    STL-10:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images (10 pre-defined folds), 800 test images
        per class.
        - URL: https://cs.stanford.edu/~acoates/stl10/.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "cifar_stl"
    domains = ["cifar", "stl"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Labeled data comes from the source domains; unlabeled train and
        # test data come from the target domains.
        src = cfg.DATASET.SOURCE_DOMAINS
        tgt = cfg.DATASET.TARGET_DOMAINS
        super().__init__(
            train_x=self._read_data(src, split="train"),
            train_u=self._read_data(tgt, split="train"),
            test=self._read_data(tgt, split="test"),
        )

    def _read_data(self, input_domains, split="train"):
        """Collect Datum items from <dataset_dir>/<domain>/<split>/<class>/."""
        out = []

        for dom_idx, dom_name in enumerate(input_domains):
            split_dir = osp.join(self.dataset_dir, dom_name, split)

            for cname in listdir_nohidden(split_dir):
                cdir = osp.join(split_dir, cname)
                # Class folders are named "<label>_...", so the numeric
                # label is the first underscore-separated token.
                lbl = int(cname.split("_")[0])
                out.extend(
                    Datum(
                        impath=osp.join(cdir, fname),
                        label=lbl,
                        domain=dom_idx,
                    )
                    for fname in listdir_nohidden(cdir)
                )

        return out
|
||||
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import random
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
# Folder names for train and test sets
# Each maps a split name ("train"/"test") to the image subfolder under
# that domain's directory. All five domains currently use the same layout.
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}
|
||||
|
||||
|
||||
def read_image_list(im_dir, n_max=None, n_repeat=None):
    """Collect (impath, label) pairs from a flat image folder.

    Image files are expected to be named "<stem>_<label>.<ext>"; the label
    is parsed from the second underscore-separated token of the stem.

    Args:
        im_dir (str): directory containing the images.
        n_max (int, optional): randomly subsample at most this many items.
        n_repeat (int, optional): replicate the resulting list this many
            times (used to balance small domains against large ones).

    Returns:
        list[tuple[str, int]]: (image path, label) pairs.
    """
    items = []

    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        # Clamp to the population size: random.sample() raises ValueError
        # when asked for more items than the folder actually contains.
        items = random.sample(items, min(n_max, len(items)))

    if n_repeat is not None:
        items *= n_repeat

    return items
|
||||
|
||||
|
||||
def load_mnist(dataset_dir, split="train"):
    """Load MNIST, keeping a random 25k train / 9k test subset."""
    if split == "train":
        limit = 25000
    else:
        limit = 9000
    return read_image_list(osp.join(dataset_dir, MNIST[split]), n_max=limit)
|
||||
|
||||
|
||||
def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M, keeping a random 25k train / 9k test subset."""
    limit = 25000 if split == "train" else 9000
    im_dir = osp.join(dataset_dir, MNIST_M[split])
    return read_image_list(im_dir, n_max=limit)
|
||||
|
||||
|
||||
def load_svhn(dataset_dir, split="train"):
    """Load SVHN, keeping a random 25k train / 9k test subset."""
    limit = 25000 if split == "train" else 9000
    im_dir = osp.join(dataset_dir, SVHN[split])
    return read_image_list(im_dir, n_max=limit)
|
||||
|
||||
|
||||
def load_syn(dataset_dir, split="train"):
    """Load SYN digits, keeping a random 25k train / 9k test subset."""
    limit = 25000 if split == "train" else 9000
    im_dir = osp.join(dataset_dir, SYN[split])
    return read_image_list(im_dir, n_max=limit)
|
||||
|
||||
|
||||
def load_usps(dataset_dir, split="train"):
    """Load USPS; the training list is replicated 3x to match other domains."""
    repeat = 3 if split == "train" else None
    return read_image_list(osp.join(dataset_dir, USPS[split]), n_repeat=repeat)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
    """Five digit datasets.

    It contains:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS which has only
    9,298 images in total, we use the entire dataset but replicate its training
    set for 3 times so as to match the training set size of other domains.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
    """

    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Load each domain via its loader and wrap results in Datum items.

        Args:
            input_domains (list[str]): domain names; each must be one of
                ``self.domains``.
            split (str): "train" or "test".

        Returns:
            list[Datum]: one item per image.

        Raises:
            ValueError: if a domain name has no registered loader.
        """
        # Explicit dispatch table instead of the original
        # eval("load_" + dname), which evaluated an expression built from
        # config input (fragile, and an injection hazard for untrusted cfg).
        loaders = {
            "mnist": load_mnist,
            "mnist_m": load_mnist_m,
            "svhn": load_svhn,
            "syn": load_syn,
            "usps": load_usps,
        }

        items = []

        for domain, dname in enumerate(input_domains):
            try:
                load_fn = loaders[dname]
            except KeyError:
                raise ValueError(
                    "Unknown digit domain: {}".format(dname)
                ) from None

            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = load_fn(domain_dir, split=split)

            for impath, label in items_d:
                # Digit datasets have no semantic class names; use the
                # digit itself as the class name.
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=str(label)
                )
                items.append(item)

        return items
|
||||
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
    """DomainNet.

    Statistics:
        - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
        Real, Sketch.
        - Around 0.6M images.
        - 345 categories.
        - URL: http://ai.bu.edu/M3SDA/.

    Special note: the t-shirt class (327) is missing in painting_train.txt.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
    """

    dataset_dir = "domainnet"
    domains = [
        "clipart", "infograph", "painting", "quickdraw", "real", "sketch"
    ]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Source test lists double as a validation set.
        src = cfg.DATASET.SOURCE_DOMAINS
        tgt = cfg.DATASET.TARGET_DOMAINS
        super().__init__(
            train_x=self._read_data(src, split="train"),
            train_u=self._read_data(tgt, split="train"),
            val=self._read_data(src, split="test"),
            test=self._read_data(tgt, split="test"),
        )

    def _read_data(self, input_domains, split="train"):
        """Parse "<relpath> <label>" lines from <domain>_<split>.txt files."""
        out = []

        for dom_idx, dom_name in enumerate(input_domains):
            list_path = osp.join(
                self.split_dir, dom_name + "_" + split + ".txt"
            )

            with open(list_path, "r") as f:
                for raw in f.readlines():
                    rel_path, lbl = raw.strip().split(" ")
                    # The class name is the second path component,
                    # e.g. "clipart/aircraft_carrier/xxx.jpg".
                    out.append(
                        Datum(
                            impath=osp.join(self.dataset_dir, rel_path),
                            label=int(lbl),
                            domain=dom_idx,
                            classname=rel_path.split("/")[1],
                        )
                    )

        return out
|
||||
@@ -0,0 +1,58 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
    """A subset of DomainNet.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
        - Zhou et al. Domain Adaptive Ensemble Learning.
    """

    # Shares the image folder with DomainNet; only the split lists differ.
    dataset_dir = "domainnet"
    domains = ["clipart", "painting", "real", "sketch"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits_mini")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        """Parse "<relpath> <label>" lines from <domain>_<split>.txt files."""
        items = []

        for domain, dname in enumerate(input_domains):
            split_file = osp.join(self.split_dir, dname + "_" + split + ".txt")

            with open(split_file, "r") as f:
                lines = f.readlines()

            for line in lines:
                rel, label = line.strip().split(" ")
                # Class name is the second component of the relative path.
                classname = rel.split("/")[1]
                items.append(
                    Datum(
                        impath=osp.join(self.dataset_dir, rel),
                        label=int(label),
                        domain=domain,
                        classname=classname,
                    )
                )

        return items
|
||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class Office31(DatasetBase):
    """Office-31.

    Statistics:
        - 4,110 images.
        - 31 classes related to office objects.
        - 3 domains: Amazon, Webcam, Dslr.
        - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.

    Reference:
        - Saenko et al. Adapting visual category models to
        new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # No held-out split: the whole target domain is used both as
        # unlabeled training data and as the test set.
        super().__init__(
            train_x=self._read_data(cfg.DATASET.SOURCE_DOMAINS),
            train_u=self._read_data(cfg.DATASET.TARGET_DOMAINS),
            test=self._read_data(cfg.DATASET.TARGET_DOMAINS),
        )

    def _read_data(self, input_domains):
        """Build Datum items from <dataset_dir>/<domain>/<class>/ folders."""
        out = []

        for dom_idx, dom_name in enumerate(input_domains):
            dom_dir = osp.join(self.dataset_dir, dom_name)
            # Sort folder names so labels are assigned deterministically.
            cnames = sorted(listdir_nohidden(dom_dir))

            for lbl, cname in enumerate(cnames):
                cdir = osp.join(dom_dir, cname)

                for fname in listdir_nohidden(cdir):
                    out.append(
                        Datum(
                            impath=osp.join(cdir, fname),
                            label=lbl,
                            domain=dom_idx,
                            classname=cname,
                        )
                    )

        return out
|
||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os.path as osp
|
||||
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home"
    domains = ["art", "clipart", "product", "real_world"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # No held-out split: the whole target domain serves as both
        # unlabeled training data and the test set.
        super().__init__(
            train_x=self._read_data(cfg.DATASET.SOURCE_DOMAINS),
            train_u=self._read_data(cfg.DATASET.TARGET_DOMAINS),
            test=self._read_data(cfg.DATASET.TARGET_DOMAINS),
        )

    def _read_data(self, input_domains):
        """Build Datum items from <dataset_dir>/<domain>/<class>/ folders."""
        out = []

        for dom_idx, dom_name in enumerate(input_domains):
            dom_dir = osp.join(self.dataset_dir, dom_name)
            # Sort folder names so labels are assigned deterministically.
            cnames = sorted(listdir_nohidden(dom_dir))

            for lbl, cname in enumerate(cnames):
                cdir = osp.join(dom_dir, cname)

                for fname in listdir_nohidden(cdir):
                    out.append(
                        Datum(
                            impath=osp.join(cdir, fname),
                            label=lbl,
                            domain=dom_idx,
                            classname=cname.lower(),
                        )
                    )

        return out
|
||||
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os.path as osp
|
||||
|
||||
from ..build import DATASET_REGISTRY
|
||||
from ..base_dataset import Datum, DatasetBase
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
    """VisDA17.

    Focusing on simulation-to-reality domain shift.

    URL: http://ai.bu.edu/visda-2017/.

    Reference:
        - Peng et al. VisDA: The Visual Domain Adaptation
        Challenge. ArXiv 2017.
    """

    dataset_dir = "visda17"
    domains = ["synthetic", "real"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Fixed direction: synthetic is always the (labeled) source and
        # real is always the (unlabeled) target / test domain.
        super().__init__(
            train_x=self._read_data("synthetic"),
            train_u=self._read_data("real"),
            test=self._read_data("real"),
        )

    def _read_data(self, dname):
        """Parse image_list.txt for one domain into Datum items."""
        # Synthetic images live under "train", real ones under "validation".
        subdir = "train" if dname == "synthetic" else "validation"
        list_file = osp.join(self.dataset_dir, subdir, "image_list.txt")
        # Each call covers a single domain, hence a constant domain index.
        domain = 0

        with open(list_file, "r") as f:
            records = f.readlines()

        out = []
        for rec in records:
            rel_path, lbl = rec.strip().split(" ")
            # Class name is the first component of the relative path.
            out.append(
                Datum(
                    impath=osp.join(self.dataset_dir, subdir, rel_path),
                    label=int(lbl),
                    domain=domain,
                    classname=rel_path.split("/")[0],
                )
            )

        return out
|
||||
Reference in New Issue
Block a user