import os
import pickle
import math
import random
from collections import defaultdict

from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import read_json, write_json, mkdir_if_missing
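
# OxfordPets prepares the Oxford-IIIT Pet dataset (37 cat and dog breeds):
# it builds or reloads a fixed train/val/test split, optionally samples a
# cached few-shot subset, and can restrict the label space to base/new classes.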
@DATASET_REGISTRY.register()
class OxfordPets(DatasetBase):

    dataset_dir = "oxford_pets"

    def __init__(self, cfg):
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, "images")
        self.anno_dir = os.path.join(self.dataset_dir, "annotations")
        self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)
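
        # Prefer the saved JSON split for reproducibility; build it from the
        # official annotation files only on the first run.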
        if os.path.exists(self.split_path):
            train, val, test = self.read_split(self.split_path, self.image_dir)
        else:
            trainval = self.read_data(split_file="trainval.txt")
            test = self.read_data(split_file="test.txt")
            train, val = self.split_trainval(trainval)
            self.save_split(train, val, test, self.split_path, self.image_dir)
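
        # Few-shot mode: draw num_shots images per class and cache the draw
        # per (shots, seed) pair, so repeated runs with the same seed reuse
        # the identical subset.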
        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")

            if os.path.exists(preprocessed):
                print(f"Loading preprocessed few-shot data from {preprocessed}")
                with open(preprocessed, "rb") as file:
                    data = pickle.load(file)
                    train, val = data["train"], data["val"]
            else:
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
                data = {"train": train, "val": val}
                print(f"Saving preprocessed few-shot data to {preprocessed}")
                with open(preprocessed, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
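
        # Optionally keep only the base or new half of the classes (used for
        # base-to-new generalization experiments).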
        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, val, test = self.subsample_classes(train, val, test, subsample=subsample)

        super().__init__(train_x=train, val=val, test=test)
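
    # Parse an annotation list such as annotations/trainval.txt, where each
    # line reads "<image_name> <class_id> <species_id> <breed_id>". The breed
    # name is recovered from the filename; class ids are shifted to 0-based.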
    def read_data(self, split_file):
        filepath = os.path.join(self.anno_dir, split_file)
        items = []

        with open(filepath, "r") as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                imname, label, species, _ = line.split(" ")
                breed = imname.split("_")[:-1]
                breed = "_".join(breed)
                breed = breed.lower()
                imname += ".jpg"
                impath = os.path.join(self.image_dir, imname)
                label = int(label) - 1  # convert to 0-based index
                item = Datum(impath=impath, label=label, classname=breed)
                items.append(item)

        return items
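
    # Split a trainval list into train/val per class, holding out p_val
    # (20% by default) of each class's images for validation.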
    @staticmethod
    def split_trainval(trainval, p_val=0.2):
        p_trn = 1 - p_val
        print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
        tracker = defaultdict(list)
        for idx, item in enumerate(trainval):
            label = item.label
            tracker[label].append(idx)

        train, val = [], []
        for label, idxs in tracker.items():
            n_val = round(len(idxs) * p_val)
            assert n_val > 0
            random.shuffle(idxs)
            for n, idx in enumerate(idxs):
                item = trainval[idx]
                if n < n_val:
                    val.append(item)
                else:
                    train.append(item)

        return train, val
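
    # Serialize a split to JSON with image paths stored relative to
    # path_prefix, so the split file survives moving the dataset root.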
    @staticmethod
    def save_split(train, val, test, filepath, path_prefix):
        def _extract(items):
            out = []
            for item in items:
                impath = item.impath
                label = item.label
                classname = item.classname
                impath = impath.replace(path_prefix, "")
                if impath.startswith("/"):
                    impath = impath[1:]
                out.append((impath, label, classname))
            return out

        train = _extract(train)
        val = _extract(val)
        test = _extract(test)

        split = {"train": train, "val": val, "test": test}

        write_json(split, filepath)
        print(f"Saved split to {filepath}")
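
    # Inverse of save_split: load the JSON split and rebuild absolute image
    # paths under path_prefix.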
    @staticmethod
    def read_split(filepath, path_prefix):
        def _convert(items):
            out = []
            for impath, label, classname in items:
                impath = os.path.join(path_prefix, impath)
                item = Datum(impath=impath, label=int(label), classname=classname)
                out.append(item)
            return out

        print(f"Reading split from {filepath}")
        split = read_json(filepath)
        train = _convert(split["train"])
        val = _convert(split["val"])
        test = _convert(split["test"])

        return train, val, test
    @staticmethod
    def subsample_classes(*args, subsample="all"):
        """Divide classes into two groups. The first group
        represents base classes while the second group represents
        new classes.

        Args:
            args: a list of datasets, e.g. train, val and test.
            subsample (str): what classes to subsample.
        """
        assert subsample in ["all", "base", "new"]

        if subsample == "all":
            return args

        dataset = args[0]
        labels = set()
        for item in dataset:
            labels.add(item.label)
        labels = list(labels)
        labels.sort()
        n = len(labels)
        # Divide classes into two halves
        m = math.ceil(n / 2)

        print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
        if subsample == "base":
            selected = labels[:m]  # take the first half
        else:
            selected = labels[m:]  # take the second half
        relabeler = {y: y_new for y_new, y in enumerate(selected)}

        output = []
        for dataset in args:
            dataset_new = []
            for item in dataset:
                if item.label not in selected:
                    continue
                item_new = Datum(
                    impath=item.impath,
                    label=relabeler[item.label],
                    classname=item.classname
                )
                dataset_new.append(item_new)
            output.append(dataset_new)

        return output
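
# A minimal usage sketch. This assumes Dassl's yacs-based config; the field
# names below are exactly the ones __init__ reads, but how cfg is constructed
# is up to the caller:
#
#     from yacs.config import CfgNode as CN
#
#     cfg = CN()
#     cfg.SEED = 1
#     cfg.DATASET = CN()
#     cfg.DATASET.ROOT = "~/data"            # parent folder of oxford_pets/
#     cfg.DATASET.NUM_SHOTS = 16             # values < 1 skip few-shot sampling
#     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # "all", "base", or "new"
#
#     dataset = OxfordPets(cfg)
#     print(len(dataset.train_x), len(dataset.val), len(dataset.test))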