125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
import os
|
|
import math
|
|
import random
|
|
from collections import defaultdict
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
|
|
|
|
|
# Prompt template(s) stored on each OxfordPets instance as `self.template`.
# `{}` is a str.format-style placeholder — presumably filled with a class
# (breed) name by the consumer; confirm at the call site.
template = ['a photo of a {}, a type of pet.']
|
|
|
|
|
|
class OxfordPets(DatasetBase):
    """Dataset that loads a pre-computed train/val/test split from a JSON file.

    On construction the split stored in ``split_zhou_OxfordPets.json`` (inside
    the dataset directory) is read, the training portion is reduced with
    ``generate_fewshot_dataset`` and the three parts are handed to the
    ``DatasetBase`` constructor.
    """

    # Directory name of the dataset, relative to the ``root`` passed to __init__.
    dataset_dir = 'oxford_pets'

    def __init__(self, root, num_shots):
        """Resolve dataset paths under *root* and load the split.

        Args:
            root: Directory that contains the ``oxford_pets`` folder.
            num_shots: Forwarded as ``num_shots`` to
                ``self.generate_fewshot_dataset`` (not defined in this class;
                presumably inherited from ``DatasetBase``).
        """
        # Shadows the class attribute with the absolute location on the instance.
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')
        self.anno_dir = os.path.join(self.dataset_dir, 'annotations')
        self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json')

        self.template = template

        train, val, test = self.read_split(self.split_path, self.image_dir)
        train = self.generate_fewshot_dataset(train, num_shots=num_shots)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, split_file):
        """Parse an annotation list file into ``Datum`` items.

        Every non-blank line must hold exactly four space-separated fields:
        the image name (without extension), a 1-based label, and two further
        fields that are ignored. The class name is the part of the image name
        before its last underscore, lower-cased.

        Args:
            split_file: File name relative to ``self.anno_dir``.

        Returns:
            A list of ``Datum`` objects in file order.

        Raises:
            ValueError: If a non-blank line does not have exactly four fields
                or its label is not an integer.
        """
        filepath = os.path.join(self.anno_dir, split_file)
        items = []

        with open(filepath, 'r') as f:
            # Iterate lazily instead of loading every line up front.
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    continue

                imname, label, _, _ = line.split(' ')
                # Text before the last underscore; '' when there is none.
                breed = imname.rpartition('_')[0].lower()
                impath = os.path.join(self.image_dir, imname + '.jpg')

                items.append(
                    Datum(
                        impath=impath,
                        label=int(label) - 1,  # convert to 0-based index
                        classname=breed,
                    )
                )

        return items

    @staticmethod
    def split_trainval(trainval, p_val=0.2):
        """Split *trainval* into train and val lists, label by label.

        For each label, ``round(count * p_val)`` randomly chosen items go to
        val and the remainder to train. Shuffling uses the global ``random``
        module state, so results depend on the current seed.

        Args:
            trainval: Sequence of items exposing a ``label`` attribute.
            p_val: Fraction of each label's items assigned to val.

        Returns:
            A ``(train, val)`` tuple of lists.

        Raises:
            AssertionError: If some label would receive zero val items.
        """
        p_trn = 1 - p_val
        print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val')

        # Group item positions by label.
        tracker = defaultdict(list)
        for idx, item in enumerate(trainval):
            tracker[item.label].append(idx)

        train, val = [], []
        for label, idxs in tracker.items():
            n_val = round(len(idxs) * p_val)
            assert n_val > 0, (
                f'label {label!r} has {len(idxs)} sample(s); '
                f'p_val={p_val} leaves no validation samples'
            )
            random.shuffle(idxs)
            # The first n_val shuffled positions form val, the rest train.
            val.extend(trainval[idx] for idx in idxs[:n_val])
            train.extend(trainval[idx] for idx in idxs[n_val:])

        return train, val

    @staticmethod
    def save_split(train, val, test, filepath, path_prefix):
        """Write a train/val/test split to *filepath* via ``write_json``.

        Each item is stored as ``(impath, label, classname)`` with
        *path_prefix* removed from the start of ``impath`` so that the stored
        paths are relative.

        Args:
            train, val, test: Iterables of items exposing ``impath``,
                ``label`` and ``classname`` attributes.
            filepath: Destination passed to ``write_json``.
            path_prefix: Leading path to strip from every ``impath``.
        """
        def _extract(items):
            out = []
            for item in items:
                impath = item.impath
                # Strip only a leading prefix; a blanket str.replace would
                # also delete later occurrences and corrupt the path.
                if path_prefix and impath.startswith(path_prefix):
                    impath = impath[len(path_prefix):]
                # Drop the separator left behind so the stored path is
                # relative and can be re-joined with os.path.join.
                if impath.startswith(('/', os.sep)):
                    impath = impath[1:]
                out.append((impath, item.label, item.classname))
            return out

        split = {
            'train': _extract(train),
            'val': _extract(val),
            'test': _extract(test),
        }

        write_json(split, filepath)
        print(f'Saved split to {filepath}')

    @staticmethod
    def read_split(filepath, path_prefix):
        """Load a split from *filepath* via ``read_json``.

        The loaded object must provide ``'train'``, ``'val'`` and ``'test'``
        entries, each an iterable of ``(impath, label, classname)`` triples.

        Args:
            filepath: Location passed to ``read_json``.
            path_prefix: Directory joined in front of every stored ``impath``.

        Returns:
            A ``(train, val, test)`` tuple of ``Datum`` lists.
        """
        def _convert(items):
            return [
                Datum(
                    impath=os.path.join(path_prefix, impath),
                    label=int(label),
                    classname=classname,
                )
                for impath, label, classname in items
            ]

        print(f'Reading split from {filepath}')
        split = read_json(filepath)

        train = _convert(split['train'])
        val = _convert(split['val'])
        test = _convert(split['test'])

        return train, val, test