Files
DAPT/deepcore/datasets/tinyimagenet.py
2025-10-07 22:42:55 +08:00

36 lines
1.4 KiB
Python

from torchvision import datasets, transforms
import os
import requests
import zipfile
def _download_and_extract(data_path, url="http://cs231n.stanford.edu/tiny-imagenet-200.zip"):
    """Download the Tiny-ImageNet zip (~248MB) into ``data_path`` and extract it there.

    The archive is streamed to a ``.part`` file and only renamed to its final name
    once the transfer has finished, so an interrupted download never leaves a
    truncated ``tiny-imagenet-200.zip`` behind.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    # open() below fails if the target directory does not exist yet.
    os.makedirs(data_path, exist_ok=True)
    zip_path = os.path.join(data_path, "tiny-imagenet-200.zip")
    part_path = zip_path + ".part"

    print("Downloading Tiny-ImageNet")
    # The timeout bounds connecting and each socket read, so a stalled server
    # cannot hang the process indefinitely; the `with` closes the connection.
    with requests.get(url, stream=True, timeout=60) as response:
        # Without this check an HTML error page would be saved as the "zip".
        response.raise_for_status()
        with open(part_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                if chunk:  # skip empty chunks
                    f.write(chunk)
    os.replace(part_path, zip_path)

    print("Unziping Tiny-ImageNet")
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(path=data_path)


def _has_class_subdirs(split_dir):
    """Return True if ``split_dir`` exists and contains per-class sub-directories.

    A split whose only sub-directory is ``images/`` is treated as unlabelled:
    ImageFolder would turn it into a single class named "images".
    """
    if not os.path.isdir(split_dir):
        return False
    with os.scandir(split_dir) as entries:
        return any(entry.is_dir() and entry.name != "images" for entry in entries)


def _labelled_val_split(val_dir, classes, class_to_idx, transform):
    """Build an ImageFolder over ``val_dir`` whose labels come from val_annotations.txt.

    Assumes the standard Tiny-ImageNet-200 layout (verify for custom copies):
    images live in ``val_dir/images`` and each line of ``val_annotations.txt``
    starts with the image file name followed by its class wnid.
    Labels are taken from the training split's ``class_to_idx`` so that both
    splits share the same class indices.

    Raises:
        KeyError: if an annotation names a wnid absent from ``class_to_idx``.
    """
    dst = datasets.ImageFolder(root=val_dir, transform=transform)
    samples = []
    with open(os.path.join(val_dir, "val_annotations.txt"), encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 2:
                continue  # tolerate blank / malformed lines
            file_name, wnid = fields[0], fields[1]
            samples.append((os.path.join(val_dir, "images", file_name), class_to_idx[wnid]))
    samples.sort()  # deterministic order, independent of the annotation file order
    # Replace the single "images" pseudo-class ImageFolder inferred with the real labels.
    dst.samples = samples
    dst.imgs = samples
    dst.targets = [target for _, target in samples]
    dst.classes = list(classes)
    dst.class_to_idx = dict(class_to_idx)
    return dst


def TinyImageNet(data_path, downsize=True):
    """Load Tiny-ImageNet-200 as torchvision ImageFolder datasets, downloading it if needed.

    Args:
        data_path: directory that holds (or will receive) ``tiny-imagenet-200/``.
        downsize: if True, images are resized to 32 px before normalisation and
            ``im_size`` is reported as (32, 32); otherwise (64, 64).

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``. ``dst_train`` is read from
        ``tiny-imagenet-200/train``. ``dst_test`` is read from
        ``tiny-imagenet-200/test`` when that folder has per-class
        sub-directories (or when no ``val/val_annotations.txt`` exists).
        Otherwise it is built from the labelled ``val`` split, because an
        unlabelled ``test/images`` folder would give every image label 0.
    """
    root = os.path.join(data_path, "tiny-imagenet-200")
    if not os.path.exists(root):
        _download_and_extract(data_path)

    channel = 3
    im_size = (32, 32) if downsize else (64, 64)
    num_classes = 200
    # Hard-coded per-channel normalisation statistics (not recomputed here).
    mean = (0.4802, 0.4481, 0.3975)
    std = (0.2770, 0.2691, 0.2821)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    if downsize:
        # An int size rescales the shorter side to 32 (32x32 for square images).
        transform = transforms.Compose([transforms.Resize(32), transform])

    dst_train = datasets.ImageFolder(root=os.path.join(root, "train"), transform=transform)

    test_dir = os.path.join(root, "test")
    val_dir = os.path.join(root, "val")
    has_val_labels = os.path.isfile(os.path.join(val_dir, "val_annotations.txt"))
    if _has_class_subdirs(test_dir) or not has_val_labels:
        dst_test = datasets.ImageFolder(root=test_dir, transform=transform)
    else:
        # test/ has no per-class folders (unlabelled layout), so ImageFolder
        # would put every image in one "images" class with label 0; evaluate
        # on the labelled val split instead.
        dst_test = _labelled_val_split(val_dir, dst_train.classes, dst_train.class_to_idx, transform)

    class_names = dst_train.classes
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test