import torch
import torch.nn as nn

from clip import clip


class Weight_Adapter(nn.Module):
    """Linear adapter producing 2 * len(classnames) outputs (two logits per class)."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = 2 * len(classnames)
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize the linear layer from the given weights

    def forward(self, x):
        return self.linear(x)


class Classifier(nn.Module):
    """Linear classifier producing one logit per class."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = len(classnames)
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize the linear layer from the given weights

    def forward(self, x):
        return self.linear(x)


class Adapter(nn.Module):
    """Linear adapter with the same shape as Classifier (one output per class)."""

    def __init__(self, args, classnames, init_weights):
        super().__init__()
        self.classnames = classnames
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        n_output = len(classnames)
        self.linear = nn.Linear(n_input, n_output, bias=False)
        self.linear.weight.data = init_weights  # initialize the linear layer from the given weights

    def forward(self, x):
        return self.linear(x)


class Linear(nn.Module):
    """Square linear projection (feature dimension to feature dimension), randomly initialized."""

    def __init__(self, args):
        super().__init__()
        if args.name in ["ViT-B/16", "ViT-B/32"]:
            n_input = 512
        elif args.name in ["RN50", "RN50x16"]:
            n_input = 1024
        else:
            raise ValueError(f"Unsupported CLIP backbone: {args.name}")
        self.linear = nn.Linear(n_input, n_input, bias=False)

    def forward(self, x):
        return self.linear(x)


class Res_Adapter(nn.Module):
    """Bottleneck adapter (dim -> dim/4 -> dim) blended with the input via a residual ratio."""

    def __init__(self, n_input):
        super().__init__()
        self.residual_ratio = 0.5
        self.fc = nn.Sequential(
            nn.Linear(n_input, n_input // 4, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(n_input // 4, n_input, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        a = self.fc(x)
        # Blend the adapted features with the original features.
        x = self.residual_ratio * a + (1 - self.residual_ratio) * x
        return x


def all_classifier(classnames, templates, model):
    """Build zero-shot classifier weights from CLIP text embeddings.

    Returns a tensor of shape [embedding_dim, len(classnames)] whose columns are the
    L2-normalized mean text embeddings of each class over all prompt templates.
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            classname = classname.replace('_', ' ')
            texts = [template.format(classname) for template in templates]  # fill each template with the class name
            texts = clip.tokenize(texts).cuda()  # clip.tokenize converts the prompts into token id tensors
            class_embeddings = model.encode_text(texts)  # embed with the CLIP text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)  # average over templates
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
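

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# building the zero-shot text weights and wrapping them in an Adapter.
# Assumptions: a CUDA device is available, the OpenAI `clip` package is
# installed, and `args` is any object exposing a `.name` attribute with the
# CLIP backbone name. `classnames` and `templates` below are illustrative
# placeholders, not values from the original code.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(name="ViT-B/32")   # hypothetical stand-in for the real args
    classnames = ["cat", "dog"]               # placeholder class names
    templates = ["a photo of a {}."]          # placeholder prompt template

    clip_model, _ = clip.load(args.name, device="cuda")
    clip_model.eval()

    # all_classifier returns weights of shape [n_input, n_classes]; nn.Linear
    # stores weights as [n_classes, n_input], hence the transpose (and a cast
    # to float32, since CLIP text features on GPU are fp16).
    zeroshot_weights = all_classifier(classnames, templates, clip_model)
    adapter = Adapter(args, classnames, zeroshot_weights.t().float()).cuda()

    # Classify a batch of (already L2-normalized) image-feature stand-ins.
    dummy_features = torch.randn(4, 512).cuda()
    dummy_features /= dummy_features.norm(dim=-1, keepdim=True)
    logits = adapter(dummy_features)          # shape: [4, len(classnames)]
    print(logits.shape)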