Release of PromptSRC with pretrained models.
trainers/zsclip.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT

CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}

@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits

@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling."""

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]

        num_temp = len(self.templates)
        print(f"Prompt ensembling (n={num_temp})")

        mean_text_features = 0
        for i, temp in enumerate(self.templates):
            prompts = [temp.format(c.replace("_", " ")) for c in classnames]
            prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
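
The two trainers above are only registered here; a driver script constructs them through Dassl's trainer registry. Below is a minimal sketch of that wiring, not code from this commit: it assumes Dassl's `get_cfg_default`/`build_trainer` API, illustrative config values, and that the chosen dataset class is registered in the repository (the repo's own training script may extend the config with further fields).

# Hypothetical driver (not part of this commit): zero-shot CLIP evaluation.
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

import trainers.zsclip  # noqa: F401 -- importing runs @TRAINER_REGISTRY.register()

cfg = get_cfg_default()
cfg.TRAINER.NAME = "ZeroshotCLIP"       # or "ZeroshotCLIP2" for prompt ensembling
cfg.MODEL.BACKBONE.NAME = "ViT-B/16"    # illustrative; any backbone load_clip_to_cpu accepts
cfg.DATASET.NAME = "OxfordPets"         # illustrative; must be a key of CUSTOM_TEMPLATES
cfg.DATASET.ROOT = "/path/to/datasets"  # illustrative path

trainer = build_trainer(cfg)  # looks up cfg.TRAINER.NAME in TRAINER_REGISTRY and calls build_model()
trainer.test()                # zero-shot: evaluation only, no training loop

Design note: ZeroshotCLIP2 averages the L2-normalized per-template text embeddings and re-normalizes the mean, so every template contributes equally on the unit sphere before the cosine-similarity logits are computed; for ImageNet the custom template is skipped and only the ImageNet template list is used.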