Upload to Main
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,318 @@
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
# Module-level tokenizer instance; used below to count the tokens of each class name.
_tokenizer = _Tokenizer()
|
||||
|
||||
|
||||
def load_clip_to_cpu(cfg):
    """Download the CLIP checkpoint named in ``cfg`` and build the model on CPU.

    Args:
        cfg: config object; ``cfg.MODEL.BACKBONE.NAME`` must be a key of
            ``clip._MODELS``.

    Returns:
        The model returned by ``clip.build_model`` for the loaded weights.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not loadable as a JIT archive: read the file as a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        # Only the weights of the scripted model are needed to rebuild it.
        state_dict = jit_model.state_dict()

    # CoCoOp passes zero prompt depth / context length for both branches.
    design_details = {
        "trainer": 'CoCoOp',
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
    }
    return clip.build_model(state_dict, design_details)
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Text branch assembled from the submodules of an existing CLIP model.

    It consumes prompt *embeddings* (not token ids), so learned context
    vectors can be spliced in before the transformer runs.
    """

    def __init__(self, clip_model):
        """Keep references to the text-side components of ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: prompt embeddings, batch-first (the code permutes
                NLD -> LND around the transformer).
            tokenized_prompts: token ids of the same prompts; only used to
                locate the end-of-text position via ``argmax``.

        Returns:
            One projected feature vector per prompt.
        """
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_positions = tokenized_prompts.argmax(dim=-1)
        x = x[torch.arange(x.shape[0]), eot_positions] @ self.text_projection

        return x
|
||||
|
||||
|
||||
class PromptLearner(nn.Module):
    """Learnable context vectors plus a meta-network that shifts them per image.

    ``forward`` returns, for every image feature vector, one prompt embedding
    per class made of ``[SOS prefix, shifted context, class-name/EOS suffix]``.
    """

    def __init__(self, cfg, classnames, clip_model):
        """Build context vectors, the meta-network and the fixed token buffers.

        Args:
            cfg: config; reads ``TRAINER.COCOOP.{N_CTX, CTX_INIT, PREC}`` and
                ``INPUT.SIZE``.
            classnames: class names; underscores are replaced by spaces.
            clip_model: CLIP model supplying ``token_embedding``, ``dtype``,
                ``ln_final`` and ``visual``.

        Raises:
            AssertionError: if ``cfg.INPUT.SIZE[0]`` differs from the CLIP
                visual input resolution.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COCOOP.N_CTX
        ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            # n_ctx is overridden by the number of space-separated words.
            # NOTE(review): this presumes each word is a single token - confirm.
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Position 0 is skipped (it is stored separately as the SOS prefix below).
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        # Maps an image feature vector to a bias added to every context vector.
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))

        if cfg.TRAINER.COCOOP.PREC == "fp16":
            self.meta_net.half()

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate prefix, context and suffix along the token dimension.

        When ``label`` is given, ``prefix`` and ``suffix`` are first indexed
        by it so that only the rows of the requested classes are used.
        """
        # dim0 is either batch_size (during training) or n_cls (during testing)
        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)

        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,  # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

        return prompts

    def forward(self, im_features):
        """Return image-conditioned prompts, stacked over the batch dimension.

        Args:
            im_features: image features fed to ``meta_net``; (batch, vis_dim)
                according to the shape comments below.

        Returns:
            Tensor whose first dimension indexes the image and whose second
            dimension indexes the class.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx  # (n_ctx, ctx_dim)
        bias = self.meta_net(im_features)  # (batch, ctx_dim)
        bias = bias.unsqueeze(1)  # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)  # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias  # (batch, n_ctx, ctx_dim)

        # Use instance-conditioned context tokens for all classes
        prompts = torch.stack([
            self.construct_prompts(
                ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1),
                prefix,
                suffix,
            )  # (n_cls, n_tkn, ctx_dim)
            for ctx_shifted_i in ctx_shifted
        ])

        return prompts
|
||||
|
||||
|
||||
class CustomCLIP(nn.Module):
    """CLIP image encoder + text encoder driven by image-conditioned prompts.

    In training mode (as reported by ``prompt_learner.training``) ``forward``
    returns the cross-entropy loss; otherwise it returns the logits.
    """

    def __init__(self, cfg, classnames, clip_model):
        """Assemble the prompt learner and both encoders from ``clip_model``."""
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        """Score every image against the class prompts generated for it.

        Args:
            image: batch of images; cast to ``self.dtype`` before encoding.
            label: target class indices; passed to ``F.cross_entropy`` when the
                prompt learner is in training mode.

        Returns:
            The loss when the prompt learner is training, else the stacked
            per-image logits.
        """
        tokenized_prompts = self.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        prompts = self.prompt_learner(image_features)

        # Each image has its own set of class prompts, so the text encoder
        # runs once per image.
        logits = []
        for pts_i, imf_i in zip(prompts, image_features):
            text_features = self.text_encoder(pts_i, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            l_i = logit_scale * imf_i @ text_features.t()
            logits.append(l_i)
        logits = torch.stack(logits)

        if self.prompt_learner.training:
            return F.cross_entropy(logits, label)

        return logits
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class CoCoOp(TrainerX):
    """Trainer that optimizes only the ``prompt_learner`` of a ``CustomCLIP``.

    ``build_model`` disables gradients for every parameter whose name does not
    contain ``"prompt_learner"``.
    """

    def check_cfg(self, cfg):
        """Assert that the configured precision is fp16, fp32 or amp."""
        assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Load CLIP, wrap it in ``CustomCLIP`` and set up optimizer/scheduler."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COCOOP.PREC in ("fp32", "amp"):
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        name_to_update = "prompt_learner"

        for name, param in self.model.named_parameters():
            if name_to_update not in name:
                param.requires_grad_(False)

        # Double check
        enabled = {
            name
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimization step and return ``{"loss": float}``."""
        image, label = self.parse_batch_train(batch)

        model = self.model
        optim = self.optim
        scaler = self.scaler

        prec = self.cfg.TRAINER.COCOOP.PREC
        if prec == "amp":
            with autocast():
                loss = model(image, label)
            # nn.DataParallel gathers the per-replica scalar losses into a
            # vector; reduce it so backward() and item() receive a scalar.
            if loss.dim() > 0:
                loss = loss.mean()
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = model(image, label)
            # Same reduction as above for the multi-GPU case.
            if loss.dim() > 0:
                loss = loss.mean()
            optim.zero_grad()
            loss.backward()
            optim.step()

        loss_summary = {"loss": loss.item()}

        # Step the learning-rate schedule once per epoch, on the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move ``batch["img"]`` and ``batch["label"]`` to ``self.device``."""
        image = batch["img"].to(self.device)
        label = batch["label"].to(self.device)
        return image, label

    def load_model(self, directory, epoch=None):
        """Load saved weights for every registered model from ``directory``.

        Args:
            directory: root directory holding one sub-directory per model
                name; when falsy, loading is skipped.
            epoch: if given, ``model.pth.tar-<epoch>`` is loaded instead of
                ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file does not exist.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Kept in its own variable so the ``epoch`` argument is not shadowed.
            saved_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            state_dict.pop("token_prefix", None)
            state_dict.pop("token_suffix", None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, saved_epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
|
||||
@@ -0,0 +1,328 @@
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
# Module-level tokenizer instance; used below to count the tokens of each class name.
_tokenizer = _Tokenizer()
|
||||
|
||||
|
||||
def load_clip_to_cpu(cfg):
    """Download the CLIP checkpoint named in ``cfg`` and build the model on CPU.

    Args:
        cfg: config object; ``cfg.MODEL.BACKBONE.NAME`` must be a key of
            ``clip._MODELS``.

    Returns:
        The model returned by ``clip.build_model`` for the loaded weights.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not loadable as a JIT archive: read the file as a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        # Only the weights of the scripted model are needed to rebuild it.
        state_dict = jit_model.state_dict()

    # CoOp passes zero prompt depth / context length for both branches.
    design_details = {
        "trainer": 'CoOp',
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
    }
    return clip.build_model(state_dict, design_details)
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Text branch assembled from the submodules of an existing CLIP model.

    It consumes prompt *embeddings* (not token ids), so learned context
    vectors can be spliced in before the transformer runs.
    """

    def __init__(self, clip_model):
        """Keep references to the text-side components of ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: prompt embeddings, batch-first (the code permutes
                NLD -> LND around the transformer).
            tokenized_prompts: token ids of the same prompts; only used to
                locate the end-of-text position via ``argmax``.

        Returns:
            One projected feature vector per prompt.
        """
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_positions = tokenized_prompts.argmax(dim=-1)
        x = x[torch.arange(x.shape[0]), eot_positions] @ self.text_projection

        return x
|
||||
|
||||
|
||||
class PromptLearner(nn.Module):
    """Learnable context vectors combined with fixed class-name embeddings.

    ``forward`` builds one prompt per class; the class-name tokens are placed
    at the end, in the middle or at the front of the context depending on
    ``cfg.TRAINER.COOP.CLASS_TOKEN_POSITION``.
    """

    def __init__(self, cfg, classnames, clip_model):
        """Create the context parameter and the fixed prefix/suffix buffers.

        Args:
            cfg: config; reads ``TRAINER.COOP.{N_CTX, CTX_INIT, CSC,
                CLASS_TOKEN_POSITION}`` and ``INPUT.SIZE``.
            classnames: class names; underscores are replaced by spaces.
            clip_model: CLIP model supplying ``token_embedding``, ``dtype``,
                ``ln_final`` and ``visual``.

        Raises:
            AssertionError: if ``cfg.INPUT.SIZE[0]`` differs from the CLIP
                visual input resolution.
        """
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.COOP.N_CTX
        ctx_init = cfg.TRAINER.COOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            # n_ctx is overridden by the number of space-separated words.
            # NOTE(review): this presumes each word is a single token - confirm.
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Position 0 is skipped (it is stored separately as the SOS prefix below).
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init

        else:
            # random initialization
            if cfg.TRAINER.COOP.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION

    def forward(self):
        """Assemble the prompt embeddings for all classes.

        Returns:
            Tensor with one prompt per class, concatenated along dim 0.

        Raises:
            ValueError: if ``class_token_position`` is not ``"end"``,
                ``"middle"`` or ``"front"``.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # A single generic context is shared by every class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,  # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                # The first ``name_len`` suffix tokens are the class name.
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                # The first ``name_len`` suffix tokens are the class name.
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unknown class_token_position: {self.class_token_position!r}"
            )

        return prompts
