init
This commit is contained in:
33
datasets/__init__.py
Normal file
33
datasets/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from .oxford_pets import OxfordPets
|
||||
from .eurosat import EuroSAT
|
||||
from .ucf101 import UCF101
|
||||
from .sun397 import SUN397
|
||||
from .caltech101 import Caltech101
|
||||
from .dtd import DescribableTextures
|
||||
from .fgvc import FGVCAircraft
|
||||
from .food101 import Food101
|
||||
from .oxford_flowers import OxfordFlowers
|
||||
from .stanford_cars import StanfordCars
|
||||
from .imagenet import ImageNet
|
||||
from .caltech101_tsne import Caltech101_TSNE
|
||||
|
||||
|
||||
|
||||
dataset_list = {
|
||||
"oxford_pets": OxfordPets,
|
||||
"eurosat": EuroSAT,
|
||||
"ucf101": UCF101,
|
||||
"sun397": SUN397,
|
||||
"caltech101": Caltech101,
|
||||
"dtd": DescribableTextures,
|
||||
"fgvc": FGVCAircraft,
|
||||
"food101": Food101,
|
||||
"oxford_flowers": OxfordFlowers,
|
||||
"stanford_cars": StanfordCars,
|
||||
"caltech101_tsne": Caltech101_TSNE,
|
||||
"imagenet":ImageNet,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset(dataset, root_path, shots, **dataset_kwargs):
    """Instantiate the dataset class registered under ``dataset``.

    Args:
        dataset: Key into ``dataset_list`` (e.g. ``"oxford_pets"``).
        root_path: Root directory, passed as the first constructor argument.
        shots: Number of shots, passed as the second constructor argument.
        **dataset_kwargs: Extra keyword arguments forwarded unchanged to the
            constructor. Required for registered classes whose ``__init__``
            takes more than ``(root, num_shots)``, such as ``ImageNet``,
            which also needs ``preprocess``.

    Returns:
        The constructed dataset object.

    Raises:
        KeyError: If ``dataset`` is not a key of ``dataset_list``.
    """
    try:
        dataset_cls = dataset_list[dataset]
    except KeyError:
        available = ', '.join(sorted(dataset_list))
        # Same exception type as the plain dict lookup, but with a message
        # that tells the caller which names are valid.
        raise KeyError(
            f"Unknown dataset '{dataset}'. Available datasets: {available}"
        ) from None
    return dataset_cls(root_path, shots, **dataset_kwargs)
|
||||
24
datasets/caltech101.py
Normal file
24
datasets/caltech101.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
template = ['a photo of a {}.']
|
||||
|
||||
|
||||
class Caltech101(DatasetBase):
    """Caltech-101 dataset wrapper.

    Resolves its paths below ``<root>/caltech-101``, loads the train/val/test
    items through ``OxfordPets.read_split`` and hands them to ``DatasetBase``.
    Both the train split and a copy of the test split are passed through
    ``generate_fewshot_dataset`` with ``num_shots``; the latter is forwarded
    as ``t_sne``.
    """

    dataset_dir = 'caltech-101'

    def __init__(self, root, num_shots):
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.image_dir = os.path.join(base_dir, '101_ObjectCategories')
        self.split_path = os.path.join(base_dir, 'split_zhou_Caltech101.json')
        self.template = template

        full_train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)

        # Order matters if the sampler draws random numbers: train first,
        # then the test-derived subset.
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)
        few_shot_test = self.generate_fewshot_dataset(test, num_shots=num_shots)

        super().__init__(
            train_x=few_shot_train,
            val=val,
            test=test,
            t_sne=few_shot_test,
        )
|
||||
24
datasets/caltech101_tsne.py
Normal file
24
datasets/caltech101_tsne.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
template = ['a photo of a {}.']
|
||||
|
||||
|
||||
class Caltech101_TSNE(DatasetBase):
    """Caltech-101 variant rooted at ``<root>/caltech-101_tsne``.

    Loads the split stored in ``split_zhou_Caltech101.json`` via
    ``OxfordPets.read_split`` and passes the train split, plus a separate
    subset drawn from the test split, through ``generate_fewshot_dataset``.
    The test-derived subset is given to ``DatasetBase`` as ``t_sne``.
    """

    dataset_dir = 'caltech-101_tsne'

    def __init__(self, root, num_shots):
        dataset_root = os.path.join(root, self.dataset_dir)
        self.dataset_dir = dataset_root
        self.image_dir = os.path.join(dataset_root, '101_ObjectCategories')
        self.split_path = os.path.join(dataset_root, 'split_zhou_Caltech101.json')
        self.template = template

        splits = OxfordPets.read_split(self.split_path, self.image_dir)
        train_items, val_items, test_items = splits

        # Keep the call order (train, then test) in case sampling is random.
        train_items = self.generate_fewshot_dataset(train_items, num_shots=num_shots)
        tsne_items = self.generate_fewshot_dataset(test_items, num_shots=num_shots)

        super().__init__(
            train_x=train_items,
            val=val_items,
            test=test_items,
            t_sne=tsne_items,
        )
