83 lines
2.2 KiB
Python
83 lines
2.2 KiB
Python
import os
|
|
import math
|
|
import random
|
|
from collections import defaultdict
|
|
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
|
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
def read_classnames(text_file):
    """Return a dictionary containing key-value pairs of <folder name>: <class name>.

    Each non-blank line of ``text_file`` is split at its first space: the
    text before it is the folder name, everything after it is the class
    name (a line without a space maps its folder to an empty class name).
    Blank lines are skipped so that, e.g., a trailing newline at the end of
    the file does not add a bogus ``"" -> ""`` entry.

    Args:
        text_file (str): path to the class-name list.

    Returns:
        OrderedDict: folder name -> class name, in file order.
    """
    classnames = OrderedDict()
    with open(text_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines instead of recording an empty folder/class.
                continue
            # partition(" ") == (tokens[0], " ", " ".join(tokens[1:])) for
            # tokens = line.split(" "), including runs of multiple spaces.
            folder, _, classname = line.partition(" ")
            classnames[folder] = classname
    return classnames
|
|
|
|
def listdir_nohidden(path, sort=False):
    """Return the entries of ``path`` whose names do not start with a dot.

    Args:
        path (str): directory to list.
        sort (bool): if True, return the names in ascending sorted order;
            otherwise keep the order reported by ``os.listdir``.

    Returns:
        list[str]: the visible entry names.
    """
    visible = []
    for name in os.listdir(path):
        if name.startswith("."):
            continue
        visible.append(name)
    return sorted(visible) if sort else visible
|
|
class ImageNetV2(DatasetBase):
    """ImageNetV2.

    This dataset is used for testing only: the same list of items is passed
    to ``DatasetBase`` as ``train_x``, ``val`` and ``test``.

    Layout read under ``root``::

        imagenetv2/
            classnames.txt        # "<folder> <class name>" per line
            imagenetv2/
                0/ 1/ ... 999/    # one sub-directory per label index
    """

    dataset_dir = "imagenetv2"
    # Number of label sub-directories ("0" .. str(num_classes - 1)) read from
    # ``image_dir``; label i takes its class name from the i-th entry of
    # classnames.txt.
    num_classes = 1000

    def __init__(self, root):
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'imagenetv2')

        text_file = os.path.join(self.dataset_dir, "classnames.txt")
        classnames = read_classnames(text_file)

        data = self.read_data(classnames)

        # Test-only dataset: expose the same items for every split.
        super().__init__(train_x=data, val=data, test=data)

    def read_data(self, classnames):
        """Build one ``Datum`` per image found in the label sub-directories.

        Args:
            classnames (OrderedDict): folder name -> class name, in label
                order; the i-th entry gives the class name for the
                sub-directory named ``str(i)``.

        Returns:
            list[Datum]: items ordered by label, then by file name.
        """
        image_dir = self.image_dir
        folders = list(classnames.keys())
        items = []

        for label in range(self.num_classes):
            class_dir = os.path.join(image_dir, str(label))
            # Sort so the item order does not depend on the filesystem's
            # arbitrary ``os.listdir`` order (reproducible datasets).
            imnames = listdir_nohidden(class_dir, sort=True)
            classname = classnames[folders[label]]
            for imname in imnames:
                items.append(
                    Datum(
                        impath=os.path.join(class_dir, imname),
                        label=label,
                        classname=classname,
                    )
                )

        return items
|
|
|
|
|
|
|