Upload to Main
This commit is contained in:
@@ -0,0 +1 @@
|
||||
# __init__.py
|
||||
Binary file not shown.
@@ -0,0 +1,8 @@
|
||||
from .cifar10 import *
|
||||
from .cifar100 import *
|
||||
from .fashionmnist import *
|
||||
from .imagenet import *
|
||||
from .mnist import *
|
||||
from .qmnist import *
|
||||
from .svhn import *
|
||||
from .tinyimagenet import *
|
||||
@@ -0,0 +1,19 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch import tensor, long
|
||||
|
||||
|
||||
def CIFAR10(data_path):
    """Build the normalised CIFAR-10 train and test datasets.

    Both splits are created with ``download=False`` and share one
    ``ToTensor`` + ``Normalize`` transform. The ``targets`` attribute of
    each split is replaced by a ``torch.long`` tensor.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.CIFAR10``.

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``.
    """
    channel, im_size, num_classes = 3, (32, 32), 10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Train split first, then test split, exactly one dataset object each.
    dst_train, dst_test = (
        datasets.CIFAR10(data_path, train=is_train, download=False, transform=transform)
        for is_train in (True, False)
    )
    class_names = dst_train.classes

    for split in (dst_train, dst_test):
        split.targets = tensor(split.targets, dtype=long)

    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,17 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch import tensor, long
|
||||
|
||||
|
||||
def CIFAR100(data_path):
    """Build the normalised CIFAR-100 train and test datasets.

    Both splits are requested with ``download=True`` and use the same
    ``ToTensor`` + ``Normalize`` transform; their ``targets`` are converted
    to ``torch.long`` tensors.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.CIFAR100``.

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``.
    """
    mean = [0.5071, 0.4865, 0.4409]
    std = [0.2673, 0.2564, 0.2762]
    pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

    def load(is_train):
        # One torchvision dataset per split, sharing the same pipeline.
        return datasets.CIFAR100(data_path, train=is_train, download=True, transform=pipeline)

    dst_train = load(True)
    dst_test = load(False)
    class_names = dst_train.classes

    dst_train.targets = tensor(dst_train.targets, dtype=long)
    dst_test.targets = tensor(dst_test.targets, dtype=long)

    channel = 3
    im_size = (32, 32)
    num_classes = 100
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,14 @@
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
|
||||
def FashionMNIST(data_path):
    """Build the normalised Fashion-MNIST train and test datasets.

    Args:
        data_path: Root directory handed to
            ``torchvision.datasets.FashionMNIST`` (called with
            ``download=True``).

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``; ``class_names`` comes from the training
        split's ``classes`` attribute.
    """
    channel, im_size, num_classes = 1, (28, 28), 10
    mean, std = [0.2861], [0.3530]

    steps = [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    transform = transforms.Compose(steps)

    dst_train, dst_test = (
        datasets.FashionMNIST(data_path, train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
    return channel, im_size, num_classes, dst_train.classes, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,27 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch import tensor, long
|
||||
|
||||
|
||||
def ImageNet(data_path):
    """Build the ImageNet ``train`` and ``val`` datasets at 224x224.

    Each split gets its own ``Resize(256)`` -> ``CenterCrop(224)`` ->
    ``ToTensor`` -> ``Normalize`` pipeline (the ``Normalize`` instance is
    shared). ``targets`` of both splits are converted to ``torch.long``
    tensors.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.ImageNet``.

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)`` where ``dst_test`` is the ``val`` split.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean, std)

    def make_pipeline():
        # A fresh Compose per split, mirroring two separate literals.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    dst_train = datasets.ImageNet(data_path, split="train", transform=make_pipeline())
    dst_test = datasets.ImageNet(data_path, split="val", transform=make_pipeline())
    class_names = dst_train.classes

    for split in (dst_train, dst_test):
        split.targets = tensor(split.targets, dtype=long)

    channel, im_size, num_classes = 3, (224, 224), 1000
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,25 @@
|
||||
from torchvision import datasets, transforms
|
||||
import numpy as np
|
||||
|
||||
|
||||
def MNIST(data_path, permuted=False, permutation_seed=None):
    """Build the normalised MNIST train and test datasets.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.MNIST``
            (called with ``download=True``).
        permuted: When true, one fixed random permutation of the 28*28
            pixels is applied to every image after normalisation.
        permutation_seed: Seed for that permutation; ``None`` draws an
            unpredictable one.

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``; ``class_names`` are the strings "0".."9".
    """
    channel = 1
    im_size = (28, 28)
    num_classes = 10
    mean = [0.1307]
    std = [0.3081]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

    if permuted:
        # A private generator yields the same permutation as seeding the
        # global NumPy RNG would, but without resetting np.random's global
        # state (which would silently break the caller's reproducibility,
        # especially when permutation_seed is None).
        rng = np.random.RandomState(permutation_seed)
        pixel_permutation = rng.permutation(28 * 28)
        transform = transforms.Compose(
            [transform, transforms.Lambda(lambda x: x.view(-1, 1)[pixel_permutation].view(1, 28, 28))])

    dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    class_names = [str(c) for c in range(num_classes)]
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
|
||||
|
||||
def permutedMNIST(data_path, permutation_seed=None):
    """Return ``MNIST(...)`` with the pixel permutation switched on.

    Args:
        data_path: Dataset root, forwarded to ``MNIST``.
        permutation_seed: Seed of the pixel permutation, forwarded to ``MNIST``.

    Returns:
        The same 8-tuple that ``MNIST`` returns.
    """
    return MNIST(data_path, permuted=True, permutation_seed=permutation_seed)
|
||||
@@ -0,0 +1,18 @@
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
|
||||
def QMNIST(data_path):
    """Build the normalised QMNIST train and test datasets.

    After construction, each split keeps only column 0 of its ``targets``
    and has its ``compat`` attribute set to ``False``.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.QMNIST``
            (called with ``download=True``).

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``; ``class_names`` are the strings "0".."9".
    """
    channel, im_size, num_classes = 1, (28, 28), 10
    mean, std = [0.1308], [0.3088]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

    dst_train = datasets.QMNIST(data_path, train=True, download=True, transform=transform)
    dst_test = datasets.QMNIST(data_path, train=False, download=True, transform=transform)
    class_names = list(map(str, range(num_classes)))

    for split in (dst_train, dst_test):
        # Slice the targets first, then switch compat off (same order per
        # dataset as before).
        split.targets = split.targets[:, 0]
        split.compat = False

    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,19 @@
|
||||
from torchvision import datasets, transforms
|
||||
from torch import tensor, long
|
||||
|
||||
|
||||
def SVHN(data_path):
    """Build the normalised SVHN ``train`` and ``test`` datasets.

    Each split additionally receives a ``classes`` list (its own copy of
    the digit names) and a ``targets`` attribute holding its ``labels`` as
    a ``torch.long`` tensor.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.SVHN``
            (called with ``download=True``).

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``; ``class_names`` are the strings "0".."9".
    """
    channel, im_size, num_classes = 3, (32, 32), 10
    mean = [0.4377, 0.4438, 0.4728]
    std = [0.1980, 0.2010, 0.1970]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])

    dst_train = datasets.SVHN(data_path, split='train', download=True, transform=transform)
    dst_test = datasets.SVHN(data_path, split='test', download=True, transform=transform)
    class_names = [str(digit) for digit in range(num_classes)]

    for split in (dst_train, dst_test):
        split.classes = list(class_names)  # independent copy per split
    for split in (dst_train, dst_test):
        split.targets = tensor(split.labels, dtype=long)

    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,35 @@
|
||||
from torchvision import datasets, transforms
|
||||
import os
|
||||
import requests
|
||||
import zipfile
|
||||
|
||||
|
||||
def TinyImageNet(data_path, downsize=True):
    """Build the Tiny-ImageNet train and test datasets, fetching them if needed.

    If ``<data_path>/tiny-imagenet-200`` does not exist, the archive is
    downloaded into ``data_path`` and extracted there. The datasets are
    ``ImageFolder`` views of the ``train`` and ``test`` sub-directories.

    Args:
        data_path: Directory that holds (or will hold) ``tiny-imagenet-200``.
        downsize: When true, images are resized with ``Resize(32)`` before
            the tensor conversion and ``im_size`` is reported as (32, 32);
            otherwise ``im_size`` is (64, 64).

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        dst_train, dst_test)``.

    Raises:
        requests.HTTPError: If the download answers with an error status.
    """
    if not os.path.exists(os.path.join(data_path, "tiny-imagenet-200")):
        url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"  # 248MB
        archive_path = os.path.join(data_path, "tiny-imagenet-200.zip")
        if data_path:
            # The target directory may not exist yet on a fresh machine.
            os.makedirs(data_path, exist_ok=True)

        print("Downloading Tiny-ImageNet")
        # The context manager releases the connection; the timeout keeps a
        # stalled server from hanging the run forever.
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of writing an HTML error page as the zip.
            r.raise_for_status()
            with open(archive_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)

        print("Unziping Tiny-ImageNet")
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(path=data_path)

    channel = 3
    im_size = (32, 32) if downsize else (64, 64)
    num_classes = 200
    mean = (0.4802, 0.4481, 0.3975)
    std = (0.2770, 0.2691, 0.2821)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    if downsize:
        transform = transforms.Compose([transforms.Resize(32), transform])

    dst_train = datasets.ImageFolder(root=os.path.join(data_path, 'tiny-imagenet-200/train'), transform=transform)
    # NOTE(review): the official archive's test split is believed to be
    # unlabeled, so ImageFolder may not yield usable labels here -- confirm.
    dst_test = datasets.ImageFolder(root=os.path.join(data_path, 'tiny-imagenet-200/test'), transform=transform)

    class_names = dst_train.classes
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test
|
||||
@@ -0,0 +1,17 @@
|
||||
from .cal import *
|
||||
from .contextualdiversity import *
|
||||
from .coresetmethod import *
|
||||
from .craig import *
|
||||
from .deepfool import *
|
||||
from .earlytrain import *
|
||||
from .forgetting import *
|
||||
from .full import *
|
||||
from .glister import *
|
||||
from .grand import *
|
||||
from .gradmatch import *
|
||||
from .herding import *
|
||||
from .kcentergreedy import *
|
||||
from .submodular import *
|
||||
from .uncertainty import *
|
||||
from .uniform import *
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,146 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
from .methods_utils.euclidean import euclidean_dist_pair_np
|
||||
from .methods_utils.cossim import cossim_pair_np
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from .. import nets
|
||||
from copy import deepcopy
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class Cal(EarlyTrain):
    """Coreset selection by neighbourhood KL divergence.

    Every training sample is embedded as a mix of the model's image and
    text features. Its score is the mean symmetric KL divergence between
    its predicted class distribution and those of its ``neighbors`` nearest
    samples under ``metric``. The lowest-scoring samples are selected.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 balance=False, metric="euclidean", neighbors: int = 10, pretrain_model: str = "ResNet18", **kwargs):
        """Configure the selector.

        Args:
            dst_train, args, fraction, random_seed, epochs, specific_model,
                **kwargs: Forwarded to ``EarlyTrain.__init__``.
            balance: Select the same fraction inside every class when true.
            metric: ``"euclidean"``, ``"cossim"`` or a callable taking the
                embedding matrix and returning a pairwise distance matrix.
                Any other value falls back to the Euclidean distance.
            neighbors: Number of nearest neighbours, in the open range (0, 100).
            pretrain_model: Stored on the instance; not used by the code
                visible in this class.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        self.balance = balance

        assert neighbors > 0 and neighbors < 100
        self.neighbors = neighbors

        if metric == "cossim":
            # find_knn calls the metric with ONE argument; the previous
            # two-argument lambda could never be called that way.
            # NOTE(review): assumes cossim_pair_np accepts a single matrix,
            # like euclidean_dist_pair_np -- confirm.
            def _neg_cossim(a, b=None):
                return -1. * (cossim_pair_np(a) if b is None else cossim_pair_np(a, b))

            self.metric = _neg_cossim
        elif metric != "euclidean" and callable(metric):
            self.metric = metric
        else:
            # "euclidean" and every unrecognised value.
            self.metric = euclidean_dist_pair_np

        self.pretrain_model = pretrain_model

    def num_classes_mismatch(self):
        """Raise ``ValueError``: pretrain and training class counts differ."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        """Return ``lam * img_fea + (1 - lam) * text_fea``.

        Initial implementation; the mixing rule may not be optimal.
        """
        return lam * img_fea + (1 - lam) * text_fea

    def _embed(self, loader, sample_num):
        """Run the model over ``loader``.

        Returns:
            ``(features, probs)``: the concatenated mixed features and a
            ``[sample_num, num_classes]`` array of softmax probabilities,
            both in loader order.
        """
        features = []
        probs = np.zeros([sample_num, self.num_classes])
        offset = 0
        for batch in loader:
            image, label = batch['img'].cuda(), batch['label'].cuda()
            img_f, text_f, logit = self.model(image, label, record=True)
            # detach() so the conversion also works if the model tracked grads.
            features.append(self.mixing_feature(img_f, text_f).detach().cpu().numpy())
            batch_probs = torch.softmax(logit, dim=1).detach().cpu().numpy()
            # A running offset keeps rows aligned whatever batch size the
            # loader really uses.
            probs[offset:offset + len(batch_probs)] = batch_probs
            offset += len(batch_probs)
        return np.concatenate(features, axis=0), probs

    def find_knn(self):
        """
        Find k-nearest-neighbor data points with the current model's embeddings.

        Also fills ``self.probs`` so that row ``i`` holds the class
        probabilities of dataset sample ``i``.

        :return: knn matrix; when ``balance`` is set, a list with one
            matrix per class whose entries index into that class's samples
        """
        self.model.eval()
        # Indexed by GLOBAL sample index, so calc_kl(knn, class_index) reads
        # the right rows. (Concatenating per-class blocks in class order
        # only matched global indices for class-sorted datasets.)
        self.probs = np.zeros([self.n_train, self.num_classes])

        if self.balance:
            knn = []
            for c in tqdm(range(self.num_classes)):
                print(f'Start processing class {c}/{self.num_classes}')
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                if len(class_index) == 0:
                    # Keep the list aligned with the class ids.
                    knn.append(np.zeros([0, self.neighbors], dtype=np.int64))
                    continue

                # NOTE(review): assumes select_dm(..., is_train=False)
                # yields the samples in class_index order -- confirm.
                data_loader = self.select_dm(self.dst_train, class_index, is_train=False)
                embeddings, c_probs = self._embed(data_loader, len(class_index))
                self.probs[class_index] = c_probs
                # Column 0 is the sample itself; skip it.
                knn.append(np.argsort(self.metric(embeddings), axis=1)[:, 1:(self.neighbors + 1)])
            return knn

        batch_loader = self.select_dm(self.dst_train, None, is_train=False)
        print(f'Start processing all class')
        embeddings, self.probs = self._embed(tqdm(batch_loader), self.n_train)
        return np.argsort(self.metric(embeddings), axis=1)[:, 1:(self.neighbors + 1)]

    def calc_kl(self, knn, index=None):
        """Score samples by mean symmetric KL divergence to their neighbours.

        Args:
            knn: Neighbour indices; row ``i`` indexes into the rows selected
                by ``index``.
            index: Global indices of the samples to score, or ``None`` for all.

        Returns:
            1-D array with one score per scored sample.
        """
        self.model.eval()
        self.model.no_grad = True
        sample_num = self.n_train if index is None else len(index)
        # self.probs[None] would add an axis rather than select everything.
        probs = self.probs if index is None else self.probs[index]
        batch_size = self.args.DATASET.SELECTION_BATCH_SIZE

        s = np.zeros(sample_num)
        # NOTE(review): probabilities that are exactly 0 make the logs
        # below produce inf/nan scores -- consider clipping.
        for i in range(0, sample_num, batch_size):
            print("| Caculating KL-divergence for batch [%3d/%3d] with batchsize [%3d]" % (i, sample_num, batch_size))
            aa = np.expand_dims(probs[i:(i + batch_size)], 1).repeat(self.neighbors, 1)
            bb = probs[knn[i:(i + batch_size)], :]
            s[i:(i + batch_size)] = np.mean(
                np.sum(0.5 * aa * np.log(aa / bb) + 0.5 * bb * np.log(bb / aa), axis=2), axis=1)
        self.model.no_grad = False
        return s

    def finish_run(self):
        """Select the lowest-scoring samples.

        Returns:
            Dict with ``"indices"`` (selected global indices) and
            ``"scores"``: a list of per-class score arrays when ``balance``
            is set, otherwise one score array over all samples.
        """
        if self.balance:
            scores = []
            selection_result = np.array([], dtype=np.int32)
            for c, knn in zip(range(self.num_classes), self.knn):
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                class_scores = self.calc_kl(knn, class_index)
                scores.append(class_scores)
                keep = round(self.fraction * len(class_index))
                selection_result = np.append(selection_result, class_index[np.argsort(class_scores)[:keep]])
        else:
            # Previously the scores were computed but an empty list returned.
            scores = self.calc_kl(self.knn)
            selection_result = np.argsort(scores)[:self.coreset_size]
        return {"indices": selection_result, "scores": scores}

    def select(self, **kwargs):
        """Compute the neighbour structure, then return the result of ``self.run()``."""
        self.knn = self.find_knn()
        selection_result = self.run()
        return selection_result
|
||||
@@ -0,0 +1,33 @@
|
||||
from .kcentergreedy import kCenterGreedy
|
||||
import torch
|
||||
|
||||
|
||||
# Acknowledgement to:
|
||||
# https://github.com/sharat29ag/CDAL
|
||||
|
||||
|
||||
class ContextualDiversity(kCenterGreedy):
    """k-Center-Greedy selection using a symmetric KL distance on softmax outputs."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance=True, already_selected=None, torchvision_pretrain: bool = False, **kwargs):
        """Configure the selector.

        All arguments are forwarded to ``kCenterGreedy.__init__``.
        ``already_selected=None`` stands for "no samples selected yet" and
        is passed on as a fresh empty list.
        """
        # A shared mutable default list would leak between instances if the
        # base class ever mutated it.
        if already_selected is None:
            already_selected = []
        super(ContextualDiversity, self).__init__(dst_train, args, fraction, random_seed, epochs=epochs, specific_model=specific_model, balance=balance, already_selected=already_selected, torchvision_pretrain=torchvision_pretrain, **kwargs)
        self.metric = self._metric

    def _metric(self, a_output, b_output):
        """Pairwise symmetric KL divergence between two sets of distributions.

        Overrides ``self.metric`` for the kCenterGreedy algorithm.

        Args:
            a_output: 2-D tensor, one distribution per row.
            b_output: 2-D tensor with the same number of columns.

        Returns:
            Tensor of shape ``[len(a_output), len(b_output)]``.
        """
        with torch.no_grad():
            # Broadcasting yields the same values as explicitly repeating
            # both operands, without materialising two [A, B, C] copies.
            aa = a_output.unsqueeze(1)
            bb = b_output.unsqueeze(0)
            return torch.sum(0.5 * aa * torch.log(aa / bb) + 0.5 * bb * torch.log(bb / aa), dim=2)

    def construct_matrix(self, index=None):
        """Return the softmax outputs of the model for the chosen samples.

        Args:
            index: Indices of the samples to evaluate; ``None`` means all.

        Returns:
            Tensor of shape ``[sample_num, args.num_classes]`` on ``args.device``.
        """
        self.model.eval()
        self.model.no_grad = True
        sample_num = self.n_train if index is None else len(index)
        matrix = torch.zeros([sample_num, self.args.num_classes], requires_grad=False).to(self.args.device)
        # NOTE(review): this still reads args.selection_batch / args.workers /
        # args.num_classes and expects (inputs, target) batches, unlike other
        # methods in this package -- confirm these exist on args.
        batch_loader = torch.utils.data.DataLoader(
            self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
            batch_size=self.args.selection_batch, num_workers=self.args.workers)
        for i, (inputs, _) in enumerate(batch_loader):
            matrix[i * self.args.selection_batch:min((i + 1) * self.args.selection_batch, sample_num)] = torch.nn.functional.softmax(self.model(inputs.to(self.args.device)), dim=1)
        self.model.no_grad = False
        return matrix
|
||||
@@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class CoresetMethod(object):
    """Base class of the coreset selection methods.

    Stores the data manager, the list of training items with their labels,
    and the requested coreset size. Subclasses supply the model and the
    training hooks used by :meth:`pre_run`.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, **kwargs):
        """Set up the shared selection state.

        Args:
            dst_train: Data manager; ``dst_train.dataset.train_x`` and
                ``dst_train.dataset.num_classes`` are read.
            args: Configuration; ``args.OPTIM_SELECTION.MAX_EPOCH`` is read.
            fraction: Share of the training set to select, in (0, 1].
            random_seed: Stored as ``self.random_seed``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``fraction`` is outside (0, 1].
        """
        if fraction <= 0.0 or fraction > 1.0:
            raise ValueError("Illegal Coreset Size.")

        self.dm = dst_train
        self.dst_train = dst_train.dataset.train_x
        self.num_classes = dst_train.dataset.num_classes
        self.fraction = fraction
        self.random_seed = random_seed
        self.index = []
        self.args = args
        self.dst_train_label = self.get_train_label(self.dst_train)
        self.n_train = len(self.dst_train)
        self.coreset_size = round(self.n_train * fraction)
        self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH

    def select(self, **kwargs):
        """Placeholder for subclasses; returns ``None``."""
        return

    def get_train_label(self, dst_train):
        """Return the ``label`` attribute of every item as a NumPy array."""
        return np.asarray([item.label for item in dst_train])

    def pre_run(self):
        """Fine-tune the model for ``max_epoch`` epochs, or load a cached result.

        The checkpoint path is ``checkpoints/<DATASET.NAME>_<SEED>.pth``.
        If it exists it is loaded into ``self.model``; otherwise the
        subclass hooks ``before_epoch`` / ``train`` / ``test`` /
        ``after_epoch`` run for each epoch and the resulting state dict is
        saved. Nothing happens beyond setting ``self.train_indx`` when
        ``max_epoch`` is not positive.
        """
        # torch is not imported at the top of this file; import it here so
        # the checkpoint calls below resolve.
        import torch

        self.train_indx = np.arange(self.n_train)
        print(f'Start pre-funing CLIP with all datasets by {self.max_epoch} epoch')
        file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
        output_checkpoint_dir = os.path.join('checkpoints', file_save_name)
        if self.max_epoch <= 0:
            return

        if os.path.exists(output_checkpoint_dir):
            print(f'The checkpiont exists! Load that shit')
            ckpt = torch.load(output_checkpoint_dir)
            self.model.load_state_dict(ckpt)
            return

        for epoch in range(self.epoch, self.max_epoch):
            self.before_epoch()  # PASS
            self.train(epoch)
            self.test(epoch)
            self.after_epoch()

        # Create the directory first, otherwise the save fails after the
        # whole training run on a fresh checkout.
        os.makedirs(os.path.dirname(output_checkpoint_dir), exist_ok=True)
        torch.save(self.model.state_dict(), output_checkpoint_dir)
|
||||
@@ -0,0 +1,126 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
from .methods_utils import FacilityLocation, submodular_optimizer
|
||||
import numpy as np
|
||||
from .methods_utils.euclidean import euclidean_dist_pair_np
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
from tqdm import tqdm
|
||||
|
||||
class Craig(EarlyTrain):
    """CRAIG-style coreset selection.

    Maximises a facility-location function over a similarity matrix built
    from pairwise Euclidean distances between per-sample gradients. Each
    selected sample also gets a weight counting the samples it represents.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
                 balance=True, greedy="LazyGreedy", **kwargs):
        """Configure the selector.

        Args:
            dst_train, args, fraction, random_seed, epochs, specific_model,
                **kwargs: Forwarded to ``EarlyTrain.__init__``.
            balance: Run the selection separately inside every class.
            greedy: Name of an optimizer in
                ``submodular_optimizer.optimizer_choices``.

        Raises:
            ModuleNotFoundError: If ``greedy`` is not a known optimizer.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        if greedy not in submodular_optimizer.optimizer_choices:
            raise ModuleNotFoundError("Greedy optimizer not found.")
        self._greedy = greedy
        self.balance = balance

    # The training hooks below are deliberately no-ops for this method.
    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        """Raise ``ValueError``: pretrain and training class counts differ."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    @staticmethod
    def _class_similarity(grad_matrix, class_index):
        """Turn the distance block of one class into a positive similarity block."""
        similarity = -1. * grad_matrix[np.ix_(class_index, class_index)]
        similarity -= np.min(similarity) - 1e-3  # smallest entry becomes 1e-3
        return similarity

    def calc_weights(self, matrix, result):
        """Count, per selected sample, how many samples it is most similar to.

        Args:
            matrix: Similarity matrix.
            result: Selected rows of ``matrix``, as a boolean mask or an
                index array.

        Returns:
            Float array with one weight per selected sample, each starting at 1.
        """
        min_sample = np.argmax(matrix[result], axis=0)
        weights = np.ones(np.sum(result) if result.dtype == bool else len(result))
        # Unbuffered add: repeated indices accumulate like an explicit loop.
        np.add.at(weights, min_sample, 1)
        return weights

    def finish_run(self):
        """Run the submodular selection.

        Returns:
            Dict with ``"indices"`` (selected global indices) and
            ``"weights"`` (one weight per selected sample).
        """
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

        self.model.no_grad = True
        # calc_gradient is not defined in this class; presumably inherited
        # from EarlyTrain and returning one gradient vector per training
        # sample -- verify.
        grad = self.calc_gradient()
        grad_matrix = euclidean_dist_pair_np(grad)

        if self.balance:
            # Do selection by class.
            selection_result = np.array([], dtype=np.int32)
            weights = np.array([])
            for c in tqdm(range(self.num_classes)):
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                if class_index.size == 0:
                    continue  # nothing to select; np.min would fail on an empty block
                matrix = self._class_similarity(grad_matrix, class_index)
                submod_function = FacilityLocation(index=class_index, similarity_matrix=matrix)
                submod_optimizer = submodular_optimizer.__dict__[self._greedy](
                    args=self.args, index=class_index, budget=round(self.fraction * len(class_index)))
                class_result = submod_optimizer.select(gain_function=submod_function.calc_gain,
                                                       update_state=submod_function.update_state)
                selection_result = np.append(selection_result, class_result)
                weights = np.append(weights, self.calc_weights(matrix, np.isin(class_index, class_result)))
        else:
            matrix = np.zeros([self.n_train, self.n_train])
            all_index = np.arange(self.n_train)
            for c in range(self.num_classes):  # Sparse matrix: only same-class blocks are filled.
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                if class_index.size == 0:
                    continue
                # Slice the precomputed distance matrix, as the balanced
                # branch does. Assigning the raw calc_gradient(class_index)
                # output here would not be a pairwise block.
                matrix[np.ix_(class_index, class_index)] = self._class_similarity(grad_matrix, class_index)
            submod_function = FacilityLocation(index=all_index, similarity_matrix=matrix)
            submod_optimizer = submodular_optimizer.__dict__[self._greedy](args=self.args, index=all_index,
                                                                           budget=self.coreset_size)
            # Prefer the legacy flat option; fall back to the nested config
            # key used by the other selection methods in this package.
            batch = getattr(self.args, "selection_batch", None)
            if batch is None:
                batch = self.args.DATASET.SELECTION_BATCH_SIZE
            selection_result = submod_optimizer.select(gain_function=submod_function.calc_gain_batch,
                                                       update_state=submod_function.update_state,
                                                       batch=batch)
            weights = self.calc_weights(matrix, selection_result)
        self.model.no_grad = False
        return {"indices": selection_result, "weights": weights}

    def select(self, **kwargs):
        """Return the result of ``self.run()``."""
        selection_result = self.run()
        return selection_result
