from torchvision import datasets, transforms from torch import tensor, long def ImageNet(data_path): channel = 3 im_size = (224, 224) num_classes = 1000 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] normalize = transforms.Normalize(mean, std) dst_train = datasets.ImageNet(data_path, split="train", transform=transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])) dst_test = datasets.ImageNet(data_path, split="val", transform=transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])) class_names = dst_train.classes dst_train.targets = tensor(dst_train.targets, dtype=long) dst_test.targets = tensor(dst_test.targets, dtype=long) return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test