Files
clip-symnets/datasets/imagenetv2.py
2024-05-21 19:41:56 +08:00

83 lines
2.2 KiB
Python

import os
import math
import random
from collections import defaultdict
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
import torch
import torchvision
import torchvision.transforms as transforms
from collections import OrderedDict
def read_classnames(text_file):
    """Return an ordered mapping of <folder name>: <class name>.

    Every non-blank line of ``text_file`` is read as a folder name, a single
    space, and then the class name (which may itself contain spaces). Blank
    lines are ignored.

    Args:
        text_file (str): path to the class-name listing.

    Returns:
        OrderedDict: folder name -> class name, in file order.
    """
    classnames = OrderedDict()
    with open(text_file, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise register a bogus "" -> ""
                # entry and shift any positional lookup made on the keys.
                continue
            # Split on the first space only; the remainder is the class name.
            folder, _, classname = line.partition(" ")
            classnames[folder] = classname
    return classnames
def listdir_nohidden(path, sort=False):
    """Return the entries of a directory whose names do not start with a dot.

    Args:
        path (str): directory path.
        sort (bool): when true, return the entries in sorted order.
    """
    visible = []
    for entry in os.listdir(path):
        if entry.startswith("."):
            continue
        visible.append(entry)
    return sorted(visible) if sort else visible
class ImageNetV2(DatasetBase):
    """ImageNetV2.

    This dataset is used for testing only: the same list of items is handed
    to the base class as the train, val and test splits.

    Paths read, relative to ``root``:
        imagenetv2/classnames.txt      folder-name -> class-name listing
        imagenetv2/imagenetv2/<label>/ images of class ``<label>`` (0..999)
    """

    dataset_dir ="imagenetv2"

    def __init__(self, root):
        """Resolve the dataset paths under ``root`` and load every item."""
        base = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base
        self.image_dir = os.path.join(base, "imagenetv2")

        names = read_classnames(os.path.join(base, "classnames.txt"))
        samples = self.read_data(names)

        super().__init__(train_x=samples, val=samples, test=samples)

    def read_data(self, classnames):
        """Build one ``Datum`` per visible file in each numbered class folder.

        Args:
            classnames: mapping of folder name -> class name; the key at
                position ``label`` supplies the class name for that label.

        Returns:
            list: ``Datum`` objects for labels 0 through 999, in label order.
        """
        folder_names = list(classnames.keys())
        samples = []
        for class_idx in range(1000):
            class_path = os.path.join(self.image_dir, str(class_idx))
            # List the directory before the key lookup, so a missing folder
            # is reported ahead of a short class-name listing.
            filenames = listdir_nohidden(class_path)
            name = classnames[folder_names[class_idx]]
            samples.extend(
                Datum(
                    impath=os.path.join(class_path, filename),
                    label=class_idx,
                    classname=name,
                )
                for filename in filenames
            )
        return samples