|
||||
@@ -0,0 +1,120 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DeepFool(EarlyTrain):
    """Coreset selection by DeepFool perturbation size.

    Each sample is scored by the squared L2 norm of the total perturbation
    DeepFool accumulates while trying to change the model's prediction.
    The samples with the smallest scores are selected.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model=None, balance: bool = False, max_iter: int = 50, **kwargs):
        """Configure the selector.

        Args:
            dst_train, args, fraction, random_seed, epochs, specific_model,
                **kwargs: Forwarded to ``EarlyTrain.__init__``.
            balance: Select the same fraction inside every class when true.
            max_iter: Maximum number of DeepFool iterations per batch.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        self.balance = balance
        self.max_iter = max_iter

    def num_classes_mismatch(self):
        """Raise ``ValueError``: pretrain and training class counts differ."""
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Print the training loss every ``args.print_freq`` batches."""
        if batch_idx % self.args.print_freq == 0:
            print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
                epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def finish_run(self):
        """Score every training sample and pick the smallest scores.

        Returns:
            Dict with ``"indices"`` (selected indices) and ``"scores"``
            (one float32 score per training sample).
        """
        self.model.no_grad = False

        # NOTE(review): this still reads args.selection_batch / args.workers /
        # args.print_freq / args.num_classes and expects (inputs, targets)
        # batches plus a dst_train.targets attribute -- confirm they exist.
        batch_loader = torch.utils.data.DataLoader(self.dst_train, batch_size=self.args.selection_batch,
                                                   num_workers=self.args.workers)

        r = np.zeros(self.n_train, dtype=np.float32)
        batch_num = len(batch_loader)
        for i, (inputs, targets) in enumerate(batch_loader):
            if i % self.args.print_freq == 0:
                print('| Selecting Batch [%3d/%3d]' % (i + 1, batch_num))
            start = i * self.args.selection_batch
            r[start:(start + targets.shape[0])] = self.deep_fool(inputs)

        if self.balance:
            selection_result = np.array([], dtype=np.int64)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                selection_result = np.append(selection_result, class_index[
                    r[class_index].argsort()[:round(len(class_index) * self.fraction)]])
        else:
            selection_result = r.argsort()[:self.coreset_size]
        return {"indices": selection_result, "scores": r}

    def deep_fool(self, inputs):
        """Run DeepFool on one batch.

        Args:
            inputs: 4-D batch tensor; it is not modified.

        Returns:
            1-D NumPy array with the squared L2 norm of the accumulated
            perturbation of every sample.
        """
        self.model.eval()

        # Boolean mask over the batch: True while a sample is still being
        # perturbed.
        sample_size = inputs.shape[0]
        boolean_mask = np.ones(sample_size, dtype=bool)
        all_idx = np.arange(sample_size)

        # Accumulated perturbation per sample, flattened.
        r_tot = np.zeros([sample_size, inputs.shape[1] * inputs.shape[2] * inputs.shape[3]])
        original_shape = inputs.shape[1:]

        # Work on a private copy so the caller's batch is neither flagged
        # for gradients nor perturbed. On CPU ``.to`` returns the same
        # tensor, which made the former in-place update fail.
        cur_inputs = inputs.detach().clone().to(self.args.device).requires_grad_(True)

        # Freeze the parameters to speed up autograd, and restore the
        # flags afterwards so the model is not left permanently frozen.
        params = list(self.model.parameters())
        saved_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)

        try:
            self.model.no_grad = True
            first_preds = self.model(cur_inputs).argmax(dim=1)
            self.model.no_grad = False

            for _ in range(self.max_iter):
                # All-True mask over the still-active rows; as an index it
                # selects positions 0..n_active-1.
                live = boolean_mask[boolean_mask]

                f_all = self.model(cur_inputs)

                w_k = []
                for c in range(self.args.num_classes):
                    w_k.append(torch.autograd.grad(
                        f_all[:, c].sum(), cur_inputs,
                        retain_graph=(c + 1 != self.args.num_classes))[0].flatten(1))
                w_k = torch.stack(w_k, dim=0)
                w_k = w_k - w_k[first_preds, live].unsqueeze(0)
                w_k_norm = w_k.norm(dim=2)

                # Set w_k_norm at the predicted class to 1. to avoid
                # division by zero.
                w_k_norm[first_preds, live] = 1.

                l_all = (f_all - f_all[live, first_preds].unsqueeze(1)).detach().abs() / w_k_norm.T
                # Exclude the predicted class from the argmin below.
                l_all[live, first_preds] = np.inf

                l_hat = l_all.argmin(dim=1)
                r_i = l_all[live, l_hat].unsqueeze(1) / w_k_norm[l_hat, live].T.unsqueeze(1) * w_k[l_hat, live]

                # Update r_tot values.
                r_tot[boolean_mask] += r_i.cpu().numpy()

                # Out-of-place update on a fresh leaf: same values as the
                # former ``+=``, but legal on every device.
                cur_inputs = (cur_inputs.detach()
                              + r_i.reshape([r_i.shape[0]] + list(original_shape))).requires_grad_(True)

                # Re-input the updated sample and get new predictions.
                self.model.no_grad = True
                preds = self.model(cur_inputs).argmax(dim=1)
                self.model.no_grad = False

                # A sample is finished once its prediction has changed.
                index_unfinished = (preds == first_preds)
                if torch.all(~index_unfinished):
                    break

                cur_inputs = cur_inputs[index_unfinished].detach().requires_grad_(True)
                first_preds = first_preds[index_unfinished]
                boolean_mask[all_idx[boolean_mask][~index_unfinished.cpu().numpy()]] = False
        finally:
            for p, flag in zip(params, saved_flags):
                p.requires_grad_(flag)

        return (r_tot * r_tot).sum(axis=1)

    def select(self, **kwargs):
        """Return the result of ``self.run()``."""
        selection_result = self.run()
        return selection_result
|
||||
@@ -0,0 +1,322 @@
|
||||
from .coresetmethod import CoresetMethod
|
||||
import torch, time
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from .. import nets
|
||||
from torchvision import transforms
|
||||
from datasets.data_manager import select_dm_loader
|
||||
from dassl.utils import MetricMeter, AverageMeter
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
import datetime
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
class EarlyTrain(CoresetMethod):
|
||||
'''
|
||||
Core code for training related to coreset selection methods when pre-training is required.
|
||||
'''
|
||||
|
||||
def __init__(self, dst_train, args,fraction=0.5, random_seed=None, epochs=200, specific_model=None,
             torchvision_pretrain: bool = False, dst_pretrain_dict: dict = {}, fraction_pretrain=1., dst_test=None,
             **kwargs):
    """Set up the state shared by selection methods that tune a model first.

    Args:
        dst_train: training dataset, forwarded to the parent constructor.
        args: experiment configuration. Read here:
            ``OPTIM_SELECTION.MAX_EPOCH`` and, depending on the options,
            ``im_size``, ``channel`` and ``num_classes``.
        fraction: fraction of the training set to keep in the coreset.
        random_seed: forwarded to the parent constructor.
        epochs: stored as ``self.epochs``.
        specific_model: model to tune, stored as ``self.model``.
        torchvision_pretrain: when true and ``args.im_size`` is not 224x224,
            the dataset(s) are deep-copied and a ``Resize(224)`` is appended
            to their transform.
        dst_pretrain_dict: optional description of a separate pretraining
            dataset. When non-empty it must contain ``im_size``, ``channel``,
            ``dst_train`` and ``num_classes``. The dict is shallow-copied, so
            neither the caller's dict nor the shared default is modified.
        fraction_pretrain: fraction of the pretraining data to use, in (0, 1].
        dst_test: stored as ``self.dst_test``.
        **kwargs: when non-empty, must provide ``optim``, ``schedule`` and
            ``scar`` (stored as ``self.optim``, ``self.sche``, ``self.scar``).

    Raises:
        ValueError: illegal ``fraction_pretrain``, or image size / channel
            mismatch between the pretrain and training datasets.
        AttributeError: ``dst_pretrain_dict`` lacks a required key.
        KeyError: ``kwargs`` is non-empty but lacks one of the expected keys.
    """
    super().__init__(dst_train, args, fraction, random_seed)
    self.epochs = epochs
    self.n_train = len(self.dst_train)
    self.coreset_size = round(self.n_train * fraction)
    self.model = specific_model
    # NOTE(review): ``self.dm`` is not assigned in this method; presumably the
    # parent constructor provides it - confirm.
    self.train_loader = self.dm.train_loader_x
    self.test_loader = self.dm.test_loader

    if kwargs:
        # self.text_feature = kwargs['text_feature']
        self.optim = kwargs['optim']
        self.sche = kwargs['schedule']
        self.scar = kwargs['scar']

    self.start_epoch = self.epoch = 0
    self.max_epoch = self.args.OPTIM_SELECTION.MAX_EPOCH

    if fraction_pretrain <= 0. or fraction_pretrain > 1.:
        raise ValueError("Illegal pretrain fraction value.")
    self.fraction_pretrain = fraction_pretrain

    if dst_pretrain_dict:
        required_keys = ('im_size', 'channel', 'dst_train', 'num_classes')
        if any(key not in dst_pretrain_dict for key in required_keys):
            raise AttributeError(
                'Argument dst_pretrain_dict must contain im_size, channel, dst_train and num_classes.')
        # Both spatial dimensions must agree (the original compared index 0 twice).
        if dst_pretrain_dict['im_size'][0] != args.im_size[0] or \
                dst_pretrain_dict['im_size'][1] != args.im_size[1]:
            raise ValueError("im_size of pretrain dataset does not match that of the training dataset.")
        if dst_pretrain_dict['channel'] != args.channel:
            raise ValueError("channel of pretrain dataset does not match that of the training dataset.")
        if dst_pretrain_dict['num_classes'] != args.num_classes:
            # Subclasses decide whether a class-count mismatch is an error.
            self.num_classes_mismatch()

    # Shallow copy: the 'dst_train' entry may be replaced below.
    self.dst_pretrain_dict = dict(dst_pretrain_dict)
    self.torchvision_pretrain = torchvision_pretrain
    self.if_dst_pretrain = (len(self.dst_pretrain_dict) != 0)

    if torchvision_pretrain:
        # Pretrained models in torchvision only accept 224*224 inputs, therefore we resize current
        # datasets to 224*224.
        if args.im_size[0] != 224 or args.im_size[1] != 224:
            self.dst_train = deepcopy(dst_train)
            self.dst_train.transform = transforms.Compose([self.dst_train.transform, transforms.Resize(224)])
            if self.if_dst_pretrain:
                pretrain_set = deepcopy(dst_pretrain_dict['dst_train'])
                pretrain_set.transform = transforms.Compose([pretrain_set.transform, transforms.Resize(224)])
                self.dst_pretrain_dict['dst_train'] = pretrain_set
    if self.if_dst_pretrain:
        self.n_pretrain = len(self.dst_pretrain_dict['dst_train'])
    self.n_pretrain_size = round(
        self.fraction_pretrain * (self.n_pretrain if self.if_dst_pretrain else self.n_train))
    self.dst_test = dst_test
|
||||
|
||||
|
||||
def train(self, epoch, list_of_train_idx=None, **kwargs):
    """Train ``self.model`` for one epoch and print progress.

    A fresh loader is built with ``select_dm_loader`` over ``self.dst_train``.
    Each batch is a dict with ``'img'``, ``'label'`` and ``'index'`` entries,
    and the model is called as ``model(image, label)`` and must return
    ``(loss, outputs)``. After every optimizer step the ``after_loss`` and
    ``while_update`` hooks run; the LR scheduler is stepped once, after the
    last batch.

    Args:
        epoch: zero-based index of the epoch being trained. It is also used
            for the progress line and the ETA (the original read
            ``self.epoch``, which ``run`` never advances).
        list_of_train_idx: unused; kept for interface compatibility.
        **kwargs: unused.

    Returns:
        Whatever ``self.finish_train()`` returns.
    """
    self.before_train()
    self.model.train()

    losses = MetricMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    end = time.time()

    print('\n=> Training Pre-tuning Epoch #%d' % epoch)
    train_loader = select_dm_loader(self.args,self.dst_train,is_train=True)
    self.num_batches = len(train_loader)

    # Loop invariants.
    model = self.model
    optim = self.optim
    use_amp = self.args.TRAINER.MAPLE.PREC == "amp"
    print_freq = self.args.TRAIN.PRINT_FREQ
    batch_size = self.args.DATALOADER.TRAIN_X.BATCH_SIZE
    only_few_batches = self.num_batches < print_freq

    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        image, label, real_ind = batch['img'].cuda(), batch['label'].cuda(), batch['index'].cuda()

        if use_amp:
            # Mixed precision: forward under autocast, scaled backward/step.
            scaler = self.scar
            with autocast():
                loss, outputs = model(image, label)
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss, outputs = model(image, label)
            optim.zero_grad()
            loss.backward()
            optim.step()

        self.after_loss(outputs, loss, label, real_ind, epoch)
        self.while_update(outputs, loss, label, epoch, i, batch_size)

        if (i + 1) == self.num_batches:
            # One scheduler step per epoch, after the final batch.
            self.sche.step()
        batch_time.update(time.time() - end)
        losses.update({"loss": loss.item()})

        if (i + 1) % print_freq == 0 or only_few_batches:
            # Batches left in this epoch plus all batches of the remaining epochs.
            nb_remain = self.num_batches - i - 1
            nb_remain += (self.max_epoch - epoch - 1) * self.num_batches
            eta = str(datetime.timedelta(seconds=int(batch_time.avg * nb_remain)))

            info = [
                f"epoch [{epoch + 1}/{self.max_epoch}]",
                f"batch [{i + 1}/{self.num_batches}]",
                f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                f"{losses}",
                f"lr {optim.param_groups[0]['lr']:.4e}",
                f"eta {eta}",
            ]
            print(" ".join(info))

        end = time.time()

    return self.finish_train()
|
||||
|
||||
def run(self):
    """Tune the model (or reuse a cached checkpoint), then finish the run.

    The checkpoint lives at ``checkpoints/<DATASET.NAME>_<SEED>.pth``. When
    ``self.max_epoch > 0`` and that file exists, its state dict is loaded and
    training is skipped. Otherwise the model is trained and tested for
    ``self.max_epoch`` epochs and its state dict is saved once training
    completes. The ``checkpoints`` directory is created on demand so the save
    cannot fail on a fresh checkout.

    Returns:
        Whatever ``self.finish_run()`` returns.
    """
    self.train_indx = np.arange(self.n_train)
    self.before_run()
    print(f'Start pre-funing CLIP with all datasets by {self.max_epoch} epoch')
    file_save_name = self.args.DATASET.NAME + '_' + str(self.args.SEED) + '.pth'
    checkpoint_dir = 'checkpoints'
    output_checkpoint_dir = os.path.join(checkpoint_dir, file_save_name)

    if self.max_epoch > 0:
        if os.path.exists(output_checkpoint_dir):
            print('The checkpiont exists! Load that shit')
            ckpt = torch.load(output_checkpoint_dir)
            self.model.load_state_dict(ckpt)
        else:
            for epoch in range(self.epoch, self.max_epoch):
                self.before_epoch()
                self.train(epoch)
                self.test(epoch)
                self.after_epoch()
            # Save only after all epochs so a partial run is never mistaken
            # for a finished checkpoint.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(self.model.state_dict(), output_checkpoint_dir)

    return self.finish_run()
|
||||
|
||||
def test(self, epoch):
    """Evaluate ``self.model`` on ``self.test_loader`` and print top-1 accuracy.

    The model is put in eval mode and its ``no_grad`` attribute is set for the
    duration of the pass. Inference runs under ``torch.no_grad()`` so no
    autograd graph is kept. An empty loader reports ``nan`` instead of raising
    ``ZeroDivisionError``.

    Args:
        epoch: epoch index, used only in the printed messages.
    """
    self.model.no_grad = True
    self.model.eval()

    correct = 0.
    total = 0.

    print('\n=> Testing Tuning Epoch #%d' % epoch)

    with torch.no_grad():
        for batch in self.test_loader:
            image, target = batch['img'].cuda(), batch['label']
            output = self.model(image, target.cuda())

            # Predictions go back to the CPU, where ``target`` still lives.
            predicted = torch.max(output.data, 1).indices.cpu()
            correct += predicted.eq(target).sum().item()
            total += target.size(0)

    accuracy = 100. * correct / total if total else float('nan')
    print(f'| Test Epoch {epoch} Test Acc: {accuracy:.3f}%')
    self.model.no_grad = False
|
||||
|
||||
def num_classes_mismatch(self):
    """Hook invoked by ``__init__`` when the pretrain dataset's class count differs.

    The base implementation accepts the mismatch silently; subclasses may raise.
    """
    return None
|
||||
|
||||
def before_train(self):
    """Hook called at the start of every ``train`` call; no-op by default."""
    return None
|
||||
|
||||
def after_loss(self, outputs, loss, targets, batch_inds, epoch):
    """Hook called by ``train`` after each optimizer step; no-op by default.

    Receives the model outputs, the loss, the labels, the per-sample dataset
    indices of the batch and the current epoch.
    """
    return None
|
||||
|
||||
def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
    """Per-batch hook called by ``train`` right after ``after_loss``; no-op by default."""
    return None
|
||||
|
||||
def finish_train(self):
    """Hook whose return value ``train`` hands back to its caller; returns ``None`` by default."""
    return None
|
||||
|
||||
def before_epoch(self):
    """Hook called by ``run`` before each epoch's ``train``; no-op by default."""
    return None
