Files
DAPT/deepcore/datasets/cifar100.py
2025-10-07 22:42:55 +08:00

18 lines
767 B
Python

from torchvision import datasets, transforms
from torch import tensor, long
def CIFAR100(data_path):
    """Build normalized CIFAR-100 train/test datasets and their metadata.

    Both splits are created with ``torchvision.datasets.CIFAR100`` using
    ``download=True`` and share one transform pipeline (``ToTensor`` followed
    by per-channel ``Normalize``). After construction, each split's
    ``targets`` attribute is replaced by a ``torch.long`` tensor.

    Args:
        data_path: Root directory handed to ``datasets.CIFAR100``.

    Returns:
        A tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``.
    """
    # Per-channel RGB statistics used by the Normalize step.
    mean = [0.5071, 0.4865, 0.4409]
    std = [0.2673, 0.2564, 0.2762]
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    )

    # Unpacking the generator evaluates train first, then test.
    dst_train, dst_test = (
        datasets.CIFAR100(data_path, train=is_train, download=True, transform=preprocess)
        for is_train in (True, False)
    )
    class_names = dst_train.classes

    for split in (dst_train, dst_test):
        # Replace the label list with a long tensor so it can be indexed as a tensor.
        split.targets = tensor(split.targets, dtype=long)

    return 3, (32, 32), 100, class_names, mean, std, dst_train, dst_test