import torch
import torch.nn as nn

from clip import clip


class Weight_Adapter(nn.Module):
    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512   # ViT-B backbones produce 512-d CLIP embeddings
        elif args.name == "RN50":
            n_input = 1024  # RN50 produces 1024-d CLIP embeddings
        elif args.name == "RN50x16":
            n_input = 768   # RN50x16 produces 768-d CLIP embeddings
        else:
            raise ValueError(f"Unsupported backbone: {args.name}")
        n_output = 2 * len(classnames)  # two weight vectors per class
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)


class Classifier(nn.Module):
    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512   # ViT-B backbones produce 512-d CLIP embeddings
        elif args.name == "RN50":
            n_input = 1024  # RN50 produces 1024-d CLIP embeddings
        elif args.name == "RN50x16":
            n_input = 768   # RN50x16 produces 768-d CLIP embeddings
        else:
            raise ValueError(f"Unsupported backbone: {args.name}")
        n_output = len(classnames)  # one weight vector per class
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)


class Adapter(nn.Module):
    # Structurally identical to Classifier; kept as a separate module.
    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512   # ViT-B backbones produce 512-d CLIP embeddings
        elif args.name == "RN50":
            n_input = 1024  # RN50 produces 1024-d CLIP embeddings
        elif args.name == "RN50x16":
            n_input = 768   # RN50x16 produces 768-d CLIP embeddings
        else:
            raise ValueError(f"Unsupported backbone: {args.name}")
        n_output = len(classnames)  # one weight vector per class
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize from precomputed weights

    def forward(self, x):
        return self.linear(x)


class Linear(nn.Module):
    def __init__(self, args):
        super().__init__()
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512   # ViT-B backbones produce 512-d CLIP embeddings
        elif args.name == "RN50":
            n_input = 1024  # RN50 produces 1024-d CLIP embeddings
        elif args.name == "RN50x16":
            n_input = 768   # RN50x16 produces 768-d CLIP embeddings
        else:
            raise ValueError(f"Unsupported backbone: {args.name}")
        # Square projection that maps features back to the same dimension
        self.linear = nn.Linear(n_input, n_input, bias=False)

    def forward(self, x):
        return self.linear(x)


class Res_Adapter(nn.Module):
    def __init__(self, n_input):
        super().__init__()
        self.residual_ratio = 0.5
        # Bottleneck MLP: project down by 4x, then back up
        self.fc = nn.Sequential(
            nn.Linear(n_input, n_input // 4, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(n_input // 4, n_input, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Blend adapted and original features with a fixed residual ratio
        a = self.fc(x)
        return self.residual_ratio * a + (1 - self.residual_ratio) * x


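# Usage sketch for Res_Adapter (illustrative; the shapes below are assumptions,
# not taken from the training code):
#
#   adapter = Res_Adapter(n_input=1024)
#   features = torch.randn(8, 1024)
#   adapted = adapter(features)  # 0.5 * adapter.fc(features) + 0.5 * features

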
def all_classifier(classnames, templates, model):
    """Build zero-shot classifier weights from CLIP text embeddings, one column per class."""
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # format prompts with the class name
            texts = clip.tokenize(texts).cuda()  # clip.tokenize converts the prompts into token id tensors
            class_embeddings = model.encode_text(texts)  # embed with the text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)  # L2-normalize each prompt embedding
            class_embedding = class_embeddings.mean(dim=0)  # average over templates
            class_embedding /= class_embedding.norm()  # re-normalize the averaged embedding
            zeroshot_weights.append(class_embedding)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()  # shape: (n_input, n_classes)
    return zeroshot_weights


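# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only): a real run would build `args` from
# the training CLI and derive `init_weights` from all_classifier(). Note that
# all_classifier() returns a (n_input, n_classes) matrix, while nn.Linear
# stores its weight as (n_classes, n_input), so a call site would pass the
# transpose when initializing a classifier.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(name="RN50")  # hypothetical stand-in for the CLI args
    classnames = ["cat", "dog"]

    # Random stand-in for CLIP-derived weights, shaped (n_classes, n_input)
    init_weights = torch.randn(len(classnames), 1024)
    classifier = Classifier(args, classnames, init_weights)

    features = torch.randn(4, 1024)  # dummy image features
    print(classifier(features).shape)  # torch.Size([4, 2])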