|
||||
|
||||
def after_epoch(self):
    """Hook called by ``run`` after each epoch's ``test``; no-op by default."""
    return None
|
||||
|
||||
def before_run(self):
    """Hook called once at the start of ``run``; no-op by default."""
    return None
|
||||
|
||||
def finish_run(self):
    """Hook that produces the result of ``run`` / ``select_without_train``.

    The base implementation returns ``None``; subclasses return their selection.
    """
    return None
|
||||
|
||||
def select(self, **kwargs):
    """Tune the model via ``run`` and return the resulting selection.

    Keyword arguments are accepted for interface compatibility and ignored.
    """
    return self.run()
|
||||
|
||||
def select_without_train(self, **kwargs):
    """Skip tuning and go straight to ``finish_run`` with the current model.

    Keyword arguments are accepted for interface compatibility and ignored.
    """
    result = self.finish_run()
    return result
|
||||
|
||||
@torch.no_grad()
def calcluate_clip_probability(self,batch):
    """Return scaled image-text similarity logits for one batch.

    The images in ``batch["img"]`` are encoded with
    ``self.specific_model.encode_image``, L2-normalised along the last
    dimension, multiplied by ``exp(logit_scale)`` and projected onto
    ``self.text_feature``.

    NOTE(review): ``__init__`` in this file stores the model as ``self.model``
    and its ``text_feature`` assignment is commented out; confirm that
    ``specific_model`` and ``text_feature`` are set elsewhere before use.
    """
    images = batch["img"].cuda()

    clip_model = self.specific_model.cuda()
    self.specific_model = clip_model
    features = clip_model.encode_image(images)
    features = features / features.norm(dim=-1, keepdim=True)
    scale = clip_model.logit_scale.exp()
    return scale * features @ self.text_feature.t()
|
||||
|
||||
# using the defined select_dm
|
||||
def select_dm(self,data,ind=None,is_train=None):
    """Build a loader for ``data`` by delegating to ``select_dm_loader``.

    ``self.args``, ``data``, ``ind`` and ``is_train`` are passed positionally,
    in that order.
    """
    loader = select_dm_loader(self.args, data, ind, is_train)
    return loader
|
||||
|
||||
|
||||
def parse_batch_test(self, batch):
    """Return ``(image, label)`` from a batch dict, each moved with ``.cuda()``."""
    return batch["img"].cuda(), batch["label"].cuda()
|
||||
|
||||
def parse_batch_train(self, batch):
    """Return ``(image, label, index)`` from a batch dict, each moved with ``.cuda()``.

    The third element is the batch's ``"index"`` entry.
    """
    image, label, index = (batch[key].cuda() for key in ("img", "label", "index"))
    return image, label, index
|
||||
|
||||
|
||||
|
||||
def calc_gradient(self, index=None):
    '''
    Calculate the gradient matrix on the current network for the training dataset.

    For every batch the model is called with ``cal_gradient=True`` and must
    return ``(loss, visual_embedding, logit)``. The gradient of the loss with
    respect to the logits is the "bias" part; its outer product with the
    visual embedding is the "weight" part. Each output row is the bias part
    followed by the flattened weight part.

    Args:
        index: forwarded to ``select_dm`` as the subset indices (``None``
            presumably selects the whole dataset - confirm in ``select_dm_loader``).

    Returns:
        float32 ``numpy.ndarray`` with one row per sample.
    '''
    self.model.eval()
    data_loader = self.select_dm(self.dst_train, index, is_train=False)
    # Per-batch results are collected on the CPU.
    gradients = []
    for batch in tqdm(data_loader):
        self.optim.zero_grad()
        image, label = batch['img'].cuda(), batch['label'].cuda()
        bs_size = image.shape[0]
        loss, visual_embedding, logit = self.model(image, label, cal_gradient=True)
        embed_dim = visual_embedding.shape[-1]
        with torch.no_grad():
            bias_parameters_grads = torch.autograd.grad(loss, logit)[0]
            # Broadcasting (bs, 1, D) * (bs, C, 1) gives the same (bs, C, D)
            # values as the explicit repeat() calls, without the copies.
            weight_parameters_grads = visual_embedding.view(bs_size, 1, embed_dim) * \
                bias_parameters_grads.view(bs_size, self.num_classes, 1)
            gradients.append(torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)],
                                       dim=1).cpu().numpy())

    gradients = np.concatenate(gradients, axis=0, dtype=np.float32)
    print('Finish Gradient Calculation')
    self.model.train()
    return gradients
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch, time
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from datasets.data_manager import select_dm_loader
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/mtoneva/example_forgetting
|
||||
|
||||
class Forgetting(EarlyTrain):
|
||||
def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None, balance=True,
             dst_test=None, **kwargs):
    """Create a forgetting-events selector.

    All arguments except ``balance`` are forwarded to the parent constructor.
    ``balance`` (default ``True``) is stored on the instance; ``select`` uses
    it to choose between per-class and global ranking.
    """
    super().__init__(dst_train, args, fraction, random_seed, epochs,
                     specific_model=specific_model, dst_test=dst_test, **kwargs)
    self.balance = balance
|
||||
|
||||
def get_hms(self, seconds):
    """Split a duration in seconds into an ``(hours, minutes, seconds)`` tuple."""
    total_minutes, leftover_seconds = divmod(seconds, 60)
    return (*divmod(total_minutes, 60), leftover_seconds)
|
||||
|
||||
def before_train(self):
    """Reset the per-epoch running loss / correct / total counters to zero."""
    self.train_loss = self.correct = self.total = 0.
|
||||
|
||||
def after_loss(self, outputs, loss, targets, batch_inds, epoch):
    """Update forgetting statistics with the current batch's predictions.

    A sample counts as forgotten when it was classified correctly the last
    time it was seen (``self.last_acc`` is 1) and is misclassified now. Its
    entry in ``self.forgetting_events`` is incremented, and ``self.last_acc``
    is then overwritten with the current correctness for the batch.
    """
    with torch.no_grad():
        predicted = torch.max(outputs.data, 1).indices
        now_correct = (predicted == targets).clone().detach().requires_grad_(False).type(torch.float32)

        # 1 -> 0 transitions give a difference of 1; anything else is <= 0.
        forgotten = (self.last_acc[batch_inds] - now_correct) > 0.01
        forgotten_ids = batch_inds.clone().detach()[forgotten]
        self.forgetting_events[forgotten_ids] += 1.
        self.last_acc[batch_inds] = now_correct
|
||||
|
||||
def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
    """Per-batch hook; intentionally a no-op for this method.

    Progress is already printed by the training loop, so nothing is done here.
    """
    return None
|
||||
|
||||
|
||||
|
||||
def after_epoch(self):
    """End-of-epoch hook; a no-op for this method (no elapsed-time bookkeeping)."""
    return None
|
||||
|
||||
def before_run(self):
    """Initialise the forgetting counters and record the pre-tuning accuracy.

    ``self.forgetting_events`` starts as a zero vector with one entry per
    training sample (moved with ``.cuda()``); ``self.last_acc`` is filled in
    by ``test_initial_acc``.
    """
    self.elapsed_time = 0
    counters = torch.zeros(self.n_train, requires_grad=False)
    self.forgetting_events = counters.cuda()
    self.test_initial_acc()
|
||||
|
||||
def test_initial_acc(self):
    """Record, per training sample, whether the untuned model classifies it correctly.

    The model runs in eval mode over a loader built from ``self.dst_train``.
    For every batch, ``self.last_acc`` is set to 1.0 at the positions given by
    the batch's ``'index'`` entry where the prediction matches the label, and
    to 0.0 otherwise.
    """
    net = self.model
    net.no_grad = True
    net.eval()
    self.last_acc = torch.zeros(self.n_train, requires_grad=False).cuda()

    print('\n=> Testing Initial acc for Forgetting')
    for batch in select_dm_loader(self.args, self.dst_train):
        image = batch['img'].cuda()
        target = batch['label'].cuda()
        batch_inds = batch['index'].cuda()

        predicted = torch.max(net(image, target).data, 1).indices
        is_correct = (predicted == target).clone().detach().requires_grad_(False).type(torch.float32)
        self.last_acc[batch_inds] = is_correct

    self.model.no_grad = False
|
||||
|
||||
def finish_run(self):
    """Nothing to finalise: the ranking is computed in ``select`` from the collected counters."""
    return None
|
||||
|
||||
def select(self, **kwargs):
    """Run tuning, then keep the samples with the most forgetting events.

    With ``self.balance`` unset, the ``self.coreset_size`` highest-scoring
    samples are taken globally. Otherwise each class contributes
    ``round(self.fraction * class_size)`` of its own highest-scoring samples.

    Returns:
        dict with ``"indices"`` (the selected sample indices) and ``"scores"``
        (``self.forgetting_events``, one count per training sample).
    """
    self.run()

    if self.balance:
        per_class = [np.array([], dtype=np.int64)]
        for c in range(self.num_classes):
            c_indx = self.train_indx[self.dst_train_label == c]
            budget = round(self.fraction * len(c_indx))
            # Descending order of forgetting counts within the class.
            ranking = np.argsort(self.forgetting_events[c_indx].cpu().numpy())[::-1]
            per_class.append(c_indx[ranking[:budget]])
        top_examples = np.concatenate(per_class)
    else:
        ascending = np.argsort(self.forgetting_events.cpu().numpy())
        top_examples = self.train_indx[ascending][::-1][:self.coreset_size]

    return {"indices": top_examples, "scores": self.forgetting_events}
|
||||
@@ -0,0 +1,10 @@
|
||||
import numpy as np
|
||||
from .coresetmethod import CoresetMethod
|
||||
|
||||
|
||||
class Full(CoresetMethod):
    """Trivial "selection" that keeps every training sample."""

    def __init__(self, dst_train, args, fraction, random_seed, **kwargs):
        """Remember only the dataset size; all other arguments are ignored.

        NOTE(review): the parent constructor is not called here - presumably
        on purpose, since nothing else is needed; confirm.
        """
        self.n_train = len(dst_train)

    def select(self, **kwargs):
        """Return ``{"indices": [0, 1, ..., n_train - 1]}`` as a NumPy array."""
        every_index = np.arange(self.n_train)
        return {"indices": every_index}
|
||||
@@ -0,0 +1,210 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
from .methods_utils import submodular_optimizer
|
||||
import torch
|
||||
import numpy as np
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
from tqdm import tqdm
|
||||
|
||||
class Glister(EarlyTrain):
|
||||
def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
             balance: bool = True, greedy="StochasticGreedy", eta=None, dst_val=None, **kwargs):
    """Create a GLISTER selector.

    Args:
        dst_train, args, fraction, random_seed, epochs, specific_model,
        **kwargs: forwarded to the parent constructor.
        balance: select per class when true, globally otherwise.
        greedy: name of an optimizer listed in
            ``submodular_optimizer.optimizer_choices``.
        eta: step size; defaults to ``args.OPTIM_SELECTION.LR`` when ``None``.
        dst_val: accepted but not used - the validation split is always taken
            from ``dst_train.dataset.val``.

    Raises:
        ModuleNotFoundError: ``greedy`` is not a known optimizer name.
    """
    super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

    self.balance = balance
    if eta is None:
        eta = args.OPTIM_SELECTION.LR
    self.eta = eta

    validation_set = dst_train.dataset.val
    self.dst_val = validation_set
    self.dst_val_label = self.get_train_label(validation_set)
    self.n_val = len(validation_set)

    if greedy not in submodular_optimizer.optimizer_choices:
        raise ModuleNotFoundError("Greedy optimizer not found.")
    self._greedy = greedy
|
||||
|
||||
def calc_gradient(self, index=None,val=False):
    '''
    Calculate the gradient matrix on the current network for the training or validation set.

    For every batch the model is called with ``cal_gradient=True`` and must
    return ``(loss, visual_embedding, logit)``. The gradient of the loss with
    respect to the logits is the "bias" part; its outer product with the
    visual embedding is the "weight" part. Each output row is the bias part
    followed by the flattened weight part.

    Args:
        index: forwarded to ``select_dm`` as the subset indices.
        val: use ``self.dst_val`` instead of ``self.dst_train``.

    Returns:
        float32 ``numpy.ndarray`` with one row per sample. The model is left
        in eval mode.
    '''
    self.model.eval()
    if val:
        val_str = 'Val'
        data_loader = self.select_dm(self.dst_val, index, is_train=False)
    else:
        val_str = 'Train'
        data_loader = self.select_dm(self.dst_train, index, is_train=False)

    # Per-batch results are collected on the CPU.
    gradients = []
    for batch in tqdm(data_loader):
        self.optim.zero_grad()
        image, label = batch['img'].cuda(), batch['label'].cuda()
        bs_size = image.shape[0]
        loss, visual_embedding, logit = self.model(image, label, cal_gradient=True)
        embed_dim = visual_embedding.shape[-1]
        with torch.no_grad():
            bias_parameters_grads = torch.autograd.grad(loss, logit)[0]
            # Broadcasting (bs, 1, D) * (bs, C, 1) gives the same (bs, C, D)
            # values as the explicit repeat() calls, without the copies.
            weight_parameters_grads = visual_embedding.view(bs_size, 1, embed_dim) * \
                bias_parameters_grads.view(bs_size, self.num_classes, 1)
            gradients.append(torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)],
                                       dim=1).cpu().numpy())

    gradients = np.concatenate(gradients, axis=0, dtype=np.float32)
    print(f'Finish Gradient Calculation on {val_str} dataset')
    return gradients
|
||||
|
||||
# def calc_gradient(self, index=None, val=False, record_val_detail=False):
|
||||
# '''
|
||||
# Calculate gradients matrix on current network for training or validation dataset.
|
||||
# '''
|
||||
#
|
||||
# self.model.eval()
|
||||
#
|
||||
# if val:
|
||||
# batch_loader = torch.utils.data.DataLoader(
|
||||
# self.dst_val if index is None else torch.utils.data.Subset(self.dst_val, index),
|
||||
# batch_size=self.args.selection_batch, num_workers=self.args.workers)
|
||||
# else:
|
||||
# batch_loader = torch.utils.data.DataLoader(
|
||||
# self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
|
||||
# batch_size=self.args.selection_batch, num_workers=self.args.workers)
|
||||
#
|
||||
# self.embedding_dim = self.model.get_last_layer().in_features
|
||||
# gradients = []
|
||||
# if val and record_val_detail:
|
||||
# self.init_out = []
|
||||
# self.init_emb = []
|
||||
# self.init_y = []
|
||||
#
|
||||
# for i, (input, targets) in enumerate(batch_loader):
|
||||
# self.model_optimizer.zero_grad()
|
||||
# outputs = self.model(input.to(self.args.device))
|
||||
# loss = self.criterion(outputs.requires_grad_(True), targets.to(self.args.device)).sum()
|
||||
# batch_num = targets.shape[0]
|
||||
# with torch.no_grad():
|
||||
# bias_parameters_grads = torch.autograd.grad(loss, outputs)[0]
|
||||
# weight_parameters_grads = self.model.embedding_recorder.embedding.view(batch_num, 1,
|
||||
# self.embedding_dim).repeat(1, self.args.num_classes, 1) *\
|
||||
# bias_parameters_grads.view(
|
||||
# batch_num, self.args.num_classes, 1).repeat(1, 1, self.embedding_dim)
|
||||
# gradients.append(torch.cat(
|
||||
# [bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1).cpu())
|
||||
#
|
||||
# if val and record_val_detail:
|
||||
# self.init_out.append(outputs.cpu())
|
||||
# self.init_emb.append(self.model.embedding_recorder.embedding.cpu())
|
||||
# self.init_y.append(targets)
|
||||
#
|
||||
# gradients = torch.cat(gradients, dim=0)
|
||||
# if val:
|
||||
# self.val_grads = torch.mean(gradients, dim=0)
|
||||
# if self.dst_val == self.dst_train:
|
||||
# # No validation set was provided while instantiating Glister, so self.dst_val == self.dst_train
|
||||
# self.train_grads = gradients
|
||||
# else:
|
||||
# self.train_grads = gradients
|
||||
# if val and record_val_detail:
|
||||
# with torch.no_grad():
|
||||
# self.init_out = torch.cat(self.init_out, dim=0)
|
||||
# self.init_emb = torch.cat(self.init_emb, dim=0)
|
||||
# self.init_y = torch.cat(self.init_y)
|
||||
#
|
||||
# self.model.train()
|
||||
|
||||
# PASS, worth discussion
|
||||
def update_val_gradients(self, new_selection, selected_for_train):
    """Re-estimate the mean validation gradient after a simulated update step.

    The mean gradient of the already-selected training samples is used to
    shift the stored validation logits (``self.init_out``) by one step of
    size ``self.eta``. The last-layer gradients of ``self.criterion`` are then
    recomputed on the shifted logits in chunks of ``self.args.selection_batch``,
    and their mean is stored in ``self.val_grads``.

    NOTE(review): this method reads ``self.init_out``, ``self.init_emb``,
    ``self.init_y``, ``self.criterion`` and ``self.embedding_dim``, none of
    which are assigned by the visible code of this class, and it mixes NumPy
    (``np.mean``) with tensor-only calls (``.view``). It looks stale; confirm
    before wiring it back in as ``update_state``.

    Args:
        new_selection: unused.
        selected_for_train: index/mask into ``self.train_gradients``.
    """
    mean_selected = np.mean(self.train_gradients[selected_for_train], axis=0)

    n_val = self.init_out.shape[0]
    bias_shift = mean_selected[:self.num_classes].reshape(1, -1).repeat(n_val, 1)
    weight_shift = torch.matmul(self.init_emb,
                                mean_selected[self.num_classes:].view(self.num_classes, -1).T)
    new_outputs = self.init_out - self.eta * bias_shift - self.eta * weight_shift

    sample_num = new_outputs.shape[0]
    num_classes = self.args.num_classes
    emb_dim = self.embedding_dim
    chunk = self.args.selection_batch
    gradients = torch.zeros([sample_num, num_classes * (emb_dim + 1)], requires_grad=False)

    for start in range(0, sample_num, chunk):
        batch_indx = np.arange(sample_num)[start:min(start + chunk, sample_num)]
        batch_num = len(batch_indx)

        shifted = new_outputs[batch_indx].clone().detach().requires_grad_(True)
        loss = self.criterion(shifted, self.init_y[batch_indx])
        bias_parameters_grads = torch.autograd.grad(loss.sum(), shifted, retain_graph=True)[0]

        embedding_part = self.init_emb[batch_indx].view(batch_num, 1, emb_dim).repeat(1, num_classes, 1)
        bias_part = bias_parameters_grads.view(batch_num, num_classes, 1).repeat(1, 1, emb_dim)
        weight_parameters_grads = embedding_part * bias_part

        gradients[batch_indx] = torch.cat(
            [bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1).cpu()

    self.val_grads = torch.mean(gradients, dim=0)
|
||||
|
||||
def finish_run(self):
    """Select the coreset by greedily maximising the GLISTER gain.

    Last-layer gradients are computed once for the training and validation
    sets. The gain of a candidate is the dot product between its training
    gradient and the mean validation gradient (the original comment describes
    this as a Taylor-series approximation of the conditional gain). With
    ``self.balance`` the optimisation runs per class, with a budget of
    ``round(self.fraction * class_size)``; otherwise it runs once over all
    samples with a budget of ``self.coreset_size``.

    Both branches use the same NumPy gain function and ``update_state=None``.
    The original unbalanced branch applied ``torch.matmul`` / ``.view`` to
    NumPy arrays and passed the misspelled keyword ``upadate_state``.

    Returns:
        dict with ``"indices"``: the selected training indices.
    """
    if isinstance(self.model, MyDataParallel):
        self.model = self.model.module

    self.model.no_grad = True

    self.train_indx = np.arange(self.n_train)
    self.val_indx = np.arange(self.n_val)

    train_gradients = self.calc_gradient(index=None)
    val_gradients = self.calc_gradient(index=None, val=True)

    def gain(idx_gain, selected, **kwargs):
        # Reads the gradients currently stored on ``self``; they are swapped
        # per class in the balanced branch.
        return np.dot(self.train_gradients[idx_gain],
                      self.val_gradients.reshape(-1, 1)).flatten()

    optimizer_cls = submodular_optimizer.__dict__[self._greedy]

    if self.balance:
        selection_result = np.array([], dtype=np.int64)
        for c in range(self.num_classes):
            c_indx = self.train_indx[self.dst_train_label == c]
            c_val_inx = self.val_indx[self.dst_val_label == c]
            self.train_gradients = train_gradients[c_indx]
            self.val_gradients = val_gradients[c_val_inx].mean(axis=0)

            submod_optimizer = optimizer_cls(args=self.args, index=c_indx,
                                             budget=round(self.fraction * len(c_indx)))
            c_selection_result = submod_optimizer.select(gain_function=gain, update_state=None)
            selection_result = np.append(selection_result, c_selection_result)
    else:
        self.train_gradients = train_gradients
        self.val_gradients = val_gradients.mean(axis=0)
        submod_optimizer = optimizer_cls(args=self.args, index=np.arange(self.n_train),
                                         budget=self.coreset_size)
        selection_result = submod_optimizer.select(gain_function=gain, update_state=None)

    self.model.no_grad = False
    return {"indices": selection_result}
|
||||
|
||||
def num_classes_mismatch(self):
    """Reject a pretrain dataset whose class count differs from the training set.

    Raises:
        ValueError: always.
    """
    message = "num_classes of pretrain dataset does not match that of the training dataset."
    raise ValueError(message)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.linalg import lstsq
|
||||
from scipy.optimize import nnls
|
||||
from .earlytrain import EarlyTrain
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
|
||||
|
||||
# https://github.com/krishnatejakk/GradMatch
|
||||
|
||||
class GradMatch(EarlyTrain):
|
||||
def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None,
             balance=True, dst_val=None, lam: float = 1., **kwargs):
    """Create a GradMatch selector.

    ``balance`` and ``dst_val`` are stored on the instance; the remaining
    arguments are forwarded to the parent constructor. ``lam`` is accepted
    but not stored.
    """
    super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)
    self.balance, self.dst_val = balance, dst_val
|
||||
|
||||
def num_classes_mismatch(self):
    """Reject a pretrain dataset whose class count differs from the training set.

    Raises:
        ValueError: always.
    """
    message = "num_classes of pretrain dataset does not match that of the training dataset."
    raise ValueError(message)
|
||||
|
||||
def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
    """Print the running loss every ``self.args.print_freq`` batches.

    The iteration total shown is ``self.n_pretrain_size // batch_size + 1``.
    """
    if batch_idx % self.args.print_freq != 0:
        return
    total_iters = (self.n_pretrain_size // batch_size) + 1
    fields = (epoch, self.epochs, batch_idx + 1, total_iters, loss.item())
    print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % fields)
