scripts and template
This commit is contained in:
@@ -18,6 +18,24 @@ _tokenizer = _Tokenizer()
|
||||
DESC_LLM = "gpt-4.1"
|
||||
DESC_TOPK = 4
|
||||
|
||||
CUSTOM_TEMPLATES = {
|
||||
"OxfordPets": "a photo of a {}, a type of pet.",
|
||||
"OxfordFlowers": "a photo of a {}, a type of flower.",
|
||||
"FGVCAircraft": "a photo of a {}, a type of aircraft.",
|
||||
"DescribableTextures": "a photo of a {}, a type of texture.",
|
||||
"EuroSAT": "a centered satellite photo of {}.",
|
||||
"StanfordCars": "a photo of a {}.",
|
||||
"Food101": "a photo of {}, a type of food.",
|
||||
"SUN397": "a photo of a {}.",
|
||||
"Caltech101": "a photo of a {}.",
|
||||
"UCF101": "a photo of a person doing {}.",
|
||||
"ImageNet": "a photo of a {}.",
|
||||
"ImageNetSketch": "a photo of a {}.",
|
||||
"ImageNetV2": "a photo of a {}.",
|
||||
"ImageNetA": "a photo of a {}.",
|
||||
"ImageNetR": "a photo of a {}.",
|
||||
}
|
||||
|
||||
|
||||
def load_clip_to_cpu(cfg, zero_shot_model=False):
|
||||
backbone_name = cfg.MODEL.BACKBONE.NAME
|
||||
@@ -125,8 +143,9 @@ class VLPromptLearner(nn.Module):
|
||||
with open(desc_file, "r") as f:
|
||||
all_desc = json.load(f)
|
||||
|
||||
template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
||||
for cls in classnames:
|
||||
cls_descs = [f"a photo of {cls}, {desc}" for desc in all_desc[cls]]
|
||||
cls_descs = [template.format(cls)[:-1] + f", {desc}" for desc in all_desc[cls]]
|
||||
cls_token = torch.cat([clip.tokenize(cls_desc) for cls_desc in cls_descs]).cuda()
|
||||
with torch.no_grad():
|
||||
cls_feature = clip_model_temp.encode_text(cls_token)
|
||||
|
||||
Reference in New Issue
Block a user