|
||||
79
datasets/dtd.py
Normal file
79
datasets/dtd.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
from .utils import Datum, DatasetBase, listdir_nohidden
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
template = ['{} texture.']
|
||||
|
||||
|
||||
class DescribableTextures(DatasetBase):
    """Describable Textures (DTD) dataset wrapper.

    Resolves its paths below ``<root>/dtd``, loads the train/val/test items
    from ``split_zhou_DescribableTextures.json`` through
    ``OxfordPets.read_split`` and passes the train split through
    ``generate_fewshot_dataset`` with ``num_shots``.
    """

    dataset_dir = 'dtd'

    def __init__(self, root, num_shots):
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')
        self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json')

        self.template = template

        train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        train = self.generate_fewshot_dataset(train, num_shots=num_shots)

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_and_split_data(
        image_dir,
        p_trn=0.5,
        p_val=0.2,
        ignored=None,
        new_cnames=None
    ):
        """Randomly split a folder-per-class image directory.

        The data are supposed to be organized into the following structure
        =============
        images/
            dog/
            cat/
            horse/
        =============

        Args:
            image_dir: Directory whose non-hidden entries are class folders.
            p_trn: Fraction of each class assigned to the train split.
            p_val: Fraction of each class assigned to the val split; the
                remainder goes to test.
            ignored: Optional collection of folder names to skip. ``None``
                (the default) skips nothing.
            new_cnames: Optional mapping from folder name to the class name
                stored on the items.

        Returns:
            Tuple ``(train, val, test)`` of lists of ``Datum``.

        Raises:
            AssertionError: If any class ends up with an empty split.
        """
        # ``None`` instead of a mutable ``[]`` default, which would be shared
        # between calls.
        ignored = () if ignored is None else ignored

        categories = sorted(
            c for c in listdir_nohidden(image_dir) if c not in ignored
        )

        p_tst = 1 - p_trn - p_val
        print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test')

        def _collate(ims, y, c):
            # ``y`` is already 0-based (it comes from ``enumerate``).
            return [Datum(impath=im, label=y, classname=c) for im in ims]

        train, val, test = [], [], []
        for label, category in enumerate(categories):
            category_dir = os.path.join(image_dir, category)
            images = [
                os.path.join(category_dir, im)
                for im in listdir_nohidden(category_dir)
            ]
            random.shuffle(images)

            n_total = len(images)
            n_train = round(n_total * p_trn)
            n_val = round(n_total * p_val)
            n_test = n_total - n_train - n_val
            assert n_train > 0 and n_val > 0 and n_test > 0, (
                f'Class folder {category_dir!r} has too few images '
                f'({n_total}) for a non-empty train/val/test split'
            )

            if new_cnames is not None and category in new_cnames:
                category = new_cnames[category]

            train.extend(_collate(images[:n_train], label, category))
            val.extend(_collate(images[n_train:n_train + n_val], label, category))
            test.extend(_collate(images[n_train + n_val:], label, category))

        return train, val, test
|
||||
50
datasets/eurosat.py
Normal file
50
datasets/eurosat.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
template = ['a centered satellite photo of {}.']
|
||||
|
||||
|
||||
NEW_CNAMES = {
|
||||
'AnnualCrop': 'Annual Crop Land',
|
||||
'Forest': 'Forest',
|
||||
'HerbaceousVegetation': 'Herbaceous Vegetation Land',
|
||||
'Highway': 'Highway or Road',
|
||||
'Industrial': 'Industrial Buildings',
|
||||
'Pasture': 'Pasture Land',
|
||||
'PermanentCrop': 'Permanent Crop Land',
|
||||
'Residential': 'Residential Buildings',
|
||||
'River': 'River',
|
||||
'SeaLake': 'Sea or Lake'
|
||||
}
|
||||
|
||||
|
||||
class EuroSAT(DatasetBase):
    """EuroSAT dataset wrapper.

    Resolves its paths below ``<root>/eurosat`` (images in the ``2750``
    sub-folder), loads the train/val/test items from
    ``split_zhou_EuroSAT.json`` through ``OxfordPets.read_split`` and passes
    the train split through ``generate_fewshot_dataset`` with ``num_shots``.
    """

    dataset_dir = 'eurosat'

    def __init__(self, root, num_shots):
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, '2750')
        self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json')

        self.template = template

        train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        train = self.generate_fewshot_dataset(train, num_shots=num_shots)
        super().__init__(train_x=train, val=val, test=test)

    def update_classname(self, dataset_old):
        """Return a copy of ``dataset_old`` with class names remapped.

        Args:
            dataset_old: Iterable of items exposing ``impath``, ``label`` and
                ``classname``.

        Returns:
            New list of ``Datum`` whose ``classname`` is looked up in the
            module-level ``NEW_CNAMES`` mapping.

        Raises:
            KeyError: If an item's class name is not a key of ``NEW_CNAMES``.
        """
        # The mapping defined in this module is NEW_CNAMES; the previous
        # reference to NEW_CLASSNAMES was an undefined name (NameError).
        return [
            Datum(
                impath=item_old.impath,
                label=item_old.label,
                classname=NEW_CNAMES[item_old.classname]
            )
            for item_old in dataset_old
        ]
|
||||
54
datasets/fgvc.py
Normal file
54
datasets/fgvc.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
|
||||
template = ['a photo of a {}, a type of aircraft.']
|
||||
|
||||
|
||||
class FGVCAircraft(DatasetBase):
    """FGVC-Aircraft dataset wrapper.

    Class names are read from ``variants.txt`` (one per line; the line index
    becomes the label). The three splits are read from the
    ``images_variant_{train,val,test}.txt`` files, and the train split is
    passed through ``generate_fewshot_dataset`` with ``num_shots``.
    """

    dataset_dir = 'fgvc_aircraft'

    def __init__(self, root, num_shots):
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')
        self.template = template

        variants_path = os.path.join(self.dataset_dir, 'variants.txt')
        with open(variants_path, 'r') as f:
            variant_names = [row.strip() for row in f]
        cname2lab = {name: index for index, name in enumerate(variant_names)}

        train_items = self.read_data(cname2lab, 'images_variant_train.txt')
        val_items = self.read_data(cname2lab, 'images_variant_val.txt')
        test_items = self.read_data(cname2lab, 'images_variant_test.txt')

        train_items = self.generate_fewshot_dataset(train_items, num_shots=num_shots)

        super().__init__(train_x=train_items, val=val_items, test=test_items)

    def read_data(self, cname2lab, split_file):
        """Parse one split file into a list of ``Datum``.

        Each line is ``<image id> <class name>``; the class name may itself
        contain spaces. The image path is ``<image_dir>/<image id>.jpg``.

        Args:
            cname2lab: Mapping from class name to integer label.
            split_file: File name relative to ``self.dataset_dir``.

        Returns:
            List of ``Datum`` in file order.
        """
        collected = []
        with open(os.path.join(self.dataset_dir, split_file), 'r') as f:
            for row in f:
                # Everything after the first space is the class name.
                image_id, _, classname = row.strip().partition(' ')
                collected.append(
                    Datum(
                        impath=os.path.join(self.image_dir, image_id + '.jpg'),
                        label=cname2lab[classname],
                        classname=classname
                    )
                )
        return collected
