scripts and template

This commit is contained in:
2026-02-04 10:24:11 +08:00
parent f9beacf476
commit ea5e9f17ba
3 changed files with 149 additions and 11 deletions

View File

@@ -18,6 +18,24 @@ _tokenizer = _Tokenizer()
# NOTE(review): presumably the name of the LLM whose generated class
# descriptions are consumed by this module — usage is not visible here; confirm.
DESC_LLM = "gpt-4.1"
# NOTE(review): presumably the number of descriptions kept per class —
# usage is not visible here; confirm against the code that reads it.
DESC_TOPK = 4
# Prompt template per dataset, keyed by dataset name.
# Each template has exactly one "{}" placeholder for the class name and ends
# with a period; code that extends a formatted prompt by dropping its final
# character depends on that trailing period being present.
CUSTOM_TEMPLATES = {
    # Fine-grained datasets: the prompt names the object category.
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "a photo of a {}, a type of texture.",
    # Satellite imagery: no article before the class name.
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    # No article before the class name.
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    # Action recognition: the class name is an activity.
    "UCF101": "a photo of a person doing {}.",
    # ImageNet and its variants all use the same generic prompt.
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}
def load_clip_to_cpu(cfg, zero_shot_model=False):
backbone_name = cfg.MODEL.BACKBONE.NAME
@@ -125,8 +143,9 @@ class VLPromptLearner(nn.Module):
with open(desc_file, "r") as f:
all_desc = json.load(f)
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
for cls in classnames:
cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]]
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
with torch.no_grad():
cls_feature = clip_model_temp.encode_text(cls_token)