MSGCoOp/Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py

import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        # Download and extract the dataset if it is not already present.
        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Source domains provide the train/crossval splits; target domains
        # provide the test split.
        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            dname = dname.upper()
            path = osp.join(self.dataset_dir, dname, split)
            folders = listdir_nohidden(path)
            folders.sort()

            # Each sorted class folder maps to one label; every image keeps
            # the index of its domain in the input list.
            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(path, folder, "*.jpg"))

                for impath in impaths:
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items
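
Because the class is registered with DATASET_REGISTRY, it is normally constructed indirectly from a config rather than instantiated by hand. Below is a minimal sketch of that flow; it assumes Dassl's default config schema (DATASET.ROOT / NAME / SOURCE_DOMAINS / TARGET_DOMAINS) and the get_cfg_default and build_dataset helpers, and it assumes the data on disk follows the layout implied by _read_data, i.e. VLCS/<DOMAIN>/<split>/<class>/*.jpg.

# Hypothetical usage sketch, assuming Dassl's default config and helpers.
from dassl.config import get_cfg_default
from dassl.data.datasets import build_dataset

cfg = get_cfg_default()
cfg.DATASET.ROOT = "~/data"            # parent folder that holds VLCS/
cfg.DATASET.NAME = "VLCS"              # registry key for this class
cfg.DATASET.SOURCE_DOMAINS = ["caltech", "labelme", "pascal"]
cfg.DATASET.TARGET_DOMAINS = ["sun"]   # leave-one-domain-out evaluation

dataset = build_dataset(cfg)
print(len(dataset.train_x), len(dataset.val), len(dataset.test))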