commit ca67205608
2024-05-21 19:41:56 +08:00
217 changed files with 201004 additions and 0 deletions

Adapter.py Normal file

@@ -0,0 +1,101 @@
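"""Adapter modules over frozen CLIP features, plus a zero-shot classifier builder.

Feature dimensions follow the chosen CLIP backbone: 512 for the ViT-B models
and 1024 for the listed ResNet models. The linear adapters can be initialized
from precomputed weights, e.g. the zero-shot text embeddings returned by
all_classifier().
"""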
import torch.nn as nn
from clip import clip
import torch
class Weight_Adapter(nn.Module):
    """Linear adapter producing two output weights per class."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = 2 * len(classnames)  # two logits per class
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)
class Classifier(nn.Module):
    """Linear classifier over CLIP features, one weight vector per class."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = len(classnames)
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)
class Adapter(nn.Module):
    """Linear adapter with one output per class (same structure as Classifier)."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = len(classnames)
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)
class Linear(nn.Module):
    """Square linear projection that preserves the feature dimension."""

    def __init__(self, args):
        super().__init__()
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        self.linear = nn.Linear(n_input, n_input, bias=False)

    def forward(self, x):
        return self.linear(x)
class Res_Adapter(nn.Module):
    """Bottleneck MLP adapter blended with its input through a fixed residual ratio."""

    def __init__(self, n_input):
        super().__init__()
        self.residual_ratio = 0.5
        self.fc = nn.Sequential(
            nn.Linear(n_input, n_input // 4, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(n_input // 4, n_input, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        a = self.fc(x)
        # Convex combination of the adapted and original features.
        return self.residual_ratio * a + (1 - self.residual_ratio) * x
def all_classifier(classnames, templates, model):
    """Build zero-shot classifier weights by averaging the normalized text
    embeddings of all prompt templates for each class."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # fill each template with the class name
            texts = clip.tokenize(texts).cuda()  # clip.tokenize turns the prompts into token id tensors
            class_embeddings = model.encode_text(texts)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)  # average over templates
            class_embedding /= class_embedding.norm()  # re-normalize the mean embedding
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()  # (embed_dim, n_classes)
    return zeroshot_weights
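
A minimal sketch of how these pieces might be wired together, not part of the committed file: it assumes `args.name` names a supported backbone and that `classnames`, `templates`, and a preprocessed `images` batch are defined elsewhere. Since `all_classifier` stacks embeddings as (embed_dim, n_classes) while nn.Linear stores weights as (out_features, in_features), the weights are transposed before initializing `Classifier`.

    # Usage sketch (assumes args, classnames, templates, images exist).
    model, _ = clip.load(args.name)  # e.g. args.name = "ViT-B/16"
    zeroshot_weights = all_classifier(classnames, templates, model)  # (embed_dim, n_classes)
    classifier = Classifier(args, classnames, zeroshot_weights.t().float()).cuda()
    with torch.no_grad():
        image_features = model.encode_image(images)  # images: a preprocessed batch (placeholder)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = classifier(image_features.float())  # (batch, n_classes)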