|
||||
|
||||
def orthogonal_matching_pursuit(self, A, b, budget: int, lam: float = 1.):
    '''approximately solves min_x |x|_0 s.t. Ax=b using Orthogonal Matching Pursuit
    Acknowlegement to:
    https://github.com/krishnatejakk/GradMatch/blob/main/GradMatch/selectionstrategies/helpers/omp_solvers.py

    All intermediate tensors are created on ``A``'s device (the original
    hard-coded ``"cuda"``), and the regularised least-squares step uses
    ``torch.linalg.lstsq`` because ``torch.lstsq`` no longer exists in
    current PyTorch.

    Args:
      A: design matrix of size (d, n), a torch tensor
      b: measurement vector of length d, on the same device as A
      budget: selection budget; clamped to the range [0, n]
      lam: regularization coef. for the final output vector
    Returns:
       float32 numpy vector of length n; non-zero entries mark selected columns
    '''
    with torch.no_grad():
        d, n = A.shape
        budget = min(max(budget, 0), n)
        device = A.device

        x = np.zeros(n, dtype=np.float32)
        resid = b.clone()
        indices = []
        boolean_mask = torch.ones(n, dtype=torch.bool, device=device)
        all_idx = torch.arange(n, device=device)

        for i in range(budget):
            if i % self.args.print_freq == 0:
                print("| Selecting [%3d/%3d]" % (i + 1, budget))
            # Pick the not-yet-selected column most aligned with the residual.
            projections = torch.matmul(A.T, resid)
            index = torch.argmax(projections[boolean_mask])
            index = all_idx[boolean_mask][index]

            indices.append(index.item())
            boolean_mask[index] = False

            if len(indices) == 1:
                A_i = A[:, index]
                x_i = projections[index] / torch.dot(A_i, A_i).view(-1)
                A_i = A[:, index].view(1, -1)
            else:
                A_i = torch.cat((A_i, A[:, index].view(1, -1)), dim=0)
                temp = torch.matmul(A_i, torch.transpose(A_i, 0, 1)) + \
                    lam * torch.eye(A_i.shape[0], device=device, dtype=A.dtype)
                # Solve temp @ x_i = A_i @ b; note the (matrix, rhs) argument order.
                x_i = torch.linalg.lstsq(temp, torch.matmul(A_i, b).view(-1, 1)).solution
            resid = b - torch.matmul(torch.transpose(A_i, 0, 1), x_i).view(-1)

        if budget > 1:
            # Final weights are constrained to be non-negative.
            x_i = nnls(temp.cpu().numpy(), torch.matmul(A_i, b).view(-1).cpu().numpy())[0]
            x[indices] = x_i
        elif budget == 1:
            x[indices[0]] = 1.
    return x
|
||||
|
||||
def orthogonal_matching_pursuit_np(self, A, b, budget: int, lam: float = 1.):
    '''NumPy variant of Orthogonal Matching Pursuit: approximately solves min_x |x|_0 s.t. Ax=b.
    Acknowlegement to:
    https://github.com/krishnatejakk/GradMatch/blob/main/GradMatch/selectionstrategies/helpers/omp_solvers.py
    Args:
      A: design matrix of size (d, n)
      b: measurement vector of length d
      budget: selection budget; clamped to the range [0, n]
      lam: regularization coef. for the final output vector
    Returns:
       float32 vector of length n; non-zero entries mark selected columns
    '''
    _, n = A.shape
    budget = 0 if budget <= 0 else min(budget, n)

    weights = np.zeros(n, dtype=np.float32)
    residual = np.copy(b)
    chosen = []
    available = np.ones(n, dtype=bool)
    positions = np.arange(n)

    for step in range(budget):
        if step % self.args.print_freq == 0:
            print("| Selecting [%3d/%3d]" % (step + 1, budget))

        # Best remaining column = largest projection of the residual.
        projections = A.T.dot(residual)
        candidates = positions[available]
        index = candidates[np.argmax(projections[available])]

        chosen.append(index.item())
        available[index] = False

        if len(chosen) == 1:
            A_i = A[:, index]
            x_i = projections[index] / A_i.T.dot(A_i)
        else:
            A_i = np.vstack([A_i, A[:, index]])
            gram = A_i.dot(A_i.T) + lam * np.identity(A_i.shape[0])
            x_i = lstsq(gram, A_i.dot(b))[0]
        residual = b - A_i.T.dot(x_i)

    if budget > 1:
        # Final weights are constrained to be non-negative.
        gram = A_i.dot(A_i.T) + lam * np.identity(A_i.shape[0])
        weights[chosen] = nnls(gram, A_i.dot(b))[0]
    elif budget == 1:
        weights[chosen[0]] = 1.
    return weights
|
||||
|
||||
def calc_gradient(self, index=None, val=False):
    """Compute per-sample last-layer gradients for the training or validation set.

    Each row holds the gradient of the summed ``self.criterion`` loss with
    respect to the model outputs, followed by the flattened product of that
    gradient with the embedding captured by ``self.model.embedding_recorder``.

    NOTE(review): this method relies on ``self.model_optimizer``,
    ``self.criterion``, ``self.model.get_last_layer()``,
    ``self.model.embedding_recorder``, ``args.selection_batch``,
    ``args.workers`` and tuple-style ``(input, targets)`` batches, none of
    which are set up by the visible base class - confirm it still matches the
    current model / data pipeline.

    Args:
        index: optional subset of sample indices; ``None`` uses the whole set.
        val: use ``self.dst_val`` instead of ``self.dst_train``.

    Returns:
        Tensor on ``self.args.device`` with one row per sample.
    """
    self.model.eval()

    dataset = self.dst_val if val else self.dst_train
    if index is None:
        source = dataset
        sample_num = len(self.dst_val.targets) if val else self.n_train
    else:
        source = torch.utils.data.Subset(dataset, index)
        sample_num = len(index)

    chunk = self.args.selection_batch
    device = self.args.device
    num_classes = self.args.num_classes
    batch_loader = torch.utils.data.DataLoader(source, batch_size=chunk, num_workers=self.args.workers)

    self.embedding_dim = self.model.get_last_layer().in_features
    emb_dim = self.embedding_dim
    gradients = torch.zeros([sample_num, num_classes * (emb_dim + 1)],
                            requires_grad=False, device=device)

    for i, (input, targets) in enumerate(batch_loader):
        self.model_optimizer.zero_grad()
        outputs = self.model(input.to(device)).requires_grad_(True)
        loss = self.criterion(outputs, targets.to(device)).sum()
        batch_num = targets.shape[0]
        with torch.no_grad():
            bias_parameters_grads = torch.autograd.grad(loss, outputs, retain_graph=True)[0].cpu()
            embedding = self.model.embedding_recorder.embedding.cpu()
            embedding_part = embedding.view(batch_num, 1, emb_dim).repeat(1, num_classes, 1)
            bias_part = bias_parameters_grads.view(batch_num, num_classes, 1).repeat(1, 1, emb_dim)
            weight_parameters_grads = embedding_part * bias_part

            # Rows of this batch; the last batch may be shorter.
            start = i * chunk
            stop = min((i + 1) * chunk, sample_num)
            gradients[start:stop] = torch.cat(
                [bias_parameters_grads, weight_parameters_grads.flatten(1)], dim=1)

    return gradients
|
||||
|
||||
def finish_run(self):
    """Select a weighted coreset by matching gradients with OMP.

    The target is the mean gradient of the validation set when
    ``self.dst_val`` is set, otherwise the mean gradient of the training
    samples being considered. With ``self.balance`` the procedure runs once
    per class with a proportional budget.

    Returns:
        Dict with ``"indices"`` (selected sample indices) and ``"weights"``
        (the non-zero OMP coefficients of those samples).
    """
    if isinstance(self.model, MyDataParallel):
        self.model = self.model.module

    device = self.args.device
    has_val = self.dst_val is not None

    def solve(grads, target, budget):
        # NumPy solver on CPU, torch solver on any other device.
        if self.args.device == "cpu":
            return self.orthogonal_matching_pursuit_np(grads.numpy().T, target.numpy(), budget=budget)
        return self.orthogonal_matching_pursuit(grads.T, target, budget=budget)

    self.model.no_grad = True
    with self.model.embedding_recorder:
        if has_val:
            val_num = len(self.dst_val.targets)

        if self.balance:
            selection_result = np.array([], dtype=np.int64)
            weights = np.array([], dtype=np.float32)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                cur_gradients = self.calc_gradient(class_index)
                if has_val:
                    # Also calculate gradients of the validation set.
                    val_class_index = np.arange(val_num)[self.dst_val.targets == c]
                    cur_val_gradients = torch.mean(self.calc_gradient(val_class_index, val=True), dim=0)
                else:
                    cur_val_gradients = torch.mean(cur_gradients, dim=0)

                cur_weights = solve(cur_gradients.to(device), cur_val_gradients.to(device),
                                    round(len(class_index) * self.fraction))
                picked = np.nonzero(cur_weights)[0]
                selection_result = np.append(selection_result, class_index[picked])
                weights = np.append(weights, cur_weights[picked])
        else:
            cur_gradients = self.calc_gradient()
            if has_val:
                # Also calculate gradients of the validation set.
                cur_val_gradients = torch.mean(self.calc_gradient(val=True), dim=0)
            else:
                cur_val_gradients = torch.mean(cur_gradients, dim=0)

            cur_weights = solve(cur_gradients, cur_val_gradients, self.coreset_size)
            selection_result = np.nonzero(cur_weights)[0]
            weights = cur_weights[selection_result]
    self.model.no_grad = False
    return {"indices": selection_result, "weights": weights}
|
||||
|
||||
def select(self, **kwargs):
    """Run the selection pipeline and return whatever ``self.run()`` yields."""
    return self.run()
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch, time
|
||||
import numpy as np
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
from tqdm import tqdm
|
||||
|
||||
class GraNd(EarlyTrain):
    """Score samples by the norm of their last-layer gradient and keep the largest."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, repeat=1,
                 specific_model=None, balance=False, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)
        self.epochs = epochs
        self.n_train = len(self.dst_train)
        self.coreset_size = round(self.n_train * fraction)
        self.specific_model = specific_model
        self.repeat = repeat

        self.balance = balance

    def before_run(self):
        """Unwrap the model from ``MyDataParallel`` before the run starts."""
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

    def calc_gradient(self, index=None):
        '''
        Calculate gradients matrix on current network for specified training dataset.

        Returns a float32 NumPy array with one row per sample: the gradient of
        the loss w.r.t. the logits followed by the flattened product of that
        gradient and the visual embedding.
        '''
        self.model.eval()
        data_loader = self.select_dm(self.dst_train, index, is_train=False)
        # Per-batch gradient blocks, kept on the CPU as NumPy arrays.
        per_batch = []

        for batch in tqdm(data_loader):
            self.optim.zero_grad()
            image = batch['img'].cuda()
            label = batch['label'].cuda()
            bs_size = image.shape[0]
            loss, visual_embedding, logit = self.model(image, label, cal_gradient=True)
            with torch.no_grad():
                bias_grads = torch.autograd.grad(loss, logit)[0]
                # Broadcasting (b, 1, d) * (b, C, 1) equals the explicit repeats.
                weight_grads = visual_embedding.view(bs_size, 1, -1) * \
                    bias_grads.view(bs_size, self.num_classes, 1)
                per_batch.append(
                    torch.cat([bias_grads, weight_grads.flatten(1)], dim=1).cpu().numpy())

        gradients = np.concatenate(per_batch, axis=0, dtype=np.float32)
        print('Finish Gradient Calculation')
        self.model.train()
        return gradients

    def finish_run(self):
        """Store the L2 norm of every sample's gradient in column 0 of ``norm_matrix``."""
        gradients = self.calc_gradient()
        # NOTE(review): only column 0 is written; with repeat > 1 the other
        # columns stay zero, which scales the mean taken in select().
        self.norm_matrix[:, 0] = np.linalg.norm(gradients, axis=1)

    def select(self, **kwargs):
        """Run once and return the highest-scoring samples.

        Returns:
            Dict with ``"indices"`` (selected sample indices) and ``"scores"``
            (the per-sample mean over the columns of ``norm_matrix``).
        """
        # Initialize a matrix to save norms of each sample on independent runs
        self.norm_matrix = np.zeros([self.n_train, self.repeat])

        self.run()

        self.norm_mean = np.mean(self.norm_matrix, axis=1)
        if self.balance:
            top_examples = np.array([], dtype=np.int64)
            for c in tqdm(range(self.num_classes)):
                c_indx = self.train_indx[self.dst_train_label == c]
                budget = round(self.fraction * len(c_indx))
                ranked = np.argsort(self.norm_mean[c_indx])[::-1]
                top_examples = np.append(top_examples, c_indx[ranked[:budget]])
        else:
            ranked = np.argsort(self.norm_mean)
            top_examples = self.train_indx[ranked][::-1][:self.coreset_size]

        return {"indices": top_examples, "scores": self.norm_mean}
|
||||
@@ -0,0 +1,109 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
from .methods_utils import euclidean_dist
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
|
||||
|
||||
class Herding(EarlyTrain):
    """Herding-based coreset selection over mixed image/text embeddings."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200,
                 specific_model="ResNet18", balance: bool = False, metric="euclidean", **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs,
                         specific_model=specific_model, **kwargs)

        if metric == "euclidean":
            self.metric = euclidean_dist
        elif callable(metric):
            self.metric = metric
        else:
            # Fallback: skip training and measure distances on flattened raw inputs.
            self.metric = euclidean_dist
            self.run = lambda: self.finish_run()

            def _construct_matrix(index=None):
                dataset = self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index)
                whole = self.n_train if index is None else len(index)
                # A single batch holding every requested sample.
                data_loader = torch.utils.data.DataLoader(
                    dataset, batch_size=whole, num_workers=self.args.workers)
                inputs, _ = next(iter(data_loader))
                return inputs.flatten(1).requires_grad_(False).to(self.args.device)

            self.construct_matrix = _construct_matrix

        self.balance = balance
        self.select_bs = self.args.DATASET.SELECTION_BATCH_SIZE

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        pass

    # Initial implementation, may not be optimal
    def mixing_feature(self, img_fea, text_fea, lam=0.5):
        """Return the convex combination ``lam * img_fea + (1 - lam) * text_fea``."""
        return lam * img_fea + (1 - lam) * text_fea

    def construct_matrix(self, index=None):
        """Build the ``[sample_num, emb_dim]`` matrix of mixed image/text features."""
        self.model.eval()
        self.model.no_grad = True
        with torch.no_grad():
            sample_num = self.n_train if index is None else len(index)
            matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).cuda()
            data_loader = self.select_dm(self.dst_train, index, is_train=False)
            for batch_no, batch in enumerate(data_loader):
                image = batch['img'].cuda()
                label = batch['label'].cuda()
                img_f, text_f, _ = self.model(image, label, record=True)
                # Using the mixed image_feature and text_feature
                start = batch_no * self.select_bs
                matrix[start:min(start + self.select_bs, sample_num)] = self.mixing_feature(img_f, text_f)

        self.model.no_grad = False
        self.model.train()
        return matrix

    def before_run(self):
        self.emb_dim = self.model.image_encoder.output_dim

    def herding(self, matrix, budget: int, index=None):
        """Greedily pick ``budget`` rows of ``matrix``.

        At step ``i`` the metric is evaluated between
        ``(i + 1) * mean - sum(selected rows)`` and every unselected row, and
        the row with the largest value is taken.

        Args:
            matrix: 2-D tensor with one row per sample.
            budget: Number of rows to pick; capped at the row count.
            index: Optional labels returned instead of row positions.

        Raises:
            ValueError: If ``budget`` is negative.
        """
        sample_num = matrix.shape[0]

        if budget < 0:
            raise ValueError("Illegal budget size.")
        budget = min(budget, sample_num)

        indices = np.arange(sample_num)
        with torch.no_grad():
            mu = torch.mean(matrix, dim=0)
            select_result = np.zeros(sample_num, dtype=bool)

            for i in range(budget):
                if i % self.args.TRAIN.PRINT_FREQ == 0:
                    print("| Selecting [%3d/%3d]" % (i + 1, budget))
                remaining = ~select_result
                target = (i + 1) * mu - torch.sum(matrix[select_result], dim=0)
                dist = self.metric(target.view(1, -1), matrix[remaining])
                # Map the position among unselected rows back to a row number.
                winner = indices[remaining][torch.argmax(dist).item()]
                select_result[winner] = True
        if index is None:
            index = indices
        return index[select_result]

    def finish_run(self):
        """Run herding (per class when ``self.balance``) and return ``{"indices": ...}``."""
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

        if not self.balance:
            selection_result = self.herding(self.construct_matrix(), budget=self.coreset_size)
            return {"indices": selection_result}

        selection_result = np.array([], dtype=np.int32)
        for c in range(self.num_classes):
            class_index = np.arange(self.n_train)[self.dst_train_label == c]
            chosen = self.herding(self.construct_matrix(class_index),
                                  budget=round(self.fraction * len(class_index)), index=class_index)
            selection_result = np.append(selection_result, chosen)
        return {"indices": selection_result}

    def select(self, **kwargs):
        return self.run()
|
||||
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
from .methods_utils import euclidean_dist
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
|
||||
|
||||
def k_center_greedy(matrix, budget: int, metric, device, random_seed=None, index=None, already_selected=None,
                    print_freq: int = 20):
    """Greedy k-center selection: repeatedly add the point farthest from the chosen set.

    Args:
        matrix: 2-D ``torch.Tensor`` or ``np.ndarray`` with one row per sample.
            NumPy input is converted to a tensor on ``device``.
        budget: Number of points to add. When nothing is pre-selected, a random
            seed point is drawn first and counts towards the budget. Capped at
            the number of rows.
        metric: Callable ``metric(a, b)`` returning the ``[len(a), len(b)]``
            matrix of distances between rows of ``a`` and rows of ``b``.
        device: Device used for converted input and for the distance buffer.
        random_seed: Passed to ``np.random.seed`` before drawing the seed point.
        index: Optional labels, one per row, returned instead of row positions.
        already_selected: Optional labels (matched against ``index``) of points
            that are selected from the start. ``None`` means none.
        print_freq: Progress is printed every ``print_freq`` iterations.

    Returns:
        The entries of ``index`` whose rows are selected.

    Raises:
        ValueError: If ``budget`` is negative.
    """
    if isinstance(matrix, torch.Tensor):
        assert matrix.dim() == 2
    elif isinstance(matrix, np.ndarray):
        assert matrix.ndim == 2
        matrix = torch.from_numpy(matrix).requires_grad_(False).to(device)

    sample_num = matrix.shape[0]
    assert sample_num >= 1

    if budget < 0:
        raise ValueError("Illegal budget size.")
    elif budget > sample_num:
        budget = sample_num

    if index is not None:
        assert matrix.shape[0] == len(index)
        index = np.asarray(index)
    else:
        index = np.arange(sample_num)

    assert callable(metric)

    # The default None used to become a 0-d array whose len() raises TypeError.
    already_selected = np.atleast_1d(np.array([] if already_selected is None else already_selected))

    with torch.no_grad():
        np.random.seed(random_seed)
        if len(already_selected) == 0:
            select_result = np.zeros(sample_num, dtype=bool)
        else:
            select_result = np.isin(index, already_selected)

        if not select_result.any():
            if budget == 0:
                return index[select_result]
            # Randomly select one initial point; it counts towards the budget.
            select_result[np.random.randint(0, sample_num)] = True
            budget -= 1

        if budget == 0 or select_result.all():
            return index[select_result]

        # Distance from every point to its nearest selected point; -1 marks
        # points that are already selected. Keeping only this running minimum
        # avoids materialising a (budget x sample_num) distance matrix.
        mins = torch.full((sample_num,), -1.0, device=device)
        unselected = ~select_result
        mins[unselected] = metric(matrix[select_result], matrix[unselected]).min(dim=0).values.to(mins.dtype)

        for i in range(budget):
            if i % print_freq == 0:
                print("| Selecting [%3d/%3d]" % (i + 1, budget))
            p = torch.argmax(mins).item()
            select_result[p] = True

            if i == budget - 1:
                break
            mins[p] = -1
            unselected = ~select_result
            if not unselected.any():
                break
            to_new_center = metric(matrix[[p]], matrix[unselected]).reshape(-1).to(mins.dtype)
            mins[unselected] = torch.min(mins[unselected], to_new_center)
    return index[select_result]
|
||||
|
||||
|
||||
class kCenterGreedy(EarlyTrain):
    """Coreset selection with greedy k-center over last-layer embeddings."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=0,
                 specific_model="ResNet18", balance: bool = False, already_selected=[], metric="euclidean",
                 torchvision_pretrain: bool = True, **kwargs):
        super().__init__(dst_train, args, fraction, random_seed, epochs=epochs, specific_model=specific_model,
                         torchvision_pretrain=torchvision_pretrain, **kwargs)

        if len(already_selected) != 0 and (min(already_selected) < 0 or max(already_selected) >= self.n_train):
            raise ValueError("List of already selected points out of the boundary.")
        self.already_selected = np.array(already_selected)

        self.min_distances = None

        if metric == "euclidean":
            self.metric = euclidean_dist
        elif callable(metric):
            self.metric = metric
        else:
            # Fallback: skip training and measure distances on flattened raw inputs.
            self.metric = euclidean_dist
            self.run = lambda: self.finish_run()

            def _construct_matrix(index=None):
                dataset = self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index)
                whole = self.n_train if index is None else len(index)
                # A single batch holding every requested sample.
                data_loader = torch.utils.data.DataLoader(
                    dataset, batch_size=whole, num_workers=self.args.workers)
                inputs, _ = next(iter(data_loader))
                return inputs.flatten(1).requires_grad_(False).to(self.args.device)

            self.construct_matrix = _construct_matrix

        self.balance = balance

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        """Print the running loss every ``self.args.print_freq`` batches."""
        if batch_idx % self.args.print_freq != 0:
            return
        print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f' % (
            epoch, self.epochs, batch_idx + 1, (self.n_pretrain_size // batch_size) + 1, loss.item()))

    def _selection_loader(self, index):
        # Loader over the training set (or a subset) in selection-sized batches.
        dataset = self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index)
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.selection_batch,
                                           num_workers=self.args.workers)

    def old_construct_matrix(self, index=None):
        """Collect recorded embeddings into a pre-allocated ``[sample_num, emb_dim]`` tensor."""
        self.model.eval()
        self.model.no_grad = True
        step = self.args.selection_batch
        with torch.no_grad():
            with self.model.embedding_recorder:
                sample_num = self.n_train if index is None else len(index)
                matrix = torch.zeros([sample_num, self.emb_dim], requires_grad=False).to(self.args.device)

                for batch_no, (inputs, _) in enumerate(self._selection_loader(index)):
                    self.model(inputs.to(self.args.device))
                    start = batch_no * step
                    matrix[start:min(start + step, sample_num)] = self.model.embedding_recorder.embedding

        self.model.no_grad = False
        return matrix

    def construct_matrix(self, index=None):
        """Collect recorded embeddings batch by batch and concatenate them along dim 0."""
        self.model.eval()
        self.model.no_grad = True
        with torch.no_grad():
            with self.model.embedding_recorder:
                chunks = []
                for inputs, _ in self._selection_loader(index):
                    self.model(inputs.to(self.args.device))
                    chunks.append(self.model.embedding_recorder.embedding)

        self.model.no_grad = False
        return torch.cat(chunks, dim=0)

    def before_run(self):
        self.emb_dim = self.model.get_last_layer().in_features

    def finish_run(self):
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

    def select(self, **kwargs):
        """Run training, then k-center selection (per class when ``self.balance``).

        Returns:
            Dict with ``"indices"`` holding the selected sample indices.
        """
        self.run()
        if self.balance:
            selection_result = np.array([], dtype=np.int32)
            for c in range(self.args.num_classes):
                class_index = np.arange(self.n_train)[self.dst_train.targets == c]
                # Only the pre-selected points that belong to this class.
                class_preselected = self.already_selected[np.isin(self.already_selected, class_index)]
                chosen = k_center_greedy(self.construct_matrix(class_index),
                                         budget=round(self.fraction * len(class_index)),
                                         metric=self.metric,
                                         device=self.args.device,
                                         random_seed=self.random_seed,
                                         index=class_index,
                                         already_selected=class_preselected,
                                         print_freq=self.args.print_freq)
                selection_result = np.append(selection_result, chosen)
        else:
            matrix = self.construct_matrix()
            # The model and its optimizer are dropped before the selection runs.
            del self.model_optimizer
            del self.model
            selection_result = k_center_greedy(matrix, budget=self.coreset_size,
                                               metric=self.metric, device=self.args.device,
                                               random_seed=self.random_seed,
                                               already_selected=self.already_selected,
                                               print_freq=self.args.print_freq)
        return {"indices": selection_result}
