43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
import sys
|
|
import os.path as osp
|
|
from torchvision.datasets import STL10
|
|
|
|
from dassl.utils import mkdir_if_missing
|
|
|
|
|
|
def extract_and_save_image(dataset, save_dir):
    """Write every image in ``dataset`` to ``save_dir`` as a ``.jpg`` file.

    Each file is named ``<index>_<label>.jpg``: the sample index is
    zero-padded to six digits and the label is rendered with ``str()``,
    except that a label equal to ``-1`` is written as ``none``.

    If ``save_dir`` already exists, a message is printed and nothing is
    written; the existing folder contents are not inspected.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, where
            each item unpacks to ``(img, label)`` and ``img`` provides a
            ``save(path)`` method.
        save_dir: Destination folder for the extracted images.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for i in range(len(dataset)):
        img, label = dataset[i]
        # A label of -1 is written out as "none" rather than as a number.
        label_name = "none" if label == -1 else str(label)
        imname = str(i).zfill(6) + "_" + label_name + ".jpg"
        img.save(osp.join(save_dir, imname))
|
|
|
|
|
|
def download_and_prepare(root):
    """Load the three STL10 splits under ``root`` and dump each to images.

    The ``train`` split is constructed with ``download=True``; ``test``
    and ``unlabeled`` are constructed without that flag. Each split is
    then passed to ``extract_and_save_image`` with the output folder
    ``<root>/<split name>``.

    Args:
        root: Dataset root directory handed to ``STL10`` and used as the
            parent of the ``train``/``test``/``unlabeled`` output folders.
    """
    # Build every dataset object before any extraction starts, so a
    # failure while loading a split leaves no partially written output.
    splits = [("train", STL10(root, split="train", download=True))]
    splits += [(name, STL10(root, split=name)) for name in ("test", "unlabeled")]

    jobs = [(data, osp.join(root, name)) for name, data in splits]

    for data, out_dir in jobs:
        extract_and_save_image(data, out_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # The dataset root is the first command-line argument; without it,
    # exit with a usage message instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <root>".format(sys.argv[0]))
    download_and_prepare(sys.argv[1])
|