# PromptSRC/trainers/zsclip.py

import torch
import torch.nn as nn

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.model import convert_weights

from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT

CUSTOM_TEMPLATES = {
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": "a photo of a {}.",
    "Food101": "a photo of {}, a type of food.",
    "SUN397": "a photo of a {}.",
    "Caltech101": "a photo of a {}.",
    "UCF101": "a photo of a person doing {}.",
    "ImageNet": "a photo of a {}.",
    "ImageNetSketch": "a photo of a {}.",
    "ImageNetV2": "a photo of a {}.",
    "ImageNetA": "a photo of a {}.",
    "ImageNetR": "a photo of a {}.",
}
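
# Note: each template is a plain `str.format` pattern. For example (with a
# hypothetical class name), CUSTOM_TEMPLATES["OxfordPets"].format("beagle")
# yields "a photo of a beagle, a type of pet.".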

@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP with one hand-crafted prompt per dataset."""

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts = [temp.format(c.replace("_", " ")) for c in classnames]
        print(f"Prompts: {prompts}")
        prompts = torch.cat([clip.tokenize(p) for p in prompts])
        prompts = prompts.to(self.device)

        # Encode and L2-normalise the class prompts once; the cached text
        # features are reused for every test batch.
        with torch.no_grad():
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.clip_model.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        return logits
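
# Illustrative call pattern (shapes are assumptions, not taken from the repo):
# for a preprocessed image batch `x` of shape (B, 3, H, W), model_inference(x)
# returns temperature-scaled cosine similarities of shape (B, num_classes);
# zero-shot predictions are `model_inference(x).argmax(dim=-1)`.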

@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling: average text features over multiple templates."""

    # templates = IMAGENET_TEMPLATES
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # Append the dataset-specific hand-crafted prompt. Rebind instead of
        # using `+=`: in-place `+=` would mutate the class-level list shared
        # by all instances, so templates would accumulate across runs.
        if cfg.DATASET.NAME != "ImageNet":
            self.templates = self.templates + [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]

        num_temp = len(self.templates)
        print(f"Prompt ensembling (n={num_temp})")

        # Average the per-template class embeddings, then re-normalise so the
        # ensembled text features are unit-length again.
        mean_text_features = 0
        for temp in self.templates:
            prompts = [temp.format(c.replace("_", " ")) for c in classnames]
            prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
            text_features = clip_model.encode_text(prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            mean_text_features = mean_text_features + text_features
        mean_text_features = mean_text_features / num_temp
        mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
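

# Standalone sanity check (an illustrative sketch, not part of the original
# trainers). It reproduces the ensembling arithmetic from
# ZeroshotCLIP2.build_model on random unit vectors: average per-template
# features, re-normalise, then score images by scaled cosine similarity.
# The constant 100.0 stands in for clip_model.logit_scale.exp(), which is
# close to this value in released CLIP checkpoints.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_templates, num_classes, dim = 4, 10, 512

    # Per-template class embeddings, L2-normalised along the feature dim.
    feats = torch.randn(num_templates, num_classes, dim)
    feats = feats / feats.norm(dim=-1, keepdim=True)

    # Ensemble: mean over templates, then re-normalise to unit length.
    mean_feats = feats.mean(dim=0)
    mean_feats = mean_feats / mean_feats.norm(dim=-1, keepdim=True)

    # Score a dummy image batch exactly as model_inference does.
    image_feats = torch.randn(2, dim)
    image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
    logits = 100.0 * image_feats @ mean_feats.t()
    print("logits shape:", tuple(logits.shape))  # (2, 10)
    print("predictions:", logits.argmax(dim=-1).tolist())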