init
This commit is contained in:
76
datasets/imagenet_sketch.py
Normal file
76
datasets/imagenet_sketch.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def listdir_nohidden(path, sort=False):
|
||||
"""List non-hidden items in a directory.
|
||||
|
||||
Args:
|
||||
path (str): directory path.
|
||||
sort (bool): sort the items.
|
||||
"""
|
||||
items = [f for f in os.listdir(path) if not f.startswith(".")]
|
||||
if sort:
|
||||
items.sort()
|
||||
return items
|
||||
class ImageNetSketch(DatasetBase):
|
||||
"""ImageNet-Sketch.
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir ="imagenet_sketch"
|
||||
|
||||
def __init__(self, data_dir):
|
||||
root = data_dir
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, 'images')
|
||||
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, val=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(
|
||||
impath=impath,
|
||||
label=label,
|
||||
classname=classname
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
Reference in New Issue
Block a user