|
||||
@@ -0,0 +1,4 @@
|
||||
from .euclidean import *
|
||||
from .cossim import *
|
||||
from .submodular_function import *
|
||||
from .submodular_optimizer import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,35 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def cossim_np(v1, v2):
    """Return cosine similarity between rows of ``v1`` and ``v2``, rescaled to ``0.5 + 0.5 * cos``.

    A constant ``1e-6`` is added to the product of norms before dividing.
    """
    row_norms = np.linalg.norm(v1, axis=1)[:, None]
    col_norms = np.linalg.norm(v2, axis=1)
    res = np.dot(v1, v2.T) / (row_norms * col_norms + 1e-6)
    res[np.isneginf(res)] = 0.
    return 0.5 + 0.5 * res
|
||||
|
||||
def cossim_pair_np(v1):
    """Return the pairwise rescaled cosine similarity ``0.5 + 0.5 * cos`` among rows of ``v1``."""
    norm = np.linalg.norm(v1, axis=1)
    res = np.dot(v1, v1.T) / (norm[:, None] * norm + 1e-6)
    res[np.isneginf(res)] = 0.
    return 0.5 + 0.5 * res
|
||||
|
||||
def cossim(v1, v2):
    """Torch version: rescaled cosine similarity ``0.5 + 0.5 * cos`` between rows of ``v1`` and ``v2``."""
    row_norms = torch.norm(v1, dim=1).view(-1, 1)
    col_norms = torch.norm(v2, dim=1)
    res = torch.matmul(v1, v2.T) / (row_norms * col_norms + 1e-6)
    res[torch.isneginf(res)] = 0.
    return 0.5 + 0.5 * res
|
||||
|
||||
def cossim_pair(v1):
    """Torch version: pairwise rescaled cosine similarity ``0.5 + 0.5 * cos`` among rows of ``v1``."""
    norm = torch.norm(v1, dim=1)
    res = torch.matmul(v1, v1.T) / (norm.view(-1, 1) * norm + 1e-6)
    res[torch.isneginf(res)] = 0.
    return 0.5 + 0.5 * res
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def euclidean_dist(x, y):
    """Return the ``[m, n]`` matrix of Euclidean distances between rows of ``x`` and ``y``.

    Uses ``|x|^2 + |y|^2 - 2 x.y``; squared distances are clamped to at least
    ``1e-12`` before the square root, so the smallest value returned is 1e-6.

    Args:
        x: 2-D tensor with ``m`` rows.
        y: 2-D tensor with ``n`` rows and the same number of columns as ``x``.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # beta/alpha must be keywords: the positional addmm_(beta, alpha, m1, m2)
    # overload was deprecated and later removed from PyTorch.
    dist = torch.addmm(xx + yy, x, y.t(), beta=1, alpha=-2)
    return dist.clamp(min=1e-12).sqrt()
|
||||
|
||||
|
||||
def euclidean_dist_pair(x):
    """Return the ``[m, m]`` matrix of Euclidean distances among the rows of ``x``.

    Squared distances are clamped to at least ``1e-12`` before the square
    root, so the diagonal is 1e-6 rather than exactly zero.
    """
    m = x.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, m)
    # beta/alpha must be keywords: the positional addmm_(beta, alpha, m1, m2)
    # overload was deprecated and later removed from PyTorch.
    dist = torch.addmm(xx + xx.t(), x, x.t(), beta=1, alpha=-2)
    return dist.clamp(min=1e-12).sqrt()
|
||||
|
||||
def euclidean_dist_np(x, y):
    """NumPy version: ``[rows(x), rows(y)]`` matrix of Euclidean distances.

    Squared distances are clipped to at least ``1e-12`` before the square root.
    """
    cross = np.dot(x, y.T)
    # Column of squared row norms of x and row of squared row norms of y;
    # broadcasting replaces the explicit np.repeat tiling.
    x_sq = np.sum(np.multiply(x, x), axis=1)[:, None]
    y_sq = np.sum(np.multiply(y, y), axis=1)[None, :]
    return np.sqrt(np.clip(x_sq + y_sq - 2. * cross, 1e-12, None))
|
||||
|
||||
def euclidean_dist_pair_np(x):
    """Return the N*N matrix of Euclidean distances among the rows of ``x``.

    Squared distances are clipped to at least ``1e-12`` before the square
    root, so the diagonal is 1e-6 rather than exactly zero.
    """
    gram = np.dot(x, x.T)
    # Broadcasting a column and a row of squared norms replaces np.repeat.
    sq = np.sum(np.multiply(x, x), axis=1)
    return np.sqrt(np.clip(sq[:, None] + sq[None, :] - 2. * gram, 1e-12, None))
|
||||
@@ -0,0 +1,144 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SubmodularFunction(object):
    """Base class holding the sample index and a similarity lookup.

    Args:
        index: Sequence of sample identifiers; its length is stored as ``n``.
        similarity_kernel: Optional callable ``kernel(a, b)`` returning
            similarities between the samples selected by ``a`` and by ``b``.
            Takes precedence over ``similarity_matrix`` when both are given.
        similarity_matrix: Optional precomputed ``n x n`` similarity matrix.
        already_selected: Samples considered selected from the start.
            ``None`` (the default) means an empty list.
    """

    def __init__(self, index, similarity_kernel=None, similarity_matrix=None, already_selected=None):
        self.index = index
        self.n = len(index)

        # A fresh list per instance: a literal [] default would be one object
        # shared by every instance created without this argument.
        self.already_selected = [] if already_selected is None else already_selected

        assert similarity_kernel is not None or similarity_matrix is not None

        # For the sample similarity matrix, the method supports two input modes, one is to input a pairwise similarity
        # matrix for the whole sample, and the other case allows the input of a similarity kernel to be used to
        # calculate similarities incrementally at a later time if required.
        if similarity_kernel is not None:
            assert callable(similarity_kernel)
            self.similarity_kernel = self._similarity_kernel(similarity_kernel)
        else:
            assert similarity_matrix.shape[0] == self.n and similarity_matrix.shape[1] == self.n
            self.similarity_matrix = similarity_matrix
            self.similarity_kernel = lambda a, b: self.similarity_matrix[np.ix_(a, b)]

    def _similarity_kernel(self, similarity_kernel):
        """Hook for subclasses to wrap the kernel; the base class returns it unchanged."""
        return similarity_kernel
|
||||
|
||||
|
||||
class FacilityLocation(SubmodularFunction):
    """Facility-location objective: each sample is credited with its best similarity to the selected set."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Set before the first kernel call below: the caching wrapper reads it
        # when it receives integer (non-boolean) column indices.
        self.all_idx = np.ones(self.n, dtype=bool)

        if len(self.already_selected) == 0:
            self.cur_max = np.zeros(self.n, dtype=np.float32)
        else:
            self.cur_max = np.max(
                self.similarity_kernel(np.arange(self.n), np.asarray(self.already_selected)), axis=1)

    def _similarity_kernel(self, similarity_kernel):
        """Wrap ``similarity_kernel`` so that each column is computed once and cached."""
        # Initialize a matrix to store similarity values of sample points.
        self.sim_matrix = np.zeros([self.n, self.n], dtype=np.float32)
        self.if_columns_calculated = np.zeros(self.n, dtype=bool)

        def _func(a, b):
            b = np.asarray(b)
            if not np.all(self.if_columns_calculated[b]):
                if b.dtype != bool:
                    # Turn integer indices into a boolean column mask.
                    temp = ~self.all_idx
                    temp[b] = True
                    b = temp
                not_calculated = b & ~self.if_columns_calculated
                self.sim_matrix[:, not_calculated] = similarity_kernel(self.all_idx, not_calculated)
                self.if_columns_calculated[not_calculated] = True
            return self.sim_matrix[np.ix_(a, b)]
        return _func

    def calc_gain(self, idx_gain, selected, **kwargs):
        """Return, per candidate in ``idx_gain``, the summed improvement over ``cur_max``."""
        gains = np.maximum(0., self.similarity_kernel(self.all_idx, idx_gain) - self.cur_max.reshape(-1, 1)).sum(axis=0)
        return gains

    def calc_gain_batch(self, idx_gain, selected, **kwargs):
        """Same quantity as ``calc_gain``, accumulated over row blocks of ``kwargs["batch"]`` samples."""
        batch = kwargs["batch"]
        batch_idx = ~self.all_idx
        batch_idx[0:batch] = True
        gains = np.maximum(0., self.similarity_kernel(batch_idx, idx_gain) - self.cur_max[batch_idx].reshape(-1, 1)).sum(axis=0)
        # `start` is already a row offset, so the block is start:start + batch.
        # Multiplying it by the batch size again skipped most rows.
        for start in range(batch, self.n, batch):
            batch_idx = ~self.all_idx
            batch_idx[start:start + batch] = True
            gains += np.maximum(0., self.similarity_kernel(batch_idx, idx_gain) - self.cur_max[batch_idx].reshape(-1, 1)).sum(axis=0)
        return gains

    def update_state(self, new_selection, total_selected, **kwargs):
        """Raise ``cur_max`` to account for the newly selected samples."""
        self.cur_max = np.maximum(self.cur_max, np.max(self.similarity_kernel(self.all_idx, new_selection), axis=1))
|
||||
|
||||
|
||||
class GraphCut(SubmodularFunction):
    """Graph-cut objective with trade-off ``lam`` between coverage and redundancy."""

    def __init__(self, lam: float = 1., **kwargs):
        super().__init__(**kwargs)
        self.lam = lam

        # Matrix mode is exactly the case where no kernel was supplied (the
        # base class prefers the kernel). Testing for the 'similarity_matrix'
        # key raised AttributeError when a kernel was passed as well.
        if kwargs.get('similarity_kernel') is None:
            self.sim_matrix_cols_sum = np.sum(self.similarity_matrix, axis=0)
        self.all_idx = np.ones(self.n, dtype=bool)

    def _similarity_kernel(self, similarity_kernel):
        """Wrap ``similarity_kernel`` so columns and their sums are computed once and cached."""
        # Initialize a matrix to store similarity values of sample points.
        self.sim_matrix = np.zeros([self.n, self.n], dtype=np.float32)
        self.sim_matrix_cols_sum = np.zeros(self.n, dtype=np.float32)
        self.if_columns_calculated = np.zeros(self.n, dtype=bool)

        def _func(a, b):
            if not np.all(self.if_columns_calculated[b]):
                if b.dtype != bool:
                    # Turn integer indices into a boolean column mask.
                    temp = ~self.all_idx
                    temp[b] = True
                    b = temp
                not_calculated = b & ~self.if_columns_calculated
                self.sim_matrix[:, not_calculated] = similarity_kernel(self.all_idx, not_calculated)
                self.sim_matrix_cols_sum[not_calculated] = np.sum(self.sim_matrix[:, not_calculated], axis=0)
                self.if_columns_calculated[not_calculated] = True
            return self.sim_matrix[np.ix_(a, b)]
        return _func

    def calc_gain(self, idx_gain, selected, **kwargs):
        """Return the conditional gain of each candidate in ``idx_gain``.

        The gain is ``lam`` times the candidate's column sum of similarities
        minus twice its summed similarity to the samples in ``selected``.
        """
        gain = -2. * np.sum(self.similarity_kernel(selected, idx_gain), axis=0) + self.lam * self.sim_matrix_cols_sum[idx_gain]
        return gain

    def update_state(self, new_selection, total_selected, **kwargs):
        pass
|
||||
|
||||
|
||||
class LogDeterminant(SubmodularFunction):
    """Submodular function scored from the similarity sub-matrix of the selected set."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.all_idx = np.ones(self.n, dtype=bool)

    def _similarity_kernel(self, similarity_kernel):
        """Wrap ``similarity_kernel`` so that each column is computed once and cached."""
        # Initialize a matrix to store similarity values of sample points.
        self.sim_matrix = np.zeros([self.n, self.n], dtype=np.float32)
        self.if_columns_calculated = np.zeros(self.n, dtype=bool)

        def _func(rows, cols):
            if not np.all(self.if_columns_calculated[cols]):
                if cols.dtype != bool:
                    # Turn integer indices into a boolean column mask.
                    mask = ~self.all_idx
                    mask[cols] = True
                    cols = mask
                pending = cols & ~self.if_columns_calculated
                self.sim_matrix[:, pending] = similarity_kernel(self.all_idx, pending)
                self.if_columns_calculated[pending] = True
            return self.sim_matrix[np.ix_(rows, cols)]
        return _func

    def calc_gain(self, idx_gain, selected, **kwargs):
        # Gain for LogDeterminant can be written as $f(x | A ) = \log\det(S_{a} - S_{a,A}S_{A}^{-1}S_{x,A}^T)$.
        # NOTE(review): what is returned is only the quadratic form
        # S_{x,A} pinv(S_A) S_{x,A}^T per candidate; no log-det is evaluated here.
        cand_to_selected = self.similarity_kernel(selected, idx_gain).T
        selected_block = self.similarity_kernel(selected, selected)
        projected = np.dot(cand_to_selected, np.linalg.pinv(selected_block))
        return (projected * cand_to_selected).sum(-1)

    def update_state(self, new_selection, total_selected, **kwargs):
        pass
|
||||
@@ -0,0 +1,155 @@
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
optimizer_choices = ["NaiveGreedy", "LazyGreedy", "StochasticGreedy", "ApproximateLazyGreedy"]
|
||||
|
||||
class optimizer(object):
    """Common state for the greedy submodular optimizers.

    Args:
        args: Configuration object, stored as given.
        index: Sequence of candidate sample identifiers.
        budget: Number of samples to end up with; must be in ``1..len(index)``.
        already_selected: Samples considered selected from the start.
            ``None`` (the default) means an empty list.

    Raises:
        ValueError: If ``budget`` is outside ``1..len(index)``.
    """

    def __init__(self, args, index, budget:int, already_selected=None):
        self.args = args
        self.index = index
        self.n = len(index)

        if budget <= 0 or budget > self.n:
            raise ValueError("Illegal budget for optimizer.")

        self.budget = budget
        # A fresh list per instance: a literal [] default would be one object
        # shared by every instance created without this argument.
        self.already_selected = [] if already_selected is None else already_selected
|
||||
|
||||
|
||||
class NaiveGreedy(optimizer):
    """Plain greedy: re-evaluate every unselected candidate at each step."""

    def __init__(self, args, index, budget:int, already_selected=[]):
        super().__init__(args, index, budget, already_selected)

    def select(self, gain_function, update_state=None, **kwargs):
        """Grow the selection to ``self.budget`` samples and return their ``index`` entries.

        Args:
            gain_function: Callable ``(candidates, selected, **kwargs)`` returning one gain per candidate.
            update_state: Optional callable ``(new_selection, selected, **kwargs)`` invoked after each pick.
        """
        assert callable(gain_function)
        assert update_state is None or callable(update_state)

        selected = np.zeros(self.n, dtype=bool)
        selected[self.already_selected] = True
        greedy_gain = np.zeros(len(self.index))

        for step in range(int(selected.sum()), self.budget):
            if step % self.args.TRAIN.PRINT_FREQ == 0:
                print("| Selecting [%3d/%3d]" % (step + 1, self.budget))
            candidates = ~selected
            greedy_gain[candidates] = gain_function(candidates, selected, **kwargs)
            winner = greedy_gain.argmax()
            selected[winner] = True
            # Picked samples can never win again.
            greedy_gain[winner] = -np.inf
            if update_state is not None:
                update_state(np.array([winner]), selected, **kwargs)
        return self.index[selected]
|
||||
|
||||
|
||||
class LazyGreedy(optimizer):
    """Lazy greedy: gains are refreshed only for the candidate currently on top."""

    def __init__(self, args, index, budget:int, already_selected=[]):
        super().__init__(args, index, budget, already_selected)

    def select(self, gain_function, update_state=None, **kwargs):
        """Grow the selection to ``self.budget`` samples and return their ``index`` entries.

        Args:
            gain_function: Callable ``(candidates, selected, **kwargs)`` returning one gain per candidate.
            update_state: Optional callable ``(new_selection, selected, **kwargs)`` invoked after each pick.
        """
        assert callable(gain_function)
        assert update_state is None or callable(update_state)

        selected = np.zeros(self.n, dtype=bool)
        selected[self.already_selected] = True

        greedy_gain = np.zeros(len(self.index))
        greedy_gain[~selected] = gain_function(~selected, selected, **kwargs)
        greedy_gain[selected] = -np.inf

        for step in tqdm(range(int(selected.sum()), self.budget)):
            if step % self.args.TRAIN.PRINT_FREQ == 0:
                print("| Selecting [%3d/%3d]" % (step + 1, self.budget))
            best_gain = -np.inf
            previous_top = -1
            while True:
                top = greedy_gain.argmax()
                if previous_top == top:
                    # The refreshed candidate is still on top: select it into the current subset.
                    selected[top] = True
                    greedy_gain[top] = -np.inf

                    if update_state is not None:
                        update_state(np.array([top]), selected, **kwargs)
                    break
                refreshed = gain_function(np.array([top]), selected, **kwargs)[0]
                greedy_gain[top] = refreshed
                if refreshed >= best_gain:
                    best_gain = refreshed
                    previous_top = top
        return self.index[selected]
|
||||
|
||||
|
||||
class StochasticGreedy(optimizer):
    """Stochastic greedy: each step evaluates only a random subset of the unselected samples."""

    def __init__(self, args, index, budget:int, already_selected=[], epsilon: float=0.9):
        super().__init__(args, index, budget, already_selected)
        self.epsilon = epsilon

    def select(self, gain_function, update_state=None, **kwargs):
        """Grow the selection to ``self.budget`` samples and return their ``index`` entries.

        Args:
            gain_function: Callable ``(candidates, selected, **kwargs)`` returning one gain per candidate.
            update_state: Optional callable ``(new_selection, selected, **kwargs)`` invoked after each pick.
        """
        assert callable(gain_function)
        assert update_state is None or callable(update_state)

        selected = np.zeros(self.n, dtype=bool)
        selected[self.already_selected] = True

        # Candidates drawn per step: -log(epsilon) * n / budget, at least one.
        sample_size = max(round(-np.log(self.epsilon) * self.n / self.budget), 1)

        greedy_gain = np.zeros(len(self.index))
        all_idx = np.arange(self.n)
        for step in range(int(selected.sum()), self.budget):
            if step % self.args.TRAIN.PRINT_FREQ == 0:
                print("| Selecting [%3d/%3d]" % (step + 1, self.budget))

            # Uniformly select a subset from unselected samples with size sample_size
            draw = min(sample_size, self.n - step)
            subset = np.random.choice(all_idx[~selected], replace=False, size=draw)

            if len(subset) == 0:
                break

            greedy_gain[subset] = gain_function(subset, selected, **kwargs)
            winner = subset[greedy_gain[subset].argmax()]
            selected[winner] = True
            greedy_gain[winner] = -np.inf
            if update_state is not None:
                update_state(np.array([winner]), selected, **kwargs)
        return self.index[selected]
|
||||
|
||||
|
||||
class ApproximateLazyGreedy(optimizer):
    """Lazy greedy maximizer that accepts a candidate within a ``beta`` tolerance."""

    def __init__(self, args, index, budget: int, already_selected=None, beta: float = 0.9):
        """Forward the configuration to the ``optimizer`` base class.

        Args:
            args: configuration object; ``args.TRAIN.PRINT_FREQ`` is read in
                :meth:`select`.
            index: array of candidate indices.
            budget: total number of elements to end up selected.
            already_selected: positions (into ``index``) that start selected.
                Defaults to an empty list.
            beta: a candidate is accepted when its refreshed gain is at least
                ``beta`` times its cached gain.
        """
        # ``None`` sentinel instead of a shared mutable ``[]`` default.
        if already_selected is None:
            already_selected = []
        super(ApproximateLazyGreedy, self).__init__(args, index, budget, already_selected)
        self.beta = beta

    def select(self, gain_function, update_state=None, **kwargs):
        """Run the approximate lazy greedy selection.

        Args:
            gain_function: callable ``(candidates, selected_mask, **kwargs)``
                returning one gain per candidate. It is called once with a
                boolean mask of all unselected elements and afterwards with a
                length-1 index array.
            update_state: optional callable
                ``(new_elements, selected_mask, **kwargs)`` invoked after each
                pick.
            **kwargs: forwarded unchanged to both callables.

        Returns:
            ``self.index`` restricted to the selected positions.
        """
        assert callable(gain_function)
        if update_state is not None:
            assert callable(update_state)

        selected = np.zeros(self.n, dtype=bool)
        selected[self.already_selected] = True

        # Cached gains; selected positions are pinned to -inf.
        greedy_gain = np.zeros(len(self.index))
        greedy_gain[~selected] = gain_function(~selected, selected, **kwargs)
        greedy_gain[selected] = -np.inf

        # Vectorised count instead of iterating the mask with builtin sum().
        num_selected = int(np.count_nonzero(selected))
        for i in range(num_selected, self.budget):
            if i % self.args.TRAIN.PRINT_FREQ == 0:
                print("| Selecting [%3d/%3d]" % (i + 1, self.budget))

            while True:
                cur_max_element = greedy_gain.argmax()
                max_gain = greedy_gain[cur_max_element]

                new_gain = gain_function(np.array([cur_max_element]), selected, **kwargs)[0]

                # The second clause matters for negative gains: with
                # 0 < beta < 1, ``beta * max_gain`` is larger than ``max_gain``
                # itself, so an exact (non-decreased) gain would never pass the
                # first test and the loop would spin forever. A refreshed gain
                # that is not below the cached maximum is always acceptable.
                if new_gain >= self.beta * max_gain or new_gain >= max_gain:
                    # Select cur_max_element into the current subset.
                    selected[cur_max_element] = True
                    greedy_gain[cur_max_element] = -np.inf

                    if update_state is not None:
                        update_state(np.array([cur_max_element]), selected, **kwargs)
                    break

                # Gain dropped too far: cache the refreshed value and retry.
                greedy_gain[cur_max_element] = new_gain

        return self.index[selected]
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import numpy as np
|
||||
import torch
|
||||
from .methods_utils import cossim_np, submodular_function, submodular_optimizer
|
||||
from ..nets.nets_utils import MyDataParallel
|
||||
|
||||
|
||||
class Submodular(EarlyTrain):
    """Coreset selection by submodular maximisation over per-sample gradients.

    Gradients are computed once over the whole training set; a submodular
    function (looked up by name in ``submodular_function``) is then maximised
    with a greedy optimizer (looked up by name in ``submodular_optimizer``),
    either per class or over the full set.
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, specific_model=None, balance=True,
                 function="GraphCut", greedy="LazyGreedy", metric="cossim", **kwargs):
        """Store the selection configuration.

        Args:
            dst_train, args, fraction, random_seed, epochs, specific_model,
                **kwargs: forwarded to ``EarlyTrain.__init__``.
            balance: when true, select ``fraction`` of every class separately.
            function: name of the class to look up in ``submodular_function``.
            greedy: name of the optimizer; must be listed in
                ``submodular_optimizer.optimizer_choices``.
            metric: stored as ``self._metric``; the kernels built in
                :meth:`finish_run` always use ``cossim_np``.

        Raises:
            ModuleNotFoundError: if ``greedy`` is not a known optimizer.
        """
        super(Submodular, self).__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        if greedy not in submodular_optimizer.optimizer_choices:
            raise ModuleNotFoundError("Greedy optimizer not found.")
        print(f"The Submodular Method is {function}")
        self._greedy = greedy
        self._metric = metric
        self._function = function

        self.balance = balance

    # The training hooks below are intentionally no-ops for this method.
    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def before_epoch(self):
        pass

    def after_epoch(self):
        pass

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def calc_gradient(self, index=None):
        """Compute the per-sample gradient matrix on the current network.

        For every batch the gradient of the loss with respect to the logits is
        taken, and its per-sample outer product with the visual embedding is
        flattened and concatenated after it.

        Args:
            index: forwarded to ``self.select_dm`` to choose the samples.

        Returns:
            A ``float32`` NumPy array with one row per sample.
        """
        self.model.eval()
        data_loader = self.select_dm(self.dst_train, index, is_train=False)

        # Per-batch gradient chunks, kept on the CPU.
        gradients = []

        for batch in data_loader:
            self.optim.zero_grad()
            image, label = batch['img'].cuda(), batch['label'].cuda()
            bs_size = image.shape[0]
            loss, visual_embedding, logit = self.model(image, label, cal_gradient=True)
            with torch.no_grad():
                bias_parameters_grads = torch.autograd.grad(loss, logit)[0]
                # Broadcasting (bs, 1, D) * (bs, C, 1) yields the same
                # (bs, C, D) product as explicitly repeating both operands.
                weight_parameters_grads = visual_embedding.view(bs_size, 1, -1) * \
                    bias_parameters_grads.view(bs_size, self.num_classes, 1)
                gradients.append(torch.cat([bias_parameters_grads, weight_parameters_grads.flatten(1)],
                                           dim=1).cpu().numpy())

        gradients = np.concatenate(gradients, axis=0, dtype=np.float32)
        print('Finish Gradient Calculation')
        return gradients

    @staticmethod
    def _similarity_kernel(gradients):
        """Return a kernel ``(a, b) -> cossim_np(gradients[a], gradients[b])``."""
        # Built through a factory so each kernel is bound to its own matrix
        # instead of a loop variable that is rebound later.
        return lambda a, b: cossim_np(gradients[a], gradients[b])

    def finish_run(self):
        """Select the coreset and return it as ``{"indices": ndarray}``."""
        if isinstance(self.model, MyDataParallel):
            self.model = self.model.module

        # Set the model's no_grad flag for the duration of the selection.
        self.model.no_grad = True
        self.train_indx = np.arange(self.n_train)

        # Computed once and shared by both branches below (the non-balanced
        # branch used to recompute the identical matrix a second time).
        gradients = self.calc_gradient(index=None)

        submod_cls = vars(submodular_function)[self._function]
        optimizer_cls = vars(submodular_optimizer)[self._greedy]

        if self.balance:
            per_class = [np.array([], dtype=np.int64)]
            for c in range(self.num_classes):
                print(f'class {c}')
                c_indx = self.train_indx[self.dst_train_label == c]
                # Rows of the gradient matrix belonging to this class.
                c_gradients = gradients[c_indx]
                submod_function = submod_cls(index=c_indx,
                                             similarity_kernel=self._similarity_kernel(c_gradients))
                submod_optimizer = optimizer_cls(args=self.args, index=c_indx,
                                                 budget=round(self.fraction * len(c_indx)),
                                                 already_selected=[])
                per_class.append(submod_optimizer.select(gain_function=submod_function.calc_gain,
                                                         update_state=submod_function.update_state))
            selection_result = np.concatenate(per_class)
        else:
            submod_function = submod_cls(index=self.train_indx,
                                         similarity_kernel=self._similarity_kernel(gradients))
            submod_optimizer = optimizer_cls(args=self.args, index=self.train_indx,
                                             budget=self.coreset_size)
            selection_result = submod_optimizer.select(gain_function=submod_function.calc_gain,
                                                       update_state=submod_function.update_state)

        self.model.no_grad = False
        return {"indices": selection_result}

    def select(self, **kwargs):
        """Run the pipeline via ``self.run()`` and return its result."""
        selection_result = self.run()
        return selection_result
