Switch frozen teacher text embeddings from hand-written templates to LLM-generated class descriptions
This commit is contained in:
@@ -5,16 +5,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
import json
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from .imagenet_templates import IMAGENET_TEMPLATES
|
||||
|
||||
# Module-level tokenizer shared across this file (CLIP's SimpleTokenizer;
# its call sites are outside this diff hunk).
_tokenizer = _Tokenizer()

# Identifier of the LLM that generated the class-description JSON files.
# Used only to build the lookup path
# ./desc/<DESC_LLM>/descriptions_top<DESC_TOPK>/<DATASET>.json
DESC_LLM = "gpt-4.1"

# Selects which descriptions_top<K> directory variant is read
# (presumably the number of descriptions kept per class — TODO confirm
# against the description-generation script).
DESC_TOPK = 4

||||
def load_clip_to_cpu(cfg, zero_shot_model=False):
|
||||
backbone_name = cfg.MODEL.BACKBONE.NAME
|
||||
@@ -116,16 +119,23 @@ class VLPromptLearner(nn.Module):
|
||||
with torch.no_grad():
|
||||
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
||||
self.ZS_image_encoder = clip_model_temp_image.visual
|
||||
# Now pre-compute the frozen VL embeddings
|
||||
# Now pre-compute the frozen VL embeddings from LLM descriptions
|
||||
all_teacher_features = []
|
||||
# Using multiple text templates to ensure textual diversity during training
|
||||
for single_template in IMAGENET_TEMPLATES:
|
||||
x = [single_template.replace("{}", name) for name in classnames]
|
||||
x_tokenized = torch.cat([clip.tokenize(p) for p in x])
|
||||
text_features = clip_model_temp.encode_text(x_tokenized.cuda())
|
||||
all_teacher_features.append(text_features.unsqueeze(1))
|
||||
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
|
||||
with open(desc_file, "r") as f:
|
||||
all_desc = json.load(f)
|
||||
|
||||
self.fixed_embeddings = torch.cat(all_teacher_features, dim=1).mean(dim=1)
|
||||
for cls in classnames:
|
||||
cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]]
|
||||
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
||||
with torch.no_grad():
|
||||
cls_feature = clip_model_temp.encode_text(cls_token)
|
||||
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
|
||||
cls_feature = torch.mean(cls_feature, dim=0)
|
||||
all_teacher_features.append(cls_feature)
|
||||
|
||||
self.fixed_embeddings = torch.stack(all_teacher_features)
|
||||
print(f"Using LLM descriptions from: {desc_file}")
|
||||
# These token vectors will be saved when in save_model(),
|
||||
# but they should be ignored in load_model() as we want to use
|
||||
# those computed using the current class names
|
||||
|
||||
Reference in New Issue
Block a user