release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
import sys
import pprint as pp
import os.path as osp
from torchvision.datasets import STL10, CIFAR10
from dassl.utils import mkdir_if_missing
# Label-index -> class-name maps for the two source datasets, plus the
# unified class-name -> new-label map used once each dataset's
# conflicting class ("frog" in CIFAR-10, "monkey" in STL-10) is dropped.
cifar_label2name = dict(enumerate([
    "airplane",
    "car",  # the original name was 'automobile'
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",  # conflict class
    "horse",
    "ship",
    "truck",
]))
stl_label2name = dict(enumerate([
    "airplane",
    "bird",
    "car",
    "cat",
    "deer",
    "dog",
    "horse",
    "monkey",  # conflict class
    "ship",
    "truck",
]))
# Nine classes shared by both datasets, relabeled 0-8 in this order.
new_name2label = {
    name: idx
    for idx, name in enumerate([
        "airplane", "bird", "car", "cat", "deer",
        "dog", "horse", "ship", "truck",
    ])
}
def extract_and_save_image(dataset, save_dir, discard, label2name):
    """Export every image of ``dataset`` as a JPEG, skipping one class.

    Images land in ``save_dir``/<new_label>_<class_name>/<index>.jpg,
    where the new label comes from the shared 9-class map. If
    ``save_dir`` already exists, extraction is assumed done and nothing
    happens.

    Args:
        dataset: indexable dataset yielding (PIL.Image, int label) pairs.
        save_dir (str): output root directory.
        discard (int): original label to drop (the conflict class).
        label2name (dict): original label -> class name for this dataset.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return
    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)
    for idx, (img, label) in enumerate(dataset):
        if label == discard:
            continue
        name = label2name[label]
        # Relabel into the unified 9-class space shared by both datasets.
        new_label = new_name2label[name]
        class_dir = osp.join(save_dir, str(new_label).zfill(3) + "_" + name)
        mkdir_if_missing(class_dir)
        img.save(osp.join(class_dir, str(idx + 1).zfill(5) + ".jpg"))
def download_and_prepare(name, root, discarded_label, label2name):
    """Download one dataset and export its splits as per-class folders.

    Args:
        name (str): "cifar" selects CIFAR-10; anything else selects STL-10.
        root (str): download/extraction root directory.
        discarded_label (int): original label with no counterpart in the
            other dataset; images with this label are skipped.
        label2name (dict): original label -> class name for this dataset.
    """
    print("Dataset: {}".format(name))
    print("Root: {}".format(root))
    print("Old labels:")
    pp.pprint(label2name)
    print("Discarded label: {}".format(discarded_label))
    print("New labels:")
    pp.pprint(new_name2label)
    if name == "cifar":
        # download=True on the train split fetches the whole archive,
        # so the test split needs no separate download.
        splits = {
            "train": CIFAR10(root, train=True, download=True),
            "test": CIFAR10(root, train=False),
        }
    else:
        splits = {
            "train": STL10(root, split="train", download=True),
            "test": STL10(root, split="test"),
        }
    for split_name, dataset in splits.items():
        out_dir = osp.join(root, name, split_name)
        extract_and_save_image(dataset, out_dir, discarded_label, label2name)
if __name__ == "__main__":
download_and_prepare("cifar", sys.argv[1], 6, cifar_label2name)
download_and_prepare("stl", sys.argv[1], 7, stl_label2name)

View File

@@ -0,0 +1,131 @@
import os
import numpy as np
import os.path as osp
import argparse
from PIL import Image
from scipy.io import loadmat
def mkdir_if_missing(directory):
    """Create ``directory`` (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of an exists-then-create check: the
    original pattern could raise ``FileExistsError`` if another process
    created the directory between the check and the ``makedirs`` call.

    Args:
        directory (str): path of the directory to create.
    """
    os.makedirs(directory, exist_ok=True)
def extract_and_save(data, label, save_dir):
    """Write paired images/labels to ``save_dir`` as numbered JPEGs.

    Each image is saved as ``<index>_<label>.jpg`` (1-based, zero-padded
    index). Single-channel images are expanded to RGB by replication,
    and the SVHN-style label 10 is remapped back to digit 0.

    Args:
        data: iterable of (H, W, C) uint8 arrays.
        label: iterable of integer labels, aligned with ``data``.
        save_dir (str): existing directory to write JPEG files into.
    """
    for idx, (arr, digit) in enumerate(zip(data, label), start=1):
        if arr.shape[2] == 1:
            # Grayscale -> RGB via channel replication.
            arr = np.repeat(arr, 3, axis=2)
        if digit == 10:
            # SVHN encodes the digit "0" as class 10.
            digit = 0
        image = Image.fromarray(arr, mode="RGB")
        filename = str(idx).zfill(6) + "_" + str(digit) + ".jpg"
        image.save(osp.join(save_dir, filename))
def load_mnist(data_dir, raw_data_dir):
    """Load the MNIST split of Digit-Five from ``mnist_data.mat``.

    Args:
        data_dir (str): unused here; kept for a uniform loader signature.
        raw_data_dir (str): directory containing ``mnist_data.mat``.

    Returns:
        tuple: (train_data, test_data, train_label, test_label) where
        the data arrays are (N, 32, 32, 1) and the labels are 1-D int
        arrays decoded from the one-hot ``label_*`` matrices.
    """
    filepath = osp.join(raw_data_dir, "mnist_data.mat")
    data = loadmat(filepath)
    # Infer N from the stored (N, 1024) matrices rather than hard-coding
    # 55000/10000, so truncated or custom dumps also load correctly.
    train_data = np.reshape(data["train_32"], (-1, 32, 32, 1))
    test_data = np.reshape(data["test_32"], (-1, 32, 32, 1))
    # np.nonzero on a one-hot matrix yields (row_idx, col_idx); the
    # column index of each row is the integer class label.
    train_label = np.nonzero(data["label_train"])[1]
    test_label = np.nonzero(data["label_test"])[1]
    return train_data, test_data, train_label, test_label
def load_mnist_m(data_dir, raw_data_dir):
    """Load the MNIST-M split of Digit-Five from ``mnistm_with_label.mat``.

    Args:
        data_dir (str): unused here; kept for a uniform loader signature.
        raw_data_dir (str): directory containing ``mnistm_with_label.mat``.

    Returns:
        tuple: (train_data, test_data, train_label, test_label); images
        are stored ready to use, labels are decoded from one-hot rows.
    """
    mat = loadmat(osp.join(raw_data_dir, "mnistm_with_label.mat"))

    def _decode(onehot):
        # Column index of the nonzero entry in each one-hot row.
        return np.nonzero(onehot)[1]

    return (
        mat["train"],
        mat["test"],
        _decode(mat["label_train"]),
        _decode(mat["label_test"]),
    )
def load_svhn(data_dir, raw_data_dir):
    """Load the SVHN split of Digit-Five from the svhn_*_32x32.mat files.

    Args:
        data_dir (str): unused here; kept for a uniform loader signature.
        raw_data_dir (str): directory containing the two svhn mat files.

    Returns:
        tuple: (train_data, test_data, train_label, test_label).
    """
    def _split(filename):
        mat = loadmat(osp.join(raw_data_dir, filename))
        # Images are stored (H, W, C, N); move the sample axis to the front.
        return mat["X"].transpose(3, 0, 1, 2), mat["y"][:, 0]

    train_data, train_label = _split("svhn_train_32x32.mat")
    test_data, test_label = _split("svhn_test_32x32.mat")
    return train_data, test_data, train_label, test_label
def load_syn(data_dir, raw_data_dir):
    """Load the synthetic-digits split of Digit-Five from ``syn_number.mat``.

    Args:
        data_dir (str): unused here; kept for a uniform loader signature.
        raw_data_dir (str): directory containing ``syn_number.mat``.

    Returns:
        tuple: (train_data, test_data, train_label, test_label); labels
        are stored as (N, 1) column vectors and flattened to 1-D.
    """
    mat = loadmat(osp.join(raw_data_dir, "syn_number.mat"))
    return (
        mat["train_data"],
        mat["test_data"],
        mat["train_label"][:, 0],
        mat["test_label"][:, 0],
    )
def load_usps(data_dir, raw_data_dir):
    """Load the USPS split of Digit-Five from ``usps_28x28.mat``.

    The mat file stores a 2x2 cell array ``dataset`` where
    ``dataset[split][0]`` holds images (N, 1, 28, 28) scaled to [0, 1]
    and ``dataset[split][1]`` holds (N, 1) labels; split 0 is train,
    split 1 is test.

    Args:
        data_dir (str): unused here; kept for a uniform loader signature.
        raw_data_dir (str): directory containing ``usps_28x28.mat``.

    Returns:
        tuple: (train_data, test_data, train_label, test_label) with
        images as (N, 28, 28, 1) uint8 in [0, 255].
    """
    dataset = loadmat(osp.join(raw_data_dir, "usps_28x28.mat"))["dataset"]

    def _images(split):
        # (N, 1, 28, 28) float in [0, 1] -> (N, 28, 28, 1) uint8 in [0, 255].
        x = dataset[split][0].transpose(0, 2, 3, 1)
        x *= 255
        return x.astype(np.uint8)

    train_data = _images(0)
    test_data = _images(1)
    train_label = dataset[0][1][:, 0]
    test_label = dataset[1][1][:, 0]
    return train_data, test_data, train_label, test_label
def main(data_dir):
    """Convert all Digit-Five .mat dumps into per-split JPEG folders.

    For each dataset the images are written to
    ``data_dir``/<name>/{train,test}_images via ``extract_and_save``.

    Args:
        data_dir (str): directory containing the raw ``Digit-Five/``
            folder of .mat files; outputs are created alongside it.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist.
    """
    data_dir = osp.abspath(osp.expanduser(data_dir))
    raw_data_dir = osp.join(data_dir, "Digit-Five")
    if not osp.exists(data_dir):
        raise FileNotFoundError('"{}" does not exist'.format(data_dir))
    # Explicit dispatch table instead of eval("load_" + name): same
    # loaders and iteration order, but no dynamic code evaluation.
    loaders = {
        "mnist": load_mnist,
        "mnist_m": load_mnist_m,
        "svhn": load_svhn,
        "syn": load_syn,
        "usps": load_usps,
    }
    for name, loader in loaders.items():
        print("Creating {}".format(name))
        train_data, test_data, train_label, test_label = loader(
            data_dir, raw_data_dir
        )
        print("# train: {}".format(train_data.shape[0]))
        print("# test: {}".format(test_data.shape[0]))
        train_dir = osp.join(data_dir, name, "train_images")
        mkdir_if_missing(train_dir)
        test_dir = osp.join(data_dir, name, "test_images")
        mkdir_if_missing(test_dir)
        extract_and_save(train_data, train_label, train_dir)
        extract_and_save(test_data, test_label, test_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"data_dir", type=str, help="directory containing Digit-Five/"
)
args = parser.parse_args()
main(args.data_dir)

View File

@@ -0,0 +1,24 @@
# ------------------------------------------------------------------------
# ROOT is the root directory where you put your domain datasets.
#
# Suppose you want to put the dataset under $DATA, which stores all the
# domain datasets; run the following command in your terminal to
# download VisDa17:
#
# $ sh visda17.sh $DATA
#------------------------------------------------------------------------
# Abort on the first failed command so a bad download or cd cannot
# cascade into extracting archives in the wrong location.
set -e

ROOT=$1
if [ -z "$ROOT" ]; then
    echo "Usage: sh visda17.sh ROOT" >&2
    exit 1
fi

# -p: no error if the directory already exists; quote to survive spaces.
mkdir -p "$ROOT/visda17"
cd "$ROOT/visda17"

wget http://csr.bu.edu/ftp/visda17/clf/train.tar
tar xvf train.tar
wget http://csr.bu.edu/ftp/visda17/clf/validation.tar
tar xvf validation.tar
wget http://csr.bu.edu/ftp/visda17/clf/test.tar
tar xvf test.tar
# Official image list for the test split (shipped separately from test.tar).
wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt