release code
This commit is contained in:
50
Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py
Normal file
50
Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import sys
|
||||
import os.path as osp
|
||||
from torchvision.datasets import SVHN, CIFAR10, CIFAR100
|
||||
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
|
||||
def extract_and_save_image(dataset, save_dir):
|
||||
if osp.exists(save_dir):
|
||||
print('Folder "{}" already exists'.format(save_dir))
|
||||
return
|
||||
|
||||
print('Extracting images to "{}" ...'.format(save_dir))
|
||||
mkdir_if_missing(save_dir)
|
||||
|
||||
for i in range(len(dataset)):
|
||||
img, label = dataset[i]
|
||||
class_dir = osp.join(save_dir, str(label).zfill(3))
|
||||
mkdir_if_missing(class_dir)
|
||||
impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg")
|
||||
img.save(impath)
|
||||
|
||||
|
||||
def download_and_prepare(name, root):
|
||||
print("Dataset: {}".format(name))
|
||||
print("Root: {}".format(root))
|
||||
|
||||
if name == "cifar10":
|
||||
train = CIFAR10(root, train=True, download=True)
|
||||
test = CIFAR10(root, train=False)
|
||||
elif name == "cifar100":
|
||||
train = CIFAR100(root, train=True, download=True)
|
||||
test = CIFAR100(root, train=False)
|
||||
elif name == "svhn":
|
||||
train = SVHN(root, split="train", download=True)
|
||||
test = SVHN(root, split="test", download=True)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
train_dir = osp.join(root, name, "train")
|
||||
test_dir = osp.join(root, name, "test")
|
||||
|
||||
extract_and_save_image(train, train_dir)
|
||||
extract_and_save_image(test, test_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_and_prepare("cifar10", sys.argv[1])
|
||||
download_and_prepare("cifar100", sys.argv[1])
|
||||
download_and_prepare("svhn", sys.argv[1])
|
||||
Reference in New Issue
Block a user