Upload to Main
This commit is contained in:
229
datasets/pascal_voc.py
Normal file
229
datasets/pascal_voc.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import math
|
||||
CAT_LIST = ['aeroplane',
|
||||
'bicycle',
|
||||
'bird',
|
||||
'boat',
|
||||
'bottle',
|
||||
'bus',
|
||||
'car',
|
||||
'cat',
|
||||
'chair',
|
||||
'cow',
|
||||
'table',
|
||||
'dog',
|
||||
'horse',
|
||||
'motorbike',
|
||||
'person',
|
||||
'plant',
|
||||
'sheep',
|
||||
'sofa',
|
||||
'train',
|
||||
'tvmonitor']
|
||||
|
||||
CAT_LIST_TO_NAME = dict(zip(range(len(CAT_LIST)) ,CAT_LIST))
|
||||
|
||||
|
||||
def _collate(ims, y, c):
|
||||
return Datum(impath=ims, label=y, classname=c)
|
||||
|
||||
def load_img_name_list(dataset_path):
|
||||
|
||||
img_gt_name_list = open(dataset_path).readlines()
|
||||
img_name_list = [img_gt_name.strip() for img_gt_name in img_gt_name_list]
|
||||
|
||||
return img_name_list
|
||||
|
||||
def load_image_label_list_from_npy(data_root,img_name_list, label_file_path=None):
|
||||
if label_file_path is None:
|
||||
label_file_path = 'voc12/cls_labels.npy'
|
||||
cls_labels_dict = np.load(label_file_path, allow_pickle=True).item()
|
||||
label_list = []
|
||||
data_dtm = []
|
||||
|
||||
for id in img_name_list:
|
||||
if id not in cls_labels_dict.keys():
|
||||
img_name = id + '.jpg'
|
||||
else:
|
||||
img_name = id
|
||||
label = cls_labels_dict[img_name]
|
||||
label_idx = np.where(label==1)[0]
|
||||
class_name = [CAT_LIST[idx] for idx in range(len(label_idx))]
|
||||
data_dtm.append(_collate(os.path.join(data_root,img_name+'.jpg'),label,class_name))
|
||||
|
||||
return data_dtm
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class VOC12(DatasetBase):
|
||||
dataset_dir = "voc12data"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir,'VOCdevkit/VOC2012/JPEGImages')
|
||||
train_img_name_list_path = os.path.join('voc12/train_aug_id.txt')
|
||||
val_img_name_list_path = os.path.join('voc12/val_id.txt')
|
||||
|
||||
train = load_image_label_list_from_npy(self.image_dir,load_img_name_list(train_img_name_list_path))
|
||||
val = load_image_label_list_from_npy(self.image_dir,load_img_name_list(val_img_name_list_path))
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train = data["train"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
data = {"train": train}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val = self.subsample_classes(train, val, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=val)
|
||||
|
||||
@staticmethod
|
||||
def subsample_classes(*args, subsample="all"):
|
||||
"""Divide classes into two groups. The first group
|
||||
represents base classes while the second group represents
|
||||
new classes.
|
||||
|
||||
Args:
|
||||
args: a list of datasets, e.g. train, val and test.
|
||||
subsample (str): what classes to subsample.
|
||||
"""
|
||||
assert subsample in ["all", "base", "new"]
|
||||
|
||||
if subsample == "all":
|
||||
return args
|
||||
|
||||
dataset = args[0]
|
||||
labels = set()
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
labels.add(label_idx)
|
||||
labels = list(labels)
|
||||
labels.sort()
|
||||
n = len(labels)
|
||||
# Divide classes into two halves
|
||||
m = math.ceil(n / 2)
|
||||
|
||||
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
|
||||
if subsample == "base":
|
||||
selected = labels[:m] # take the first half
|
||||
else:
|
||||
selected = labels[m:] # take the second half
|
||||
relabeler = {y: y_new for y_new, y in enumerate(selected)}
|
||||
|
||||
output = []
|
||||
for dataset in args:
|
||||
dataset_new = []
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
if label_idx not in selected:
|
||||
continue
|
||||
|
||||
item_new = Datum(
|
||||
impath=item.impath,
|
||||
label=item.label,
|
||||
classname=item.classname
|
||||
)
|
||||
dataset_new.append(item_new)
|
||||
output.append(dataset_new)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_num_classes(data_source):
|
||||
"""Count number of classes.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return len(CAT_LIST)
|
||||
|
||||
@staticmethod
|
||||
def get_lab2cname(data_source):
|
||||
"""Get a label-to-classname mapping (dict).
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return CAT_LIST_TO_NAME, CAT_LIST
|
||||
|
||||
def split_dataset_by_label(self, data_source):
|
||||
"""Split a dataset, i.e. a list of Datum objects,
|
||||
into class-specific groups stored in a dictionary.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
output = defaultdict(list)
|
||||
|
||||
for item in data_source:
|
||||
one_hot_label = item.label
|
||||
label_idx = random.choices(np.where(one_hot_label==1)[0])[0]
|
||||
output[label_idx].append(item)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def read_data(self, classnames, split_dir):
|
||||
split_dir = os.path.join(self.image_dir, split_dir)
|
||||
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(split_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(split_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
Reference in New Issue
Block a user