|
||||
24
datasets/food101.py
Normal file
24
datasets/food101.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
template = ['a photo of {}, a type of food.']
|
||||
|
||||
|
||||
class Food101(DatasetBase):
    """Food-101 dataset wrapper.

    Resolves its paths below ``<root>/food-101``, loads the train/val/test
    items from ``split_zhou_Food101.json`` through ``OxfordPets.read_split``
    and passes the train split through ``generate_fewshot_dataset`` with
    ``num_shots``.
    """

    dataset_dir = 'food-101'

    def __init__(self, root, num_shots):
        dataset_root = os.path.join(root, self.dataset_dir)
        self.dataset_dir = dataset_root
        self.image_dir = os.path.join(dataset_root, 'images')
        self.split_path = os.path.join(dataset_root, 'split_zhou_Food101.json')
        self.template = template

        splits = OxfordPets.read_split(self.split_path, self.image_dir)
        full_train, val_items, test_items = splits
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)

        super().__init__(train_x=few_shot_train, val=val_items, test=test_items)
|
||||
221
datasets/imagenet.py
Normal file
221
datasets/imagenet.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
|
||||
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
|
||||
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
|
||||
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
|
||||
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
|
||||
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
|
||||
"box turtle", "banded gecko", "green iguana", "Carolina anole",
|
||||
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
|
||||
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
|
||||
"American alligator", "triceratops", "worm snake", "ring-necked snake",
|
||||
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
|
||||
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
|
||||
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
|
||||
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
|
||||
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
|
||||
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
|
||||
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
|
||||
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
|
||||
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
|
||||
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
|
||||
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
|
||||
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
|
||||
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
|
||||
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
|
||||
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
|
||||
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
|
||||
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
|
||||
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
|
||||
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
|
||||
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
|
||||
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
|
||||
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
|
||||
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
|
||||
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
|
||||
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
|
||||
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
|
||||
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
|
||||
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
|
||||
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
|
||||
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
|
||||
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
|
||||
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
|
||||
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
|
||||
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
|
||||
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
|
||||
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
|
||||
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
|
||||
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
|
||||
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
|
||||
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
|
||||
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
|
||||
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
|
||||
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
|
||||
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
|
||||
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
|
||||
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
|
||||
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
|
||||
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
|
||||
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
|
||||
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
|
||||
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
|
||||
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
|
||||
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
|
||||
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
|
||||
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
|
||||
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
|
||||
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
|
||||
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
|
||||
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
|
||||
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
|
||||
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
|
||||
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
|
||||
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
|
||||
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
|
||||
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
|
||||
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
|
||||
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
|
||||
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
|
||||
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
|
||||
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
|
||||
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
|
||||
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
|
||||
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
|
||||
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
|
||||
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
|
||||
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
|
||||
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
|
||||
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
|
||||
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
|
||||
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
|
||||
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
|
||||
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
|
||||
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
|
||||
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
|
||||
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
|
||||
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
|
||||
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
|
||||
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
|
||||
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
|
||||
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
|
||||
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
|
||||
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
|
||||
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
|
||||
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
|
||||
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
|
||||
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
|
||||
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
|
||||
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
|
||||
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
|
||||
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
|
||||
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
|
||||
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
|
||||
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
|
||||
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
|
||||
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
|
||||
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
|
||||
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
|
||||
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
|
||||
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
|
||||
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
|
||||
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
|
||||
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
|
||||
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
|
||||
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
|
||||
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
|
||||
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
|
||||
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
|
||||
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
|
||||
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
|
||||
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
|
||||
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
|
||||
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
|
||||
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
|
||||
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
|
||||
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
|
||||
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
|
||||
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
|
||||
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
|
||||
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
|
||||
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
|
||||
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
|
||||
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
|
||||
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
|
||||
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
|
||||
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
|
||||
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
|
||||
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
|
||||
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
|
||||
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
|
||||
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
|
||||
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
|
||||
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
|
||||
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
|
||||
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
|
||||
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
|
||||
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
|
||||
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
|
||||
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
|
||||
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
|
||||
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
|
||||
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
|
||||
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
|
||||
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
|
||||
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
|
||||
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
||||
|
||||
imagenet_templates = ["itap of a {}.",
|
||||
"a bad photo of the {}.",
|
||||
"a origami {}.",
|
||||
"a photo of the large {}.",
|
||||
"a {} in a video game.",
|
||||
"art of the {}.",
|
||||
"a photo of the small {}."]
|
||||
|
||||
|
||||
class ImageNet():
    """Few-shot ImageNet built on ``torchvision.datasets.ImageNet``.

    ``train`` uses a random-resized-crop / horizontal-flip pipeline, while
    ``val`` and ``test`` are both built from the ``'val'`` split with the
    caller-supplied ``preprocess`` transform. After construction the train
    set is reduced in place to ``num_shots`` randomly sampled images per
    label.

    Args:
        root: Directory containing the ``imagenet`` folder.
        num_shots: Number of training images to sample for each label.
        preprocess: Transform applied to the ``val`` and ``test`` datasets.
    """

    dataset_dir = 'imagenet'

    def __init__(self, root, num_shots, preprocess):

        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')

        train_preprocess = transforms.Compose([
            transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])
        test_preprocess = preprocess

        self.train = torchvision.datasets.ImageNet(self.image_dir, split='train', transform=train_preprocess)
        self.val = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess)
        self.test = torchvision.datasets.ImageNet(self.image_dir, split='val', transform=test_preprocess)

        self.template = imagenet_templates
        self.classnames = imagenet_classes

        imgs, targets = self._sample_few_shot(
            self.train.imgs, self.train.targets, num_shots
        )
        # ``imgs`` and ``samples`` are set to the same list so they stay in
        # sync, as before.
        self.train.imgs = imgs
        self.train.targets = targets
        self.train.samples = imgs

    @staticmethod
    def _sample_few_shot(samples, targets, num_shots):
        """Pick ``num_shots`` random samples for every label.

        Args:
            samples: Sequence of per-image entries.
            targets: Sequence of labels aligned with ``samples``.
            num_shots: Number of entries to draw per label.

        Returns:
            Tuple ``(imgs, labels)`` of equally long lists, grouped by label
            in order of first appearance.

        Raises:
            ValueError: From ``random.sample`` if a label has fewer than
                ``num_shots`` entries.
        """
        by_label = defaultdict(list)
        for sample, target in zip(samples, targets):
            by_label[target].append(sample)

        imgs = []
        labels = []
        for label, items in by_label.items():
            # extend() avoids rebuilding the whole list for every label.
            imgs.extend(random.sample(items, num_shots))
            labels.extend([label] * num_shots)
        return imgs, labels
