release code
This commit is contained in:
131
Dassl.ProGrad.pytorch/datasets/da/digit5.py
Normal file
131
Dassl.ProGrad.pytorch/datasets/da/digit5.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import argparse
|
||||
from PIL import Image
|
||||
from scipy.io import loadmat
|
||||
|
||||
|
||||
def mkdir_if_missing(directory):
    """Create ``directory`` (and any missing parents) unless it already exists.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    # exist_ok=True avoids the race between a separate existence check and
    # the creation call; calling this repeatedly is a no-op.
    os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
def extract_and_save(data, label, save_dir):
    """Write each sample of ``data`` into ``save_dir`` as a JPEG image.

    Files are named ``<index>_<label>.jpg``, where ``<index>`` is the 1-based
    position of the sample, zero-padded to six digits.

    Args:
        data: iterable of image arrays indexed as (row, column, channel).
            Images with a single channel are replicated to three channels.
        label: iterable of labels, aligned with ``data``.
        save_dir: directory that receives the image files.
    """
    for index, (image, digit) in enumerate(zip(data, label), start=1):
        if image.shape[2] == 1:
            # Replicate the lone channel so the array can be saved as RGB.
            image = np.repeat(image, 3, axis=2)

        # A label of 10 is stored as 0 (presumably for sources that encode
        # the digit zero as class 10 -- confirm against the raw data).
        digit = 0 if digit == 10 else digit

        filename = str(index).zfill(6) + "_" + str(digit) + ".jpg"
        picture = Image.fromarray(image, mode="RGB")
        picture.save(osp.join(save_dir, filename))
|
||||
|
||||
|
||||
def load_mnist(data_dir, raw_data_dir):
    """Load the MNIST split stored in ``mnist_data.mat``.

    Args:
        data_dir: unused; accepted so all loaders share one call signature.
        raw_data_dir: directory containing ``mnist_data.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``. Image
        arrays are reshaped to ``(N, 32, 32, 1)``; labels are the column
        indices of the non-zero entries of the stored label matrices.
    """
    data = loadmat(osp.join(raw_data_dir, "mnist_data.mat"))

    # Let NumPy infer the sample count (-1) rather than hard-coding
    # 55000/10000, so files with a different number of samples still load.
    train_data = np.reshape(data["train_32"], (-1, 32, 32, 1))
    test_data = np.reshape(data["test_32"], (-1, 32, 32, 1))

    # Label matrices are indexed (sample, class); take the class column of
    # each non-zero entry.
    train_label = np.nonzero(data["label_train"])[1]
    test_label = np.nonzero(data["label_test"])[1]

    return train_data, test_data, train_label, test_label
|
||||
|
||||
|
||||
def load_mnist_m(data_dir, raw_data_dir):
    """Load the MNIST-M split stored in ``mnistm_with_label.mat``.

    Args:
        data_dir: unused; accepted so all loaders share one call signature.
        raw_data_dir: directory containing ``mnistm_with_label.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``. Image
        arrays are returned as stored; labels are the column indices of the
        non-zero entries of the stored label matrices.
    """
    mat = loadmat(osp.join(raw_data_dir, "mnistm_with_label.mat"))

    def class_indices(key):
        # Column index of every non-zero entry in the label matrix.
        return np.nonzero(mat[key])[1]

    return (
        mat["train"],
        mat["test"],
        class_indices("label_train"),
        class_indices("label_test"),
    )
|
||||
|
||||
|
||||
def load_svhn(data_dir, raw_data_dir):
    """Load the SVHN splits from ``svhn_train_32x32.mat`` / ``svhn_test_32x32.mat``.

    Args:
        data_dir: unused; accepted so all loaders share one call signature.
        raw_data_dir: directory containing the two ``.mat`` files.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``. The
        stored ``X`` arrays have their last axis moved to the front; labels
        are the first column of the stored ``y`` arrays.
    """

    def read_split(filename):
        mat = loadmat(osp.join(raw_data_dir, filename))
        # Move the last axis of X to the front so samples can be iterated.
        images = mat["X"].transpose(3, 0, 1, 2)
        targets = mat["y"][:, 0]
        return images, targets

    train_data, train_label = read_split("svhn_train_32x32.mat")
    test_data, test_label = read_split("svhn_test_32x32.mat")

    return train_data, test_data, train_label, test_label
|
||||
|
||||
|
||||
def load_syn(data_dir, raw_data_dir):
    """Load the synthetic-digits split stored in ``syn_number.mat``.

    Args:
        data_dir: unused; accepted so all loaders share one call signature.
        raw_data_dir: directory containing ``syn_number.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``. Image
        arrays are returned as stored; labels are the first column of the
        stored label arrays.
    """
    mat = loadmat(osp.join(raw_data_dir, "syn_number.mat"))

    train_data, test_data = mat["train_data"], mat["test_data"]

    # Labels are stored with a second axis; keep only its first column.
    train_label = mat["train_label"][:, 0]
    test_label = mat["test_label"][:, 0]

    return train_data, test_data, train_label, test_label
|
||||
|
||||
|
||||
def load_usps(data_dir, raw_data_dir):
    """Load the USPS split stored under ``dataset`` in ``usps_28x28.mat``.

    Args:
        data_dir: unused; accepted so all loaders share one call signature.
        raw_data_dir: directory containing ``usps_28x28.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``. Image
        arrays have axis 1 moved to the end, are multiplied by 255 and cast
        to ``uint8``; labels are the first column of the stored label arrays.
    """
    dataset = loadmat(osp.join(raw_data_dir, "usps_28x28.mat"))["dataset"]
    train_split, test_split = dataset[0], dataset[1]

    def to_uint8_images(split):
        # Move axis 1 last (presumably channel-first -> channel-last; confirm
        # against the raw file).
        images = split[0].transpose(0, 2, 3, 1)
        # Scale in place, then truncate to bytes. Assumes stored values lie
        # in [0, 1] -- TODO confirm.
        images *= 255
        return images.astype(np.uint8)

    train_data = to_uint8_images(train_split)
    test_data = to_uint8_images(test_split)

    train_label = train_split[1][:, 0]
    test_label = test_split[1][:, 0]

    return train_data, test_data, train_label, test_label
|
||||
|
||||
|
||||
def main(data_dir):
    """Convert the raw Digit-Five ``.mat`` files into folders of JPEG images.

    For each of mnist, mnist_m, svhn, syn and usps, the samples are written
    to ``<data_dir>/<name>/train_images`` and ``<data_dir>/<name>/test_images``.

    Args:
        data_dir: directory containing the ``Digit-Five`` folder. ``~`` is
            expanded and the path is made absolute.

    Raises:
        FileNotFoundError: if ``data_dir`` or ``<data_dir>/Digit-Five`` does
            not exist.
    """
    data_dir = osp.abspath(osp.expanduser(data_dir))
    raw_data_dir = osp.join(data_dir, "Digit-Five")

    # Fail up front with a clear message instead of part-way through the
    # first loader when the raw folder is missing.
    for required in (data_dir, raw_data_dir):
        if not osp.exists(required):
            raise FileNotFoundError('"{}" does not exist'.format(required))

    # Explicit dispatch table: no eval() on constructed names, and a missing
    # loader is an ordinary NameError at this line.
    loaders = [
        ("mnist", load_mnist),
        ("mnist_m", load_mnist_m),
        ("svhn", load_svhn),
        ("syn", load_syn),
        ("usps", load_usps),
    ]

    for name, loader in loaders:
        print("Creating {}".format(name))

        output = loader(data_dir, raw_data_dir)
        train_data, test_data, train_label, test_label = output

        print("# train: {}".format(train_data.shape[0]))
        print("# test: {}".format(test_data.shape[0]))

        train_dir = osp.join(data_dir, name, "train_images")
        mkdir_if_missing(train_dir)
        test_dir = osp.join(data_dir, name, "test_images")
        mkdir_if_missing(test_dir)

        extract_and_save(train_data, train_label, train_dir)
        extract_and_save(test_data, test_label, test_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the directory
    # that holds Digit-Five/.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "data_dir",
        type=str,
        help="directory containing Digit-Five/",
    )
    main(cli.parse_args().data_dir)
|
||||
Reference in New Issue
Block a user