Files
DAPT/datasets/pascal_voc.py
2025-10-07 22:42:55 +08:00

230 lines
7.2 KiB
Python

import os
import pickle
from collections import OrderedDict
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from dassl.utils import listdir_nohidden, mkdir_if_missing
from .oxford_pets import OxfordPets
import numpy as np
from pathlib import Path
from collections import defaultdict
import random
import math
# Ordered category names. A category's position in this list is the integer
# class id used as the key in CAT_LIST_TO_NAME.
CAT_LIST = [
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'table',
    'dog',
    'horse',
    'motorbike',
    'person',
    'plant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
]

# Mapping from integer class id (list position) to category name.
CAT_LIST_TO_NAME = dict(enumerate(CAT_LIST))
def _collate(ims, y, c):
    """Wrap one sample in a ``Datum``.

    Args:
        ims: Value stored as the datum's ``impath``.
        y: Value stored as the datum's ``label``.
        c: Value stored as the datum's ``classname``.

    Returns:
        The newly constructed ``Datum``.
    """
    datum = Datum(impath=ims, label=y, classname=c)
    return datum
def load_img_name_list(dataset_path):
    """Read a text file of image ids, one per line.

    Args:
        dataset_path: Path of the text file to read.

    Returns:
        list[str]: Every line of the file with surrounding whitespace
        (including the trailing newline) stripped, in file order.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the previous implementation left the handle open.
    with open(dataset_path) as id_file:
        return [line.strip() for line in id_file]
def load_image_label_list_from_npy(data_root, img_name_list, label_file_path=None):
    """Build one ``Datum`` per image id from a ``.npy`` label dictionary.

    Args:
        data_root: Directory joined in front of every image file name.
        img_name_list: Iterable of image id strings.
        label_file_path: Path of a ``.npy`` file containing a pickled dict
            that maps an image id (with or without a ``.jpg`` suffix) to its
            label vector. Defaults to ``'voc12/cls_labels.npy'``, resolved
            relative to the current working directory.

    Returns:
        list: One ``Datum`` per id, in input order. ``impath`` is
        ``data_root/<id>.jpg``, ``label`` is the stored label vector and
        ``classname`` is the list of ``CAT_LIST`` entries at the positions
        where the label vector equals 1.

    Raises:
        KeyError: If neither ``<id>`` nor ``<id>.jpg`` is a key of the label
            dictionary.
    """
    if label_file_path is None:
        label_file_path = 'voc12/cls_labels.npy'
    # NOTE: allow_pickle=True unpickles the file contents; only load label
    # files from a trusted source.
    cls_labels_dict = np.load(label_file_path, allow_pickle=True).item()

    data_dtm = []
    for img_id in img_name_list:
        # The label dict may be keyed by the bare id or by "<id>.jpg".
        key = img_id if img_id in cls_labels_dict else img_id + '.jpg'
        label = cls_labels_dict[key]

        # Indices of the positive classes; look the names up by those indices
        # (not by their positions 0..k-1 within label_idx).
        label_idx = np.where(label == 1)[0]
        class_name = [CAT_LIST[idx] for idx in label_idx]

        # Append the extension exactly once, independent of which key form
        # matched, so the path never ends in ".jpg.jpg".
        file_name = img_id if img_id.endswith('.jpg') else img_id + '.jpg'
        data_dtm.append(_collate(os.path.join(data_root, file_name), label, class_name))
    return data_dtm
@DATASET_REGISTRY.register()
class VOC12(DatasetBase):
    """VOC12 dataset with multi-hot labels, registered in ``DATASET_REGISTRY``.

    Each ``Datum`` carries a label vector in which entries equal to 1 mark the
    classes present in the image. Wherever a single class index per item is
    needed (few-shot grouping, base/new filtering) one positive index is drawn
    at random with ``random.choices``.
    """

    # Sub-directory of ``cfg.DATASET.ROOT`` that holds the dataset.
    dataset_dir = "voc12data"

    def __init__(self, cfg):
        """Load the train/val lists, apply few-shot and class sub-sampling.

        Args:
            cfg: Config object exposing ``DATASET.ROOT``, ``DATASET.NUM_SHOTS``,
                ``DATASET.SUBSAMPLE_CLASSES`` and ``SEED``.
        """
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'VOCdevkit/VOC2012/JPEGImages')

        # The id lists are resolved relative to the current working
        # directory, not relative to the dataset root.
        train_img_name_list_path = 'voc12/train_aug_id.txt'
        val_img_name_list_path = 'voc12/val_id.txt'
        train = load_image_label_list_from_npy(self.image_dir, load_img_name_list(train_img_name_list_path))
        val = load_image_label_list_from_npy(self.image_dir, load_img_name_list(val_img_name_list_path))

        self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
        mkdir_if_missing(self.split_fewshot_dir)

        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots >= 1:
            seed = cfg.SEED
            # Few-shot splits are cached per (num_shots, seed) pair.
            preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")

            if os.path.exists(preprocessed):
                print(f"Loading preprocessed few-shot data from {preprocessed}")
                # NOTE: pickle.load can execute arbitrary code; the cache
                # directory must only contain files written by this class.
                with open(preprocessed, "rb") as file:
                    data = pickle.load(file)
                    train = data["train"]
            else:
                train = self.generate_fewshot_dataset(train, num_shots=num_shots)
                data = {"train": train}
                print(f"Saving preprocessed few-shot data to {preprocessed}")
                with open(preprocessed, "wb") as file:
                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)

        subsample = cfg.DATASET.SUBSAMPLE_CLASSES
        train, val = self.subsample_classes(train, val, subsample=subsample)

        # The validation split doubles as the test split.
        super().__init__(train_x=train, val=val, test=val)

    @staticmethod
    def subsample_classes(*args, subsample="all"):
        """Divide classes into two groups. The first group
        represents base classes while the second group represents
        new classes.

        The class set is collected from the first dataset only, sorted, and
        split in half (the first half gets the extra class when the count is
        odd). Labels are not re-indexed: kept items retain their original
        label vector.

        Args:
            args: a list of datasets, e.g. train, val and test.
            subsample (str): what classes to subsample; one of
                ``"all"``, ``"base"`` or ``"new"``.

        Returns:
            ``args`` unchanged (a tuple) when ``subsample == "all"``;
            otherwise a list with one filtered list of ``Datum`` per dataset.
        """
        assert subsample in ["all", "base", "new"]

        if subsample == "all":
            return args

        dataset = args[0]
        labels = set()
        for item in dataset:
            # One positive class index is drawn at random per item.
            label_idx = random.choices(np.where(item.label == 1)[0])[0]
            labels.add(label_idx)
        labels = sorted(labels)
        n = len(labels)
        # Divide classes into two halves
        m = math.ceil(n / 2)

        print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
        if subsample == "base":
            selected = set(labels[:m])  # take the first half
        else:
            selected = set(labels[m:])  # take the second half

        output = []
        for dataset in args:
            dataset_new = []
            for item in dataset:
                # NOTE(review): this is a fresh random draw, independent of
                # the one used above to collect the class set, so a
                # multi-label item is kept or dropped depending on which of
                # its positive classes happens to be drawn.
                label_idx = random.choices(np.where(item.label == 1)[0])[0]
                if label_idx not in selected:
                    continue
                item_new = Datum(
                    impath=item.impath,
                    label=item.label,
                    classname=item.classname
                )
                dataset_new.append(item_new)
            output.append(dataset_new)

        return output

    @staticmethod
    def get_num_classes(data_source):
        """Count number of classes.

        Always returns ``len(CAT_LIST)``; ``data_source`` is ignored.

        Args:
            data_source (list): a list of Datum objects.
        """
        return len(CAT_LIST)

    @staticmethod
    def get_lab2cname(data_source):
        """Get a label-to-classname mapping (dict).

        Returns the pair ``(CAT_LIST_TO_NAME, CAT_LIST)``; ``data_source``
        is ignored.

        Args:
            data_source (list): a list of Datum objects.
        """
        return CAT_LIST_TO_NAME, CAT_LIST

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Each item is placed in exactly one group: the group of one of its
        positive class indices, drawn at random.

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            defaultdict: class index -> list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            one_hot_label = item.label
            label_idx = random.choices(np.where(one_hot_label == 1)[0])[0]
            output[label_idx].append(item)

        return output

    @staticmethod
    def read_classnames(text_file):
        """Return a dictionary containing
        key-value pairs of <folder name>: <class name>.

        Each line of ``text_file`` is split on single spaces: the first
        token is the folder name and the remaining tokens, re-joined with
        spaces, form the class name.

        Args:
            text_file: path of the text file to parse.

        Returns:
            OrderedDict: folder name -> class name, in file order.
        """
        classnames = OrderedDict()
        with open(text_file, "r") as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip().split(" ")
                folder = line[0]
                classname = " ".join(line[1:])
                classnames[folder] = classname
        return classnames

    def read_data(self, classnames, split_dir):
        """Read a folder-per-class image tree below ``self.image_dir``.

        Args:
            classnames: mapping of folder name -> class name.
            split_dir: sub-directory of ``self.image_dir`` to scan.

        Returns:
            list: one ``Datum`` per non-hidden entry of every class folder.
            The integer label is the folder's position in the sorted list
            of sub-directory names.
        """
        split_dir = os.path.join(self.image_dir, split_dir)
        folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
        items = []

        for label, folder in enumerate(folders):
            imnames = listdir_nohidden(os.path.join(split_dir, folder))
            classname = classnames[folder]
            for imname in imnames:
                impath = os.path.join(split_dir, folder, imname)
                item = Datum(impath=impath, label=label, classname=classname)
                items.append(item)

        return items