|
||||
76
datasets/imagenet_sketch.py
Normal file
76
datasets/imagenet_sketch.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def read_classnames(text_file):
    """Parse a ``<folder name> <class name>`` text file.

    Each line is split at its first space: the part before is the folder
    name, everything after it is the class name.

    Args:
        text_file (str): path of the file to read.

    Returns:
        OrderedDict mapping folder name to class name, in file order.
    """
    mapping = OrderedDict()
    with open(text_file, "r") as f:
        for raw_line in f:
            folder, _, classname = raw_line.strip().partition(" ")
            mapping[folder] = classname
    return mapping
|
||||
|
||||
def listdir_nohidden(path, sort=False):
    """Return the entries of ``path`` whose names do not start with a dot.

    Args:
        path (str): directory path.
        sort (bool): return the entries in sorted order.
    """
    visible = []
    for entry in os.listdir(path):
        if entry.startswith("."):
            continue
        visible.append(entry)
    if sort:
        visible.sort()
    return visible
|
||||
class ImageNetSketch(DatasetBase):
    """ImageNet-Sketch.

    This dataset is used for testing only: the same item list is handed to
    ``DatasetBase`` as train, val and test.
    """

    dataset_dir = "imagenet_sketch"

    def __init__(self, data_dir):
        self.dataset_dir = os.path.join(data_dir, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'images')

        classnames = read_classnames(
            os.path.join(self.dataset_dir, "classnames.txt")
        )
        all_items = self.read_data(classnames)

        super().__init__(train_x=all_items, val=all_items, test=all_items)

    def read_data(self, classnames):
        """Build one ``Datum`` per image found below ``self.image_dir``.

        Class folders are visited in sorted order and the position in that
        order is used as the label.

        Args:
            classnames: Mapping from folder name to class name.

        Returns:
            List of ``Datum``.
        """
        root = self.image_dir
        collected = []

        for label, folder in enumerate(listdir_nohidden(root, sort=True)):
            folder_dir = os.path.join(root, folder)
            imnames = listdir_nohidden(folder_dir)
            classname = classnames[folder]
            collected.extend(
                Datum(
                    impath=os.path.join(folder_dir, imname),
                    label=label,
                    classname=classname
                )
                for imname in imnames
            )

        return collected
|
||||
82
datasets/imagenetv2.py
Normal file
82
datasets/imagenetv2.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
||||
def read_classnames(text_file):
    """Load the folder-to-class-name table stored in ``text_file``.

    Every line has the form ``<folder name> <class name>``; the class name
    is whatever follows the first space and may contain further spaces.

    Returns:
        OrderedDict of ``<folder name>: <class name>`` in file order.
    """
    with open(text_file, "r") as f:
        rows = [row.strip().split(" ") for row in f]

    table = OrderedDict()
    for folder, *name_parts in rows:
        table[folder] = " ".join(name_parts)
    return table
|
||||
|
||||
def listdir_nohidden(path, sort=False):
    """List the non-hidden entries of a directory.

    Entries whose name begins with ``"."`` are left out.

    Args:
        path (str): directory path.
        sort (bool): sort the entries before returning them.
    """
    names = list(filter(lambda name: not name.startswith("."), os.listdir(path)))
    return sorted(names) if sort else names
|
||||
class ImageNetV2(DatasetBase):
    """ImageNetV2.

    This dataset is used for testing only: the same list of items is
    handed to the base class as train, val and test data.
    """

    dataset_dir = "imagenetv2"

    def __init__(self, root):
        """Locate the dataset under ``root`` and index every image.

        Args:
            root (str): directory that contains the ``imagenetv2`` folder.
        """
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.image_dir = os.path.join(base_dir, 'imagenetv2')

        classnames = read_classnames(os.path.join(base_dir, "classnames.txt"))
        samples = self.read_data(classnames)

        super().__init__(train_x=samples, val=samples, test=samples)

    def read_data(self, classnames):
        """Create a Datum for every image of the 1000 class directories.

        Class directories are named after the integer label ("0", "1", ...).
        The class name of label ``i`` is the value of the ``i``-th entry of
        ``classnames``.

        Args:
            classnames: ordered mapping of folder name -> class name.

        Returns:
            list: Datum objects ordered by label.
        """
        folder_ids = list(classnames)
        samples = []

        for class_idx in range(1000):
            class_dir = os.path.join(self.image_dir, str(class_idx))
            file_names = listdir_nohidden(class_dir)
            human_name = classnames[folder_ids[class_idx]]
            samples.extend(
                Datum(
                    impath=os.path.join(class_dir, file_name),
                    label=class_idx,
                    classname=human_name
                )
                for file_name in file_names
            )

        return samples