|
||||
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
from .earlytrain import EarlyTrain
|
||||
import torch
|
||||
import numpy as np
|
||||
from datasets.data_manager import select_dm_loader
|
||||
import time
|
||||
|
||||
class Uncertainty(EarlyTrain):
    """Coreset selection driven by per-sample prediction confidence scores.

    All three scores grow with the model's confidence: the maximum logit
    ("LeastConfidence"), ``sum(p * log(p + 1e-6))`` ("Entropy"), and the gap
    between the two largest softmax probabilities ("Margin").
    """

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, epochs=200, selection_method="Margin",
                 specific_model=None, balance=False, **kwargs):
        """Store the selection configuration.

        Args:
            dst_train, args, fraction, random_seed, epochs, specific_model,
                **kwargs: forwarded to ``EarlyTrain.__init__``.
            selection_method: one of "LeastConfidence", "Entropy", "Margin".
            balance: when true, select ``fraction`` of every class separately.

        Raises:
            NotImplementedError: if ``selection_method`` is not supported.
        """
        super().__init__(dst_train, args, fraction, random_seed, epochs, specific_model, **kwargs)

        selection_choices = ["LeastConfidence",
                             "Entropy",
                             "Margin"]
        if selection_method not in selection_choices:
            raise NotImplementedError("Selection algorithm unavailable.")
        self.selection_method = selection_method

        self.epochs = epochs
        self.balance = balance

    # The training hooks below are intentionally no-ops for this method.
    def before_train(self):
        pass

    def after_loss(self, outputs, loss, targets, batch_inds, epoch):
        pass

    def after_epoch(self):
        pass

    def before_run(self):
        pass

    def num_classes_mismatch(self):
        raise ValueError("num_classes of pretrain dataset does not match that of the training dataset.")

    def while_update(self, outputs, loss, targets, epoch, batch_idx, batch_size):
        pass

    def finish_run(self):
        """Score the training set and return ``{"indices": ..., "scores": ...}``.

        With ``balance`` set, ``scores`` is a list holding one array per class;
        otherwise it is a single array over the whole training set.
        """
        # NOTE(review): the balanced branch keeps the LOWEST scores (ascending
        # argsort) while the non-balanced branch keeps the HIGHEST (reversed
        # argsort). The two look inconsistent; behaviour is left unchanged
        # here - confirm which ordering is intended.
        if self.balance:
            selection_result = np.array([], dtype=np.int64)
            scores = []
            for c in range(self.num_classes):
                print(f"Balance Processing on the train set class {c}")
                class_index = np.arange(self.n_train)[self.dst_train_label == c]
                scores.append(self.rank_uncertainty_clip(class_index))
                keep = round(len(class_index) * self.fraction)
                selection_result = np.append(selection_result,
                                             class_index[np.argsort(scores[-1])[:keep]])
        else:
            print("Imbalance Processing on the train set class")
            scores = self.rank_uncertainty_clip()
            selection_result = np.argsort(scores)[::-1][:self.coreset_size]
        return {"indices": selection_result, "scores": scores}

    def _batch_scores(self, logits):
        """Turn a batch of logits into a 1-D NumPy array of confidence scores."""
        if self.selection_method == "LeastConfidence":
            return logits.max(axis=1).values.cpu().numpy()
        if self.selection_method == "Entropy":
            preds = torch.softmax(logits, dim=1).cpu().numpy()
            return (np.log(preds + 1e-6) * preds).sum(axis=1)
        if self.selection_method == 'Margin':
            preds = torch.softmax(logits, dim=1)
            rows = torch.arange(preds.shape[0], device=preds.device)
            top_idx = torch.argmax(preds, dim=1)
            max_preds = preds[rows, top_idx].clone()
            # Mask out the best class so the next argmax finds the runner-up.
            preds[rows, top_idx] = -1.0
            runner_up = preds[rows, torch.argmax(preds, dim=1)]
            return (max_preds - runner_up).cpu().numpy()
        # Unknown method: contribute nothing, as the if/elif chain used to.
        return np.array([])

    @staticmethod
    def _join_scores(chunks):
        """Concatenate per-batch chunks once, keeping the float64 result dtype."""
        # The leading empty float64 array preserves the dtype (and the empty
        # result) that repeated ``np.append`` onto ``np.array([])`` produced.
        return np.concatenate([np.array([])] + chunks)

    def rank_uncertainty(self, index=None):
        """Score samples using a plain ``(input, target)`` data loader.

        Args:
            index: optional indices restricting ``self.dst_train`` via
                ``torch.utils.data.Subset``.

        Returns:
            1-D array of scores in loader order.
        """
        # ``self.model`` is the module invoked below, so it is the one that
        # must be in eval mode (this used to call ``self.specific_model.eval()``).
        self.model.eval()
        chunks = []
        with torch.no_grad():
            train_loader = torch.utils.data.DataLoader(
                self.dst_train if index is None else torch.utils.data.Subset(self.dst_train, index),
                batch_size=self.args.selection_batch,
                num_workers=self.args.workers)

            batch_num = len(train_loader)

            for i, (inputs, _) in enumerate(train_loader):
                if i % self.args.print_freq == 0:
                    print("| Selecting for batch [%3d/%3d]" % (i + 1, batch_num))
                chunks.append(self._batch_scores(self.model(inputs.to(self.args.device))))
        return self._join_scores(chunks)

    def rank_uncertainty_clip(self, index=None):
        """Score samples using the ``select_dm_loader`` data loader.

        Batches are dicts providing ``'img'`` and ``'label'``; both are moved
        to CUDA and passed to ``self.model``. The model is switched back to
        train mode before returning.

        Args:
            index: forwarded to ``select_dm_loader`` to choose the samples.

        Returns:
            1-D array of scores in loader order.
        """
        self.model.eval()
        chunks = []
        with torch.no_grad():
            train_loader = select_dm_loader(self.args, self.dst_train, index)

            for batch in train_loader:
                image, label = batch['img'].cuda(), batch['label'].cuda()
                logits = self.model(image, label)  # model is in eval mode here
                chunks.append(self._batch_scores(logits))
        self.model.train()
        return self._join_scores(chunks)

    def select(self, **kwargs):
        """Run the pipeline via ``self.run()`` and return its result."""
        selection_result = self.run()
        return selection_result

    def select_without_train(self):
        """Score and select directly via :meth:`finish_run`, skipping ``self.run()``."""
        selection_result = self.finish_run()
        return selection_result
|
||||
@@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
from .coresetmethod import CoresetMethod
|
||||
|
||||
|
||||
class Uniform(CoresetMethod):
    """Random coreset selection, optionally stratified by class."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, balance=True, replace=False, **kwargs):
        """Store the sampling options.

        Args:
            dst_train, args, fraction, random_seed: forwarded to
                ``CoresetMethod.__init__``.
            balance: sample the same fraction from every class when true.
            replace: passed to ``np.random.choice`` as its ``replace`` flag.
            **kwargs: accepted and ignored.
        """
        super().__init__(dst_train, args, fraction, random_seed)
        self.balance = balance
        self.replace = replace
        self.n_train = len(self.dst_train)

    def select_balance(self):
        """Draw ``fraction`` of the samples of every class independently."""
        np.random.seed(self.random_seed)
        positions = np.arange(self.n_train)
        self.index = np.array([], dtype=np.int64)
        for label in range(self.num_classes):
            in_class = (self.dst_train_label == label)
            quota = round(self.fraction * in_class.sum().item())
            drawn = np.random.choice(positions[in_class], quota, replace=self.replace)
            self.index = np.append(self.index, drawn)
        return self.index

    def select_no_balance(self):
        """Draw ``fraction`` of all samples without regard to class."""
        np.random.seed(self.random_seed)
        population = np.arange(self.n_train)
        sample_count = round(self.n_train * self.fraction)
        self.index = np.random.choice(population, sample_count, replace=self.replace)
        return self.index

    def select(self, **kwargs):
        """Return ``{"indices": ...}`` using the strategy chosen by ``balance``."""
        if self.balance:
            chosen = self.select_balance()
        else:
            chosen = self.select_no_balance()
        return {"indices": chosen}
|
||||
@@ -0,0 +1,8 @@
|
||||
from .alexnet import *
|
||||
from .inceptionv3 import *
|
||||
from .lenet import *
|
||||
from .mlp import *
|
||||
from .mobilenetv3 import *
|
||||
from .resnet import *
|
||||
from .vgg import *
|
||||
from .wideresnet import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,100 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled
|
||||
from torchvision import models
|
||||
import torch
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
class AlexNet_32x32(nn.Module):
    """AlexNet-style network for small inputs with an embedding recorder.

    The classifier expects a ``192 * 4 * 4`` feature vector. With a single
    input channel the first convolution pads by 4 instead of 2.
    """

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        """Build the feature extractor, the linear head and the recorder.

        Args:
            channel: number of input channels.
            num_classes: size of the output layer.
            record_embedding: forwarded to ``EmbeddingRecorder``.
            no_grad: when true, ``forward`` runs with gradients disabled.
        """
        super().__init__()
        stem_padding = 4 if channel == 1 else 2
        layers = [
            nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=stem_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*layers)
        self.fc = nn.Linear(192 * 4 * 4, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer."""
        return self.fc

    def forward(self, x):
        """Return class logits; the flattened features pass through the recorder."""
        with set_grad_enabled(not self.no_grad):
            feature_maps = self.features(x)
            flat = feature_maps.view(feature_maps.size(0), -1)
            flat = self.embedding_recorder(flat)
            logits = self.fc(flat)
        return logits
|
||||
|
||||
|
||||
class AlexNet_224x224(models.AlexNet):
    """torchvision AlexNet with an embedding recorder in front of its last layer."""

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        """Build the torchvision model and rewire its classifier.

        The classifier's last layer is exposed as ``self.fc``; its slot is
        taken by the embedding recorder and the layer is re-appended to the
        classifier under the name ``"fc"``.

        Args:
            channel: number of input channels; when not 3 the first
                convolution is replaced.
            num_classes: passed to ``models.AlexNet``.
            record_embedding: forwarded to ``EmbeddingRecorder``.
            no_grad: when true, ``forward`` runs with gradients disabled.
            **kwargs: forwarded to ``models.AlexNet``.
        """
        super().__init__(num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        needs_new_stem = channel != 3
        if needs_new_stem:
            self.features[0] = nn.Conv2d(channel, 64, kernel_size=11, stride=4, padding=2)

        head = self.classifier[-1]
        self.fc = head
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", head)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer."""
        return self.fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, honouring the ``no_grad`` flag."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            logits = self.classifier(torch.flatten(pooled, 1))
        return logits
|
||||
|
||||
|
||||
def AlexNet(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Build the AlexNet variant that matches the input configuration.

    Args:
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: indexable ``(height, width)`` of the inputs.
        record_embedding: forwarded to the network constructor.
        no_grad: forwarded to the network constructor.
        pretrained: load the torchvision ImageNet weights (224x224 only) and
            then swap the first convolution / last layer if ``channel`` or
            ``num_classes`` differ from 3 / 1000.

    Raises:
        NotImplementedError: for pretrained weights with a non-224x224 input,
            or when no architecture matches ``channel`` and ``im_size``.
    """
    is_224 = im_size[0] == 224 and im_size[1] == 224

    if pretrained:
        if not is_224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = AlexNet_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        weights = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-7be5be79.pth',
                                           progress=True)
        net.load_state_dict(weights)

        if channel != 3:
            net.features[0] = nn.Conv2d(channel, 64, kernel_size=11, stride=4, padding=2)
        if num_classes != 1000:
            net.fc = nn.Linear(4096, num_classes)
            net.classifier[-1] = net.fc
        return net

    if is_224:
        return AlexNet_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                               no_grad=no_grad)

    is_small_gray = channel == 1 and im_size[0] == 28 and im_size[1] == 28
    is_small_rgb = channel == 3 and im_size[0] == 32 and im_size[1] == 32
    if is_small_gray or is_small_rgb:
        return AlexNet_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                             no_grad=no_grad)

    raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
|
||||
@@ -0,0 +1,426 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models import inception
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU."""

    def __init__(self, input_channels, output_channels, **kwargs):
        """Create the three layers.

        Args:
            input_channels: channels of the incoming tensor.
            output_channels: channels produced by the convolution.
            **kwargs: extra ``nn.Conv2d`` arguments (kernel size, stride, ...).
        """
        super().__init__()
        # The convolution carries no bias; BatchNorm2d follows it directly.
        self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
# same naive inception module
|
||||
class InceptionA(nn.Module):
    """Size-preserving inception block with four parallel branches.

    Output channels: ``64 + 64 + 96 + pool_features``.
    """

    def __init__(self, input_channels, pool_features):
        """Create the 1x1, 5x5, double-3x3 and pooling branches."""
        super().__init__()
        self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1)

        reduce_5x5 = BasicConv2d(input_channels, 48, kernel_size=1)
        conv_5x5 = BasicConv2d(48, 64, kernel_size=5, padding=2)
        self.branch5x5 = nn.Sequential(reduce_5x5, conv_5x5)

        reduce_3x3 = BasicConv2d(input_channels, 64, kernel_size=1)
        conv_3x3_a = BasicConv2d(64, 96, kernel_size=3, padding=1)
        conv_3x3_b = BasicConv2d(96, 96, kernel_size=3, padding=1)
        self.branch3x3 = nn.Sequential(reduce_3x3, conv_3x3_a, conv_3x3_b)

        pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        pool_conv = BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1)
        self.branchpool = nn.Sequential(pool, pool_conv)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along the channel axis."""
        # Branch order: 1x1, 1x1->5x5, 1x1->3x3->3x3, avgpool->3x3.
        return torch.cat(
            [self.branch1x1(x), self.branch5x5(x), self.branch3x3(x), self.branchpool(x)],
            1,
        )
|
||||
|
||||
|
||||
# downsample
|
||||
# Factorization into smaller convolutions
|
||||
class InceptionB(nn.Module):
    """Down-sampling inception block (all three branches use stride 2).

    Output channels: ``384 + 96 + input_channels``.
    """

    def __init__(self, input_channels):
        """Create the strided 3x3, stacked 3x3 and max-pooling branches."""
        super().__init__()

        self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2)

        stack = [
            BasicConv2d(input_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        ]
        self.branch3x3stack = nn.Sequential(*stack)

        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along the channel axis."""
        # Parallel stride-2 convolution and pooling paths whose filter banks
        # are concatenated.
        reduced_conv = self.branch3x3(x)
        reduced_stack = self.branch3x3stack(x)
        reduced_pool = self.branchpool(x)
        return torch.cat([reduced_conv, reduced_stack, reduced_pool], 1)
|
||||
|
||||
|
||||
# Factorizing Convolutions with Large Filter Size
|
||||
class InceptionC(nn.Module):
    """Size-preserving inception block built from factorised 7x7 convolutions.

    Each of the four branches outputs 192 channels (768 in total).
    """

    def __init__(self, input_channels, channels_7x7):
        """Create the branches; ``channels_7x7`` is the width inside the 7x7 paths."""
        super().__init__()
        width = channels_7x7

        self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1)

        # An n x n convolution is replaced by a 7x1 followed by a 1x7 one.
        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, 192, kernel_size=(1, 7), padding=(0, 3)),
        )

        self.branch7x7stack = nn.Sequential(
            BasicConv2d(input_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, 192, kernel_size=(1, 7), padding=(0, 3)),
        )

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along the channel axis."""
        # Branch order: 1x1, single 7x1/1x7 pair, double pair, avgpool->1x1.
        pieces = [
            self.branch1x1(x),
            self.branch7x7(x),
            self.branch7x7stack(x),
            self.branch_pool(x),
        ]
        return torch.cat(pieces, 1)
|
||||
|
||||
|
||||
class InceptionD(nn.Module):
    """Down-sampling inception block (all three branches use stride 2).

    Output channels: ``320 + 192 + input_channels``.
    """

    def __init__(self, input_channels):
        """Create the 3x3, factorised-7x7 and average-pooling branches."""
        super().__init__()

        squeeze_3x3 = BasicConv2d(input_channels, 192, kernel_size=1)
        down_3x3 = BasicConv2d(192, 320, kernel_size=3, stride=2)
        self.branch3x3 = nn.Sequential(squeeze_3x3, down_3x3)

        self.branch7x7 = nn.Sequential(
            BasicConv2d(input_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )

        self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along the channel axis."""
        # Branch order: 1x1->3x3, 1x1->1x7->7x1->3x3, avgpool.
        down_conv = self.branch3x3(x)
        down_7x7 = self.branch7x7(x)
        down_pool = self.branchpool(x)
        return torch.cat([down_conv, down_7x7, down_pool], 1)
