template to llm desc

This commit is contained in:
2026-02-04 01:03:23 +08:00
parent bb95c77b63
commit f9beacf476
31 changed files with 39563 additions and 13 deletions

View File

@@ -5,16 +5,19 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
import json
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from .imagenet_templates import IMAGENET_TEMPLATES
_tokenizer = _Tokenizer()
DESC_LLM = "gpt-4.1"
DESC_TOPK = 4
def load_clip_to_cpu(cfg, zero_shot_model=False):
backbone_name = cfg.MODEL.BACKBONE.NAME
@@ -116,16 +119,23 @@ class VLPromptLearner(nn.Module):
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.ZS_image_encoder = clip_model_temp_image.visual
# Now pre-compute the frozen VL embeddings
# Now pre-compute the frozen VL embeddings from LLM descriptions
all_teacher_features = []
# Using multiple text templates to ensure textual diversity during training
for single_template in IMAGENET_TEMPLATES:
x = [single_template.replace("{}", name) for name in classnames]
x_tokenized = torch.cat([clip.tokenize(p) for p in x])
text_features = clip_model_temp.encode_text(x_tokenized.cuda())
all_teacher_features.append(text_features.unsqueeze(1))
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
with open(desc_file, "r") as f:
all_desc = json.load(f)
self.fixed_embeddings = torch.cat(all_teacher_features, dim=1).mean(dim=1)
for cls in classnames:
cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]]
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
with torch.no_grad():
cls_feature = clip_model_temp.encode_text(cls_token)
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
cls_feature = torch.mean(cls_feature, dim=0)
all_teacher_features.append(cls_feature)
self.fixed_embeddings = torch.stack(all_teacher_features)
print(f"Using LLM descriptions from: {desc_file}")
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names