|
||||
|
||||
|
||||
|
||||
67
datasets/oxford_flowers.py
Normal file
67
datasets/oxford_flowers.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
import random
|
||||
from scipy.io import loadmat
|
||||
from collections import defaultdict
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .utils import Datum, DatasetBase, read_json
|
||||
|
||||
|
||||
# Prompt template stored on each dataset instance as `self.template`.
template = ['a photo of a {}, a type of flower.']


class OxfordFlowers(DatasetBase):
    """Oxford Flowers dataset loaded from a pre-computed split file."""

    dataset_dir = 'oxford_flowers'

    def __init__(self, root, num_shots):
        """Read the stored split and sub-sample the training set.

        Args:
            root (str): directory that contains the ``oxford_flowers`` folder.
            num_shots (int): forwarded to ``generate_fewshot_dataset`` for
                the training split.
        """
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.image_dir = os.path.join(base_dir, 'jpg')
        self.label_file = os.path.join(base_dir, 'imagelabels.mat')
        self.lab2cname_file = os.path.join(base_dir, 'cat_to_name.json')
        self.split_path = os.path.join(base_dir, 'split_zhou_OxfordFlowers.json')

        self.template = template

        full_train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)

        super().__init__(train_x=few_shot_train, val=val, test=test)

    def read_data(self):
        """Build a random 50/20/30 train/val/test split from the raw labels.

        This method is not called by ``__init__``, which reads the stored
        split file instead. Labels in ``imagelabels.mat`` are converted to
        0-based labels on the created Datum objects. The shuffle uses the
        global ``random`` state.

        Returns:
            tuple: (train, val, test) lists of Datum objects.
        """
        paths_by_label = defaultdict(list)
        raw_labels = loadmat(self.label_file)['labels'][0]
        # Image files are numbered from 1, in the same order as the label array.
        for position, raw_label in enumerate(raw_labels, start=1):
            file_name = f'image_{position:05d}.jpg'
            paths_by_label[int(raw_label)].append(os.path.join(self.image_dir, file_name))

        print('Splitting data into 50% train, 20% val, and 30% test')

        def _collate(ims, y, c):
            # convert to 0-based label
            return [Datum(impath=im, label=y - 1, classname=c) for im in ims]

        label_names = read_json(self.lab2cname_file)
        train, val, test = [], [], []
        for label, impaths in paths_by_label.items():
            random.shuffle(impaths)
            total = len(impaths)
            n_train = round(total * 0.5)
            n_val = round(total * 0.2)
            n_test = total - n_train - n_val
            assert n_train > 0 and n_val > 0 and n_test > 0
            cname = label_names[str(label)]
            val_end = n_train + n_val
            train.extend(_collate(impaths[:n_train], label, cname))
            val.extend(_collate(impaths[n_train:val_end], label, cname))
            test.extend(_collate(impaths[val_end:], label, cname))

        return train, val, test
|
||||
125
datasets/oxford_pets.py
Normal file
125
datasets/oxford_pets.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
|
||||
# Prompt template stored on each dataset instance as `self.template`.
template = ['a photo of a {}, a type of pet.']


class OxfordPets(DatasetBase):
    """Oxford-IIIT Pets dataset loaded from a pre-computed split file.

    The static helpers ``split_trainval``, ``save_split`` and ``read_split``
    do not depend on instance state.
    """

    dataset_dir = 'oxford_pets'

    def __init__(self, root, num_shots):
        """Read the stored split and sub-sample the training set.

        Args:
            root (str): directory that contains the ``oxford_pets`` folder.
            num_shots (int): forwarded to ``generate_fewshot_dataset`` for
                the training split.
        """
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.image_dir = os.path.join(base_dir, 'images')
        self.anno_dir = os.path.join(base_dir, 'annotations')
        self.split_path = os.path.join(base_dir, 'split_zhou_OxfordPets.json')

        self.template = template

        full_train, val, test = self.read_split(self.split_path, self.image_dir)
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)

        super().__init__(train_x=few_shot_train, val=val, test=test)

    def read_data(self, split_file):
        """Parse an annotation list from ``self.anno_dir`` into Datum objects.

        Every line must contain exactly four space-separated fields; the
        first is the image name without extension and the second a 1-based
        label. The class name is the image name with its last
        ``_``-separated token removed, lower-cased.

        Args:
            split_file (str): file name inside the annotation directory.

        Returns:
            list: Datum objects in file order.
        """
        annotation_path = os.path.join(self.anno_dir, split_file)
        samples = []

        with open(annotation_path, 'r') as handle:
            for raw_line in handle.readlines():
                imname, label, species, _ = raw_line.strip().split(' ')
                # Drop the trailing "_<suffix>" token to obtain the breed.
                breed = imname.rpartition('_')[0].lower()
                samples.append(
                    Datum(
                        impath=os.path.join(self.image_dir, imname + '.jpg'),
                        label=int(label) - 1,  # convert to 0-based index
                        classname=breed
                    )
                )

        return samples

    @staticmethod
    def split_trainval(trainval, p_val=0.2):
        """Randomly split items into train and val, class by class.

        For each label, ``round(count * p_val)`` shuffled items go to val
        and the rest to train. An AssertionError is raised when a class
        would contribute no validation item.

        Args:
            trainval (list): objects exposing a ``label`` attribute.
            p_val (float): fraction of every class assigned to val.

        Returns:
            tuple: (train, val) lists.
        """
        p_trn = 1 - p_val
        print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val')

        positions_by_label = defaultdict(list)
        for position, entry in enumerate(trainval):
            positions_by_label[entry.label].append(position)

        train, val = [], []
        for label, positions in positions_by_label.items():
            n_val = round(len(positions) * p_val)
            assert n_val > 0
            random.shuffle(positions)
            val.extend(trainval[p] for p in positions[:n_val])
            train.extend(trainval[p] for p in positions[n_val:])

        return train, val

    @staticmethod
    def save_split(train, val, test, filepath, path_prefix):
        """Write the three splits to a JSON file.

        Each item is stored as ``(impath, label, classname)`` where every
        occurrence of ``path_prefix`` has been removed from ``impath`` and
        one leading ``/`` is stripped.

        Args:
            train, val, test (list): items exposing ``impath``, ``label``
                and ``classname``.
            filepath (str): destination JSON path.
            path_prefix (str): text removed from every image path.
        """
        def _extract(items):
            rows = []
            for entry in items:
                relative = entry.impath.replace(path_prefix, '')
                if relative.startswith('/'):
                    relative = relative[1:]
                rows.append((relative, entry.label, entry.classname))
            return rows

        split = {
            'train': _extract(train),
            'val': _extract(val),
            'test': _extract(test)
        }

        write_json(split, filepath)
        print(f'Saved split to {filepath}')

    @staticmethod
    def read_split(filepath, path_prefix):
        """Load a JSON split file and turn its entries into Datum objects.

        Args:
            filepath (str): JSON file with ``train``, ``val`` and ``test``
                lists of ``(impath, label, classname)`` entries.
            path_prefix (str): directory joined in front of every stored
                image path.

        Returns:
            tuple: (train, val, test) lists of Datum objects.
        """
        def _convert(items):
            return [
                Datum(
                    impath=os.path.join(path_prefix, impath),
                    label=int(label),
                    classname=classname
                )
                for impath, label, classname in items
            ]

        print(f'Reading split from {filepath}')
        split = read_json(filepath)
        train = _convert(split['train'])
        val = _convert(split['val'])
        test = _convert(split['test'])

        return train, val, test
