Upload to Main
This commit is contained in:
27
deepcore/datasets/imagenet.py
Normal file
27
deepcore/datasets/imagenet.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch import tensor, long
|
||||
|
||||
|
||||
def ImageNet(data_path):
|
||||
channel = 3
|
||||
im_size = (224, 224)
|
||||
num_classes = 1000
|
||||
mean = [0.485, 0.456, 0.406]
|
||||
std = [0.229, 0.224, 0.225]
|
||||
normalize = transforms.Normalize(mean, std)
|
||||
dst_train = datasets.ImageNet(data_path, split="train", transform=transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
dst_test = datasets.ImageNet(data_path, split="val", transform=transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
class_names = dst_train.classes
|
||||
dst_train.targets = tensor(dst_train.targets, dtype=long)
|
||||
dst_test.targets = tensor(dst_test.targets, dtype=long)
|
||||
return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
Reference in New Issue
Block a user