import os
import math
import random
from collections import defaultdict

import torchvision.transforms as transforms

from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader


# Prompt template(s) exposed on the dataset instance as ``self.template``.
template = ['a photo of a {}, a type of pet.']


class OxfordPets(DatasetBase):
    """Oxford Pets dataset loaded from a pre-computed JSON split.

    The constructor reads the train/val/test split stored in
    ``split_zhou_OxfordPets.json`` under the dataset directory, sub-samples
    the training set through ``generate_fewshot_dataset`` and hands the
    three lists to ``DatasetBase``.
    """

    # Directory name of the dataset, relative to the ``root`` passed to
    # ``__init__`` (which shadows it with the joined path on the instance).
    dataset_dir = 'oxford_pets'

    def __init__(self, root, num_shots):
        """Resolve dataset paths, read the split and build the dataset.

        Args:
            root: Directory that contains the ``oxford_pets`` folder.
            num_shots: Forwarded as ``num_shots`` to
                ``generate_fewshot_dataset`` (not defined in this class;
                presumably inherited from ``DatasetBase``).
        """
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')
        self.anno_dir = os.path.join(self.dataset_dir, 'annotations')
        self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json')

        self.template = template

        train, val, test = self.read_split(self.split_path, self.image_dir)
        train = self.generate_fewshot_dataset(train, num_shots=num_shots)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, split_file):
        """Parse an annotation list file into ``Datum`` items.

        Each data line must hold four whitespace-separated fields; the first
        is the image name without extension and the second a 1-based class
        id. The class name is the image name with its last ``_``-separated
        component removed, lower-cased.

        Blank lines and lines starting with ``#`` are skipped.

        Args:
            split_file: File name relative to ``self.anno_dir``.

        Returns:
            A list of ``Datum`` objects, one per data line, in file order.

        Raises:
            ValueError: If a data line does not have exactly four fields or
                its class id is not an integer.
        """
        filepath = os.path.join(self.anno_dir, split_file)
        items = []

        with open(filepath, 'r') as f:
            for line in f:
                line = line.strip()
                # A trailing newline or a header comment would otherwise
                # break the four-field unpacking below.
                if not line or line.startswith('#'):
                    continue

                imname, label, _species, _ = line.split()
                # Everything before the last underscore; '' when the name
                # has no underscore at all.
                breed = imname.rpartition('_')[0].lower()
                impath = os.path.join(self.image_dir, imname + '.jpg')
                items.append(
                    Datum(
                        impath=impath,
                        label=int(label) - 1,  # convert to 0-based index
                        classname=breed,
                    )
                )

        return items

    @staticmethod
    def split_trainval(trainval, p_val=0.2):
        """Split items into train and val lists, label by label.

        For every label, ``round(count * p_val)`` randomly chosen items go
        to val and the remainder to train. Shuffling uses the global
        ``random`` module state.

        Args:
            trainval: Sequence of items exposing a ``label`` attribute.
            p_val: Fraction of each label's items assigned to val.

        Returns:
            A ``(train, val)`` tuple of lists.

        Raises:
            AssertionError: If some label would receive no val item.
        """
        p_trn = 1 - p_val
        print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val')

        # label -> indices into ``trainval`` of the items with that label
        tracker = defaultdict(list)
        for idx, item in enumerate(trainval):
            tracker[item.label].append(idx)

        train, val = [], []
        for label, idxs in tracker.items():
            n_val = round(len(idxs) * p_val)
            assert n_val > 0, (
                f'label {label!r} has {len(idxs)} item(s); '
                f'p_val={p_val} leaves no validation item'
            )
            random.shuffle(idxs)
            # The first n_val shuffled items form val, the rest train.
            val.extend(trainval[idx] for idx in idxs[:n_val])
            train.extend(trainval[idx] for idx in idxs[n_val:])

        return train, val

    @staticmethod
    def save_split(train, val, test, filepath, path_prefix):
        """Write a train/val/test split to ``filepath`` via ``write_json``.

        Every item is stored as ``(impath, label, classname)`` with
        ``path_prefix`` removed from the start of ``impath``.

        Args:
            train: Items exposing ``impath``, ``label`` and ``classname``.
            val: Same as ``train``.
            test: Same as ``train``.
            filepath: Destination passed to ``write_json``.
            path_prefix: Leading directory to strip from each ``impath``.
        """
        def _extract(items):
            out = []
            for item in items:
                impath = item.impath
                # Strip the prefix only where it leads the path; removing
                # every occurrence would corrupt paths that happen to
                # contain the same text further down.
                if impath.startswith(path_prefix):
                    impath = impath[len(path_prefix):]
                # Drop the separator left behind so that ``read_split`` can
                # re-join the path onto a prefix with ``os.path.join``.
                if impath.startswith(('/', os.sep)):
                    impath = impath[1:]
                out.append((impath, item.label, item.classname))
            return out

        split = {
            'train': _extract(train),
            'val': _extract(val),
            'test': _extract(test),
        }

        write_json(split, filepath)
        print(f'Saved split to {filepath}')

    @staticmethod
    def read_split(filepath, path_prefix):
        """Load a split via ``read_json`` and rebuild ``Datum`` items.

        Args:
            filepath: Split file to read; the loaded object must provide
                ``'train'``, ``'val'`` and ``'test'`` entries, each an
                iterable of ``(impath, label, classname)`` triples.
            path_prefix: Directory joined in front of every stored path.

        Returns:
            A ``(train, val, test)`` tuple of ``Datum`` lists.
        """
        def _convert(items):
            return [
                Datum(
                    impath=os.path.join(path_prefix, impath),
                    label=int(label),
                    classname=classname,
                )
                for impath, label, classname in items
            ]

        print(f'Reading split from {filepath}')
        split = read_json(filepath)
        train = _convert(split['train'])
        val = _convert(split['val'])
        test = _convert(split['test'])

        return train, val, test