|
||||
48
datasets/stanford_cars.py
Normal file
48
datasets/stanford_cars.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
from scipy.io import loadmat
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .utils import Datum, DatasetBase
|
||||
|
||||
|
||||
# Prompt template stored on each dataset instance as `self.template`.
template = ['a photo of a {}.']


class StanfordCars(DatasetBase):
    """Stanford Cars dataset loaded from a pre-computed split file."""

    dataset_dir = 'stanford_cars'

    def __init__(self, root, num_shots):
        """Read the stored split and sub-sample the training set.

        Image paths in the split file are resolved against the dataset
        directory itself.

        Args:
            root (str): directory that contains the ``stanford_cars`` folder.
            num_shots (int): forwarded to ``generate_fewshot_dataset`` for
                the training split.
        """
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.split_path = os.path.join(base_dir, 'split_zhou_StanfordCars.json')

        self.template = template

        full_train, val, test = OxfordPets.read_split(self.split_path, base_dir)
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)

        super().__init__(train_x=few_shot_train, val=val, test=test)

    def read_data(self, image_dir, anno_file, meta_file):
        """Create Datum objects from the ``.mat`` annotation files.

        This method is not called by ``__init__``. The last space-separated
        word of each class name is moved to the front before the name is
        stored.

        Args:
            image_dir (str): image folder, relative to the dataset directory.
            anno_file (str): ``.mat`` file providing ``annotations``.
            meta_file (str): ``.mat`` file providing ``class_names``.

        Returns:
            list: Datum objects in annotation order.
        """
        annotations = loadmat(anno_file)['annotations'][0]
        class_names = loadmat(meta_file)['class_names'][0]
        samples = []

        for record in annotations:
            imname = record['fname'][0]
            impath = os.path.join(self.dataset_dir, image_dir, imname)
            label = int(record['class'][0, 0]) - 1  # convert to 0-based index
            words = class_names[label][0].split(' ')
            # Rotate the trailing word (the year, per the original code) to the front.
            classname = ' '.join(words[-1:] + words[:-1])
            samples.append(
                Datum(
                    impath=impath,
                    label=label,
                    classname=classname
                )
            )

        return samples
|
||||
50
datasets/sun397.py
Normal file
50
datasets/sun397.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
# Prompt template stored on each dataset instance as `self.template`.
template = ['a photo of a {}.']


class SUN397(DatasetBase):
    """SUN397 dataset loaded from a pre-computed split file."""

    dataset_dir = 'sun397'

    def __init__(self, root, num_shots):
        """Read the stored split and sub-sample the training set.

        Args:
            root (str): directory that contains the ``sun397`` folder.
            num_shots (int): forwarded to ``generate_fewshot_dataset`` for
                the training split.
        """
        base_dir = os.path.join(root, self.dataset_dir)
        self.dataset_dir = base_dir
        self.image_dir = os.path.join(base_dir, 'SUN397')
        self.split_path = os.path.join(base_dir, 'split_zhou_SUN397.json')

        self.template = template

        full_train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        few_shot_train = self.generate_fewshot_dataset(full_train, num_shots=num_shots)

        super().__init__(train_x=few_shot_train, val=val, test=test)

    def read_data(self, cname2lab, text_file):
        """Create Datum objects from a listing of image paths.

        This method is not called by ``__init__``. The first character of
        every stripped line is dropped, and the directory part of the
        remainder is the key looked up in ``cname2lab``. The stored class
        name is that directory path without its first component, with the
        remaining components reversed and joined by spaces.

        Args:
            cname2lab (dict): directory path -> integer label.
            text_file (str): listing file, relative to the dataset directory.

        Returns:
            list: Datum objects in file order.
        """
        listing_path = os.path.join(self.dataset_dir, text_file)
        samples = []

        with open(listing_path, 'r') as handle:
            for raw_line in handle.readlines():
                imname = raw_line.strip()[1:]  # remove /
                class_path = os.path.dirname(imname)
                label = cname2lab[class_path]
                impath = os.path.join(self.image_dir, imname)

                # Skip the leading single-letter folder, then put words
                # like indoor/outdoor first.
                parts = class_path.split('/')[1:]
                classname = ' '.join(reversed(parts))

                samples.append(
                    Datum(
                        impath=impath,
                        label=label,
                        classname=classname
                    )
                )

        return samples