|
||||
|
||||
|
||||
class CustomCLIP(nn.Module):
    """CLIP image encoder + text encoder fed by learned class prompts."""

    def __init__(self, cfg, classnames, clip_model):
        """Assemble the prompt learner and both encoders from ``clip_model``."""
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image):
        """Return the scaled similarity between image and class-prompt features.

        Args:
            image: batch of images; cast to ``self.dtype`` before encoding.

        Returns:
            Logits: ``exp(logit_scale)`` times the product of the L2-normalized
            image features with the transposed L2-normalized text features.
        """
        image_features = self.image_encoder(image.type(self.dtype))

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        # Normalize both sides so the product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class CoOp(TrainerX):
    """Context Optimization (CoOp).

    Learning to Prompt for Vision-Language Models
    https://arxiv.org/abs/2109.01134
    """

    def check_cfg(self, cfg):
        """Assert that the configured precision is fp16, fp32 or amp."""
        assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Load CLIP, wrap it in ``CustomCLIP`` and set up optimizer/scheduler."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COOP.PREC in ("fp32", "amp"):
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimization step and return the loss and accuracy summary."""
        image, label = self.parse_batch_train(batch)

        prec = self.cfg.TRAINER.COOP.PREC
        if prec == "amp":
            with autocast():
                output = self.model(image)
                loss = F.cross_entropy(output, label)
            self.optim.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()
        else:
            output = self.model(image)
            loss = F.cross_entropy(output, label)
            self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        # Step the learning-rate schedule once per epoch, on the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move ``batch["img"]`` and ``batch["label"]`` to ``self.device``."""
        image = batch["img"].to(self.device)
        label = batch["label"].to(self.device)
        return image, label

    def load_model(self, directory, epoch=None):
        """Load saved weights for every registered model from ``directory``.

        Args:
            directory: root directory holding one sub-directory per model
                name; when falsy, loading is skipped.
            epoch: if given, ``model.pth.tar-<epoch>`` is loaded instead of
                ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file does not exist.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            # Kept in its own variable so the ``epoch`` argument is not shadowed.
            saved_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            state_dict.pop("token_prefix", None)
            state_dict.pop("token_suffix", None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, saved_epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
|
||||
@@ -0,0 +1,94 @@
|
||||
# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb

# Prompt templates; each contains one "{}" placeholder for the class name.
IMAGENET_TEMPLATES = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]
|
||||
|
||||
# A short list of seven prompt templates, same "{}" placeholder convention.
IMAGENET_TEMPLATES_SELECT = [
    "itap of a {}.",
    "a bad photo of the {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
]
|
||||
@@ -0,0 +1,304 @@
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
# Module-level tokenizer instance; used below to count the tokens of each class name.
_tokenizer = _Tokenizer()
|
||||
|
||||
|
||||
def load_clip_to_cpu(cfg):
    """Download the CLIP checkpoint named in ``cfg`` and build the model on CPU.

    Args:
        cfg: config object; reads ``MODEL.BACKBONE.NAME`` (a key of
            ``clip._MODELS``) and the ``TRAINER.IVLP`` prompt depth / context
            length settings for both branches.

    Returns:
        The model returned by ``clip.build_model`` for the loaded weights.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not loadable as a JIT archive: read the file as a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        # Only the weights of the scripted model are needed to rebuild it.
        state_dict = jit_model.state_dict()

    # Prompt depth and context length of each branch come from the config.
    design_details = {
        "trainer": 'IVLP',
        "vision_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION,
        "language_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT,
        "vision_ctx": cfg.TRAINER.IVLP.N_CTX_VISION,
        "language_ctx": cfg.TRAINER.IVLP.N_CTX_TEXT,
    }
    return clip.build_model(state_dict, design_details)
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Text branch assembled from the submodules of an existing CLIP model.

    It consumes prompt *embeddings* (not token ids), so learned context
    vectors can be spliced in before the transformer runs.
    """

    def __init__(self, clip_model):
        """Keep references to the text-side components of ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: prompt embeddings, batch-first (the code permutes
                NLD -> LND around the transformer).
            tokenized_prompts: token ids of the same prompts; only used to
                locate the end-of-text position via ``argmax``.

        Returns:
            One projected feature vector per prompt.
        """
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_positions = tokenized_prompts.argmax(dim=-1)
        x = x[torch.arange(x.shape[0]), eot_positions] @ self.text_projection

        return x
|
||||
|
||||
|
||||
class VLPromptLearner(nn.Module):
|
||||
def __init__(self, cfg, classnames, clip_model):
|
||||
super().__init__()
|
||||
n_cls = len(classnames)
|
||||
# Make sure Language depth >= 1
|
||||
assert cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
|
||||
"\nPlease use VPT trainer if you want to learn only vision " \
|
||||
"branch "
|
||||
n_ctx = cfg.TRAINER.IVLP.N_CTX_TEXT
|
||||
ctx_init = cfg.TRAINER.IVLP.CTX_INIT
|
||||
dtype = clip_model.dtype
|
||||
ctx_dim = clip_model.ln_final.weight.shape[0]
|
||||
vis_dim = clip_model.visual.output_dim
|
||||
clip_imsize = clip_model.visual.input_resolution
|
||||
cfg_imsize = cfg.INPUT.SIZE[0]
|
||||
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
|
||||
|
||||
if ctx_init and (n_ctx) <= 4:
|
||||
# use given words to initialize context vectors
|
||||
ctx_init = ctx_init.replace("_", " ")
|
||||
n_ctx = n_ctx
|
||||
prompt = clip.tokenize(ctx_init)
|
||||
with torch.no_grad():
|
||||
embedding = clip_model.token_embedding(prompt).type(dtype)
|
||||
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
|
||||
prompt_prefix = ctx_init
|
||||
else:
|
||||
# random initialization
|
||||
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
|
||||
nn.init.normal_(ctx_vectors, std=0.02)
|
||||
prompt_prefix = " ".join(["X"] * n_ctx)
|
||||
print(f"Independent V-L design")
|
||||
print(f'Initial text context: "{prompt_prefix}"')
|
||||
print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
|
||||
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.IVLP.N_CTX_VISION}")
|
||||
self.ctx = nn.Parameter(ctx_vectors)
|
||||
|
||||
classnames = [name.replace("_", " ") for name in classnames]
|
||||
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
|
||||
prompts = [prompt_prefix + " " + name + "." for name in classnames]
|
||||
|
||||
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
|
||||
with torch.no_grad():
|
||||
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
|
||||
|
||||
# These token vectors will be saved when in save_model(),
|
||||
# but they should be ignored in load_model() as we want to use
|
||||
# those computed using the current class names
|
||||
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
|
||||
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
|
||||
|
||||
self.n_cls = n_cls
|
||||
self.n_ctx = n_ctx
|
||||
self.tokenized_prompts = tokenized_prompts # torch.Tensor
|
||||
self.name_lens = name_lens
|
||||
|
||||
def construct_prompts(self, ctx, prefix, suffix, label=None):
    """Concatenate prefix, context and suffix embeddings along the token axis.

    Args:
        ctx: context tokens, indexed as (dim0, n_ctx, ctx_dim) by the caller.
        prefix: the SOS token embeddings, one row per class.
        suffix: the remaining (class name / EOS) token embeddings, one row
            per class.
        label: optional index used to pick the rows of ``prefix`` and
            ``suffix`` that match each row of ``ctx``.

    Returns:
        The tensor ``[prefix | ctx | suffix]`` joined on dimension 1.
    """
    head, tail = prefix, suffix
    if label is not None:
        # Select the per-sample class rows instead of using every class.
        head, tail = prefix[label], suffix[label]

    return torch.cat((head, ctx, tail), dim=1)
|
||||
|
||||
def forward(self):
    """Build the prompt embeddings for all ``self.n_cls`` classes."""
    shared_ctx = self.ctx
    if shared_ctx.dim() == 2:
        # One 2-D context is shared by every class: broadcast it per class.
        shared_ctx = shared_ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

    return self.construct_prompts(shared_ctx, self.token_prefix, self.token_suffix)
|
||||
|
||||
|
||||
class CustomCLIP(nn.Module):
    """CLIP model whose text input comes from a ``VLPromptLearner``."""

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = VLPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None):
        """Return the cross-entropy loss while training, the logits otherwise.

        Args:
            image: batch of images; cast to ``self.dtype`` before encoding.
            label: class indices, only used when the prompt learner is in
                training mode.
        """
        scale = self.logit_scale.exp()

        text_features = self.text_encoder(self.prompt_learner(), self.tokenized_prompts)
        image_features = self.image_encoder(image.type(self.dtype))

        # L2-normalise both modalities so the product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = scale * image_features @ text_features.t()

        if not self.prompt_learner.training:
            return logits
        return F.cross_entropy(logits, label)
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class IVLP(TrainerX):
    """Trainer that tunes the prompt parameters of a wrapped CLIP model.

    Parameters whose name contains ``VPT`` are made trainable, parameters
    under ``prompt_learner`` are left untouched, and all others are frozen.
    """

    def check_cfg(self, cfg):
        """Validate the configured numeric precision."""
        assert cfg.TRAINER.IVLP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Load CLIP, wrap it in ``CustomCLIP`` and create optimizer/scheduler."""
        cfg = self.cfg
        precision = cfg.TRAINER.IVLP.PREC
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if precision in ("fp32", "amp"):
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        for param_name, param in self.model.named_parameters():
            if "prompt_learner" in param_name:
                continue
            # Outside the prompt learner only the VPT prompts are updated.
            param.requires_grad_("VPT" in param_name)

        # Double check
        enabled = {
            param_name
            for param_name, param in self.model.named_parameters()
            if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("VLPromptLearner", self.model, self.optim, self.sched)

        self.scaler = GradScaler() if precision == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            print(f"Multiple GPUs detected (n_gpus={n_gpus}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimisation step and return ``{"loss": value}``."""
        image, label = self.parse_batch_train(batch)
        model, optim, scaler = self.model, self.optim, self.scaler

        if self.cfg.TRAINER.IVLP.PREC == "amp":
            with autocast():
                loss = model(image, label)
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = model(image, label)
            optim.zero_grad()
            loss.backward()
            optim.step()

        loss_summary = {"loss": loss.item()}

        # The learning rate is stepped once, after the last batch of an epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the ``img`` and ``label`` entries of ``batch`` to the device."""
        images = batch["img"].to(self.device)
        labels = batch["label"].to(self.device)
        return images, labels

    def load_model(self, directory, epoch=None):
        """Load saved weights for every registered model from ``directory``.

        Args:
            directory: root folder holding one sub-folder per model name;
                loading is skipped when it is empty or ``None``.
            epoch: load ``model.pth.tar-<epoch>`` instead of the default
                ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        # By default, the best model is loaded
        model_file = "model-best.pth.tar" if epoch is None else "model.pth.tar-" + str(epoch)

        for name in self.get_model_names():
            model_path = osp.join(directory, name, model_file)
            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            saved_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors: they are recomputed from the
            # current class names.
            for fixed_key in ("prompt_learner.token_prefix", "prompt_learner.token_suffix"):
                state_dict.pop(fixed_key, None)

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, saved_epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
|
||||
@@ -0,0 +1,928 @@
|
||||
import os.path as osp
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import os
|
||||
import pickle
|
||||
import deepcore.methods as s_method
|
||||
import numpy as np
|
||||
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
from dassl.evaluation import Classification,EvaluatorBase
|
||||
from pygrad.pcgrad import PCGrad
|
||||
from datasets.data_manager import DataManager
|
||||
from dassl.data.datasets import build_dataset
|
||||
|
||||
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from trainers.zsclip import CUSTOM_TEMPLATES
|
||||
from .coop import load_clip_to_cpu as lcp
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import f1_score, confusion_matrix
|
||||
from collections import OrderedDict, defaultdict
|
||||
from .util import GradCAM,denorm
|
||||
import cv2
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
# Background concept names used to build the background text prompts.
BACKGROUND_CATEGORY = [
    'ground', 'land', 'grass', 'tree', 'building',
    'wall', 'sky', 'lake', 'water', 'river',
    'sea', 'railway', 'railroad', 'keyboard', 'helmet',
    'cloud', 'house', 'mountain', 'ocean', 'road',
    'rock', 'street', 'valley', 'bridge', 'sign',
]

# Extra background concept names appended to BACKGROUND_CATEGORY when the
# background prompts are built.
BACKGROUND_CATEGORY_FOOD = [
    'table', 'forks', 'tablecloth', 'hands', 'spoon', 'glasses', 'dishes',
]
|
||||
|
||||
def load_clip_to_cpu(cfg):
    """Download the configured CLIP backbone and build the MaPLe model on CPU.

    The checkpoint is first opened as a TorchScript archive; if that raises
    ``RuntimeError`` it is loaded as a plain state dict instead.
    """
    model_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    state_dict = None
    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    design_details = {
        "trainer": 'MaPLe',
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
        "maple_length": cfg.TRAINER.MAPLE.N_CTX,
    }

    weights = state_dict or model.state_dict()
    return clip.build_model(weights, design_details)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Text tower of CLIP that also receives the deeper compound prompts."""

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: prompt token embeddings, batch first.
            tokenized_prompts: token ids of the same prompts; the position of
                the largest id in each row selects the output token.
            compound_prompts_deeper_text: prompts handed to the transformer
                alongside the sequence.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND

        # Pass as a list, as nn.Sequential cannot process multiple arguments
        # in the forward pass; the third item is the prompt-depth counter.
        transformer_out = self.transformer([hidden, compound_prompts_deeper_text, 0])

        hidden = transformer_out[0].permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # Take features from the eot embedding (eot_token is the highest
        # number in each sequence).
        rows = torch.arange(hidden.shape[0])
        eot_positions = tokenized_prompts.argmax(dim=-1)
        return hidden[rows, eot_positions] @ self.text_projection
|
||||
|
||||
|
||||
class MultiModalPromptLearner(nn.Module):
    """Learnable text context plus the projections that yield visual prompts.

    Besides the class prompts, a fixed set of background prompts (built from
    ``BACKGROUND_CATEGORY + BACKGROUND_CATEGORY_FOOD``) is appended after the
    class prompts; their embeddings are not learnable.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.MAPLE.N_CTX
        ctx_init = cfg.TRAINER.MAPLE.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        # NOTE(review): the width of the visual prompts is hard-coded; it
        # presumably matches the vision transformer width - confirm.
        vis_dim = 768
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        # Default is 1, which is compound shallow prompting
        assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, "For MaPLe, PROMPT_DEPTH should be >= 1"
        self.compound_prompts_depth = cfg.TRAINER.MAPLE.PROMPT_DEPTH
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init and (n_ctx) <= 4:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)
        print('MaPLe design: Multi-modal Prompt Learning')
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of MaPLe context words (tokens): {n_ctx}")

        # Shallow prompts: a linear layer maps the text context to the
        # visual prompt width.
        # NOTE(review): the projection is always cast to half precision,
        # regardless of ``dtype`` - confirm this is intended for fp32 runs.
        self.proj = nn.Linear(ctx_dim, vis_dim)
        self.proj.half()
        self.ctx = nn.Parameter(ctx_vectors)

        # Deeper (shared) prompts: PROMPT_DEPTH - 1 learnable text prompts,
        # each with its own projection to the visual width. They use the
        # text width ``ctx_dim`` so they always fit the projections below.
        n_deep = self.compound_prompts_depth - 1
        self.compound_prompts_text = nn.ParameterList(
            [nn.Parameter(torch.empty(n_ctx, ctx_dim)) for _ in range(n_deep)]
        )
        for single_para in self.compound_prompts_text:
            nn.init.normal_(single_para, std=0.02)
        single_layer = nn.Linear(ctx_dim, vis_dim)
        self.compound_prompt_projections = _get_clones(single_layer, n_deep)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)

        # Background prompts are appended after the class prompts.
        bg_template = 'a clean origami {}.'
        bg_classesnames = [
            bg_template.format(name)
            for name in BACKGROUND_CATEGORY + BACKGROUND_CATEGORY_FOOD
        ]
        tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
        tokenized_prompts = torch.cat((tokenized_prompts, tokenized_bg_prompts), dim=0)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Rows [n_cls:] are the background prompts. Registered as a
        # non-persistent buffer so it follows the module across devices
        # without being written to checkpoints.
        self.register_buffer("bg_embeding", embedding[n_cls:], persistent=False)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:n_cls, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:n_cls, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # class rows followed by background rows
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Join prefix, context and suffix, then append the background prompts.

        Args:
            ctx: context tokens, indexed as (dim0, n_ctx, ctx_dim).
            prefix: the SOS token embeddings, one row per class.
            suffix: the remaining token embeddings, one row per class.
            label: optional index selecting rows of ``prefix``/``suffix``.

        Returns:
            The class prompts followed, along dimension 0, by the fixed
            background prompt embeddings.
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,     # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

        # Follow the device of the class prompts instead of assuming CUDA.
        background = self.bg_embeding.to(prompts.device)
        return torch.cat((prompts, background), dim=0)

    def forward(self):
        """Return text prompts and the prompts for the visual branch.

        Returns:
            A 4-tuple ``(prompts, projected_ctx, compound_prompts_text,
            visual_deep_prompts)`` where ``projected_ctx`` is ``self.ctx``
            passed through ``self.proj`` and ``visual_deep_prompts`` holds
            each deep text prompt passed through its own projection.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prompts = self.construct_prompts(ctx, self.token_prefix, self.token_suffix)

        # Project every deep text prompt to the visual width.
        visual_deep_prompts = [
            layer(self.compound_prompts_text[index])
            for index, layer in enumerate(self.compound_prompt_projections)
        ]
        # The text side receives the original deep prompts, the visual side
        # the projected ones.
        return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts
|
||||
|
||||
|
||||
class CustomCLIP(nn.Module):
    """CLIP wrapper driven by a ``MultiModalPromptLearner``.

    The text features produced from the prompt learner contain the class
    prompts first (``prompt_learner.n_cls`` rows) followed by the background
    prompts.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.image_encoder_ori = clip_model.visual_ori
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.txt_f = []
        self.img_f = []
        self.one_hot_label = []
        self.vtx = []
        self.loaded_mask = None

    def get_uniform_ball_noise(self, input_shape, radius=1.0):
        """Sample noise of shape ``input_shape`` with row norms <= ``radius``.

        The noise is placed on the device of ``self.logit_scale`` and cast to
        ``self.dtype``.
        """
        device = self.logit_scale.device
        # Sample on the CPU first, then move, so the random stream is the
        # same whatever the target device is.
        uniform_noise_ball = torch.randn(input_shape).to(device)
        uniform_noise_sphere = F.normalize(uniform_noise_ball, dim=1)
        u = torch.rand(input_shape[0]).to(device)
        u = u ** (1. / input_shape[1])
        uniform_noise_ball = (uniform_noise_sphere.T * u * radius).T
        return uniform_noise_ball.type(self.dtype)

    def get_learnable_noise(self, input_shape):
        """Return Gaussian noise scaled by 0.05, cast to ``self.dtype``."""
        para = 0.05
        noise = torch.nn.Parameter(torch.randn(input_shape) * para).to(self.logit_scale.device)
        return noise.type(self.dtype)

    def cos_sim(self, a, b):
        """Cosine similarity of ``a`` and ``b`` (``F.cosine_similarity``)."""
        return F.cosine_similarity(a, b)

    def forward(self, image, label=None, record=False, cal_gradient=False, weight=None,
                epoch=None, index=None, cfg=None, mask=None):
        """Compute logits and, in training mode, the training loss.

        Args:
            image: batch of images; cast to ``self.dtype`` before encoding.
            label: class indices; required for every loss computation.
            record: outside training, return detached features instead.
            cal_gradient: outside training, return a loss on the logits.
            weight: optional per-sample weights applied to the logits.
            epoch: current epoch; selects which auxiliary loss is used.
                ``None`` is treated like an early epoch.
            index, cfg: accepted for interface compatibility, unused here.
            mask: optional foreground mask multiplied with the image. When
                given, the masked auxiliary losses are computed (this needs
                ``label``).

        Returns:
            * training, ``weight is None``: ``(loss, logits, loss_dict)``;
            * training, ``weight`` given: ``(loss, logits)``;
            * ``record``: ``(image_features, label_text_features, logits)``;
            * ``cal_gradient``: ``(loss, image_features, logits)``;
            * otherwise: ``logits``.
        """
        tokenized_prompts = self.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
        text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)

        # Split on the number of real classes: every row after them is a
        # background prompt, whichever background lists were used.
        n_cls = self.prompt_learner.n_cls
        text_features_fg = text_features[:n_cls]
        text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)

        ori_image_input = image.type(self.dtype)
        image_features, _, _ = self.image_encoder(ori_image_input, shared_ctx,
                                                  deep_compound_prompts_vision)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logits = logit_scale * image_features @ text_features_fg.t()

        loss_re = None
        loss_dict = None
        if mask is not None:
            text_features_bg = text_features[n_cls:]
            text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)

            # Encode the foreground-only and background-only images.
            image_features_fg, _, _ = self.image_encoder(ori_image_input * mask, shared_ctx,
                                                         deep_compound_prompts_vision)
            image_features_fg = image_features_fg / image_features_fg.norm(dim=-1, keepdim=True)
            image_features_bg, _, _ = self.image_encoder(ori_image_input * (1 - mask), shared_ctx,
                                                         deep_compound_prompts_vision)
            image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)

            # Pull the full-image feature towards the foreground feature and
            # away from the background feature.
            loss_re1 = F.triplet_margin_loss(image_features, image_features_fg.detach(),
                                             image_features_bg.detach(), margin=1.5)

            foreground_score = logit_scale * image_features_fg.detach() @ text_features_fg.t()
            # The best-matching background prompt serves as pseudo label.
            pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
            logits_bg = logit_scale * (image_features_bg) @ text_features_bg.t()

            para_bg = 0.5
            para_fg = 0.1
            para_vd = 0.8

            loss_bg = F.cross_entropy(logits_bg, pseudo_label)
            loss_fg = F.cross_entropy(foreground_score, label)

            if epoch is not None and epoch > 6:  # Tunable parameters
                loss_re = para_fg * loss_fg + para_bg * loss_bg
            else:
                loss_re = para_vd * loss_re1

            loss_dict = {'loss_vd': loss_re1.item(), 'loss_bg': loss_bg.item(), 'loss_fg': loss_fg.item()}

        if self.prompt_learner.training:
            if weight is None:
                loss_ce = F.cross_entropy(logits, label)
                if loss_re is None:
                    # No mask: only the classification loss is available.
                    return loss_ce, logits, {'loss_ce': loss_ce.item()}
                return loss_ce + loss_re, logits, loss_dict
            else:
                return F.cross_entropy(weight.unsqueeze(-1) * logits, label), logits

        if record:  # store the embedding
            one_hot_label = F.one_hot(label, num_classes=text_features.shape[0]).to(text_features.dtype)
            return image_features.detach(), (one_hot_label @ text_features).detach(), logits

        if cal_gradient:
            # Treating this as initial gradient
            return F.cross_entropy(logits.requires_grad_(True), label), image_features.detach(), logits
        return logits

    def grad_norm(self, loss_group, original_loss_group):
        """Set ``self.loss_weights.grad`` from a gradient-norm balancing loss.

        NOTE(review): ``self.loss_weights`` is not created by ``__init__``;
        it has to be defined before this method can be called.
        """
        alpha = 0.10
        self.loss_weights.grad.data = self.loss_weights.grad.data * 0.0
        W = self.prompt_learner.compound_prompt_projections[0]
        norms = []
        for i in range(len(loss_group)):
            gygw = torch.autograd.grad(loss_group[i], W.parameters(), retain_graph=True)
            norms.append(torch.norm(torch.mul(self.loss_weights[i], gygw[0])))
        norms = torch.stack(norms)
        loss_ratio = loss_group.data.cpu().numpy() / original_loss_group
        inverse_train_rate = loss_ratio / np.mean(loss_ratio)
        mean_norm = np.mean(norms.data.cpu().numpy())
        constant_norm = torch.tensor(mean_norm * (inverse_train_rate ** alpha),
                                     requires_grad=False).to(norms.device)
        grad_norm_loss = torch.sum(torch.abs(norms - constant_norm))

        self.loss_weights.grad = torch.autograd.grad(grad_norm_loss, self.loss_weights)[0]

    def forward_test(self, image, label=None, record=False, cal_gradient=False, weight=None,
                     cfg=None, attn_mask=False):
        """Forward pass without the masked auxiliary losses.

        NOTE(review): the logits here are computed against all text features,
        i.e. they also contain one column per background prompt - confirm
        that callers expect this.

        Returns:
            * training, ``weight is None``: ``(loss, logits)``;
            * training, ``weight`` given: ``(weighted_loss, logits)``;
            * ``record``: ``(image_features, label_text_features, logits)``;
            * ``attn_mask``: ``(logits, mask)`` with the image encoder's
              third output;
            * ``cal_gradient``: ``(loss, image_features, logits)``;
            * otherwise: ``logits``.
        """
        tokenized_prompts = self.tokenized_prompts
        logit_scale = self.logit_scale.exp()

        prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
        text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
        image_features, _, mask = self.image_encoder(image.type(self.dtype), shared_ctx,
                                                     deep_compound_prompts_vision)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = logit_scale * image_features @ text_features.t()

        if self.prompt_learner.training:
            if weight is None:
                return F.cross_entropy(logits, label), logits
            else:
                return F.cross_entropy(weight.unsqueeze(-1) * logits, label), logits

        if record:  # store the embedding
            one_hot_label = F.one_hot(label, num_classes=text_features.shape[0]).to(text_features.dtype)
            return image_features.detach(), (one_hot_label @ text_features).detach(), logits
        if attn_mask:
            return logits, mask
        if cal_gradient:
            # Treating this as initial gradient
            return F.cross_entropy(logits.requires_grad_(True), label), image_features.detach(), logits
        return logits
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
|
||||
class MaPLe(TrainerX):
|
||||
|
||||
|
||||
|
||||
def check_cfg(self, cfg):
    """Validate that ``cfg.TRAINER.MAPLE.PREC`` is a supported precision."""
    supported = ["fp16", "fp32", "amp"]
    assert cfg.TRAINER.MAPLE.PREC in supported
|
||||
|
||||
|
||||
|
||||
def build_model(self):
    """Load CLIP, wrap it in ``CustomCLIP`` and create the optimizers.

    Two optimizer/scheduler pairs are created: ``optim``/``sched`` from
    ``cfg.OPTIM`` (registered with the model) and
    ``selected_optim``/``selected_sched`` from ``cfg.OPTIM_SELECTION``.
    """
    cfg = self.cfg
    classnames = self.dm.dataset.classnames

    print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
    clip_model = load_clip_to_cpu(cfg)

    if cfg.TRAINER.MAPLE.PREC == "fp32" or cfg.TRAINER.MAPLE.PREC == "amp":
        # CLIP's default precision is fp16
        clip_model.float()

    print("Building custom CLIP")
    self.model = CustomCLIP(cfg, classnames, clip_model)

    print("Turning off gradients in both the image and the text encoder")
    name_to_update = "prompt_learner"

    for name, param in self.model.named_parameters():
        if name_to_update not in name:
            # Make sure that VPT prompts are updated
            param.requires_grad_("VPT" in name)

    # Double check
    enabled = {name for name, param in self.model.named_parameters() if param.requires_grad}
    print(f"Parameters to be updated: {enabled}")

    if cfg.MODEL.INIT_WEIGHTS:
        load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)

    self.model.to(self.device)
    # NOTE: only give prompt_learner to the optimizer
    self.optim = build_optimizer(self.model, cfg.OPTIM)
    self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)

    # The selection scheduler must drive the selection optimizer; attaching
    # it to ``self.optim`` would let it rewrite the main learning rate.
    self.selected_optim = build_optimizer(self.model, cfg.OPTIM_SELECTION)
    self.selected_sched = build_lr_scheduler(self.selected_optim, cfg.OPTIM_SELECTION)

    self.register_model("MultiModalPromptLearner", self.model, self.optim, self.sched)

    self.scaler = GradScaler() if cfg.TRAINER.MAPLE.PREC == "amp" else None
|
||||
|
||||
# Note that multi-gpu training could be slow because CLIP's size is
|
||||
# big, which slows down the copy operation in DataParallel
|
||||
# device_count = torch.cuda.device_count()
|
||||
# if device_count > 1:
|
||||
# print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
|
||||
# self.model = nn.DataParallel(self.model)
|
||||
|
||||
|
||||
# def generate_text_feature(self):
|
||||
# cfg = self.cfg
|
||||
# classnames = self.dm.dataset.classnames
|
||||
# #
|
||||
# # print(f"Loading Custom CLIP (backbone: {cfg.MODEL.BACKBONE.NAME}) for selection")
|
||||
# # clip_model = lcp(cfg)
|
||||
# # clip_model.to(self.device)
|
||||
#
|
||||
# temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
|
||||
# prompts = [temp.format(c.replace("_", " ")) for c in classnames]
|
||||
# print(f"Prompts: {prompts}")
|
||||
# prompts = torch.cat([clip.tokenize(p) for p in prompts])
|
||||
# prompts = prompts.to(self.device)
|
||||
#
|
||||
# p, _, deep_compound_prompts_text, _ = self.model.prompt_learner()
|
||||
# with torch.no_grad():
|
||||
# text = self.model.text_encoder(prompts)
|
||||
# text_features = self.model.encode_text(prompts, tokenized_prompts, deep_compound_prompts_text)
|
||||
# text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||
#
|
||||
# self.ori_text_features = text_features
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def forward_backward(self, batch):
    """Run one optimisation step and return the model's loss dictionary.

    Both precision paths call the model with the same arguments and unpack
    the same ``(loss, logits, loss_dict)`` triple, so the ``amp`` path
    reports the same summary as the default path.
    """
    image, label, index, mask = self.parse_batch_train_pair(batch)
    # Per-sample weights are not applied to the loss in this step.
    weight = None

    model = self.model
    optim = self.optim
    scaler = self.scaler

    prec = self.cfg.TRAINER.MAPLE.PREC
    if prec == "amp":
        with autocast():
            loss, _, loss_dict = model(image, label, weight=weight, epoch=self.epoch,
                                       index=index, cfg=self.cfg, mask=mask)
        optim.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
    else:
        loss, _, loss_dict = model(image, label, weight=weight, epoch=self.epoch,
                                   index=index, cfg=self.cfg, mask=mask)
        optim.zero_grad()
        loss.backward()
        optim.step()

    loss_summary = loss_dict

    # The learning rate is stepped once, after the last batch of an epoch.
    if (self.batch_idx + 1) == self.num_batches:
        self.update_lr()

    return loss_summary
|
||||
|
||||
def parse_batch_train_pair(self, batch):
    """Unpack a training batch that also carries a mask.

    Returns:
        ``(image, label, index, mask)``; image, label and mask are moved to
        ``self.device`` while ``index`` is returned as found in the batch.
    """
    image = batch["img"].to(self.device)
    label = batch["label"].to(self.device)
    mask = batch['mask'].to(self.device)
    index = batch["index"]
    # The result does not depend on whether sample weights are set.
    return image, label, index, mask
|
||||
|
||||
|
||||
|
||||
def parse_batch_train(self, batch):
    """Unpack a training batch, including sample weights when they are set.

    Returns:
        ``(image, label, weight, index)`` when ``self.sample_weights`` is not
        ``None``, otherwise ``(image, label, index)``. Image, label and
        weight are moved to ``self.device``.
    """
    image = batch["img"].to(self.device)
    label = batch["label"].to(self.device)
    index = batch["index"]

    if self.sample_weights is None:
        return image, label, index

    # Follow the trainer's device rather than assuming CUDA.
    weight = batch['weight'].to(self.device)
    return image, label, weight, index
|
||||
|
||||
def load_model(self, directory, epoch=None):
    """Load saved weights for every registered model from ``directory``.

    Args:
        directory: root folder holding one sub-folder per model name;
            loading is skipped when it is empty or ``None``.
        epoch: load ``model.pth.tar-<epoch>`` instead of the default
            ``model-best.pth.tar``.

    Raises:
        FileNotFoundError: if an expected checkpoint file is missing.
    """
    if not directory:
        print("Note that load_model() is skipped as no pretrained model is given")
        return

    # By default, the best model is loaded
    if epoch is None:
        model_file = "model-best.pth.tar"
    else:
        model_file = "model.pth.tar-" + str(epoch)

    for name in self.get_model_names():
        model_path = osp.join(directory, name, model_file)
        if not osp.exists(model_path):
            raise FileNotFoundError('Model not found at "{}"'.format(model_path))

        checkpoint = load_checkpoint(model_path)
        state_dict = checkpoint["state_dict"]
        saved_epoch = checkpoint["epoch"]

        # Ignore fixed token vectors: they are recomputed from the current
        # class names.
        for fixed_key in ("prompt_learner.token_prefix", "prompt_learner.token_suffix"):
            state_dict.pop(fixed_key, None)

        print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, saved_epoch))
        # set strict=False
        self._models[name].load_state_dict(state_dict, strict=False)
|
||||
|
||||
def before_train(self):
    """Select the training subset, rebuild the loaders and create the first masks."""
    # Only consumed by the (currently disabled) resume logic.
    directory = self.cfg.OUTPUT_DIR
    if self.cfg.RESUME:
        directory = self.cfg.RESUME

    # Redefine the dataloader
    selected_res = self.selector()
    picked = selected_res['indices']
    if 'weights' not in selected_res:
        self.sample_weights = None
    else:
        # Scatter the weights over the full training set, then read them
        # back in selection order.
        c_weight = np.zeros(len(self.dm.dataset.train_x))
        c_weight[picked] = selected_res['weights']
        self.sample_weights = c_weight[picked]

    self.build_final_data_loader(picked, self.sample_weights)
    print(f'Finish the selecting process, now continue tune CLIP')

    # Initialize summary writer
    writer_dir = osp.join(self.output_dir, "tensorboard")
    mkdir_if_missing(writer_dir)
    self.init_writer(writer_dir)

    # Remember the starting time (for computing the elapsed time)
    self.time_start = time.time()

    print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")

    if self.cfg.TRAINER.DAPT_MODE != 'dapt-s':
        self.generate_gradcam_train(split='train')
    else:
        self.generate_mask_train()
|
||||
|
||||
|
||||
|
||||
def after_epoch(self):
    """Validate, keep the best checkpoint and refresh the masks.

    When testing is enabled and ``cfg.TEST.FINAL_MODEL == "best_val"``, the
    model is evaluated on the validation split and saved as
    ``model-best.pth.tar`` whenever the result beats ``self.best_result``.
    Periodic / last-epoch checkpointing is disabled in this trainer.

    Afterwards the masks are regenerated for the next epoch: random grid
    masks in ``'dapt-s'`` mode, GradCAM-based masks otherwise.
    """
    do_test = not self.cfg.TEST.NO_TEST

    if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
        curr_result = self.test(split="val")
        if curr_result > self.best_result:
            self.best_result = curr_result
            self.save_model(
                self.epoch,
                self.output_dir,
                val_result=curr_result,
                model_name="model-best.pth.tar"
            )

    print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")

    if self.cfg.TRAINER.DAPT_MODE == 'dapt-s':
        self.generate_mask_train()
    else:
        self.generate_gradcam_train(split='train')
|
||||
|
||||
|
||||
|
||||
|
||||
def build_final_data_loader(self, selected_ind=None, weight=None):
    """Rebuild the training loaders for the selected subset and reset the masks.

    Args:
        selected_ind: indices of the selected training samples, forwarded
            to ``DataManager``. May be ``None``.
        weight: optional per-sample weights, forwarded to ``DataManager``.

    Side effects:
        Sets ``self.train_loader_x``, ``self.train_loader_xmore`` (the loader
        used for generating the attentive masking) and ``self.mask_list``,
        a zero-filled float16 tensor holding one single-channel mask of
        spatial size ``cfg.INPUT.SIZE`` per sample.
    """
    new_dm = DataManager(self.cfg, self.dm.dataset, selected_ind, weight=weight)
    self.train_loader_x = new_dm.train_loader_x
    self.train_loader_xmore = new_dm.train_loader_xmore  # for generate the attentive masking

    if selected_ind is None:
        # The signature allows None, but the original read ``.shape[0]``
        # unconditionally and crashed. Falls back to the full training set.
        # NOTE(review): assumes DataManager treats None as "use every
        # sample" - confirm against DataManager.
        num_samples = len(self.dm.dataset.train_x)
    else:
        num_samples = len(selected_ind)

    self.mask_list = torch.zeros(
        (num_samples, 1, *self.cfg.INPUT.SIZE), dtype=torch.float16
    )
|
||||
|
||||
def selector(self):
    """Run the configured data-selection strategy and return its result.

    The strategy class is looked up by ``cfg.DATASET.SELECTION_METHOD`` in
    the ``s_method`` module; ``'Uniform'`` takes positional arguments, every
    other strategy receives the model, optimizer, scheduler and grad scaler
    as keyword arguments.

    Returns:
        Whatever the strategy's ``select()`` returns (``before_train`` reads
        the ``'indices'`` and optional ``'weights'`` entries from it).
    """
    cfg = self.cfg
    fraction = cfg.DATASET.SELECTION_RATIO
    seed = cfg.SEED
    method = cfg.DATASET.SELECTION_METHOD
    print(f"Selecting {fraction*100}% data by {method}")

    if method == 'Uniform':
        strategy = s_method.Uniform(self.dm, cfg, fraction, seed)
        return strategy.select()

    strategy_cls = s_method.__dict__[method]
    strategy = strategy_cls(
        dst_train=self.dm,
        args=cfg,
        fraction=fraction,
        random_seed=seed,
        specific_model=self.model,
        optim=self.selected_optim,
        schedule=self.selected_sched,
        scar=self.scaler,
        balance=True,
    )
    return strategy.select()
|
||||
|
||||
@torch.no_grad()
def test_withlabel(self, split=None):
    """Evaluate with ``model.forward_test`` (which also receives the labels).

    Args:
        split: ``"val"`` or ``"test"``; defaults to ``cfg.TEST.SPLIT``.
            Falls back to the test loader when no validation loader exists.

    Returns:
        The first metric produced by ``NewClassification.evaluate()``.
    """
    self.set_model_mode("eval")
    evaluator = NewClassification(self.cfg, self.evaluator._lab2cname)
    evaluator.reset()

    if split is None:
        split = self.cfg.TEST.SPLIT

    use_val = split == "val" and self.val_loader is not None
    if use_val:
        loader = self.val_loader
    else:
        # in case val_loader is None
        split = "test"
        loader = self.test_loader

    print(f"Evaluate on the *{split}* set")

    for batch in tqdm(loader):
        images, targets = self.parse_batch_test(batch)
        logits = self.model.forward_test(images, targets, cfg=self.cfg)
        evaluator.process(logits, targets)

    results = evaluator.evaluate()

    # Log every metric under "<split>/<metric>".
    for metric_name, metric_value in results.items():
        self.write_scalar(f"{split}/{metric_name}", metric_value, self.epoch)

    return list(results.values())[0]
|
||||
|
||||
|
||||
def generate_gradcam(self, split=None, attn_mask=False):
    """Render a GradCAM overlay for every sample of an evaluation split.

    Overlays are written to
    ``<cfg.OUTPUT_DIR>/<split>_wrong_promptcamother/<image file name>``.

    Args:
        split: ``"val"`` or ``"test"``; defaults to ``cfg.TEST.SPLIT``.
            Falls back to the test loader when no validation loader exists.
        attn_mask: forwarded to ``GradCAM.forward``.

    Only the first ``impath`` of each batch is used for the file name.
    NOTE(review): this presumably expects a batch size of 1 - confirm.
    """
    self.set_model_mode("eval")
    cam = GradCAM({'arch': self.model, 'layer_name': 'target.layer'})

    img_split = 'wrong'  # true/wrong; forwarded to cam.forward as ``split``
    if split is None:
        split = self.cfg.TEST.SPLIT

    if split == "val" and self.val_loader is not None:
        data_loader = self.val_loader
    else:
        split = "test"  # in case val_loader is None
        data_loader = self.test_loader

    print(f"Generate GradCAM on the *{split}* set")

    save_path = os.path.join(self.cfg.OUTPUT_DIR, f'{split}_{img_split}_promptcamother')
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(save_path, exist_ok=True)

    for batch in tqdm(data_loader):
        image, label = self.parse_batch_test(batch)
        img_name = os.path.basename(batch['impath'][0])
        img_save_path = os.path.join(save_path, img_name)
        img0 = denorm(batch['img0'].numpy(), self.cfg.INPUT.PIXEL_MEAN, self.cfg.INPUT.PIXEL_STD)
        saliency_map = cam.forward(image, label, cfg=self.cfg, split=img_split, attn_mask=attn_mask)
        # Identity test: ``!= None`` on a tensor is not a reliable None check.
        if saliency_map is not None:
            cam.show_cam(img0, saliency_map.detach().cpu(), img_save_path)
|
||||
|
||||
|
||||
|
||||
|
||||
def generate_mask_train(self):
    """Fill ``self.mask_list`` with random grid masks (cell size 32 or 16).

    For every batch of ``self.train_loader_xmore`` one cell size is drawn,
    the image plane is tiled with cells of that size and each cell is zeroed
    with probability 0.5. The same mask is stored for every sample index of
    the batch. Finally the mask tensor is attached to the model as
    ``loaded_mask``.
    """
    cell_choices = [32, 16]
    drop_prob = 0.5

    for batch in tqdm(self.train_loader_xmore):
        images, _, sample_idx = self.parse_batch_train(batch)
        _, _, height, width = images.shape
        grid_mask = torch.ones((1, height, width), dtype=torch.float16)

        pick = torch.randint(0, len(cell_choices), size=(1,))
        cell = cell_choices[pick]

        if cell != 0:
            for top in range(0, height, cell):
                for left in range(0, width, cell):
                    # Slicing clips at the border, so edge cells may be smaller.
                    if random.random() <= drop_prob:
                        grid_mask[:, top:top + cell, left:left + cell] = 0

        self.mask_list[sample_idx, :] = grid_mask

    self.model.loaded_mask = self.mask_list
|
||||
|
||||
|
||||
def generate_mask_bg(self):
    """Fill ``self.mask_list`` with coarse random grid masks (cell size 64 or 128).

    Same procedure as the fine-grained variant: per batch one cell size is
    drawn, each cell of the tiling is zeroed with probability 0.5, and the
    resulting mask is stored for every sample index of the batch. The mask
    tensor is then attached to the model as ``loaded_mask``.
    """
    cell_choices = [64, 128]
    drop_prob = 0.5

    for batch in tqdm(self.train_loader_xmore):
        images, _, sample_idx = self.parse_batch_train(batch)
        _, _, height, width = images.shape
        grid_mask = torch.ones((1, height, width), dtype=torch.float16)

        pick = torch.randint(0, len(cell_choices), size=(1,))
        cell = cell_choices[pick]

        if cell != 0:
            for top in range(0, height, cell):
                for left in range(0, width, cell):
                    # Slicing clips at the border, so edge cells may be smaller.
                    if random.random() <= drop_prob:
                        grid_mask[:, top:top + cell, left:left + cell] = 0

        self.mask_list[sample_idx, :] = grid_mask

    self.model.loaded_mask = self.mask_list
|
||||
|
||||
|
||||
def generate_gradcam_train(self, split=None, attn_mask=False):
    """Fill ``self.mask_list`` with GradCAM saliency masks of the training set.

    Args:
        split: only used in the progress message.
        attn_mask: forwarded to ``GradCAM.forward_train``.

    Side effects:
        Puts the model in eval mode, writes one mask per sample index of
        ``self.train_loader_xmore`` into ``self.mask_list`` and attaches the
        tensor to the model as ``loaded_mask``.
    """
    self.set_model_mode("eval")

    # GradCAM registers forward/backward hooks on the model when it is
    # constructed. This method runs every epoch, so building a fresh GradCAM
    # each time would pile up hooks (and the tensors they capture). Reuse
    # one instance for as long as the model object stays the same.
    cam = getattr(self, "_train_gradcam", None)
    if cam is None or cam.model_arch is not self.model:
        cam = GradCAM({'arch': self.model, 'layer_name': 'target.layer'})
        self._train_gradcam = cam

    print(f"Generate GradCAM on the *{split}* set")

    for batch in tqdm(self.train_loader_xmore):
        image, label, index = self.parse_batch_train(batch)
        saliency_map = cam.forward_train(image, label, cfg=self.cfg, attn_mask=attn_mask)
        self.mask_list[index, :] = saliency_map.detach().cpu()

    self.model.loaded_mask = self.mask_list
|
||||
|
||||
|
||||
class NewClassification(Classification):
    """Classification evaluator that also dumps per-sample results to disk.

    In addition to the printed summary, ``evaluate`` writes into
    ``cfg.OUTPUT_DIR``:
      * ``wrongind.npy``  - boolean array, True where prediction != label;
      * ``per-class.npy`` - per-class accuracies and their mean (only when
        ``cfg.TEST.PER_CLASS_RESULT`` is set);
      * ``cmat.pt``       - normalised confusion matrix (only when
        ``cfg.TEST.COMPUTE_CMAT`` is set).
    """

    def __init__(self, cfg, lab2cname=None, **kwargs):
        """Initialise the counters; extra keyword arguments are ignored."""
        super().__init__(cfg, lab2cname)
        self._lab2cname = lab2cname
        self._correct = 0
        self._total = 0
        self._per_class_res = None
        self._y_true = []
        self._y_pred = []
        if cfg.TEST.PER_CLASS_RESULT:
            assert lab2cname is not None
            self._per_class_res = defaultdict(list)

    def evaluate(self):
        """Compute, print and persist the metrics.

        Returns:
            OrderedDict with ``accuracy``, ``error_rate``, ``macro_f1`` and,
            when per-class results are tracked, ``perclass_accuracy``.
        """
        results = OrderedDict()
        acc = 100.0 * self._correct / self._total
        err = 100.0 - acc
        macro_f1 = 100.0 * f1_score(
            self._y_true,
            self._y_pred,
            average="macro",
            labels=np.unique(self._y_true)
        )

        # The first value will be returned by trainer.test()
        results["accuracy"] = acc
        results["error_rate"] = err
        results["macro_f1"] = macro_f1

        # Paths are built with osp.join, like the confusion-matrix path below.
        wrong_ind = np.array(self._y_true) != np.array(self._y_pred)
        np.save(osp.join(self.cfg.OUTPUT_DIR, 'wrongind.npy'), wrong_ind)

        print(
            "=> result\n"
            f"* total: {self._total:,}\n"
            f"* correct: {self._correct:,}\n"
            f"* accuracy: {acc:.1f}%\n"
            f"* error: {err:.1f}%\n"
            f"* macro_f1: {macro_f1:.1f}%"
        )

        if self._per_class_res is not None:
            labels = sorted(self._per_class_res.keys())

            print("=> per-class result")
            accs = []

            for label in labels:
                classname = self._lab2cname[label]
                res = self._per_class_res[label]
                correct = sum(res)
                total = len(res)
                class_acc = 100.0 * correct / total
                accs.append(class_acc)

                print(
                    f"* class: {label} ({classname})\t"
                    f"total: {total:,}\t"
                    f"correct: {correct:,}\t"
                    f"acc: {class_acc:.1f}%"
                )

            mean_acc = np.mean(accs)
            # The dict is stored as a pickled object array by np.save.
            np.save(
                osp.join(self.cfg.OUTPUT_DIR, 'per-class.npy'),
                {'per_cls': accs, 'mean_acc': mean_acc},
            )
            print(f"* average: {mean_acc:.1f}%")

            results["perclass_accuracy"] = mean_acc

        if self.cfg.TEST.COMPUTE_CMAT:
            cmat = confusion_matrix(
                self._y_true, self._y_pred, normalize="true"
            )
            save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
            torch.save(cmat, save_path)
            print(f"Confusion matrix is saved to {save_path}")

        return results
|
||||
@@ -0,0 +1,265 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
# Names of generic background concepts. GradCAM.forward treats the last
# len(BACKGROUND_CATEGORY) logit columns as belonging to these names and
# slices them off before building its loss.
BACKGROUND_CATEGORY = [
    'ground', 'land', 'grass', 'tree', 'building',
    'wall', 'sky', 'lake', 'water', 'river',
    'sea', 'railway', 'railroad', 'keyboard', 'helmet',
    'cloud', 'house', 'mountain', 'ocean', 'road',
    'rock', 'street', 'valley', 'bridge', 'sign',
]
|
||||
|
||||
class GradCAM(object):
    """Grad-CAM style saliency maps for a prompt-tuned CLIP image encoder.

    On construction, a forward and a backward hook are registered on
    ``ln_1`` of the last residual block of
    ``model_dict['arch'].image_encoder.transformer``. After a forward and
    backward pass, the captured activation and output gradient are combined
    into one saliency map per image.

    The captured tensors are indexed as ``[token, batch, dim]``: token 0 and
    the trailing ``prompt_learner.n_ctx`` tokens are skipped, the remaining
    ones are reshaped to a ``patch_num x patch_num`` grid.
    NOTE(review): presumably token 0 is the class token and the trailing
    tokens are visual prompts - confirm against the image encoder.
    """

    def __init__(self, model_dict):
        """Attach the hooks.

        Args:
            model_dict: mapping with key ``'arch'`` (the model to explain).
                Other keys, e.g. ``'layer_name'``, are not used.
        """
        self.model_arch = model_dict['arch']

        self.gradient = dict()
        self.activation = dict()

        # Kept for backward compatibility; no hook fills these.
        self.gradient_t = dict()
        self.activation_t = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradient['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activation['value'] = output
            return None

        target_layer = self.model_arch.image_encoder.transformer.resblocks[-1].ln_1
        # Keep the handles so the hooks can be detached again.
        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_backward_hook(backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks registered by this instance (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _class_logits(self, input, labels, cfg, attn_mask):
        """Run ``forward_test`` and return only the logits."""
        if attn_mask:
            # With attn_mask the model also returns an attention mask; it is
            # not used for the saliency computation.
            logit, _ = self.model_arch.forward_test(input, labels, cfg=cfg, attn_mask=attn_mask)
        else:
            logit = self.model_arch.forward_test(input, labels, cfg=cfg)
        return logit

    def _saliency_from_loss(self, loss, retain_graph, binarize):
        """Back-propagate ``loss`` and turn the hooked tensors into a map.

        Args:
            loss: scalar whose gradient defines the saliency.
            retain_graph: forwarded to ``loss.backward``.
            binarize: if True, values >= 0.5 become 1; values < 0.5 always
                become 0.

        Returns:
            Tensor of shape ``(batch, 1, input_resolution, input_resolution)``
            scaled by the maximum over the whole batch.
        """
        patch_num = self.model_arch.image_encoder.patch_num
        ori_size = self.model_arch.image_encoder.input_resolution

        self.model_arch.zero_grad()
        loss.backward(retain_graph=retain_graph)
        gradients = self.gradient['value']
        activations = self.activation['value']

        # Explicit end index: ``[1:-n_ctx]`` would be empty for n_ctx == 0.
        n_ctx = self.model_arch.prompt_learner.n_ctx
        patch_end = activations.shape[0] - n_ctx
        visual_feature = activations[1:patch_end]
        # One averaged gradient vector per image, used as channel weights.
        token_gradient = gradients[1:patch_end].mean(dim=0, keepdim=True)

        # (batch, tokens, dim) x (batch, dim, 1) -> (batch, tokens, 1)
        weighted = torch.bmm(visual_feature.permute(1, 0, 2), token_gradient.permute(1, 2, 0))
        weighted = F.relu(weighted).permute(0, 2, 1)
        weighted = weighted.reshape(weighted.shape[0], 1, patch_num, patch_num)
        # F.interpolate is the non-deprecated equivalent of F.upsample.
        weighted = F.interpolate(weighted, size=(ori_size, ori_size), mode='bilinear', align_corners=True)

        saliency_map = weighted / (weighted.max() + 1e-6)

        threshold = 0.5
        if binarize:
            saliency_map[saliency_map >= threshold] = 1
        saliency_map[saliency_map < threshold] = 0
        return saliency_map

    def forward(self, input, labels, cfg=None, retain_graph=False, split=None, attn_mask=False):
        """Saliency for evaluation images (soft map, values below 0.5 zeroed).

        The last ``len(BACKGROUND_CATEGORY)`` logit columns are excluded; the
        loss is the mean softmax probability of the labelled class.

        Args:
            input: image batch.
            labels: integer class labels, one per image.
            cfg: forwarded to ``forward_test``.
            retain_graph: forwarded to ``backward``.
            split: unused; kept for interface compatibility.
            attn_mask: forwarded to ``forward_test``.
        """
        logit = self._class_logits(input, labels, cfg, attn_mask)

        n_background = len(BACKGROUND_CATEGORY)
        class_logit = logit[:, :-n_background]
        one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1] - n_background).to(logit.dtype)
        # dim=1 is what the implicit-dim softmax resolved to for 2-D logits.
        loss = (F.softmax(class_logit, dim=1) * one_hot_labels).mean()

        return self._saliency_from_loss(loss, retain_graph, binarize=False)

    def forward_train(self, input, labels, cfg=None, retain_graph=False, split=None, attn_mask=False):
        """Binary saliency mask for training images.

        All logit columns are used; the loss is the mean raw logit of the
        labelled class. Arguments are the same as for ``forward``.
        """
        logit = self._class_logits(input, labels, cfg, attn_mask)

        one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(logit.dtype)
        loss = (logit * one_hot_labels).mean()

        return self._saliency_from_loss(loss, retain_graph, binarize=True)

    def show_cam(self, img, mask, save_path=None):
        """Blend ``mask`` as a JET heat map over ``img`` and return a PIL image.

        Args:
            img: de-normalised image, channels first.
            mask: saliency map for a single image.
            save_path: where to write the overlay; nothing is written when
                it is None.
        """
        heat_map = cv2.applyColorMap(np.uint8(255*mask.squeeze()), cv2.COLORMAP_JET)
        heatmap = torch.from_numpy(heat_map).permute(2, 0, 1).float().div(255)
        # OpenCV colour maps are BGR; reorder to RGB.
        b, g, r = heatmap.split(1)
        heatmap = torch.cat([r, g, b])
        rate = 0.5
        res = rate*heatmap + (1-rate)*img
        res = res.div(res.max()).squeeze()
        res = np.transpose(np.uint8(255*res), (1, 2, 0))

        pil_image = Image.fromarray(res)
        if save_path is not None:
            pil_image.save(save_path)
        return pil_image
|
||||
|
||||
|
||||
|
||||
def denorm(img, mean, std):
    """Undo a per-channel normalisation: ``img * std + mean``.

    ``mean`` and ``std`` are per-channel sequences; they are broadcast over
    the two trailing (spatial) axes of ``img``.
    """
    channel_mean = np.array(mean)[:, np.newaxis, np.newaxis]
    channel_std = np.array(std)[:, np.newaxis, np.newaxis]
    return img * channel_std + channel_mean
|
||||
|
||||
+239
@@ -0,0 +1,239 @@
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.utils import load_pretrained_weights, load_checkpoint
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
|
||||
from clip import clip
|
||||
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
# Module-level CLIP tokenizer instance. No definition in the visible part of
# this module references it.
_tokenizer = _Tokenizer()
|
||||
|
||||
|
||||
def load_clip_to_cpu(cfg):
    """Download the configured CLIP backbone and build it on the CPU in fp32.

    Args:
        cfg: config providing ``MODEL.BACKBONE.NAME``,
            ``TRAINER.VPT.PROMPT_DEPTH_VISION`` and
            ``TRAINER.VPT.N_CTX_VISION``.

    Returns:
        The CLIP model built with vision-prompt design details, cast to
        float32.

    Raises:
        AssertionError: if ``PROMPT_DEPTH_VISION`` is smaller than 1.
    """
    # Validate first so a bad config fails before the download and load.
    assert cfg.TRAINER.VPT.PROMPT_DEPTH_VISION >= 1, "For Vision Prompting, PROMPT_DEPTH_VISION should be >= 1"

    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = jit_model.state_dict()
    except RuntimeError:
        # Not a JIT archive: the file holds a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")

    design_details = {"trainer": "VPT",
                      "vision_depth": cfg.TRAINER.VPT.PROMPT_DEPTH_VISION,
                      "vision_ctx": cfg.TRAINER.VPT.N_CTX_VISION,
                      "language_depth": 0,
                      "language_ctx": 0}
    model = clip.build_model(state_dict, design_details)

    return model.float()
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Text tower of CLIP that consumes already-embedded prompts."""

    def __init__(self, clip_model):
        """Borrow the text-side modules and parameters from ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode ``prompts`` and project the feature at the EOT position.

        Args:
            prompts: prompt embeddings, indexed ``[batch, token, width]``.
            tokenized_prompts: token ids; the position of the largest id in
                each row (the EOT token) selects the output feature.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer works on LND, so swap N and L around the call.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # eot_token is the highest number in each sequence
        eot_positions = tokenized_prompts.argmax(dim=-1)
        rows = torch.arange(hidden.shape[0])
        return hidden[rows, eot_positions] @ self.text_projection
|
||||
|
||||
|
||||
class FixedEmbeddings():
    """Pre-computed CLIP text features for hand-written class prompts."""

    def __init__(self, cfg, classnames, clip_model):
        """Encode ``"a photo of a <class>."`` for every class name, once.

        Raises:
            AssertionError: if ``cfg.INPUT.SIZE[0]`` differs from the input
                resolution of the CLIP visual encoder.
        """
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        prompt_prefix = "a photo of a"
        print('Vision Prompting Design')
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.VPT.N_CTX_VISION}")
        print("Using fixed hand crated prompts")

        # Underscores in class names become spaces before templating.
        sentences = [
            prompt_prefix + " " + name.replace("_", " ") + "."
            for name in classnames
        ]
        token_ids = torch.cat([clip.tokenize(sentence) for sentence in sentences])

        with torch.no_grad():
            self.fixed_embeddings = clip_model.encode_text(token_ids)

    def return_fixed_embeddings(self):
        """Return the cached text features."""
        return self.fixed_embeddings
|
||||
|
||||
|
||||
class CustomCLIP(nn.Module):
    """CLIP classifier with fixed text features and a (prompted) image encoder."""

    def __init__(self, cfg, classnames, clip_model):
        """Wrap ``clip_model`` and pre-compute the class text features."""
        super().__init__()
        self.embeddings = FixedEmbeddings(cfg, classnames, clip_model)
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image, label=None, training=False):
        """Return class logits, or the cross-entropy loss when ``training``.

        Args:
            image: image batch.
            label: class labels; required when ``training`` is True.
            training: if True, return ``F.cross_entropy(logits, label)``
                instead of the logits.
        """
        logit_scale = self.logit_scale.exp()

        image_features = self.image_encoder(image.type(self.dtype))
        # Follow the device of the image features instead of a hard-coded
        # .cuda(), so the model also runs on CPU and other devices.
        text_features = self.embeddings.return_fixed_embeddings().to(image_features.device)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = logit_scale * image_features @ text_features.t()

        if training:
            return F.cross_entropy(logits, label)

        return logits
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class VPT(TrainerX):
    """Trainer for vision prompt tuning of CLIP with fixed text prompts."""

    def check_cfg(self, cfg):
        """Reject unsupported precision settings."""
        assert cfg.TRAINER.VPT.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Build the model, freeze everything but the prompts, set up optimisation."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.VPT.PREC == "fp32" or cfg.TRAINER.VPT.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        name_to_update = "prompt_learner"

        for name, param in self.model.named_parameters():
            if name_to_update not in name:
                # Make sure that VPT prompts are updated
                if "VPT" in name:
                    param.requires_grad_(True)
                else:
                    param.requires_grad_(False)

        # Double check
        enabled = set()
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                enabled.add(name)
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # The whole model is handed to the optimizer builder; only the
        # parameters left with requires_grad=True above receive gradients.
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.VPT.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

    def forward_backward(self, batch):
        """Run one optimisation step and return ``{"loss": float}``."""
        image, label = self.parse_batch_train(batch)

        model = self.model
        optim = self.optim
        scaler = self.scaler

        prec = self.cfg.TRAINER.VPT.PREC
        if prec == "amp":
            with autocast():
                # training=True makes the model return the scalar loss.
                # Without it the logits came back, which cannot be
                # back-propagated or turned into a float with .item().
                loss = model(image, label, training=True)
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = model(image, label, training=True)
            optim.zero_grad()
            loss.backward()
            optim.step()

        loss_summary = {"loss": loss.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the ``img`` and ``label`` entries of ``batch`` to the device."""
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        """Load weights for every registered model from ``directory``.

        Args:
            directory: folder with one sub-folder per registered model name;
                loading is skipped when it is empty or None.
            epoch: load ``model.pth.tar-<epoch>`` instead of the default
                ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            ckpt_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "prompt_learner.token_prefix" in state_dict:
                del state_dict["prompt_learner.token_prefix"]

            if "prompt_learner.token_suffix" in state_dict:
                del state_dict["prompt_learner.token_suffix"]

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)
|
||||
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||
|
||||
from clip import clip
|
||||
from clip.model import convert_weights
|
||||
|
||||
from .coop import load_clip_to_cpu
|
||||
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
|
||||
|
||||
# Prompt shared by every dataset that has no domain-specific wording.
_GENERIC_PHOTO_TEMPLATE = "a photo of a {}."

# Hand-written prompt template per dataset name. The "{}" placeholder is
# filled with a class name through str.format when prompts are built.
CUSTOM_TEMPLATES = {
    # Datasets with a domain-specific description.
    "OxfordPets": "a photo of a {}, a type of pet.",
    "OxfordFlowers": "a photo of a {}, a type of flower.",
    "FGVCAircraft": "a photo of a {}, a type of aircraft.",
    "DescribableTextures": "{} texture.",
    "EuroSAT": "a centered satellite photo of {}.",
    "StanfordCars": _GENERIC_PHOTO_TEMPLATE,
    "Food101": "a photo of {}, a type of food.",
    "SUN397": _GENERIC_PHOTO_TEMPLATE,
    "Caltech101": _GENERIC_PHOTO_TEMPLATE,
    "UCF101": "a photo of a person doing {}.",
    # ImageNet and its variants all use the generic prompt.
    "ImageNet": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetSketch": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetV2": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetA": _GENERIC_PHOTO_TEMPLATE,
    "ImageNetR": _GENERIC_PHOTO_TEMPLATE,
}
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
    """Zero-shot CLIP classifier driven by one hand-written prompt per dataset.

    ``build_model`` encodes one prompt per class name and stores the
    L2-normalised text features; ``model_inference`` scores image features
    against those stored text features.
    """

    def build_model(self):
        """Load CLIP and precompute the normalised text feature of every class prompt.

        Sets ``self.text_features`` and ``self.clip_model``.

        Raises:
            KeyError: if ``cfg.DATASET.NAME`` has no entry in ``CUSTOM_TEMPLATES``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        # One prompt per class; underscores in class names become spaces.
        template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        prompts = [template.format(name.replace("_", " ")) for name in classnames]
        print(f"Prompts: {prompts}")

        tokens = torch.cat([clip.tokenize(text) for text in prompts]).to(self.device)

        # The text side is only used as a fixed set of class embeddings,
        # so no autograd graph is needed.
        with torch.no_grad():
            encoded = clip_model.encode_text(tokens)
            text_features = encoded / encoded.norm(dim=-1, keepdim=True)

        self.text_features = text_features
        self.clip_model = clip_model

    def model_inference(self, image):
        """Return logits of ``image`` against the cached class text features.

        The image features are L2-normalised, scaled by ``exp(logit_scale)``
        of the CLIP model and multiplied with the transposed text features.
        """
        encoded = self.clip_model.encode_image(image)
        image_features = encoded / encoded.norm(dim=-1, keepdim=True)

        logit_scale = self.clip_model.logit_scale.exp()
        # Scale first, then project onto the text features (same evaluation
        # order as ``a * b @ c``).
        scaled_features = logit_scale * image_features
        return scaled_features @ self.text_features.t()
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
    """Prompt ensembling.

    Instead of a single prompt per class, every template in ``templates`` is
    filled with each class name; the normalised text features are averaged
    over the templates and re-normalised. Inference is inherited from
    ``ZeroshotCLIP``.
    """

    # IMAGENET_TEMPLATES (imported alongside) can be substituted here.
    templates = IMAGENET_TEMPLATES_SELECT

    def build_model(self):
        """Load CLIP, freeze it and precompute template-averaged class text features.

        Sets ``self.text_features`` and ``self.clip_model``. For datasets other
        than ``"ImageNet"`` the dataset's entry of ``CUSTOM_TEMPLATES`` is added
        to the ensemble, and ``self.templates`` is rebound to that extended
        list on the instance only.

        Raises:
            KeyError: if ``cfg.DATASET.NAME`` is not ``"ImageNet"`` and has no
                entry in ``CUSTOM_TEMPLATES``.
        """
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)
        clip_model.to(self.device)

        for params in clip_model.parameters():
            params.requires_grad_(False)

        # add custom-made prompt
        if cfg.DATASET.NAME != "ImageNet":
            # Build a NEW list instead of ``self.templates += [...]``: the
            # augmented assignment extends the class-level list in place, and
            # that list is the imported module-level template list, so every
            # call would leak one more template into shared state. Starting
            # from the class attribute also keeps repeated calls idempotent.
            self.templates = [*type(self).templates, CUSTOM_TEMPLATES[cfg.DATASET.NAME]]

        templates = self.templates
        num_temp = len(templates)
        print(f"Prompt ensembling (n={num_temp})")

        # Loop-invariant: underscores in class names become spaces.
        class_texts = [c.replace("_", " ") for c in classnames]

        mean_text_features = 0
        # Parameters are frozen above; no_grad matches ZeroshotCLIP.build_model
        # and makes it explicit that no autograd graph is wanted here.
        with torch.no_grad():
            for temp in templates:
                prompts = [temp.format(text) for text in class_texts]
                prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
                text_features = clip_model.encode_text(prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                mean_text_features = mean_text_features + text_features
            mean_text_features = mean_text_features / num_temp
            mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

        self.text_features = mean_text_features
        self.clip_model = clip_model
|
||||
Reference in New Issue
Block a user