Upload to Main
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
import deepcore.methods as s_method
|
||||
import numpy as np
|
||||
|
||||
IGNORED = ["BACKGROUND_Google", "Faces_easy"]
|
||||
NEW_CNAMES = {
|
||||
"airplanes": "airplane",
|
||||
"Faces": "face",
|
||||
"Leopards": "leopard",
|
||||
"Motorbikes": "motorbike",
|
||||
}
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Caltech101(DatasetBase):
|
||||
|
||||
dataset_dir = "caltech-101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
@@ -0,0 +1,481 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
import os
|
||||
from dassl.utils import read_image
|
||||
|
||||
from dassl.data.datasets import build_dataset
|
||||
from dassl.data.samplers import build_sampler
|
||||
from dassl.data.transforms import INTERPOLATION_MODES, build_transform
|
||||
from .new_da import RandomResizedCropPair, build_transform_pair
|
||||
from PIL import Image
|
||||
|
||||
def build_data_loader(
|
||||
cfg,
|
||||
sampler_type="SequentialSampler",
|
||||
data_source=None,
|
||||
batch_size=64,
|
||||
n_domain=0,
|
||||
n_ins=2,
|
||||
tfm=None,
|
||||
is_train=True,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# Build sampler
|
||||
sampler = build_sampler(
|
||||
sampler_type,
|
||||
cfg=cfg,
|
||||
data_source=data_source,
|
||||
batch_size=batch_size,
|
||||
n_domain=n_domain,
|
||||
n_ins=n_ins
|
||||
)
|
||||
|
||||
if dataset_wrapper is None:
|
||||
dataset_wrapper = DatasetWrapper
|
||||
|
||||
# Build data loader
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset_wrapper(cfg, data_source,transform=tfm, is_train=is_train,weight=weight),
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||
drop_last=is_train and len(data_source) >= batch_size,
|
||||
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
|
||||
)
|
||||
assert len(data_loader) > 0
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
|
||||
def build_data_loader_mask(
|
||||
cfg,
|
||||
dataset,
|
||||
sampler_type="SequentialSampler",
|
||||
data_source=None,
|
||||
batch_size=64,
|
||||
n_domain=0,
|
||||
n_ins=2,
|
||||
tfm=None,
|
||||
is_train=True,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# Build sampler
|
||||
sampler = build_sampler(
|
||||
sampler_type,
|
||||
cfg=cfg,
|
||||
data_source=data_source,
|
||||
batch_size=batch_size,
|
||||
n_domain=n_domain,
|
||||
n_ins=n_ins
|
||||
)
|
||||
|
||||
if dataset_wrapper is None:
|
||||
dataset_wrapper = DatasetWrapperMask
|
||||
|
||||
# Build data loader
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset_wrapper(cfg, dataset,data_source,transform=tfm, is_train=is_train,weight=weight),
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||
drop_last=is_train and len(data_source) >= batch_size,
|
||||
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
|
||||
)
|
||||
assert len(data_loader) > 0
|
||||
|
||||
return data_loader
|
||||
|
||||
def select_dm_loader(cfg,dataset,s_ind=None,is_train=False):
|
||||
|
||||
tfm = build_transform(cfg, is_train=is_train)
|
||||
if is_train:
|
||||
dataloader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset)[s_ind]) if s_ind is not None else dataset,
|
||||
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, #cfg.DATALOADER.TRAIN_X.BATCH_SIZE*
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm,
|
||||
is_train=is_train,
|
||||
dataset_wrapper=None,
|
||||
)
|
||||
else:
|
||||
dataloader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=list(np.asarray(dataset)[s_ind]) if s_ind is not None else dataset,
|
||||
batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm,
|
||||
is_train=is_train,
|
||||
dataset_wrapper=None,
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
class DataManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dataset,
|
||||
s_ind=None,
|
||||
custom_tfm_train=None,
|
||||
custom_tfm_test=None,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# # Load dataset
|
||||
# dataset = build_dataset(cfg)
|
||||
|
||||
# Build transform
|
||||
if custom_tfm_train is None:
|
||||
###pair is for
|
||||
tfm_train_pair = build_transform_pair(cfg, is_train=True)
|
||||
tfm_train = build_transform(cfg,is_train=True)
|
||||
else:
|
||||
print("* Using custom transform for training")
|
||||
tfm_train = custom_tfm_train
|
||||
|
||||
if custom_tfm_test is None:
|
||||
tfm_test = build_transform(cfg, is_train=False)
|
||||
else:
|
||||
print("* Using custom transform for testing")
|
||||
tfm_test = custom_tfm_test
|
||||
|
||||
|
||||
# Build train_loader_x
|
||||
|
||||
train_loader_x = build_data_loader_mask(
|
||||
cfg,
|
||||
dataset,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset.train_x)[s_ind]) if s_ind is not None else dataset.train_x,
|
||||
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm_train_pair,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
weight=weight
|
||||
)
|
||||
|
||||
|
||||
train_loader_xmore = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset.train_x)[s_ind]) if s_ind is not None else dataset.train_x,
|
||||
batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
weight=weight
|
||||
)
|
||||
|
||||
# Build train_loader_u
|
||||
train_loader_u = None
|
||||
if dataset.train_u:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
|
||||
|
||||
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
|
||||
|
||||
train_loader_u = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=sampler_type_,
|
||||
data_source=dataset.train_u,
|
||||
batch_size=batch_size_,
|
||||
n_domain=n_domain_,
|
||||
n_ins=n_ins_,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Build val_loader
|
||||
val_loader = None
|
||||
if dataset.val:
|
||||
val_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.val,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Build test_loader
|
||||
test_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.test,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Attributes
|
||||
self._num_classes = dataset.num_classes
|
||||
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
|
||||
self._lab2cname = dataset.lab2cname
|
||||
|
||||
# Dataset and data-loaders
|
||||
self.dataset = dataset
|
||||
self.train_loader_x = train_loader_x
|
||||
self.train_loader_u = train_loader_u
|
||||
self.train_loader_xmore = train_loader_xmore
|
||||
self.val_loader = val_loader
|
||||
self.test_loader = test_loader
|
||||
|
||||
if cfg.VERBOSE:
|
||||
self.show_dataset_summary(cfg)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._num_classes
|
||||
|
||||
@property
|
||||
def num_source_domains(self):
|
||||
return self._num_source_domains
|
||||
|
||||
@property
|
||||
def lab2cname(self):
|
||||
return self._lab2cname
|
||||
|
||||
def show_dataset_summary(self, cfg):
|
||||
dataset_name = cfg.DATASET.NAME
|
||||
source_domains = cfg.DATASET.SOURCE_DOMAINS
|
||||
target_domains = cfg.DATASET.TARGET_DOMAINS
|
||||
|
||||
table = []
|
||||
table.append(["Dataset", dataset_name])
|
||||
if source_domains:
|
||||
table.append(["Source", source_domains])
|
||||
if target_domains:
|
||||
table.append(["Target", target_domains])
|
||||
table.append(["# classes", f"{self.num_classes:,}"])
|
||||
table.append(["# train_x", f"{len(self.dataset.train_x):,}"])
|
||||
if self.dataset.train_u:
|
||||
table.append(["# train_u", f"{len(self.dataset.train_u):,}"])
|
||||
if self.dataset.val:
|
||||
table.append(["# val", f"{len(self.dataset.val):,}"])
|
||||
table.append(["# test", f"{len(self.dataset.test):,}"])
|
||||
|
||||
print(tabulate(table))
|
||||
|
||||
|
||||
class DatasetWrapperMask(TorchDataset):
|
||||
|
||||
def __init__(self, cfg, dataset,data_source,transform=None, is_train=False,weight=None):
|
||||
self.cfg = cfg
|
||||
self.data_source = data_source
|
||||
self.transform = transform # accept list (tuple) as input
|
||||
self.is_train = is_train
|
||||
self.data_path = dataset.dataset_dir
|
||||
self.mask_path = os.path.join(dataset.dataset_dir,'mask')
|
||||
# Augmenting an image K>1 times is only allowed during training
|
||||
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
if self.k_tfm > 1 and transform is None:
|
||||
raise ValueError(
|
||||
"Cannot augment the image {} times "
|
||||
"because transform is None".format(self.k_tfm)
|
||||
)
|
||||
|
||||
# Build transform that doesn't apply any data augmentation
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
to_tensor = []
|
||||
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||
to_tensor += [T.ToTensor()]
|
||||
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||
normalize = T.Normalize(
|
||||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||
)
|
||||
to_tensor += [normalize]
|
||||
self.to_tensor = T.Compose(to_tensor)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_source[idx]
|
||||
|
||||
if self.weight is None:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx
|
||||
}
|
||||
else:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx,
|
||||
"weight": self.weight[idx]
|
||||
}
|
||||
|
||||
# img_path = os.path.join('/'.join(item.impath.split('/')[:-1]),'mask',item.impath.split('/')[-1]) ('/').join(item.impath.split('/')[-2:])
|
||||
if self.cfg.DATASET.NAME in ['Food101','Caltech101','DescribableTextures','EuroSAT','UCF101']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[-2:])))
|
||||
elif self.cfg.DATASET.NAME in ['SUN397']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[7:])))
|
||||
elif self.cfg.DATASET.NAME in ['ImageNet']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[7:])))
|
||||
elif self.cfg.DATASET.NAME in ['VOC12']:
|
||||
mask_path = os.path.join(self.data_path,'VOCdevkit/VOC2012/SegmentationClass_All',item.impath.split('/')[-1][:-3]+'png')
|
||||
mask = read_image(mask_path)
|
||||
else:
|
||||
mask = read_image(os.path.join(self.mask_path, item.impath.split('/')[-1]))
|
||||
img0 = read_image(item.impath)
|
||||
mask = mask.resize(img0.size)
|
||||
if self.transform is not None:
|
||||
if isinstance(self.transform, (list, tuple)):
|
||||
for i, tfm in enumerate(self.transform):
|
||||
img = self._transform_image(tfm, img0,img0)
|
||||
keyname = "img"
|
||||
if (i + 1) > 1:
|
||||
keyname += str(i + 1)
|
||||
output[keyname] = img
|
||||
else:
|
||||
img,mask = self._transform_image(self.transform, img0,mask)
|
||||
output["img"] = img
|
||||
output["mask"] = mask
|
||||
else:
|
||||
output["img"] = img0
|
||||
|
||||
if self.return_img0:
|
||||
output["img0"] = self.to_tensor(img0) # without any augmentation
|
||||
|
||||
return output
|
||||
|
||||
def _transform_image(self, tfm, img0,mask):
|
||||
img_list = []
|
||||
for k in range(self.k_tfm):
|
||||
img_list.append(tfm(img0,mask))
|
||||
|
||||
img = img_list
|
||||
if len(img_list) == 1:
|
||||
img = img_list[0][0]
|
||||
mask = img_list[0][1]
|
||||
|
||||
return img,mask
|
||||
|
||||
|
||||
class DatasetWrapper(TorchDataset):
|
||||
|
||||
def __init__(self, cfg, data_source,transform=None, is_train=False,weight=None):
|
||||
self.cfg = cfg
|
||||
self.data_source = data_source
|
||||
self.transform = transform # accept list (tuple) as input
|
||||
self.is_train = is_train
|
||||
self.mask_path = ('/').join(data_source[0].impath.split('/')[:-2])+'/mask'
|
||||
# Augmenting an image K>1 times is only allowed during training
|
||||
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
if self.k_tfm > 1 and transform is None:
|
||||
raise ValueError(
|
||||
"Cannot augment the image {} times "
|
||||
"because transform is None".format(self.k_tfm)
|
||||
)
|
||||
|
||||
# Build transform that doesn't apply any data augmentation
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
to_tensor = []
|
||||
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||
to_tensor += [T.ToTensor()]
|
||||
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||
normalize = T.Normalize(
|
||||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||
)
|
||||
to_tensor += [normalize]
|
||||
self.to_tensor = T.Compose(to_tensor)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_source[idx]
|
||||
|
||||
if self.weight is None:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx
|
||||
}
|
||||
else:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx,
|
||||
"weight": self.weight[idx]
|
||||
}
|
||||
|
||||
# img0 = read_image(item.impath)
|
||||
img0 = read_image(item.impath)
|
||||
# img0 = img0.resize(mask.size)
|
||||
# mask = read_image(item.impath.split('/')[:-1].join('/'))
|
||||
if self.transform is not None:
|
||||
if isinstance(self.transform, (list, tuple)):
|
||||
for i, tfm in enumerate(self.transform):
|
||||
img = self._transform_image(tfm, img0)
|
||||
keyname = "img"
|
||||
if (i + 1) > 1:
|
||||
keyname += str(i + 1)
|
||||
output[keyname] = img
|
||||
else:
|
||||
img = self._transform_image(self.transform, img0)
|
||||
output["img"] = img
|
||||
output['mask'] = 1
|
||||
else:
|
||||
output["img"] = img0
|
||||
|
||||
if self.return_img0:
|
||||
output["img0"] = self.to_tensor(img0) # without any augmentation
|
||||
|
||||
return output
|
||||
|
||||
def _transform_image(self, tfm, img0):
|
||||
img_list = []
|
||||
|
||||
for k in range(self.k_tfm):
|
||||
img_list.append(tfm(img0))
|
||||
|
||||
img = img_list
|
||||
if len(img) == 1:
|
||||
img = img[0]
|
||||
|
||||
return img
|
||||
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class DescribableTextures(DatasetBase):
|
||||
|
||||
dataset_dir = "dtd"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = self.read_and_split_data(self.image_dir)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
@staticmethod
|
||||
def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None):
|
||||
# The data are supposed to be organized into the following structure
|
||||
# =============
|
||||
# images/
|
||||
# dog/
|
||||
# cat/
|
||||
# horse/
|
||||
# =============
|
||||
categories = listdir_nohidden(image_dir)
|
||||
categories = [c for c in categories if c not in ignored]
|
||||
categories.sort()
|
||||
|
||||
p_tst = 1 - p_trn - p_val
|
||||
print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test")
|
||||
|
||||
def _collate(ims, y, c):
|
||||
items = []
|
||||
for im in ims:
|
||||
item = Datum(impath=im, label=y, classname=c) # is already 0-based
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
train, val, test = [], [], []
|
||||
for label, category in enumerate(categories):
|
||||
category_dir = os.path.join(image_dir, category)
|
||||
images = listdir_nohidden(category_dir)
|
||||
images = [os.path.join(category_dir, im) for im in images]
|
||||
random.shuffle(images)
|
||||
n_total = len(images)
|
||||
n_train = round(n_total * p_trn)
|
||||
n_val = round(n_total * p_val)
|
||||
n_test = n_total - n_train - n_val
|
||||
assert n_train > 0 and n_val > 0 and n_test > 0
|
||||
|
||||
if new_cnames is not None and category in new_cnames:
|
||||
category = new_cnames[category]
|
||||
|
||||
train.extend(_collate(images[:n_train], label, category))
|
||||
val.extend(_collate(images[n_train : n_train + n_val], label, category))
|
||||
test.extend(_collate(images[n_train + n_val :], label, category))
|
||||
|
||||
return train, val, test
|
||||
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
|
||||
NEW_CNAMES = {
|
||||
"AnnualCrop": "Annual Crop Land",
|
||||
"Forest": "Forest",
|
||||
"HerbaceousVegetation": "Herbaceous Vegetation Land",
|
||||
"Highway": "Highway or Road",
|
||||
"Industrial": "Industrial Buildings",
|
||||
"Pasture": "Pasture Land",
|
||||
"PermanentCrop": "Permanent Crop Land",
|
||||
"Residential": "Residential Buildings",
|
||||
"River": "River",
|
||||
"SeaLake": "Sea or Lake",
|
||||
}
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class EuroSAT(DatasetBase):
|
||||
|
||||
dataset_dir = "eurosat"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "2750")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def update_classname(self, dataset_old):
|
||||
dataset_new = []
|
||||
for item_old in dataset_old:
|
||||
cname_old = item_old.classname
|
||||
cname_new = NEW_CNAMES[cname_old]
|
||||
item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new)
|
||||
dataset_new.append(item_new)
|
||||
return dataset_new
|
||||
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class FGVCAircraft(DatasetBase):
|
||||
|
||||
dataset_dir = "fgvc_aircraft"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
classnames = []
|
||||
with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
classnames.append(line.strip())
|
||||
cname2lab = {c: i for i, c in enumerate(classnames)}
|
||||
|
||||
train = self.read_data(cname2lab, "images_variant_train.txt")
|
||||
val = self.read_data(cname2lab, "images_variant_val.txt")
|
||||
test = self.read_data(cname2lab, "images_variant_test.txt")
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, split_file):
|
||||
filepath = os.path.join(self.dataset_dir, split_file)
|
||||
items = []
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
imname = line[0] + ".jpg"
|
||||
classname = " ".join(line[1:])
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = cname2lab[classname]
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Food101(DatasetBase):
|
||||
|
||||
dataset_dir = "food-101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from random import sample
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNet(DatasetBase):
|
||||
|
||||
dataset_dir = "imagenet"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.preprocessed):
|
||||
with open(self.preprocessed, "rb") as f:
|
||||
preprocessed = pickle.load(f)
|
||||
train = preprocessed["train"]
|
||||
test = preprocessed["test"]
|
||||
else:
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = self.read_classnames(text_file)
|
||||
train = self.read_data(classnames, "train")
|
||||
# Follow standard practice to perform evaluation on the val set
|
||||
# Also used as the val set (so evaluate the last-step model)
|
||||
test = self.read_data(classnames, "val")
|
||||
|
||||
preprocessed = {"train": train, "test": test}
|
||||
with open(self.preprocessed, "wb") as f:
|
||||
pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1000:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train = data["train"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
data = {"train": train}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, test = OxfordPets.subsample_classes(train, test, subsample=subsample)
|
||||
|
||||
|
||||
super().__init__(train_x=sample(train,int(len(train)*0.8)), val=sample(test,5000), test=test)
|
||||
|
||||
@staticmethod
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def read_data(self, classnames, split_dir):
|
||||
split_dir = os.path.join(self.image_dir, split_dir)
|
||||
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders): ##sub evaluation
|
||||
imnames = listdir_nohidden(os.path.join(split_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(split_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
TO_BE_IGNORED = ["README.txt"]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetA(DatasetBase):
|
||||
"""ImageNet-A(dversarial).
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-adversarial"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "imagenet-a")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
folders = [f for f in folders if f not in TO_BE_IGNORED]
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
TO_BE_IGNORED = ["README.txt"]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetR(DatasetBase):
|
||||
"""ImageNet-R(endition).
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-rendition"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "imagenet-r")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
folders = [f for f in folders if f not in TO_BE_IGNORED]
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetSketch(DatasetBase):
|
||||
"""ImageNet-Sketch.
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-sketch"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetV2(DatasetBase):
|
||||
"""ImageNetV2.
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenetv2"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
image_dir = "imagenetv2-matched-frequency-format-val"
|
||||
self.image_dir = os.path.join(self.dataset_dir, image_dir)
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = list(classnames.keys())
|
||||
items = []
|
||||
|
||||
for label in range(1000):
|
||||
class_dir = os.path.join(image_dir, str(label))
|
||||
imnames = listdir_nohidden(class_dir)
|
||||
folder = folders[label]
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(class_dir, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,567 @@
|
||||
import torch
|
||||
from torchvision.transforms import RandomResizedCrop,InterpolationMode
|
||||
from torchvision.transforms import functional as F
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from torchvision.transforms import (
|
||||
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
|
||||
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
|
||||
RandomHorizontalFlip
|
||||
)
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from dassl.data.transforms.transforms import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
|
||||
from dassl.data.transforms.transforms import RandAugment, RandAugment2, RandAugmentFixMatch
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class RandomResizedCropPair(RandomResizedCrop):
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
|
||||
super(RandomResizedCropPair, self).__init__(size, scale, ratio, interpolation)
|
||||
|
||||
def __call__(self, img,mask):
|
||||
i,j,h,w = self.get_params(img,self.scale,self.ratio)
|
||||
return F.resized_crop(img,i,j,h,w,self.size,self.interpolation),F.resized_crop(mask,i,j,h,w,self.size,self.interpolation)
|
||||
|
||||
|
||||
class ComposePair:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img,mask):
|
||||
|
||||
for t in self.transforms:
|
||||
if isinstance(t,Normalize):
|
||||
img = t(img)
|
||||
elif isinstance(t,ToTensor):
|
||||
img = t(img)
|
||||
mask = torch.from_numpy(np.array(mask,dtype=np.float16)).permute(2,0,1)[:1]
|
||||
|
||||
|
||||
###design the mask split
|
||||
mask[mask==255] = 0
|
||||
mask[mask > 1] = 1
|
||||
else:
|
||||
img,mask = t(img,mask)
|
||||
|
||||
return img,mask
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
class RandomHorizontalFlipPair(RandomHorizontalFlip):
|
||||
def __init__(self, p=0.5):
|
||||
super().__init__(p)
|
||||
|
||||
def __call__(self, img, mask):
|
||||
if torch.rand(1) < self.p:
|
||||
return F.hflip(img),F.hflip(mask)
|
||||
return img,mask
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
AVAI_CHOICES = [
|
||||
"random_flip",
|
||||
"random_resized_crop",
|
||||
"normalize",
|
||||
"instance_norm",
|
||||
"random_crop",
|
||||
"random_translation",
|
||||
"center_crop", # This has become a default operation during testing
|
||||
"cutout",
|
||||
"imagenet_policy",
|
||||
"cifar10_policy",
|
||||
"svhn_policy",
|
||||
"randaugment",
|
||||
"randaugment_fixmatch",
|
||||
"randaugment2",
|
||||
"gaussian_noise",
|
||||
"colorjitter",
|
||||
"randomgrayscale",
|
||||
"gaussian_blur",
|
||||
|
||||
"random_flip_pair",
|
||||
"random_resized_crop_pair",
|
||||
]
|
||||
|
||||
INTERPOLATION_MODES = {
|
||||
"bilinear": InterpolationMode.BILINEAR,
|
||||
"bicubic": InterpolationMode.BICUBIC,
|
||||
"nearest": InterpolationMode.NEAREST,
|
||||
}
|
||||
|
||||
|
||||
class Random2DTranslation:
|
||||
"""Given an image of (height, width), we resize it to
|
||||
(height*1.125, width*1.125), and then perform random cropping.
|
||||
|
||||
Args:
|
||||
height (int): target image height.
|
||||
width (int): target image width.
|
||||
p (float, optional): probability that this operation takes place.
|
||||
Default is 0.5.
|
||||
interpolation (int, optional): desired interpolation. Default is
|
||||
``torchvision.transforms.functional.InterpolationMode.BILINEAR``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, height, width, p=0.5, interpolation=InterpolationMode.BILINEAR
|
||||
):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.p = p
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return F.resize(
|
||||
img=img,
|
||||
size=[self.height, self.width],
|
||||
interpolation=self.interpolation
|
||||
)
|
||||
|
||||
new_width = int(round(self.width * 1.125))
|
||||
new_height = int(round(self.height * 1.125))
|
||||
resized_img = F.resize(
|
||||
img=img,
|
||||
size=[new_height, new_width],
|
||||
interpolation=self.interpolation
|
||||
)
|
||||
x_maxrange = new_width - self.width
|
||||
y_maxrange = new_height - self.height
|
||||
x1 = int(round(random.uniform(0, x_maxrange)))
|
||||
y1 = int(round(random.uniform(0, y_maxrange)))
|
||||
croped_img = F.crop(
|
||||
img=resized_img,
|
||||
top=y1,
|
||||
left=x1,
|
||||
height=self.height,
|
||||
width=self.width
|
||||
)
|
||||
|
||||
return croped_img
|
||||
|
||||
|
||||
class InstanceNormalization:
|
||||
"""Normalize data using per-channel mean and standard deviation.
|
||||
|
||||
Reference:
|
||||
- Ulyanov et al. Instance normalization: The missing in- gredient
|
||||
for fast stylization. ArXiv 2016.
|
||||
- Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
|
||||
ICLR 2018.
|
||||
"""
|
||||
|
||||
def __init__(self, eps=1e-8):
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, img):
|
||||
C, H, W = img.shape
|
||||
img_re = img.reshape(C, H * W)
|
||||
mean = img_re.mean(1).view(C, 1, 1)
|
||||
std = img_re.std(1).view(C, 1, 1)
|
||||
return (img-mean) / (std + self.eps)
|
||||
|
||||
|
||||
class Cutout:
|
||||
"""Randomly mask out one or more patches from an image.
|
||||
|
||||
https://github.com/uoguelph-mlrg/Cutout
|
||||
|
||||
Args:
|
||||
n_holes (int, optional): number of patches to cut out
|
||||
of each image. Default is 1.
|
||||
length (int, optinal): length (in pixels) of each square
|
||||
patch. Default is 16.
|
||||
"""
|
||||
|
||||
def __init__(self, n_holes=1, length=16):
|
||||
self.n_holes = n_holes
|
||||
self.length = length
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): tensor image of size (C, H, W).
|
||||
|
||||
Returns:
|
||||
Tensor: image with n_holes of dimension
|
||||
length x length cut out of it.
|
||||
"""
|
||||
h = img.size(1)
|
||||
w = img.size(2)
|
||||
|
||||
mask = np.ones((h, w), np.float32)
|
||||
|
||||
for n in range(self.n_holes):
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1:y2, x1:x2] = 0.0
|
||||
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
return img * mask
|
||||
|
||||
|
||||
class GaussianNoise:
|
||||
"""Add gaussian noise."""
|
||||
|
||||
def __init__(self, mean=0, std=0.15, p=0.5):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img):
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return img
|
||||
noise = torch.randn(img.size()) * self.std + self.mean
|
||||
return img + noise
|
||||
|
||||
|
||||
def build_transform(cfg, is_train=True, choices=None):
|
||||
"""Build transformation function.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): config.
|
||||
is_train (bool, optional): for training (True) or test (False).
|
||||
Default is True.
|
||||
choices (list, optional): list of strings which will overwrite
|
||||
cfg.INPUT.TRANSFORMS if given. Default is None.
|
||||
"""
|
||||
if cfg.INPUT.NO_TRANSFORM:
|
||||
print("Note: no transform is applied!")
|
||||
return None
|
||||
|
||||
if choices is None:
|
||||
choices = cfg.INPUT.TRANSFORMS
|
||||
|
||||
for choice in choices:
|
||||
assert choice in AVAI_CHOICES
|
||||
|
||||
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
|
||||
|
||||
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
||||
|
||||
if is_train:
|
||||
return _build_transform_train(cfg, choices, target_size, normalize)
|
||||
else:
|
||||
return _build_transform_test(cfg, choices, target_size, normalize)
|
||||
|
||||
|
||||
def build_transform_pair(cfg, is_train=True, choices=None):
|
||||
"""Build transformation function.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): config.
|
||||
is_train (bool, optional): for training (True) or test (False).
|
||||
Default is True.
|
||||
choices (list, optional): list of strings which will overwrite
|
||||
cfg.INPUT.TRANSFORMS if given. Default is None.
|
||||
"""
|
||||
if cfg.INPUT.NO_TRANSFORM:
|
||||
print("Note: no transform is applied!")
|
||||
return None
|
||||
|
||||
if choices is None:
|
||||
choices = cfg.INPUT.TRANSFORMS
|
||||
|
||||
for choice in choices:
|
||||
assert choice in AVAI_CHOICES
|
||||
|
||||
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
|
||||
|
||||
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
||||
|
||||
if is_train:
|
||||
return _build_transform_train_pair(cfg, choices, target_size, normalize)
|
||||
else:
|
||||
return _build_transform_test(cfg, choices, target_size, normalize)
|
||||
|
||||
def _build_transform_train_pair(cfg, choices, target_size, normalize):
|
||||
print("Building transform_train_pair")
|
||||
tfm_train = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
# Make sure the image size matches the target size
|
||||
conditions = []
|
||||
conditions += ["random_crop" not in choices]
|
||||
conditions += ["random_resized_crop" not in choices]
|
||||
if all(conditions):
|
||||
print(f"+ resize to {target_size}")
|
||||
tfm_train += [Resize(input_size, interpolation=interp_mode)]
|
||||
|
||||
# if "random_translation" in choices:
|
||||
# print("+ random translation")
|
||||
# tfm_train += [Random2DTranslation(input_size[0], input_size[1])]
|
||||
#
|
||||
# if "random_crop" in choices:
|
||||
# crop_padding = cfg.INPUT.CROP_PADDING
|
||||
# print(f"+ random crop (padding = {crop_padding})")
|
||||
# tfm_train += [RandomCrop(input_size, padding=crop_padding)]
|
||||
|
||||
if "random_resized_crop" in choices:
|
||||
s_ = cfg.INPUT.RRCROP_SCALE
|
||||
print(f"+ random resized crop pair (size={input_size}, scale={s_})")
|
||||
tfm_train += [
|
||||
RandomResizedCropPair(input_size, scale=s_, interpolation=interp_mode)
|
||||
]
|
||||
|
||||
if "random_flip" in choices:
|
||||
print("+ random flip pair")
|
||||
tfm_train += [RandomHorizontalFlipPair()]
|
||||
|
||||
if "imagenet_policy" in choices:
|
||||
print("+ imagenet policy")
|
||||
tfm_train += [ImageNetPolicy()]
|
||||
|
||||
if "cifar10_policy" in choices:
|
||||
print("+ cifar10 policy")
|
||||
tfm_train += [CIFAR10Policy()]
|
||||
|
||||
if "svhn_policy" in choices:
|
||||
print("+ svhn policy")
|
||||
tfm_train += [SVHNPolicy()]
|
||||
|
||||
if "randaugment" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
m_ = cfg.INPUT.RANDAUGMENT_M
|
||||
print(f"+ randaugment (n={n_}, m={m_})")
|
||||
tfm_train += [RandAugment(n_, m_)]
|
||||
|
||||
if "randaugment_fixmatch" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment_fixmatch (n={n_})")
|
||||
tfm_train += [RandAugmentFixMatch(n_)]
|
||||
|
||||
if "randaugment2" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment2 (n={n_})")
|
||||
tfm_train += [RandAugment2(n_)]
|
||||
|
||||
if "colorjitter" in choices:
|
||||
b_ = cfg.INPUT.COLORJITTER_B
|
||||
c_ = cfg.INPUT.COLORJITTER_C
|
||||
s_ = cfg.INPUT.COLORJITTER_S
|
||||
h_ = cfg.INPUT.COLORJITTER_H
|
||||
print(
|
||||
f"+ color jitter (brightness={b_}, "
|
||||
f"contrast={c_}, saturation={s_}, hue={h_})"
|
||||
)
|
||||
tfm_train += [
|
||||
ColorJitter(
|
||||
brightness=b_,
|
||||
contrast=c_,
|
||||
saturation=s_,
|
||||
hue=h_,
|
||||
)
|
||||
]
|
||||
|
||||
if "randomgrayscale" in choices:
|
||||
print("+ random gray scale")
|
||||
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
|
||||
|
||||
if "gaussian_blur" in choices:
|
||||
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
|
||||
gb_k, gb_p = cfg.INPUT.GB_K, cfg.INPUT.GB_P
|
||||
tfm_train += [RandomApply([GaussianBlur(gb_k)], p=gb_p)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_train += [ToTensor()]
|
||||
|
||||
if "cutout" in choices:
|
||||
cutout_n = cfg.INPUT.CUTOUT_N
|
||||
cutout_len = cfg.INPUT.CUTOUT_LEN
|
||||
print(f"+ cutout (n_holes={cutout_n}, length={cutout_len})")
|
||||
tfm_train += [Cutout(cutout_n, cutout_len)]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_train += [normalize]
|
||||
|
||||
if "gaussian_noise" in choices:
|
||||
print(
|
||||
f"+ gaussian noise (mean={cfg.INPUT.GN_MEAN}, std={cfg.INPUT.GN_STD})"
|
||||
)
|
||||
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_train += [InstanceNormalization()]
|
||||
|
||||
tfm_train = ComposePair(tfm_train)
|
||||
|
||||
|
||||
return tfm_train
|
||||
|
||||
|
||||
def _build_transform_train(cfg, choices, target_size, normalize):
|
||||
print("Building transform_train")
|
||||
tfm_train = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
# Make sure the image size matches the target size
|
||||
conditions = []
|
||||
conditions += ["random_crop" not in choices]
|
||||
conditions += ["random_resized_crop" not in choices]
|
||||
if all(conditions):
|
||||
print(f"+ resize to {target_size}")
|
||||
tfm_train += [Resize(input_size, interpolation=interp_mode)]
|
||||
|
||||
if "random_translation" in choices:
|
||||
print("+ random translation")
|
||||
tfm_train += [Random2DTranslation(input_size[0], input_size[1])]
|
||||
|
||||
if "random_crop" in choices:
|
||||
crop_padding = cfg.INPUT.CROP_PADDING
|
||||
print(f"+ random crop (padding = {crop_padding})")
|
||||
tfm_train += [RandomCrop(input_size, padding=crop_padding)]
|
||||
|
||||
if "random_resized_crop" in choices:
|
||||
s_ = cfg.INPUT.RRCROP_SCALE
|
||||
print(f"+ random resized crop (size={input_size}, scale={s_})")
|
||||
tfm_train += [
|
||||
RandomResizedCrop(input_size, scale=s_, interpolation=interp_mode)
|
||||
]
|
||||
|
||||
if "random_flip" in choices:
|
||||
print("+ random flip")
|
||||
tfm_train += [RandomHorizontalFlip()]
|
||||
|
||||
if "imagenet_policy" in choices:
|
||||
print("+ imagenet policy")
|
||||
tfm_train += [ImageNetPolicy()]
|
||||
|
||||
if "cifar10_policy" in choices:
|
||||
print("+ cifar10 policy")
|
||||
tfm_train += [CIFAR10Policy()]
|
||||
|
||||
if "svhn_policy" in choices:
|
||||
print("+ svhn policy")
|
||||
tfm_train += [SVHNPolicy()]
|
||||
|
||||
if "randaugment" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
m_ = cfg.INPUT.RANDAUGMENT_M
|
||||
print(f"+ randaugment (n={n_}, m={m_})")
|
||||
tfm_train += [RandAugment(n_, m_)]
|
||||
|
||||
if "randaugment_fixmatch" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment_fixmatch (n={n_})")
|
||||
tfm_train += [RandAugmentFixMatch(n_)]
|
||||
|
||||
if "randaugment2" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment2 (n={n_})")
|
||||
tfm_train += [RandAugment2(n_)]
|
||||
|
||||
if "colorjitter" in choices:
|
||||
b_ = cfg.INPUT.COLORJITTER_B
|
||||
c_ = cfg.INPUT.COLORJITTER_C
|
||||
s_ = cfg.INPUT.COLORJITTER_S
|
||||
h_ = cfg.INPUT.COLORJITTER_H
|
||||
print(
|
||||
f"+ color jitter (brightness={b_}, "
|
||||
f"contrast={c_}, saturation={s_}, hue={h_})"
|
||||
)
|
||||
tfm_train += [
|
||||
ColorJitter(
|
||||
brightness=b_,
|
||||
contrast=c_,
|
||||
saturation=s_,
|
||||
hue=h_,
|
||||
)
|
||||
]
|
||||
|
||||
if "randomgrayscale" in choices:
|
||||
print("+ random gray scale")
|
||||
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
|
||||
|
||||
if "gaussian_blur" in choices:
|
||||
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
|
||||
gb_k, gb_p = cfg.INPUT.GB_K, cfg.INPUT.GB_P
|
||||
tfm_train += [RandomApply([GaussianBlur(gb_k)], p=gb_p)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_train += [ToTensor()]
|
||||
|
||||
if "cutout" in choices:
|
||||
cutout_n = cfg.INPUT.CUTOUT_N
|
||||
cutout_len = cfg.INPUT.CUTOUT_LEN
|
||||
print(f"+ cutout (n_holes={cutout_n}, length={cutout_len})")
|
||||
tfm_train += [Cutout(cutout_n, cutout_len)]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_train += [normalize]
|
||||
|
||||
if "gaussian_noise" in choices:
|
||||
print(
|
||||
f"+ gaussian noise (mean={cfg.INPUT.GN_MEAN}, std={cfg.INPUT.GN_STD})"
|
||||
)
|
||||
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_train += [InstanceNormalization()]
|
||||
|
||||
tfm_train = Compose(tfm_train)
|
||||
|
||||
return tfm_train
|
||||
|
||||
|
||||
def _build_transform_test(cfg, choices, target_size, normalize):
|
||||
print("Building transform_test")
|
||||
tfm_test = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
print(f"+ resize the smaller edge to {max(input_size)}")
|
||||
tfm_test += [Resize(max(input_size), interpolation=interp_mode)]
|
||||
|
||||
print(f"+ {target_size} center crop")
|
||||
tfm_test += [CenterCrop(input_size)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_test += [ToTensor()]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_test += [normalize]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_test += [InstanceNormalization()]
|
||||
|
||||
tfm_test = Compose(tfm_test)
|
||||
|
||||
return tfm_test
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
from scipy.io import loadmat
|
||||
from collections import defaultdict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import read_json, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class OxfordFlowers(DatasetBase):
|
||||
|
||||
dataset_dir = "oxford_flowers"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "jpg")
|
||||
self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
|
||||
self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = self.read_data()
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self):
|
||||
tracker = defaultdict(list)
|
||||
label_file = loadmat(self.label_file)["labels"][0]
|
||||
for i, label in enumerate(label_file):
|
||||
imname = f"image_{str(i + 1).zfill(5)}.jpg"
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = int(label)
|
||||
tracker[label].append(impath)
|
||||
|
||||
print("Splitting data into 50% train, 20% val, and 30% test")
|
||||
|
||||
def _collate(ims, y, c):
|
||||
items = []
|
||||
for im in ims:
|
||||
item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
lab2cname = read_json(self.lab2cname_file)
|
||||
train, val, test = [], [], []
|
||||
for label, impaths in tracker.items():
|
||||
random.shuffle(impaths)
|
||||
n_total = len(impaths)
|
||||
n_train = round(n_total * 0.5)
|
||||
n_val = round(n_total * 0.2)
|
||||
n_test = n_total - n_train - n_val
|
||||
assert n_train > 0 and n_val > 0 and n_test > 0
|
||||
cname = lab2cname[str(label)]
|
||||
train.extend(_collate(impaths[:n_train], label, cname))
|
||||
val.extend(_collate(impaths[n_train : n_train + n_val], label, cname))
|
||||
test.extend(_collate(impaths[n_train + n_val :], label, cname))
|
||||
|
||||
return train, val, test
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import pickle
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import read_json, write_json, mkdir_if_missing
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class OxfordPets(DatasetBase):
|
||||
|
||||
dataset_dir = "oxford_pets"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.anno_dir = os.path.join(self.dataset_dir, "annotations")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = self.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
trainval = self.read_data(split_file="trainval.txt")
|
||||
test = self.read_data(split_file="test.txt")
|
||||
train, val = self.split_trainval(trainval)
|
||||
self.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = self.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, split_file):
|
||||
filepath = os.path.join(self.anno_dir, split_file)
|
||||
items = []
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
imname, label, species, _ = line.split(" ")
|
||||
breed = imname.split("_")[:-1]
|
||||
breed = "_".join(breed)
|
||||
breed = breed.lower()
|
||||
imname += ".jpg"
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = int(label) - 1 # convert to 0-based index
|
||||
item = Datum(impath=impath, label=label, classname=breed)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def split_trainval(trainval, p_val=0.2):
|
||||
p_trn = 1 - p_val
|
||||
print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
|
||||
tracker = defaultdict(list)
|
||||
for idx, item in enumerate(trainval):
|
||||
label = item.label
|
||||
tracker[label].append(idx)
|
||||
|
||||
train, val = [], []
|
||||
for label, idxs in tracker.items():
|
||||
n_val = round(len(idxs) * p_val)
|
||||
assert n_val > 0
|
||||
random.shuffle(idxs)
|
||||
for n, idx in enumerate(idxs):
|
||||
item = trainval[idx]
|
||||
if n < n_val:
|
||||
val.append(item)
|
||||
else:
|
||||
train.append(item)
|
||||
|
||||
return train, val
|
||||
|
||||
@staticmethod
|
||||
def save_split(train, val, test, filepath, path_prefix):
|
||||
def _extract(items):
|
||||
out = []
|
||||
for item in items:
|
||||
impath = item.impath
|
||||
label = item.label
|
||||
classname = item.classname
|
||||
impath = impath.replace(path_prefix, "")
|
||||
if impath.startswith("/"):
|
||||
impath = impath[1:]
|
||||
out.append((impath, label, classname))
|
||||
return out
|
||||
|
||||
train = _extract(train)
|
||||
val = _extract(val)
|
||||
test = _extract(test)
|
||||
|
||||
split = {"train": train, "val": val, "test": test}
|
||||
|
||||
write_json(split, filepath)
|
||||
print(f"Saved split to {filepath}")
|
||||
|
||||
@staticmethod
|
||||
def read_split(filepath, path_prefix):
|
||||
def _convert(items):
|
||||
out = []
|
||||
for impath, label, classname in items:
|
||||
impath = os.path.join(path_prefix, impath)
|
||||
item = Datum(impath=impath, label=int(label), classname=classname)
|
||||
out.append(item)
|
||||
return out
|
||||
|
||||
print(f"Reading split from {filepath}")
|
||||
split = read_json(filepath)
|
||||
train = _convert(split["train"])
|
||||
val = _convert(split["val"])
|
||||
test = _convert(split["test"])
|
||||
|
||||
return train, val, test
|
||||
|
||||
@staticmethod
|
||||
def subsample_classes(*args, subsample="all"):
|
||||
"""Divide classes into two groups. The first group
|
||||
represents base classes while the second group represents
|
||||
new classes.
|
||||
|
||||
Args:
|
||||
args: a list of datasets, e.g. train, val and test.
|
||||
subsample (str): what classes to subsample.
|
||||
"""
|
||||
assert subsample in ["all", "base", "new"]
|
||||
|
||||
if subsample == "all":
|
||||
return args
|
||||
|
||||
dataset = args[0]
|
||||
labels = set()
|
||||
for item in dataset:
|
||||
labels.add(item.label)
|
||||
labels = list(labels)
|
||||
labels.sort()
|
||||
n = len(labels)
|
||||
# Divide classes into two halves
|
||||
m = math.ceil(n / 2)
|
||||
|
||||
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
|
||||
if subsample == "base":
|
||||
selected = labels[:m] # take the first half
|
||||
else:
|
||||
selected = labels[m:] # take the second half
|
||||
relabeler = {y: y_new for y_new, y in enumerate(selected)}
|
||||
|
||||
output = []
|
||||
for dataset in args:
|
||||
dataset_new = []
|
||||
for item in dataset:
|
||||
if item.label not in selected:
|
||||
continue
|
||||
item_new = Datum(
|
||||
impath=item.impath,
|
||||
label=relabeler[item.label],
|
||||
classname=item.classname
|
||||
)
|
||||
dataset_new.append(item_new)
|
||||
output.append(dataset_new)
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import math
|
||||
CAT_LIST = ['aeroplane',
|
||||
'bicycle',
|
||||
'bird',
|
||||
'boat',
|
||||
'bottle',
|
||||
'bus',
|
||||
'car',
|
||||
'cat',
|
||||
'chair',
|
||||
'cow',
|
||||
'table',
|
||||
'dog',
|
||||
'horse',
|
||||
'motorbike',
|
||||
'person',
|
||||
'plant',
|
||||
'sheep',
|
||||
'sofa',
|
||||
'train',
|
||||
'tvmonitor']
|
||||
|
||||
CAT_LIST_TO_NAME = dict(zip(range(len(CAT_LIST)) ,CAT_LIST))
|
||||
|
||||
|
||||
def _collate(ims, y, c):
|
||||
return Datum(impath=ims, label=y, classname=c)
|
||||
|
||||
def load_img_name_list(dataset_path):
|
||||
|
||||
img_gt_name_list = open(dataset_path).readlines()
|
||||
img_name_list = [img_gt_name.strip() for img_gt_name in img_gt_name_list]
|
||||
|
||||
return img_name_list
|
||||
|
||||
def load_image_label_list_from_npy(data_root,img_name_list, label_file_path=None):
|
||||
if label_file_path is None:
|
||||
label_file_path = 'voc12/cls_labels.npy'
|
||||
cls_labels_dict = np.load(label_file_path, allow_pickle=True).item()
|
||||
label_list = []
|
||||
data_dtm = []
|
||||
|
||||
for id in img_name_list:
|
||||
if id not in cls_labels_dict.keys():
|
||||
img_name = id + '.jpg'
|
||||
else:
|
||||
img_name = id
|
||||
label = cls_labels_dict[img_name]
|
||||
label_idx = np.where(label==1)[0]
|
||||
class_name = [CAT_LIST[idx] for idx in range(len(label_idx))]
|
||||
data_dtm.append(_collate(os.path.join(data_root,img_name+'.jpg'),label,class_name))
|
||||
|
||||
return data_dtm
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class VOC12(DatasetBase):
|
||||
dataset_dir = "voc12data"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir,'VOCdevkit/VOC2012/JPEGImages')
|
||||
train_img_name_list_path = os.path.join('voc12/train_aug_id.txt')
|
||||
val_img_name_list_path = os.path.join('voc12/val_id.txt')
|
||||
|
||||
train = load_image_label_list_from_npy(self.image_dir,load_img_name_list(train_img_name_list_path))
|
||||
val = load_image_label_list_from_npy(self.image_dir,load_img_name_list(val_img_name_list_path))
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train = data["train"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
data = {"train": train}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val = self.subsample_classes(train, val, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=val)
|
||||
|
||||
@staticmethod
|
||||
def subsample_classes(*args, subsample="all"):
|
||||
"""Divide classes into two groups. The first group
|
||||
represents base classes while the second group represents
|
||||
new classes.
|
||||
|
||||
Args:
|
||||
args: a list of datasets, e.g. train, val and test.
|
||||
subsample (str): what classes to subsample.
|
||||
"""
|
||||
assert subsample in ["all", "base", "new"]
|
||||
|
||||
if subsample == "all":
|
||||
return args
|
||||
|
||||
dataset = args[0]
|
||||
labels = set()
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
labels.add(label_idx)
|
||||
labels = list(labels)
|
||||
labels.sort()
|
||||
n = len(labels)
|
||||
# Divide classes into two halves
|
||||
m = math.ceil(n / 2)
|
||||
|
||||
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
|
||||
if subsample == "base":
|
||||
selected = labels[:m] # take the first half
|
||||
else:
|
||||
selected = labels[m:] # take the second half
|
||||
relabeler = {y: y_new for y_new, y in enumerate(selected)}
|
||||
|
||||
output = []
|
||||
for dataset in args:
|
||||
dataset_new = []
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
if label_idx not in selected:
|
||||
continue
|
||||
|
||||
item_new = Datum(
|
||||
impath=item.impath,
|
||||
label=item.label,
|
||||
classname=item.classname
|
||||
)
|
||||
dataset_new.append(item_new)
|
||||
output.append(dataset_new)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_num_classes(data_source):
|
||||
"""Count number of classes.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return len(CAT_LIST)
|
||||
|
||||
@staticmethod
|
||||
def get_lab2cname(data_source):
|
||||
"""Get a label-to-classname mapping (dict).
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return CAT_LIST_TO_NAME, CAT_LIST
|
||||
|
||||
def split_dataset_by_label(self, data_source):
|
||||
"""Split a dataset, i.e. a list of Datum objects,
|
||||
into class-specific groups stored in a dictionary.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
output = defaultdict(list)
|
||||
|
||||
for item in data_source:
|
||||
one_hot_label = item.label
|
||||
label_idx = random.choices(np.where(one_hot_label==1)[0])[0]
|
||||
output[label_idx].append(item)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def read_data(self, classnames, split_dir):
|
||||
split_dir = os.path.join(self.image_dir, split_dir)
|
||||
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(split_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(split_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import pickle
|
||||
from scipy.io import loadmat
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
import numpy as np
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class StanfordCars(DatasetBase):
|
||||
|
||||
dataset_dir = "stanford_cars"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir)
|
||||
else:
|
||||
trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat")
|
||||
test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat")
|
||||
meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat")
|
||||
trainval = self.read_data("cars_train", trainval_file, meta_file)
|
||||
test = self.read_data("cars_test", test_file, meta_file)
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, image_dir, anno_file, meta_file):
|
||||
anno_file = loadmat(anno_file)["annotations"][0]
|
||||
meta_file = loadmat(meta_file)["class_names"][0]
|
||||
items = []
|
||||
|
||||
for i in range(len(anno_file)):
|
||||
imname = anno_file[i]["fname"][0]
|
||||
impath = os.path.join(self.dataset_dir, image_dir, imname)
|
||||
label = anno_file[i]["class"][0, 0]
|
||||
label = int(label) - 1 # convert to 0-based index
|
||||
classname = meta_file[label][0]
|
||||
names = classname.split(" ")
|
||||
year = names.pop(-1)
|
||||
names.insert(0, year)
|
||||
classname = " ".join(names)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class SUN397(DatasetBase):
|
||||
|
||||
dataset_dir = "sun397"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "SUN397")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
classnames = []
|
||||
with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()[1:] # remove /
|
||||
classnames.append(line)
|
||||
cname2lab = {c: i for i, c in enumerate(classnames)}
|
||||
trainval = self.read_data(cname2lab, "Training_01.txt")
|
||||
test = self.read_data(cname2lab, "Testing_01.txt")
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, text_file):
|
||||
text_file = os.path.join(self.dataset_dir, text_file)
|
||||
items = []
|
||||
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
imname = line.strip()[1:] # remove /
|
||||
classname = os.path.dirname(imname)
|
||||
label = cname2lab[classname]
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
|
||||
names = classname.split("/")[1:] # remove 1st letter
|
||||
names = names[::-1] # put words like indoor/outdoor at first
|
||||
classname = " ".join(names)
|
||||
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class UCF101(DatasetBase):
|
||||
|
||||
dataset_dir = "ucf101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
cname2lab = {}
|
||||
filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
label, classname = line.strip().split(" ")
|
||||
label = int(label) - 1 # conver to 0-based index
|
||||
cname2lab[classname] = label
|
||||
|
||||
trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
|
||||
test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, text_file):
|
||||
text_file = os.path.join(self.dataset_dir, text_file)
|
||||
items = []
|
||||
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")[0] # trainlist: filename, label
|
||||
action, filename = line.split("/")
|
||||
label = cname2lab[action]
|
||||
|
||||
elements = re.findall("[A-Z][^A-Z]*", action)
|
||||
renamed_action = "_".join(elements)
|
||||
|
||||
filename = filename.replace(".avi", ".jpg")
|
||||
impath = os.path.join(self.image_dir, renamed_action, filename)
|
||||
|
||||
item = Datum(impath=impath, label=label, classname=renamed_action)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
Reference in New Issue
Block a user