|
||||
51
datasets/ucf101.py
Normal file
51
datasets/ucf101.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
|
||||
from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
# Prompt template stored on each dataset instance as `self.template`.
template = ['a photo of a person doing {}.']


class UCF101(DatasetBase):
    """UCF101 (middle frames) dataset loaded from a pre-computed split file."""

    dataset_dir = 'ucf101'

    def __init__(self, root, num_shots):
        """Read the stored split and sub-sample the training set.

        Args:
            root (str): directory that contains the ``ucf101`` folder.
            num_shots (int): forwarded to ``generate_fewshot_dataset`` for
                the training split.
        """
        self.dataset_dir = os.path.join(root, self.dataset_dir)
        self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes')
        self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json')

        self.template = template

        train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
        train = self.generate_fewshot_dataset(train, num_shots=num_shots)

        super().__init__(train_x=train, val=val, test=test)

    def read_data(self, cname2lab, text_file):
        """Create Datum objects from a ``<Action>/<file>.avi`` listing.

        This method is not called by ``__init__``. Only the first
        space-separated field of each line is used. The action name is
        split in front of every capital letter and re-joined with
        underscores; that renamed action is both the image sub-folder and
        the stored class name.

        Args:
            cname2lab (dict): action name -> integer label.
            text_file (str): listing file, relative to the dataset directory.

        Returns:
            list: Datum objects in file order.
        """
        # Imported here because this module has no file-level `import re`;
        # without it the regex call below raised NameError.
        import re

        capitalized_word = re.compile('[A-Z][^A-Z]*')
        text_file = os.path.join(self.dataset_dir, text_file)
        items = []

        with open(text_file, 'r') as f:
            for line in f:
                line = line.strip().split(' ')[0]  # trainlist: filename, label
                action, filename = line.split('/')
                label = cname2lab[action]

                renamed_action = '_'.join(capitalized_word.findall(action))

                filename = filename.replace('.avi', '.jpg')
                impath = os.path.join(self.image_dir, renamed_action, filename)

                items.append(
                    Datum(
                        impath=impath,
                        label=label,
                        classname=renamed_action
                    )
                )

        return items
