import torch
import torch.nn as nn

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT

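# Hand-written prompt template per dataset; "{}" is filled with the class name
# (underscores in class names are replaced with spaces before formatting).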
CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}


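# Zero-shot CLIP baseline: text features for all class prompts are encoded once
# in build_model(), and images are classified by scaled cosine similarity.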
@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

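    # Classify a batch of images by comparing their L2-normalized CLIP image
    # features against the cached text features; logits are cosine similarities
    # scaled by CLIP's learned logit_scale.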
    def model_inference(self, image):
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits


@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling."""

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

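    # For each template, encode prompts for all classes and L2-normalize the
    # resulting text features; the per-template features are then averaged and
    # renormalized to form the ensembled zero-shot classifier weights.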
    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]

        num_temp = len(self.templates)
        print(f"Prompt ensembling (n={num_temp})")

        mean_text_features = 0
        for i, temp in enumerate(self.templates):
            prompts = [temp.format(c.replace("_", " ")) for c in classnames]
            prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
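

# Usage sketch (illustrative only, not part of this module): assuming Dassl's
# standard build_trainer entry point and a CoOp-style config, these trainers
# are selected by name through the config, roughly:
#
#     from dassl.engine import build_trainer
#     cfg.TRAINER.NAME = "ZeroshotCLIP"  # or "ZeroshotCLIP2" for prompt ensembling
#     trainer = build_trainer(cfg)
#     trainer.test()  # zero-shot evaluation; no training loop is needed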