|
||||
|
||||
|
||||
# same
|
||||
class InceptionE(nn.Module):
    """Size-preserving inception block with expanded filter-bank outputs.

    Output channels: ``320 + 2 * 384 + 2 * 384 + 192`` (2048 in total).
    """

    def __init__(self, input_channels):
        """Create the 1x1, split-3x3, stacked split-3x3 and pooling branches."""
        super().__init__()
        self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1)

        # 1x1 reduction whose output feeds a 1x3 and a 3x1 convolution.
        self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Same idea with an additional 3x3 convolution in front of the split.
        self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1)
        self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(input_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate along the channel axis."""
        out_1x1 = self.branch1x1(x)

        # x -> 1x1 -> {1x3, 3x1}, concatenated.
        split_in = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(split_in), self.branch3x3_2b(split_in)], 1)

        # x -> 1x1 -> 3x3 -> {1x3, 3x1}, concatenated.
        stack_in = self.branch3x3stack_2(self.branch3x3stack_1(x))
        out_stack = torch.cat([self.branch3x3stack_3a(stack_in), self.branch3x3stack_3b(stack_in)], 1)

        out_pool = self.branch_pool(x)

        return torch.cat([out_1x1, out_3x3, out_stack, out_pool], 1)
|
||||
|
||||
|
||||
class InceptionV3_32x32(nn.Module):
    """Inception-v3 style network for small inputs with an embedding recorder."""

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        """Build the stem, the inception stages and the classifier head.

        Args:
            channel: number of input channels; a single channel uses padding 3
                in the first convolution instead of 1.
            num_classes: size of the output layer.
            record_embedding: forwarded to ``EmbeddingRecorder``.
            no_grad: when true, ``forward`` runs with gradients disabled.
        """
        super().__init__()
        stem_padding = 3 if channel == 1 else 1

        # Convolutional stem.
        self.Conv2d_1a_3x3 = BasicConv2d(channel, 32, kernel_size=3, padding=stem_padding)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)

        # Size-preserving inception modules.
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)

        # First down-sampling stage.
        self.Mixed_6a = InceptionB(288)

        # Factorised 7x7 modules.
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        # Second down-sampling stage.
        self.Mixed_7a = InceptionD(768)

        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        # Head: global average pooling, dropout, linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d()
        self.linear = nn.Linear(2048, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer."""
        return self.linear

    def forward(self, x):
        """Return class logits; the pooled features pass through the recorder."""
        # Stages are applied strictly in this order.
        pipeline = (
            self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3,
            self.Conv2d_3b_1x1, self.Conv2d_4a_3x3,
            self.Mixed_5b, self.Mixed_5c, self.Mixed_5d,
            self.Mixed_6a,
            self.Mixed_6b, self.Mixed_6c, self.Mixed_6d, self.Mixed_6e,
            self.Mixed_7a,
            self.Mixed_7b, self.Mixed_7c,
            self.avgpool, self.dropout,
        )
        with torch.set_grad_enabled(not self.no_grad):
            for stage in pipeline:
                x = stage(x)
            x = x.view(x.size(0), -1)
            x = self.embedding_recorder(x)
            x = self.linear(x)
        return x
|
||||
|
||||
|
||||
class InceptionV3_224x224(inception.Inception3):
    """torchvision Inception3 with an embedding recorder before the last layer."""

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        """Build the torchvision model and attach the recorder.

        Args:
            channel: number of input channels; when not 3 the first
                convolution block is replaced.
            num_classes: passed to ``inception.Inception3``.
            record_embedding: forwarded to ``EmbeddingRecorder``.
            no_grad: when true, ``_forward`` runs with gradients disabled.
            **kwargs: forwarded to ``inception.Inception3``.
        """
        super().__init__(num_classes=num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            # ``conv_block`` is only a local name inside torchvision's
            # Inception3 constructor; the module-level block class is
            # ``BasicConv2d``.
            self.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer."""
        return self.fc

    def _forward(self, x):
        """Return ``(logits, aux)``; ``aux`` is ``None`` unless in training mode.

        NOTE(review): presumably called by the inherited ``forward`` - confirm
        against the installed torchvision version.
        """
        with torch.set_grad_enabled(not self.no_grad):
            # Stem.
            x = self.Conv2d_1a_3x3(x)
            x = self.Conv2d_2a_3x3(x)
            x = self.Conv2d_2b_3x3(x)
            x = self.maxpool1(x)
            x = self.Conv2d_3b_1x1(x)
            x = self.Conv2d_4a_3x3(x)
            x = self.maxpool2(x)
            # Inception stages.
            x = self.Mixed_5b(x)
            x = self.Mixed_5c(x)
            x = self.Mixed_5d(x)
            x = self.Mixed_6a(x)
            x = self.Mixed_6b(x)
            x = self.Mixed_6c(x)
            x = self.Mixed_6d(x)
            x = self.Mixed_6e(x)
            # Auxiliary classifier, evaluated only while training.
            aux = None
            if self.AuxLogits is not None:
                if self.training:
                    aux = self.AuxLogits(x)
            x = self.Mixed_7a(x)
            x = self.Mixed_7b(x)
            x = self.Mixed_7c(x)
            # Head: pooling, dropout, flatten, recorder, classifier.
            x = self.avgpool(x)
            x = self.dropout(x)
            x = torch.flatten(x, 1)
            x = self.embedding_recorder(x)
            x = self.fc(x)
        return x, aux
|
||||
|
||||
|
||||
def InceptionV3(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                pretrained: bool = False):
    """Build an Inception v3 network for the given input configuration.

    Args:
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: two-element size; both entries are compared against 224
            (or 28 / 32 for the small variant).
        record_embedding: initial state of the network's embedding recorder.
        no_grad: when true the network's forward pass disables gradients.
        pretrained: load the weights published at
            ``inception.model_urls["inception_v3_google"]`` and then swap the
            first convolution / final layer if ``channel`` / ``num_classes``
            differ from 3 / 1000.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for ``pretrained`` with a non-224x224 size, or
            for a channel/size combination no branch handles.
    """
    is_224 = im_size[0] == 224 and im_size[1] == 224
    if pretrained:
        if not is_224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = InceptionV3_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(inception.model_urls["inception_v3_google"], progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            # NOTE(review): assumes ``inception`` is torchvision.models.inception
            # (import not visible here). Its conv+BN+ReLU unit is exposed as
            # ``BasicConv2d``; there is no module-level ``conv_block``.
            net.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif is_224:
        net = InceptionV3_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad)
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = InceptionV3_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

    return net
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 conv / 2x2 max-pool stages and three linear layers.

    The first convolution pads by 2 for single-channel input and by 0
    otherwise. The width of ``fc_1`` is derived from ``im_size`` and that
    padding, so it always matches what ``features`` actually produces
    (e.g. 16*5*5 for 1x28x28 and 3x32x32, 16*53*53 for 3x224x224).
    """

    def __init__(self, channel, num_classes, im_size, record_embedding: bool = False, no_grad: bool = False,
                 pretrained: bool = False):
        """Build the network.

        Args:
            channel: number of input channels.
            num_classes: number of output classes.
            im_size: two-element input size.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
            pretrained: not supported; must be false.

        Raises:
            NotImplementedError: if ``pretrained`` is true.
            ValueError: if ``im_size`` is too small for the two conv/pool stages.
        """
        if pretrained:
            raise NotImplementedError("torchvison pretrained models not available.")
        super(LeNet, self).__init__()
        first_padding = 2 if channel == 1 else 0
        self.features = nn.Sequential(
            nn.Conv2d(channel, 6, kernel_size=5, padding=first_padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # The previous hard-coded 53x53 / 5x5 choice was wrong for 1-channel
        # 224x224 input (padding 2 gives 54x54) and for every other size.
        flat_rows = self._feature_extent(im_size[0], first_padding)
        flat_cols = self._feature_extent(im_size[1], first_padding)
        self.fc_1 = nn.Linear(16 * flat_rows * flat_cols, 120)
        self.fc_2 = nn.Linear(120, 84)
        self.fc_3 = nn.Linear(84, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    @staticmethod
    def _feature_extent(size, first_padding):
        """Return the output extent of ``features`` along one axis of length ``size``."""
        size = (size + 2 * first_padding - 4) // 2  # 5x5 conv, then 2x2 pool
        size = (size - 4) // 2                      # 5x5 conv (no padding), then 2x2 pool
        if size < 1:
            raise ValueError("im_size is too small for the LeNet feature extractor.")
        return size

    def get_last_layer(self):
        """Return the final linear layer (``fc_3``)."""
        return self.fc_3

    def forward(self, x):
        """Return class scores for ``x``; the input to ``fc_3`` goes through the recorder."""
        with set_grad_enabled(not self.no_grad):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc_1(x))
            x = F.relu(self.fc_2(x))
            x = self.embedding_recorder(x)
            x = self.fc_3(x)
        return x
|
||||
@@ -0,0 +1,37 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
|
||||
''' MLP '''
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Three-layer perceptron over the flattened input (hidden width 128)."""

    def __init__(self, channel, num_classes, im_size, record_embedding: bool = False, no_grad: bool = False,
                 pretrained: bool = False):
        """Build the layers.

        Args:
            channel: number of input channels.
            num_classes: number of output classes.
            im_size: two-element input size; the first layer takes
                ``im_size[0] * im_size[1] * channel`` features.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
            pretrained: not supported; must be false.

        Raises:
            NotImplementedError: if ``pretrained`` is true.
        """
        if pretrained:
            raise NotImplementedError("torchvison pretrained models not available.")
        super().__init__()
        hidden = 128
        flat_features = im_size[0] * im_size[1] * channel
        self.fc_1 = nn.Linear(flat_features, hidden)
        self.fc_2 = nn.Linear(hidden, hidden)
        self.fc_3 = nn.Linear(hidden, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer (``fc_3``)."""
        return self.fc_3

    def forward(self, x):
        """Return class scores; the input to ``fc_3`` goes through the recorder."""
        with set_grad_enabled(not self.no_grad):
            hidden = x.view(x.size(0), -1)
            for layer in (self.fc_1, self.fc_2):
                hidden = F.relu(layer(hidden))
            hidden = self.embedding_recorder(hidden)
            scores = self.fc_3(hidden)
        return scores
|
||||
@@ -0,0 +1,304 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from torchvision.models import mobilenetv3
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
import math
|
||||
|
||||
'''MobileNetV3 in PyTorch.
|
||||
Paper: "Searching for MobileNetV3"
|
||||
|
||||
Acknowledgement to:
|
||||
https://github.com/d-li14/mobilenetv3.pytorch/blob/master/mobilenetv3.py
|
||||
'''
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``."""

    def __init__(self, inplace=True):
        """Create the underlying ``ReLU6``; ``inplace`` is passed straight to it."""
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the hard-sigmoid of ``x``."""
        shifted = x + 3  # new tensor, so an in-place ReLU6 does not touch ``x``
        return self.relu(shifted) / 6
|
||||
|
||||
|
||||
class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``."""

    def __init__(self, inplace=True):
        """Create the inner ``h_sigmoid``; ``inplace`` is passed straight to it."""
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        """Return ``x`` scaled by its hard-sigmoid gate."""
        gate = self.sigmoid(x)
        return x * gate
|
||||
|
||||
|
||||
class SELayer(nn.Module):
    """Squeeze-and-excite: rescale each channel by a gate computed from its global average."""

    def __init__(self, channel, reduction=4):
        """Build the pooling layer and the two-layer gating network.

        Args:
            channel: number of channels in (and out of) the layer.
            reduction: the bottleneck width is ``channel // reduction``
                rounded by ``_make_divisible(..., 8)``.
        """
        super().__init__()
        squeezed = _make_divisible(channel // reduction, 8)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel),
            h_sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate
|
||||
|
||||
|
||||
def conv_3x3_bn(inp, oup, stride, padding=1):
    """Return a bias-free 3x3 convolution followed by batch norm and ``h_swish``."""
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=padding, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm, h_swish())
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
    """Return a bias-free 1x1 convolution followed by batch norm and ``h_swish``."""
    conv = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm, h_swish())
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """MobileNetV3 block: optional 1x1 expansion, depthwise conv, optional SE, 1x1 projection."""

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        """Assemble the block.

        Args:
            inp: input channels.
            hidden_dim: channels of the depthwise stage; when equal to ``inp``
                the 1x1 expansion is skipped.
            oup: output channels.
            kernel_size: depthwise kernel size (padding is ``(kernel_size - 1) // 2``).
            stride: depthwise stride, 1 or 2.
            use_se: truthy to insert an ``SELayer``, otherwise ``nn.Identity``.
            use_hs: truthy to use ``h_swish``, otherwise in-place ``ReLU``.
        """
        super().__init__()
        assert stride in [1, 2]

        # The skip connection is only valid when the shape is unchanged.
        self.identity = stride == 1 and inp == oup

        def activation():
            return h_swish() if use_hs else nn.ReLU(inplace=True)

        def squeeze_excite():
            return SELayer(hidden_dim) if use_se else nn.Identity()

        def depthwise():
            return nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2,
                             groups=hidden_dim, bias=False)

        if inp == hidden_dim:
            # No expansion: dw -> BN -> act -> SE
            stages = [
                depthwise(),
                nn.BatchNorm2d(hidden_dim),
                activation(),
                squeeze_excite(),
            ]
        else:
            # pw expansion -> BN -> act -> dw -> BN -> SE -> act
            stages = [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                activation(),
                depthwise(),
                nn.BatchNorm2d(hidden_dim),
                squeeze_excite(),
                activation(),
            ]
        # pw-linear projection shared by both layouts
        stages += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the skip connection is enabled."""
        transformed = self.conv(x)
        if self.identity:
            return x + transformed
        return transformed
|
||||
|
||||
|
||||
class MobileNetV3_32x32(nn.Module):
    """MobileNetV3 built from an explicit block table (used by the 28x28 / 32x32 branch of the factory)."""

    def __init__(self, cfgs, mode, channel=3, num_classes=1000, record_embedding=False,
                 no_grad=False, width_mult=1.):
        """Build the feature extractor and classifier.

        Args:
            cfgs: rows of ``(k, t, c, use_se, use_hs, s)`` -- kernel size,
                expansion ratio, output channels, SE flag, h-swish flag, stride.
            mode: ``'mobilenet_v3_large'`` or ``'mobilenet_v3_small'``; selects
                the classifier hidden width (1280 / 1024).
            channel: input channels; single-channel input uses padding 3 in
                the first convolution, otherwise padding 1.
            num_classes: number of output classes.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
            width_mult: multiplier applied to channel counts.
        """
        super().__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs
        assert mode in ['mobilenet_v3_large', 'mobilenet_v3_small']

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

        # first layer
        in_width = _make_divisible(16 * width_mult, 8)
        stem_padding = 3 if channel == 1 else 1
        stages = [conv_3x3_bn(channel, in_width, 2, padding=stem_padding)]

        # inverted residual blocks
        for kernel, expand_ratio, out_channels, use_se, use_hs, stride in self.cfgs:
            out_width = _make_divisible(out_channels * width_mult, 8)
            exp_size = _make_divisible(in_width * expand_ratio, 8)
            stages.append(InvertedResidual(in_width, exp_size, out_width, kernel, stride, use_se, use_hs))
            in_width = out_width
        self.features = nn.Sequential(*stages)

        # last several layers; ``exp_size`` is the expansion width of the final block
        self.conv = conv_1x1_bn(in_width, exp_size)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        hidden_width = {'mobilenet_v3_large': 1280, 'mobilenet_v3_small': 1024}[mode]
        if width_mult > 1.0:
            hidden_width = _make_divisible(hidden_width * width_mult, 8)
        self.classifier = nn.Sequential(
            nn.Linear(exp_size, hidden_width),
            h_swish(),
            nn.Dropout(0.2),
            self.embedding_recorder,
            nn.Linear(hidden_width, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class scores for ``x``."""
        with set_grad_enabled(not self.no_grad):
            x = self.conv(self.features(x))
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear in the network."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()

    def get_last_layer(self):
        """Return the final linear layer of the classifier."""
        return self.classifier[-1]
|
||||
|
||||
|
||||
class MobileNetV3_224x224(mobilenetv3.MobileNetV3):
    """torchvision MobileNetV3 with an embedding recorder placed before the final layer."""

    def __init__(self, inverted_residual_setting, last_channel,
                 channel=3, num_classes=1000, record_embedding=False, no_grad=False, **kwargs):
        """Build the parent network and rewire the classifier head.

        Args:
            inverted_residual_setting: forwarded to the parent constructor.
            last_channel: forwarded to the parent constructor.
            channel: accepted for interface parity; not used in this constructor.
            num_classes: forwarded to the parent constructor.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``_forward_impl`` runs with gradients disabled.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(inverted_residual_setting, last_channel, num_classes=num_classes, **kwargs)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)

        # Move the final layer to a new trailing slot named "fc" and put the
        # recorder where it used to be, so the recorder sees that layer's input.
        final_layer = self.classifier[-1]
        self.fc = final_layer
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", final_layer)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.fc

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Return class scores for ``x``."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            scores = self.classifier(flatten(pooled, 1))
        return scores
|
||||
|
||||
|
||||
def MobileNetV3(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False,
                no_grad: bool = False,
                pretrained: bool = False, **kwargs):
    """Build a MobileNetV3 network for the given input configuration.

    Args:
        arch: ``"mobilenet_v3_large"`` or ``"mobilenet_v3_small"`` (case-insensitive).
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: two-element input size.
        record_embedding: initial state of the network's embedding recorder.
        no_grad: when true the network's forward pass disables gradients.
        pretrained: load the weights published at ``mobilenetv3.model_urls[arch]``
            and replace the final layer when ``num_classes`` is not 1000.
        **kwargs: forwarded to ``MobileNetV3_224x224`` (ignored by the small-input branch).

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for an unsupported channel/size combination.
        ValueError: for an unknown ``arch`` in the small-input branch.
    """
    arch = arch.lower()
    if pretrained:
        if channel != 3:
            raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

        inverted_residual_setting, last_channel = mobilenetv3._mobilenet_v3_conf(arch)
        net = MobileNetV3_224x224(inverted_residual_setting=inverted_residual_setting, last_channel=last_channel,
                                  channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad,
                                  **kwargs)

        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(mobilenetv3.model_urls[arch], progress=True)

        # MobileNetV3_224x224 moves the final Linear out of its numeric
        # classifier slot (now holding the parameter-free recorder) and
        # registers it as both ``classifier.fc`` and ``fc``. The published
        # checkpoint still uses the numeric key, so a strict load would fail
        # without this remapping.
        old_prefix = "classifier." + str(len(net.classifier) - 2) + "."
        for key in [k for k in state_dict if k.startswith(old_prefix)]:
            weights = state_dict.pop(key)
            suffix = key[len(old_prefix):]
            state_dict["classifier.fc." + suffix] = weights
            state_dict["fc." + suffix] = weights
        net.load_state_dict(state_dict)

        if num_classes != 1000:
            net.fc = nn.Linear(last_channel, num_classes)
            net.classifier[-1] = net.fc

    elif im_size[0] == 224 and im_size[1] == 224:
        if channel != 3:
            raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
        inverted_residual_setting, last_channel = mobilenetv3._mobilenet_v3_conf(arch)
        net = MobileNetV3_224x224(inverted_residual_setting=inverted_residual_setting, last_channel=last_channel,
                                  channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad, **kwargs)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        if arch == "mobilenet_v3_large":
            cfgs = [
                # k, t, c, SE, HS, s
                [3, 1, 16, 0, 0, 1],
                [3, 4, 24, 0, 0, 2],
                [3, 3, 24, 0, 0, 1],
                [5, 3, 40, 1, 0, 2],
                [5, 3, 40, 1, 0, 1],
                [5, 3, 40, 1, 0, 1],
                [3, 6, 80, 0, 1, 2],
                [3, 2.5, 80, 0, 1, 1],
                [3, 2.3, 80, 0, 1, 1],
                [3, 2.3, 80, 0, 1, 1],
                [3, 6, 112, 1, 1, 1],
                [3, 6, 112, 1, 1, 1],
                [5, 6, 160, 1, 1, 2],
                [5, 6, 160, 1, 1, 1],
                [5, 6, 160, 1, 1, 1],
            ]
        elif arch == "mobilenet_v3_small":
            cfgs = [
                # k, t, c, SE, HS, s
                [3, 1, 16, 1, 0, 2],
                [3, 4.5, 24, 0, 0, 2],
                [3, 3.67, 24, 0, 0, 1],
                [5, 4, 40, 1, 1, 2],
                [5, 6, 40, 1, 1, 1],
                [5, 6, 40, 1, 1, 1],
                [5, 3, 48, 1, 1, 1],
                [5, 3, 48, 1, 1, 1],
                [5, 6, 96, 1, 1, 2],
                [5, 6, 96, 1, 1, 1],
                [5, 6, 96, 1, 1, 1],
            ]
        else:
            raise ValueError("Model architecture not found.")
        net = MobileNetV3_32x32(cfgs, arch, channel=channel, num_classes=num_classes,
                                record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def MobileNetV3Large(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                     pretrained: bool = False, **kwargs):
    """Build the ``mobilenet_v3_large`` variant through ``MobileNetV3``."""
    return MobileNetV3(
        "mobilenet_v3_large",
        channel,
        num_classes,
        im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
        **kwargs,
    )
|
||||
|
||||
|
||||
def MobileNetV3Small(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                     pretrained: bool = False, **kwargs):
    """Build the ``mobilenet_v3_small`` variant through ``MobileNetV3``."""
    return MobileNetV3(
        "mobilenet_v3_small",
        channel,
        num_classes,
        im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
        **kwargs,
    )
|
||||
@@ -0,0 +1,2 @@
|
||||
from .parallel import *
|
||||
from .recorder import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class MyDataParallel(DataParallel):
    """``DataParallel`` that falls back to the wrapped module for attribute access."""

    def __getattr__(self, name):
        """Look ``name`` up on the wrapper first, then on ``self.module``."""
        try:
            found = super().__getattr__(name)
        except AttributeError:
            found = getattr(self.module, name)
        return found

    def __setattr__(self, name, value):
        """Set ``name`` on the wrapper, except ``no_grad`` which always goes to ``self.module``."""
        try:
            if name != "no_grad":
                return super().__setattr__(name, value)
            return setattr(self.module, name, value)
        except AttributeError:
            # The wrapper refused the assignment; store it on the wrapped module.
            return setattr(self.module, name, value)
|
||||
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class EmbeddingRecorder(nn.Module):
    """Pass-through module that can keep a reference to the last tensor it saw.

    While ``record_embedding`` is true, ``forward`` stores its input in
    ``self.embedding``. The module is also a context manager: entering turns
    recording on, leaving turns it off.
    """

    def __init__(self, record_embedding: bool = False):
        """Create the recorder with recording initially set to ``record_embedding``."""
        super().__init__()
        self.record_embedding = record_embedding

    def forward(self, x):
        """Return ``x`` unchanged, remembering it first when recording is on."""
        if self.record_embedding:
            self.embedding = x
        return x

    def __enter__(self):
        """Turn recording on and return the recorder so ``with rec as r`` binds it."""
        self.record_embedding = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Turn recording off; exceptions raised in the block are not suppressed."""
        self.record_embedding = False
|
||||
@@ -0,0 +1,241 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import resnet
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """Two-convolution residual block (channel expansion factor 1)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the two conv/BN pairs and the shortcut.

        The shortcut is an empty ``nn.Sequential`` unless the stride or the
        channel count changes, in which case it is a 1x1 conv plus batch norm.
        """
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        residual += self.shortcut(x)
        return F.relu(residual)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``4 * planes`` channels."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """Build the three conv/BN pairs and the shortcut.

        The stride is applied by the 3x3 convolution. The shortcut is an empty
        ``nn.Sequential`` unless the stride or the channel count changes, in
        which case it is a 1x1 conv plus batch norm.
        """
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = F.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))
        residual += self.shortcut(x)
        return F.relu(residual)
|
||||
|
||||
|
||||
class ResNet_32x32(nn.Module):
    """ResNet with a 3x3 stem and no stem pooling (used by the 28x28 / 32x32 branch of the factory)."""

    def __init__(self, block, num_blocks, channel=3, num_classes=10, record_embedding: bool = False,
                 no_grad: bool = False):
        """Build the stem, four residual stages and the linear head.

        Args:
            block: residual block class; must expose an ``expansion`` attribute.
            num_blocks: number of blocks in each of the four stages.
            channel: number of input channels.
            num_classes: number of output classes.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
        """
        super().__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(channel, 64)
        self.bn1 = nn.BatchNorm2d(64)
        # (stage width, stride of the stage's first block)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index], stride=stride)
            setattr(self, "layer{}".format(index + 1), stage)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer."""
        return self.linear

    def _make_layer(self, block, planes, num_blocks, stride):
        """Return one stage: ``num_blocks`` blocks, only the first using ``stride``."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            # the next block consumes this block's expanded output
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class scores; the pooled features go through the recorder."""
        with set_grad_enabled(not self.no_grad):
            features = F.relu(self.bn1(self.conv1(x)))
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                features = stage(features)
            features = F.avg_pool2d(features, 4)
            features = features.view(features.size(0), -1)
            features = self.embedding_recorder(features)
            scores = self.linear(features)
        return scores
|
||||
|
||||
|
||||
class ResNet_224x224(resnet.ResNet):
    """torchvision ResNet with an embedding recorder in front of ``fc``."""

    def __init__(self, block, layers, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        """Build the parent network, then adapt the stem and head.

        Args:
            block: forwarded to the parent constructor.
            layers: forwarded to the parent constructor.
            channel: input channels; the stem conv is rebuilt when it is not 3.
            num_classes: output classes; ``fc`` is rebuilt when it is not 1000.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``_forward_impl`` runs with gradients disabled.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(block, layers, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        needs_new_stem = channel != 3
        needs_new_head = num_classes != 1000
        if needs_new_stem:
            self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if needs_new_head:
            self.fc = nn.Linear(self.fc.in_features, num_classes)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.fc

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Return class scores; the pooled features go through the recorder."""
        # See note [TorchScript super()]
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        with set_grad_enabled(not self.no_grad):
            for stage in pipeline:
                x = stage(x)
            x = flatten(x, 1)
            x = self.embedding_recorder(x)
            x = self.fc(x)
        return x
|
||||
|
||||
|
||||
def ResNet(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Build a ResNet for the given input configuration.

    Args:
        arch: one of ``resnet18/34/50/101/152`` (case-insensitive).
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: two-element input size (not consulted when ``pretrained``).
        record_embedding: initial state of the network's embedding recorder.
        no_grad: when true the network's forward pass disables gradients.
        pretrained: load the weights published at ``resnet.model_urls[arch]``
            and then swap the stem / head if ``channel`` / ``num_classes``
            differ from 3 / 1000.

    Returns:
        The constructed network.

    Raises:
        ValueError: unknown ``arch`` within a supported branch.
        NotImplementedError: unsupported channel/size combination.
    """
    # arch -> (uses bottleneck blocks, blocks per stage)
    variants = {
        "resnet18": (False, (2, 2, 2, 2)),
        "resnet34": (False, (3, 4, 6, 3)),
        "resnet50": (True, (3, 4, 6, 3)),
        "resnet101": (True, (3, 4, 23, 3)),
        "resnet152": (True, (3, 8, 36, 3)),
    }
    arch = arch.lower()

    def variant():
        if arch not in variants:
            raise ValueError("Model architecture not found.")
        return variants[arch]

    def build_224(in_channels, classes):
        bottleneck, depths = variant()
        block = resnet.Bottleneck if bottleneck else resnet.BasicBlock
        return ResNet_224x224(block, list(depths), channel=in_channels, num_classes=classes,
                              record_embedding=record_embedding, no_grad=no_grad)

    if pretrained:
        net = build_224(3, 1000)
        from torch.hub import load_state_dict_from_url
        state_dict = load_state_dict_from_url(resnet.model_urls[arch], progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            net.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif im_size[0] == 224 and im_size[1] == 224:
        net = build_224(channel, num_classes)
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        bottleneck, depths = variant()
        # the small-input network uses this module's own block classes
        block = Bottleneck if bottleneck else BasicBlock
        net = ResNet_32x32(block, list(depths), channel=channel, num_classes=num_classes,
                           record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def ResNet18(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Build the ``resnet18`` variant through ``ResNet``."""
    return ResNet("resnet18", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet34(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Build the ``resnet34`` variant through ``ResNet``."""
    return ResNet("resnet34", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet50(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
             pretrained: bool = False):
    """Build the ``resnet50`` variant through ``ResNet``."""
    return ResNet("resnet50", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet101(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
              pretrained: bool = False):
    """Build the ``resnet101`` variant through ``ResNet``."""
    return ResNet("resnet101", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
|
||||
|
||||
def ResNet152(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
              pretrained: bool = False):
    """Build the ``resnet152`` variant through ``ResNet``."""
    return ResNet("resnet152", channel, num_classes, im_size,
                  record_embedding=record_embedding, no_grad=no_grad, pretrained=pretrained)
|
||||
@@ -0,0 +1,128 @@
|
||||
import torch.nn as nn
|
||||
from torch import set_grad_enabled, flatten, Tensor
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import vgg
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/kuangliu/pytorch-cifar,
|
||||
# https://github.com/BIGBALLON/CIFAR-ZOO,
|
||||
|
||||
# Output widths of the five convolutional stages shared by every VGG variant.
_VGG_STAGE_WIDTHS = (64, 128, 256, 512, 512)

# Layer layout per VGG variant: an int is a conv layer's output channels,
# 'M' is a 2x2 max-pool. Each variant differs only in how many convs each
# stage repeats before its pooling layer.
cfg_vgg = {
    name: [entry
           for width, repeats in zip(_VGG_STAGE_WIDTHS, stage_depths)
           for entry in [width] * repeats + ['M']]
    for name, stage_depths in (
        ('vgg11', (1, 1, 2, 2, 2)),
        ('vgg13', (2, 2, 2, 2, 2)),
        ('vgg16', (2, 2, 3, 3, 3)),
        ('vgg19', (2, 2, 4, 4, 4)),
    )
}
|
||||
|
||||
|
||||
class VGG_32x32(nn.Module):
    """VGG with batch norm and a single linear classifier (used by the 28x28 / 32x32 branch of the factory)."""

    def __init__(self, vgg_name, channel, num_classes, record_embedding=False, no_grad=False):
        """Build the feature extractor from ``cfg_vgg[vgg_name]`` and the linear head.

        Args:
            vgg_name: key into ``cfg_vgg``.
            channel: number of input channels.
            num_classes: number of output classes.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
        """
        super().__init__()
        self.channel = channel
        self.features = self._make_layers(cfg_vgg[vgg_name])
        feature_width = 128 if vgg_name == 'VGGS' else 512
        self.classifier = nn.Linear(feature_width, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def forward(self, x):
        """Return class scores; the flattened features go through the recorder."""
        with set_grad_enabled(not self.no_grad):
            features = self.features(x)
            features = features.view(features.size(0), -1)
            features = self.embedding_recorder(features)
            scores = self.classifier(features)
        return scores

    def get_last_layer(self):
        """Return the linear classifier."""
        return self.classifier

    def _make_layers(self, cfg):
        """Translate a layout list into an ``nn.Sequential``.

        ``'M'`` becomes a 2x2 max-pool; any other entry becomes
        conv3x3 + batch norm + ReLU with that many output channels. The very
        first convolution uses padding 3 for single-channel input, otherwise
        padding 1.
        """
        modules = []
        in_channels = self.channel
        for position, entry in enumerate(cfg):
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            widen_first = self.channel == 1 and position == 0
            modules.extend([
                nn.Conv2d(in_channels, entry, kernel_size=3, padding=3 if widen_first else 1),
                nn.BatchNorm2d(entry),
                nn.ReLU(inplace=True),
            ])
            in_channels = entry
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)
|
||||
|
||||
|
||||
class VGG_224x224(vgg.VGG):
    """torchvision VGG with an embedding recorder placed before the final layer."""

    def __init__(self, features: nn.Module, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        """Build the parent network, adapt the first conv and rewire the head.

        Args:
            features: feature extractor, forwarded to the parent constructor.
            channel: input channels; ``features[0]`` is rebuilt when it is not 3.
            num_classes: forwarded to the parent constructor.
            record_embedding: initial state of the embedding recorder.
            no_grad: when true, ``forward`` runs with gradients disabled.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(features, num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            self.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)

        # Move the final layer to a new trailing slot named "fc" and put the
        # recorder where it used to be, so the recorder sees that layer's input.
        final_layer = self.classifier[-1]
        self.fc = final_layer
        self.classifier[-1] = self.embedding_recorder
        self.classifier.add_module("fc", final_layer)

        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.fc

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores for ``x``."""
        with set_grad_enabled(not self.no_grad):
            pooled = self.avgpool(self.features(x))
            scores = self.classifier(flatten(pooled, 1))
        return scores
|
||||
|
||||
|
||||
def VGG(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
        pretrained: bool = False):
    """Build a VGG network for the given input configuration.

    Args:
        arch: key into ``cfg_vgg`` (case-insensitive), e.g. ``"vgg16"``.
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: two-element input size.
        record_embedding: initial state of the network's embedding recorder.
        no_grad: when true the network's forward pass disables gradients.
        pretrained: load published torchvision weights, then swap the first
            conv / final layer if ``channel`` / ``num_classes`` differ from
            3 / 1000.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: for ``pretrained`` with a non-224x224 size, or
            for a channel/size combination no branch handles.
    """
    arch = arch.lower()
    if pretrained:
        if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=3, num_classes=1000,
                          record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        # The feature extractor above is built with batch norm, so the matching
        # checkpoint is the "<arch>_bn" one; the plain "<arch>" checkpoint has
        # a different ``features`` layout and cannot be loaded strictly.
        # NOTE(review): assumes ``vgg.model_urls`` carries the "_bn" keys, as
        # torchvision releases that still expose ``model_urls`` do -- confirm
        # against the pinned torchvision version.
        state_dict = load_state_dict_from_url(vgg.model_urls[arch + "_bn"], progress=True)

        # VGG_224x224 moves the final Linear out of its numeric classifier slot
        # (now holding the parameter-free recorder) and registers it as both
        # ``classifier.fc`` and ``fc``; remap the checkpoint keys to match.
        old_prefix = "classifier." + str(len(net.classifier) - 2) + "."
        for key in [k for k in state_dict if k.startswith(old_prefix)]:
            weights = state_dict.pop(key)
            suffix = key[len(old_prefix):]
            state_dict["classifier.fc." + suffix] = weights
            state_dict["fc." + suffix] = weights
        net.load_state_dict(state_dict)

        if channel != 3:
            net.features[0] = nn.Conv2d(channel, 64, kernel_size=3, padding=1)

        if num_classes != 1000:
            net.fc = nn.Linear(4096, num_classes)
            net.classifier[-1] = net.fc

    elif im_size[0] == 224 and im_size[1] == 224:
        net = VGG_224x224(features=vgg.make_layers(cfg_vgg[arch], True), channel=channel, num_classes=num_classes,
                          record_embedding=record_embedding, no_grad=no_grad)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = VGG_32x32(arch, channel, num_classes=num_classes, record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def VGG11(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Build the ``"vgg11"`` variant by delegating to ``VGG``."""
    return VGG(
        arch="vgg11",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def VGG13(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Build the ``"vgg13"`` variant by delegating to ``VGG``."""
    return VGG(
        arch="vgg13",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def VGG16(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Build the ``"vgg16"`` variant by delegating to ``VGG``."""
    return VGG(
        arch="vgg16",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def VGG19(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
          pretrained: bool = False):
    """Build the ``"vgg19"`` variant by delegating to ``VGG``."""
    return VGG(
        arch="vgg19",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .nets_utils import EmbeddingRecorder
|
||||
from torchvision.models import resnet
|
||||
from .resnet import ResNet_224x224
|
||||
|
||||
|
||||
# Acknowledgement to
|
||||
# https://github.com/xternalz/WideResNet-pytorch
|
||||
|
||||
class BasicBlock(nn.Module):
    """Residual block: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> (dropout) -> 3x3 conv.

    When the input and output channel counts differ, a 1x1 convolution with
    the block's stride is applied on the shortcut path; otherwise the raw
    input is added back unchanged.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the shortcut conv).
        dropRate: Dropout probability applied between the two convolutions;
            dropout is skipped entirely when this is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        # Submodules are created in a fixed order so parameter registration
        # (and any RNG-dependent initialisation) stays reproducible.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # Explicit conditional expression instead of the `cond and a or b`
        # idiom, which depends on the truthiness of the middle operand.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the summed output."""
        # Both branches of the original computed the same BN+ReLU; only the
        # shortcut differs in whether it sees the activated or the raw input.
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(activated)
        return torch.add(shortcut, out)
|
||||
|
||||
|
||||
class NetworkBlock(nn.Module):
    """A sequential stack of ``nb_layers`` blocks built from one block factory.

    The first block maps ``in_planes`` to ``out_planes`` with the given
    ``stride``; every following block maps ``out_planes`` to ``out_planes``
    with stride 1.

    Args:
        nb_layers: Number of blocks to stack; converted with ``int()``.
        in_planes: Input channel count of the first block.
        out_planes: Output channel count of every block.
        block: Callable invoked as ``block(in_planes, out_planes, stride, dropRate)``
            that returns a module.
        stride: Stride passed to the first block only.
        dropRate: Dropout rate forwarded unchanged to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Instantiate the blocks and wrap them in an ``nn.Sequential``."""
        # Real conditional expressions: the former `i == 0 and a or b` form
        # silently fell through to `b` whenever `a` was falsy (e.g. 0).
        layers = [
            block(
                in_planes if index == 0 else out_planes,
                out_planes,
                stride if index == 0 else 1,
                dropRate,
            )
            for index in range(int(nb_layers))
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the stacked blocks."""
        return self.layer(x)
|
||||
|
||||
|
||||
class WideResNet_32x32(nn.Module):
    """Wide residual network: a stem conv, three ``NetworkBlock`` stages, then a linear head.

    Args:
        depth: Total depth; must satisfy ``(depth - 4) % 6 == 0``. Each of the
            three stages holds ``(depth - 4) // 6`` blocks.
        num_classes: Output size of the final linear layer.
        channel: Number of input channels. A value of 1 switches the stem
            convolution to padding 3 (instead of 1), which enlarges the
            spatial size by 4 in each dimension.
        widen_factor: Multiplier applied to the 16/32/64 stage widths.
        drop_rate: Dropout rate forwarded to every block.
        record_embedding: Passed to ``EmbeddingRecorder``.
        no_grad: When true, the feature-extraction part of ``forward`` runs
            with gradient tracking disabled.
    """

    def __init__(self, depth, num_classes, channel=3, widen_factor=1, drop_rate=0.0, record_embedding=False,
                 no_grad=False):
        super().__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, "depth must be of the form 6 * n + 4"
        # Integer division: the block count is a count, not a float.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(channel, nChannels[0], kernel_size=3, stride=1,
                               padding=3 if channel == 1 else 1, bias=False)
        # Three stages; the 2nd and 3rd use stride 2 in their first block.
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, drop_rate)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, drop_rate)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, drop_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the linear weight keeps its default init.
                nn.init.zeros_(m.bias)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final linear layer (``self.fc``)."""
        return self.fc

    def forward(self, x):
        """Compute class scores for ``x``."""
        with torch.set_grad_enabled(not self.no_grad):
            out = self.conv1(x)
            out = self.block1(out)
            out = self.block2(out)
            out = self.block3(out)
            out = self.relu(self.bn1(out))
            # Fixed 8x8 pooling window; presumably the feature map is 8x8 at
            # this point for the input sizes this class targets - TODO confirm.
            out = F.avg_pool2d(out, 8)
            out = out.view(-1, self.nChannels)
            out = self.embedding_recorder(out)
        # NOTE(review): indentation was lost in the extracted source; the
        # classifier call is placed outside the grad-toggle context so that
        # `fc` always takes part in autograd - confirm against the original.
        return self.fc(out)
|
||||
|
||||
|
||||
def WideResNet(arch: str, channel: int, num_classes: int, im_size, record_embedding: bool = False,
               no_grad: bool = False, pretrained: bool = False):
    """Build a wide ResNet for the given architecture name and input geometry.

    Args:
        arch: Architecture key (case-insensitive). ``"wrn502"`` and ``"wrn1012"``
            are available for 224x224 inputs; ``"wrn168"``, ``"wrn2810"`` and
            ``"wrn282"`` for 1x28x28 or 3x32x32 inputs.
        channel: Number of input channels.
        num_classes: Number of output classes.
        im_size: Indexable ``(height, width)``; only the first two entries are read.
        record_embedding: Forwarded to the network constructor.
        no_grad: Forwarded to the network constructor.
        pretrained: When true, download weights via ``resnet.model_urls`` and load
            them; only 224x224 inputs are accepted in that case.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: If ``pretrained`` is requested for a non-224x224
            input, or no implementation matches the channel/size combination.
        ValueError: If ``arch`` is not available for the selected input geometry.
    """
    arch = arch.lower()

    # arch key -> (torchvision weight name, bottleneck blocks per stage)
    large_archs = {
        "wrn502": ("wide_resnet50_2", [3, 4, 6, 3]),
        "wrn1012": ("wide_resnet101_2", [3, 4, 23, 3]),
    }
    # arch key -> (depth, widen factor)
    small_archs = {
        "wrn168": (16, 8),
        "wrn2810": (28, 10),
        "wrn282": (28, 2),
    }

    full_res = im_size[0] == 224 and im_size[1] == 224

    if pretrained:
        if not full_res:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        if arch not in large_archs:
            raise ValueError("Model architecture not found.")
        weight_name, stage_layers = large_archs[arch]
        # Build with the 3-channel / 1000-class layout first so the downloaded
        # state dict can be loaded, then swap layers that need another shape.
        net = ResNet_224x224(resnet.Bottleneck, stage_layers, channel=3, num_classes=1000,
                             record_embedding=record_embedding, no_grad=no_grad, width_per_group=64 * 2)

        from torch.hub import load_state_dict_from_url
        # NOTE(review): relies on `resnet.model_urls`; confirm the pinned torchvision still exposes it.
        state_dict = load_state_dict_from_url(resnet.model_urls[weight_name], progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            net.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)

    elif full_res:
        # Same architectures as above, without pretrained parameters.
        if arch not in large_archs:
            raise ValueError("Model architecture not found.")
        _, stage_layers = large_archs[arch]
        net = ResNet_224x224(resnet.Bottleneck, stage_layers, channel=channel, num_classes=num_classes,
                             record_embedding=record_embedding, no_grad=no_grad, width_per_group=64 * 2)

    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        if arch not in small_archs:
            raise ValueError("Model architecture not found.")
        depth, widen_factor = small_archs[arch]
        # Forward the embedding/no-grad flags; they were previously dropped
        # here, so callers' settings had no effect for the small-input models.
        net = WideResNet_32x32(depth, num_classes, channel, widen_factor,
                               record_embedding=record_embedding, no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")
    return net
|
||||
|
||||
|
||||
def WRN168(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Build the ``"wrn168"`` variant by delegating to ``WideResNet``."""
    return WideResNet(
        arch="wrn168",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def WRN2810(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Build the ``"wrn2810"`` variant by delegating to ``WideResNet``."""
    return WideResNet(
        arch="wrn2810",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def WRN282(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Build the ``"wrn282"`` variant by delegating to ``WideResNet``."""
    return WideResNet(
        arch="wrn282",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def WRN502(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
           pretrained: bool = False):
    """Build the ``"wrn502"`` variant by delegating to ``WideResNet``."""
    return WideResNet(
        arch="wrn502",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
|
||||
|
||||
def WRN1012(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
            pretrained: bool = False):
    """Build the ``"wrn1012"`` variant by delegating to ``WideResNet``."""
    return WideResNet(
        arch="wrn1012",
        channel=channel,
        num_classes=num_classes,
        im_size=im_size,
        record_embedding=record_embedding,
        no_grad=no_grad,
        pretrained=pretrained,
    )
|
||||
Reference in New Issue
Block a user