|
||||
377
datasets/utils.py
Normal file
377
datasets/utils.py
Normal file
@@ -0,0 +1,377 @@
|
||||
import os
|
||||
import random
|
||||
import os.path as osp
|
||||
import tarfile
|
||||
import zipfile
|
||||
from collections import defaultdict
|
||||
import gdown
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def read_json(fpath):
    """Load the JSON document stored at ``fpath`` and return the parsed object."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def write_json(obj, fpath):
    """Serialize ``obj`` as indented JSON to ``fpath``.

    Missing parent directories are created. A path without a directory
    part (e.g. ``'split.json'``) is written to the current working
    directory.

    Args:
        obj: JSON-serializable object.
        fpath (str): destination file path.
    """
    parent_dir = osp.dirname(fpath)
    # dirname is '' for a bare file name, and os.makedirs('') raises;
    # exist_ok also avoids failing when the directory appears concurrently.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
|
||||
|
||||
|
||||
def read_image(path):
    """Open an image with ``PIL.Image`` and convert it to RGB.

    Args:
        path (str): path to an image.

    Returns:
        PIL image

    Raises:
        IOError: if nothing exists at ``path``.

    NOTE: an IOError raised while opening or converting the file is only
    printed and the read is attempted again, with no limit on the number
    of attempts.
    """
    if not osp.exists(path):
        raise IOError('No file exists at {}'.format(path))

    retry_notice = (
        'Cannot read image from {}, '
        'probably due to heavy IO. Will re-try'.format(path)
    )
    while True:
        try:
            picture = Image.open(path).convert('RGB')
        except IOError:
            print(retry_notice)
        else:
            return picture
|
||||
|
||||
|
||||
def listdir_nohidden(path, sort=False):
    """List the entries of a directory, skipping hidden and 'sh' names.

    An entry is dropped when its name starts with a dot or contains the
    substring ``'sh'`` anywhere.

    NOTE(review): the substring test presumably targets shell scripts, but
    it also drops names such as ``fish.jpg`` -- confirm this is intended.

    Args:
        path (str): directory path.
        sort (bool): sort the items.
    """
    kept = []
    for entry in os.listdir(path):
        if entry.startswith('.') or 'sh' in entry:
            continue
        kept.append(entry)
    if sort:
        kept.sort()
    return kept
|
||||
|
||||
|
||||
class Datum:
    """Read-only record describing one image sample.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.

    The argument types are checked with ``assert`` statements, so a value
    of the wrong type raises AssertionError.
    """

    def __init__(self, impath='', label=0, domain=-1, classname=''):
        expected_types = (
            (impath, str),
            (label, int),
            (domain, int),
            (classname, str),
        )
        for value, kind in expected_types:
            assert isinstance(value, kind)

        self._impath, self._label = impath, label
        self._domain, self._classname = domain, classname

    @property
    def impath(self):
        """Image path given at construction."""
        return self._impath

    @property
    def label(self):
        """Class label given at construction."""
        return self._label

    @property
    def domain(self):
        """Domain label given at construction (defaults to -1)."""
        return self._domain

    @property
    def classname(self):
        """Class name given at construction."""
        return self._classname
|
||||
|
||||
|
||||
class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning

    Holds the train/val/test lists and derives the number of classes and
    the label -> class-name mapping from the labeled training data.
    """

    dataset_dir = ''  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None, t_sne=None):
        """Store the splits and compute class statistics from ``train_x``.

        Args:
            train_x (list): labeled training Datum objects; iterated to
                derive ``num_classes``, ``lab2cname`` and ``classnames``.
            train_u (list): unlabeled training data (optional).
            val (list): validation data (optional).
            test (list): test data.
            t_sne: accepted but not used by this class.
        """
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes as ``max(label) + 1``.

        Args:
            data_source (list): a list of Datum objects.

        Raises:
            ValueError: if ``data_source`` is empty.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.

        Returns:
            tuple: (mapping, classnames) where ``classnames`` lists the
            class names ordered by ascending label.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        """Validate both domain lists against ``self.domains``."""
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        """Raise ValueError for the first domain not listed in ``self.domains``."""
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    'Input domain must belong to {}, '
                    'but got [{}]'.format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        """Download an archive with gdown and extract it next to ``dst``.

        The file is first opened as a tar archive; if it is not a readable
        tar archive it is opened as a zip archive instead. Errors raised
        while extracting a valid tar archive are propagated.

        Args:
            url (str): download URL handed to ``gdown.download``.
            dst (str): path of the downloaded archive.
            from_gdrive (bool): must be true; any other source raises
                NotImplementedError.
        """
        dst_dir = osp.dirname(dst)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if dst_dir:
            os.makedirs(dst_dir, exist_ok=True)
        extract_dir = dst_dir or '.'

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print('Extracting file ...')

        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=extract_dir)
        except tarfile.ReadError:
            # Not a tar archive: fall back to zip. Only ReadError triggers
            # the fallback so genuine extraction failures and
            # KeyboardInterrupt are no longer swallowed.
            with zipfile.ZipFile(dst, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        print('File extracted to {}'.format(dst_dir))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=True
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a few number of images. Sampling uses the global ``random`` state.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample; a
                value below 1 returns the input unchanged.
            repeat (bool): when a class has fewer than ``num_shots`` items,
                sample with replacement to reach ``num_shots``; otherwise
                keep the class as is.

        Returns:
            A single list when one data source is given, otherwise a
            sequence with one entry per data source.
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f'Creating a {num_shots}-shot dataset')

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output
|
||||
|
||||
|
||||
class DatasetWrapper(TorchDataset):
    """Torch dataset that loads the images referenced by Datum-like items.

    Indexing returns a ``(img, label, impath)`` tuple.
    """

    def __init__(self, data_source, input_size, transform=None, is_train=False,
                 return_img0=False, k_tfm=1):
        """
        Args:
            data_source (list): items exposing ``impath``, ``label`` and
                ``domain``.
            input_size: size passed to ``T.Resize`` for the
                augmentation-free transform.
            transform: a callable, or a list/tuple of callables, applied to
                the loaded image.
            is_train (bool): whether the wrapper is used for training.
            return_img0 (bool): also compute the augmentation-free tensor
                in ``__getitem__`` (it is not part of the returned tuple).
            k_tfm (int): number of times each transform is applied; forced
                to 1 when ``is_train`` is false.

        Raises:
            ValueError: if ``k_tfm`` > 1 is requested without a transform.
        """
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = k_tfm if is_train else 1
        self.return_img0 = return_img0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                'Cannot augment the image {} times '
                'because transform is None'.format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = T.InterpolationMode.BICUBIC
        to_tensor = []
        to_tensor += [T.Resize(input_size, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        normalize = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
        )
        to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(img, label, impath)``.

        With a list/tuple of transforms, ``img`` is the result of the first
        one. When no transform produced an image (``transform`` is None or
        an empty sequence), ``img`` is the augmentation-free tensor.
        """
        item = self.data_source[idx]

        output = {
            'label': item.label,
            'domain': item.domain,
            'impath': item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    # First result is stored under 'img', later ones under
                    # 'img2', 'img3', ...
                    keyname = 'img'
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output['img'] = img

        if self.return_img0:
            output['img0'] = self.to_tensor(img0)

        if 'img' not in output:
            # Without this fallback the return below raised KeyError for
            # the default transform=None (and for an empty transform list).
            output['img'] = output['img0'] if 'img0' in output else self.to_tensor(img0)

        return output['img'], output['label'], output['impath']

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``self.k_tfm`` times.

        Returns the single result when ``k_tfm`` is 1, otherwise the list
        of results.
        """
        img_list = [tfm(img0) for _ in range(self.k_tfm)]

        if len(img_list) == 1:
            return img_list[0]

        return img_list
|
||||
|
||||
|
||||
def build_data_loader(
    data_source=None,
    batch_size=64,
    input_size=224,
    tfm=None,
    is_train=True,
    shuffle=False,
    dataset_wrapper=None,
    num_workers=8
):
    """Wrap ``data_source`` in a dataset and return a torch DataLoader.

    Args:
        data_source (list): items handed to the dataset wrapper.
        batch_size (int): batch size of the loader.
        input_size: forwarded to the wrapper as ``input_size``.
        tfm: forwarded to the wrapper as ``transform``.
        is_train (bool): forwarded to the wrapper as ``is_train``.
        shuffle (bool): whether the loader shuffles the data.
        dataset_wrapper: dataset class to instantiate; defaults to
            ``DatasetWrapper``.
        num_workers (int): number of loader worker processes. Defaults to
            8, the value that used to be hard-coded.

    Returns:
        torch.utils.data.DataLoader: loader with ``drop_last=False`` and
        pinned memory when CUDA is available.

    Raises:
        AssertionError: if the loader would yield no batch.
    """
    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(data_source, input_size=input_size, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=False,
        pin_memory=(torch.cuda.is_available())
    )
    assert len(data_loader) > 0

    return data_loader
|
||||
Reference in New Issue
Block a user