Upload to Main

Author: 张菲
Date: 2025-10-07 22:42:55 +08:00
Commit: d3ddab7c5d
218 changed files with 125,815 additions and 0 deletions

trainers/__init__.py (new, empty file)

9 binary files not shown.

trainers/cocoop.py (new file, 318 additions)

@@ -0,0 +1,318 @@
import os.path as osp
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
design_details = {"trainer": 'CoCoOp',
"vision_depth": 0,
"language_depth": 0, "vision_ctx": 0,
"language_ctx": 0}
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class PromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.COCOOP.N_CTX
ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
vis_dim = clip_model.visual.output_dim
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ")
n_ctx = len(ctx_init.split(" "))
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
self.ctx = nn.Parameter(ctx_vectors)
self.meta_net = nn.Sequential(OrderedDict([
("linear1", nn.Linear(vis_dim, vis_dim // 16)),
("relu", nn.ReLU(inplace=True)),
("linear2", nn.Linear(vis_dim // 16, ctx_dim))
]))
if cfg.TRAINER.COCOOP.PREC == "fp16":
self.meta_net.half()
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
# dim0 is either batch_size (during training) or n_cls (during testing)
# ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
# prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
# suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
if label is not None:
prefix = prefix[label]
suffix = suffix[label]
prompts = torch.cat(
[
prefix, # (dim0, 1, dim)
ctx, # (dim0, n_ctx, dim)
suffix, # (dim0, *, dim)
],
dim=1,
)
return prompts
def forward(self, im_features):
prefix = self.token_prefix
suffix = self.token_suffix
ctx = self.ctx # (n_ctx, ctx_dim)
bias = self.meta_net(im_features) # (batch, ctx_dim)
bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
# Use instance-conditioned context tokens for all classes
prompts = []
for ctx_shifted_i in ctx_shifted:
ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim)
prompts.append(pts_i)
prompts = torch.stack(prompts)
return prompts
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
image_features = self.image_encoder(image.type(self.dtype))
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
prompts = self.prompt_learner(image_features)
logits = []
for pts_i, imf_i in zip(prompts, image_features):
text_features = self.text_encoder(pts_i, tokenized_prompts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
l_i = logit_scale * imf_i @ text_features.t()
logits.append(l_i)
logits = torch.stack(logits)
if self.prompt_learner.training:
return F.cross_entropy(logits, label)
return logits
@TRAINER_REGISTRY.register()
class CoCoOp(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
param.requires_grad_(False)
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.COCOOP.PREC
if prec == "amp":
with autocast():
loss = model(image, label)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss = model(image, label)
optim.zero_grad()
loss.backward()
optim.step()
loss_summary = {"loss": loss.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
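
Illustrative sketch (not part of the commit): how cocoop.py's meta_net turns each image feature into a bias that shifts the shared context vectors, giving one prompt context per image. The sizes below are assumptions chosen for readability, not the configured values.
import torch
import torch.nn as nn
# Toy sizes for illustration; the real values come from the CLIP backbone and cfg.
n_ctx, ctx_dim, vis_dim, batch = 4, 512, 1024, 2
ctx = nn.Parameter(torch.empty(n_ctx, ctx_dim).normal_(std=0.02))  # shared context vectors
meta_net = nn.Sequential(
    nn.Linear(vis_dim, vis_dim // 16),
    nn.ReLU(inplace=True),
    nn.Linear(vis_dim // 16, ctx_dim),
)
im_features = torch.randn(batch, vis_dim)   # stand-in for normalized image features
bias = meta_net(im_features).unsqueeze(1)   # (batch, 1, ctx_dim)
ctx_shifted = ctx.unsqueeze(0) + bias       # (batch, n_ctx, ctx_dim), one context per image
print(ctx_shifted.shape)                    # torch.Size([2, 4, 512])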

trainers/coop.py (new file, 328 additions)

@@ -0,0 +1,328 @@
import os.path as osp
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
design_details = {"trainer": 'CoOp',
"vision_depth": 0,
"language_depth": 0, "vision_ctx": 0,
"language_ctx": 0}
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class PromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.COOP.N_CTX
ctx_init = cfg.TRAINER.COOP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ")
n_ctx = len(ctx_init.split(" "))
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
if cfg.TRAINER.COOP.CSC:
print("Initializing class-specific contexts")
ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
else:
print("Initializing a generic context")
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens): {n_ctx}")
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
if self.class_token_position == "end":
prompts = torch.cat(
[
prefix, # (n_cls, 1, dim)
ctx, # (n_cls, n_ctx, dim)
suffix, # (n_cls, *, dim)
],
dim=1,
)
elif self.class_token_position == "middle":
half_n_ctx = self.n_ctx // 2
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i : i + 1, :, :]
class_i = suffix[i : i + 1, :name_len, :]
suffix_i = suffix[i : i + 1, name_len:, :]
ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
ctx_i_half1, # (1, n_ctx//2, dim)
class_i, # (1, name_len, dim)
ctx_i_half2, # (1, n_ctx//2, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
elif self.class_token_position == "front":
prompts = []
for i in range(self.n_cls):
name_len = self.name_lens[i]
prefix_i = prefix[i : i + 1, :, :]
class_i = suffix[i : i + 1, :name_len, :]
suffix_i = suffix[i : i + 1, name_len:, :]
ctx_i = ctx[i : i + 1, :, :]
prompt = torch.cat(
[
prefix_i, # (1, 1, dim)
class_i, # (1, name_len, dim)
ctx_i, # (1, n_ctx, dim)
suffix_i, # (1, *, dim)
],
dim=1,
)
prompts.append(prompt)
prompts = torch.cat(prompts, dim=0)
else:
            raise ValueError(f"Unknown class token position: {self.class_token_position}")
return prompts
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image):
image_features = self.image_encoder(image.type(self.dtype))
prompts = self.prompt_learner()
tokenized_prompts = self.tokenized_prompts
text_features = self.text_encoder(prompts, tokenized_prompts)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
logits = logit_scale * image_features @ text_features.t()
return logits
@TRAINER_REGISTRY.register()
class CoOp(TrainerX):
"""Context Optimization (CoOp).
Learning to Prompt for Vision-Language Models
https://arxiv.org/abs/2109.01134
"""
def check_cfg(self, cfg):
assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
for name, param in self.model.named_parameters():
if "prompt_learner" not in name:
param.requires_grad_(False)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
prec = self.cfg.TRAINER.COOP.PREC
if prec == "amp":
with autocast():
output = self.model(image)
loss = F.cross_entropy(output, label)
self.optim.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optim)
self.scaler.update()
else:
output = self.model(image)
loss = F.cross_entropy(output, label)
self.model_backward_and_update(loss)
loss_summary = {
"loss": loss.item(),
"acc": compute_accuracy(output, label)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
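
A minimal sketch (separate from the commit) of the "end" class-token layout assembled in PromptLearner.forward above: the learned context sits between the SOS embedding and the class-name/EOS embeddings along the token dimension. 77 is CLIP's context length; the other sizes are assumptions.
import torch
n_cls, n_ctx, n_tkn, ctx_dim = 3, 16, 77, 512
prefix = torch.zeros(n_cls, 1, ctx_dim)                  # SOS embedding (token_prefix)
ctx = torch.zeros(n_cls, n_ctx, ctx_dim)                 # learned context, shared across classes here
suffix = torch.zeros(n_cls, n_tkn - 1 - n_ctx, ctx_dim)  # class-name tokens, EOS, padding (token_suffix)
prompts = torch.cat([prefix, ctx, suffix], dim=1)
assert prompts.shape == (n_cls, n_tkn, ctx_dim)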


@@ -0,0 +1,94 @@
# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
IMAGENET_TEMPLATES = [
"a bad photo of a {}.",
"a photo of many {}.",
"a sculpture of a {}.",
"a photo of the hard to see {}.",
"a low resolution photo of the {}.",
"a rendering of a {}.",
"graffiti of a {}.",
"a bad photo of the {}.",
"a cropped photo of the {}.",
"a tattoo of a {}.",
"the embroidered {}.",
"a photo of a hard to see {}.",
"a bright photo of a {}.",
"a photo of a clean {}.",
"a photo of a dirty {}.",
"a dark photo of the {}.",
"a drawing of a {}.",
"a photo of my {}.",
"the plastic {}.",
"a photo of the cool {}.",
"a close-up photo of a {}.",
"a black and white photo of the {}.",
"a painting of the {}.",
"a painting of a {}.",
"a pixelated photo of the {}.",
"a sculpture of the {}.",
"a bright photo of the {}.",
"a cropped photo of a {}.",
"a plastic {}.",
"a photo of the dirty {}.",
"a jpeg corrupted photo of a {}.",
"a blurry photo of the {}.",
"a photo of the {}.",
"a good photo of the {}.",
"a rendering of the {}.",
"a {} in a video game.",
"a photo of one {}.",
"a doodle of a {}.",
"a close-up photo of the {}.",
"a photo of a {}.",
"the origami {}.",
"the {} in a video game.",
"a sketch of a {}.",
"a doodle of the {}.",
"a origami {}.",
"a low resolution photo of a {}.",
"the toy {}.",
"a rendition of the {}.",
"a photo of the clean {}.",
"a photo of a large {}.",
"a rendition of a {}.",
"a photo of a nice {}.",
"a photo of a weird {}.",
"a blurry photo of a {}.",
"a cartoon {}.",
"art of a {}.",
"a sketch of the {}.",
"a embroidered {}.",
"a pixelated photo of a {}.",
"itap of the {}.",
"a jpeg corrupted photo of the {}.",
"a good photo of a {}.",
"a plushie {}.",
"a photo of the nice {}.",
"a photo of the small {}.",
"a photo of the weird {}.",
"the cartoon {}.",
"art of the {}.",
"a drawing of the {}.",
"a photo of the large {}.",
"a black and white photo of a {}.",
"the plushie {}.",
"a dark photo of a {}.",
"itap of a {}.",
"graffiti of the {}.",
"a toy {}.",
"itap of my {}.",
"a photo of a cool {}.",
"a photo of a small {}.",
"a tattoo of the {}.",
]
IMAGENET_TEMPLATES_SELECT = [
"itap of a {}.",
"a bad photo of the {}.",
"a origami {}.",
"a photo of the large {}.",
"a {} in a video game.",
"art of the {}.",
"a photo of the small {}.",
]
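
For context, a sketch of how these templates are typically used in the prompt-engineering notebook cited above: embed every filled-in template for a class with CLIP's text encoder and average the normalized embeddings into one classifier weight per class. Here model and classnames are placeholders, and the sketch assumes the OpenAI CLIP API (clip.tokenize, model.encode_text).
import torch
from clip import clip

def build_zero_shot_weights(model, classnames, templates=IMAGENET_TEMPLATES, device="cpu"):
    # Returns an (n_cls, dim) matrix of L2-normalized per-class text embeddings.
    weights = []
    with torch.no_grad():
        for name in classnames:
            texts = clip.tokenize([t.format(name) for t in templates]).to(device)
            emb = model.encode_text(texts)             # (n_templates, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            mean = emb.mean(dim=0)
            weights.append(mean / mean.norm())
    return torch.stack(weights, dim=0)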

trainers/independentVL.py (new file, 304 additions)

@@ -0,0 +1,304 @@
import os.path as osp
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
design_details = {"trainer": 'IVLP',
"vision_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION,
"language_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT, "vision_ctx": cfg.TRAINER.IVLP.N_CTX_VISION,
"language_ctx": cfg.TRAINER.IVLP.N_CTX_TEXT}
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class VLPromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
# Make sure Language depth >= 1
assert cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
"\nPlease use VPT trainer if you want to learn only vision " \
"branch "
n_ctx = cfg.TRAINER.IVLP.N_CTX_TEXT
ctx_init = cfg.TRAINER.IVLP.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
vis_dim = clip_model.visual.output_dim
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init and (n_ctx) <= 4:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ")
n_ctx = n_ctx
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f"Independent V-L design")
print(f'Initial text context: "{prompt_prefix}"')
print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.IVLP.N_CTX_VISION}")
self.ctx = nn.Parameter(ctx_vectors)
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
# dim0 is either batch_size (during training) or n_cls (during testing)
# ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
# prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
# suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
if label is not None:
prefix = prefix[label]
suffix = suffix[label]
prompts = torch.cat(
[
prefix, # (dim0, 1, dim)
ctx, # (dim0, n_ctx, dim)
suffix, # (dim0, *, dim)
],
dim=1,
)
return prompts
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
prompts = self.construct_prompts(ctx, prefix, suffix)
return prompts
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = VLPromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts)
image_features = self.image_encoder(image.type(self.dtype))
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits = logit_scale * image_features @ text_features.t()
if self.prompt_learner.training:
return F.cross_entropy(logits, label)
return logits
@TRAINER_REGISTRY.register()
class IVLP(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.IVLP.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.IVLP.PREC == "fp32" or cfg.TRAINER.IVLP.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
# Make sure that VPT prompts are updated
if "VPT" in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("VLPromptLearner", self.model, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.IVLP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.IVLP.PREC
if prec == "amp":
with autocast():
loss = model(image, label)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss = model(image, label)
optim.zero_grad()
loss.backward()
optim.step()
loss_summary = {"loss": loss.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "prompt_learner.token_prefix" in state_dict:
del state_dict["prompt_learner.token_prefix"]
if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
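
A toy check (not part of the commit) of the logit computation these trainers share: after L2 normalization, image-text logits are cosine similarities scaled by logit_scale.exp(), which is roughly 100 for pretrained CLIP since the scale is clamped. All tensors and sizes below are arbitrary stand-ins.
import torch
d, n_cls, batch = 512, 10, 4
image_features = torch.randn(batch, d)
text_features = torch.randn(n_cls, d)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = torch.tensor(100.0)                            # stands in for clip_model.logit_scale.exp()
logits = logit_scale * image_features @ text_features.t()    # (batch, n_cls); rows feed F.cross_entropy
print(logits.shape)                                          # torch.Size([4, 10])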

trainers/maple.py (new file, 928 additions)

@@ -0,0 +1,928 @@
import os.path as osp
import random
from collections import OrderedDict
import math
import copy
import torch
import torch.nn as nn
import time
import os
import pickle
import deepcore.methods as s_method
import numpy as np
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.evaluation import Classification,EvaluatorBase
from pygrad.pcgrad import PCGrad
from datasets.data_manager import DataManager
from dassl.data.datasets import build_dataset
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from trainers.zsclip import CUSTOM_TEMPLATES
from .coop import load_clip_to_cpu as lcp
from tqdm import tqdm
from sklearn.metrics import f1_score, confusion_matrix
from collections import OrderedDict, defaultdict
from .util import GradCAM,denorm
import cv2
_tokenizer = _Tokenizer()
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',]
BACKGROUND_CATEGORY_FOOD = ['table','forks','tablecloth','hands','spoon','glasses','dishes']
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
design_details = {"trainer": 'MaPLe',
"vision_depth": 0,
"language_depth": 0, "vision_ctx": 0,
"language_ctx": 0,
"maple_length": cfg.TRAINER.MAPLE.N_CTX}
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
# Pass as the list, as nn.sequential cannot process multiple arguments in the forward pass
combined = [x, compound_prompts_deeper_text, 0] # third argument is the counter which denotes depth of prompt
outputs = self.transformer(combined)
x = outputs[0] # extract the x back from here
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class MultiModalPromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.MAPLE.N_CTX # n_ctx
ctx_init = cfg.TRAINER.MAPLE.CTX_INIT # a photo of
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0] #512
clip_imsize = clip_model.visual.input_resolution #224
cfg_imsize = cfg.INPUT.SIZE[0] #224
# Default is 1, which is compound shallow prompting
assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, "For MaPLe, PROMPT_DEPTH should be >= 1"
self.compound_prompts_depth = cfg.TRAINER.MAPLE.PROMPT_DEPTH #9 # max=12, but will create 11 such shared prompts
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
if ctx_init and (n_ctx) <= 4:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ")
n_ctx = n_ctx
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print('MaPLe design: Multi-modal Prompt Learning')
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of MaPLe context words (tokens): {n_ctx}")
# These below, related to the shallow prompts
# Linear layer so that the tokens will project to 512 and will be initialized from 768
self.proj = nn.Linear(ctx_dim, 768)
self.proj.half()
self.ctx = nn.Parameter(ctx_vectors) #[2 512]
# These below parameters related to the shared prompts
# Define the compound prompts for the deeper layers
# Minimum can be 1, which defaults to shallow MaPLe
# compound prompts
self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(n_ctx, 512))
for _ in range(self.compound_prompts_depth - 1)])
for single_para in self.compound_prompts_text:
nn.init.normal_(single_para, std=0.02)
# Also make corresponding projection layers, for each prompt
single_layer = nn.Linear(ctx_dim, 768)
self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
###Introduce Background
bg_template = 'a clean origami {}.'
bg_classesnames = [bg_template.format(name) for name in BACKGROUND_CATEGORY +BACKGROUND_CATEGORY_FOOD ]
tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
bg_num = len(BACKGROUND_CATEGORY) + len(BACKGROUND_CATEGORY_FOOD)
tokenized_prompts = torch.cat((tokenized_prompts,tokenized_bg_prompts),dim=0)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.bg_embeding = embedding[-bg_num:]
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:-bg_num, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:-bg_num, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num 77] [:-bg_num]
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
# dim0 is either batch_size (during training) or n_cls (during testing)
# ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
# prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
# suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
if label is not None:
prefix = prefix[label]
suffix = suffix[label]
prompts = torch.cat(
[
prefix, # (dim0, 1, dim)
ctx, # (dim0, n_ctx, dim)
suffix, # (dim0, *, dim)
],
dim=1,
)
final_prompts = torch.cat((prompts,self.bg_embeding.cuda()),dim=0)
return final_prompts
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
prompts = self.construct_prompts(ctx, prefix, suffix)
# Before returning, need to transform
# prompts to 768 for the visual side
visual_deep_prompts = []
for index, layer in enumerate(self.compound_prompt_projections):
visual_deep_prompts.append(layer(self.compound_prompts_text[index]))
# Now the other way around
# We will project the textual prompts from 512 to 768
return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts # pass here original, as for visual 768 is required
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.image_encoder_ori = clip_model.visual_ori
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
self.txt_f = []
self.img_f = []
self.one_hot_label = []
self.vtx = []
self.loaded_mask = None
# self.loss_weights = torch.nn.Parameter(torch.tensor([0.8,0.03],dtype=self.dtype))
def get_uniform_ball_noise(self,input_shape,radius=1.0):
uniform_noise_ball = torch.randn(input_shape).cuda()
uniform_noise_sphere = F.normalize(uniform_noise_ball,dim=1)
u = torch.rand(input_shape[0]).cuda()
u = u **(1. / input_shape[1])
uniform_noise_ball = (uniform_noise_sphere.T *u *radius).T
return uniform_noise_ball.type(self.dtype)
def get_learnable_noise(self,input_shape):
para = 0.05
noise = torch.nn.Parameter(torch.randn(input_shape)*para).cuda()
return noise.type(self.dtype)
def cos_sim(self,a,b):
return F.cosine_similarity(a,b)
def forward(self, image, label=None,record=False,cal_gradient=False,weight=None,epoch=None,index=None,cfg=None,mask=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
text_features_fg = text_features[:-len(BACKGROUND_CATEGORY)]
ori_image_input = image.type(self.dtype)
# text_features = text_features + self.get_learnable_noise(text_features.shape)
text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)
image_features, visual_ctx, mask_similarity = self.image_encoder(ori_image_input, shared_ctx,
deep_compound_prompts_vision)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# if label is not None:
# image_features = image_features + self.get_uniform_ball_noise(image_features.shape)
logits = logit_scale * image_features @ text_features_fg.t()
        if mask is not None:
text_features_bg = text_features[-len(BACKGROUND_CATEGORY):]
text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)
image_features_fg,_,_ = self.image_encoder(ori_image_input*mask, shared_ctx, deep_compound_prompts_vision) #, shared_ctx, deep_compound_prompts_vision
image_features_fg = image_features_fg / image_features_fg.norm(dim=-1, keepdim=True)
image_features_bg,_,_ = self.image_encoder(ori_image_input*(1-mask), shared_ctx, deep_compound_prompts_vision)
image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)
loss_re1 = F.triplet_margin_loss(image_features,image_features_fg.detach(),image_features_bg.detach(),margin=1.5)
# image_features_fg_ori = self.image_encoder_ori(ori_image_input*mask_random)
# image_features_bg_ori = self.image_encoder_ori(ori_image_input*(1-mask_random))
# image_features_fg_ori = image_features_fg_ori / image_features_fg_ori.norm(dim=-1, keepdim=True)
# image_features_bg_ori = image_features_bg_ori / image_features_bg_ori.norm(dim=-1,keepdim=True)
# image_features_all_ori = image_features_fg_ori + image_features_bg_ori
# image_features_all_ori = image_features_all_ori / image_features_all_ori.norm(dim=-1,keepdim=True)
# loss_reo = torch.abs(image_features_all_ori.detach() - image_features).mean()
foreground_score = logit_scale*image_features_fg.detach()@text_features_fg.t()
pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
logits_bg = logit_scale*(image_features_bg) @ text_features_bg.t()
para_bg = 0.5
para_fg = 0.1
para_vd = 0.8
loss_bg = F.cross_entropy(logits_bg,pseudo_label)
loss_fg = F.cross_entropy(foreground_score,label)
            if epoch > 6:  # tunable threshold: switch to the fg/bg losses after warm-up
loss_re = para_fg*loss_fg + para_bg*loss_bg
else:
loss_re = para_vd*loss_re1 #loss_reo would be effective in base2novel setting
if self.prompt_learner.training:
if weight is None:
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_vd':loss_re1.item(),'loss_bg':loss_bg.item(),'loss_fg':loss_fg.item()}
else:
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
        if record:  # store the embeddings
one_hot_label = F.one_hot(label,num_classes=text_features.shape[0]).to(torch.float16)
return image_features.detach(),(one_hot_label @ text_features).detach(), logits
if cal_gradient:
#Treating this as initial gradient
# one_hot_label = F.one_hot(label,num_classes=text_features.shape[0]).to(torch.float16)
return F.cross_entropy(logits.requires_grad_(True), label), image_features.detach(), logits #,(one_hot_label @ text_features).detach()
return logits
def grad_norm(self,loss_group,original_loss_group):
alpha = 0.10
self.loss_weights.grad.data = self.loss_weights.grad.data * 0.0
W = self.prompt_learner.compound_prompt_projections[0]
norms = []
for i in range(len(loss_group)):
gygw = torch.autograd.grad(loss_group[i],W.parameters(),retain_graph=True)
norms.append(torch.norm(torch.mul(self.loss_weights[i],gygw[0])))
norms = torch.stack(norms)
loss_ratio = loss_group.data.cpu().numpy() / original_loss_group
inverse_train_rate = loss_ratio / np.mean(loss_ratio)
mean_norm = np.mean(norms.data.cpu().numpy())
constant_norm = torch.tensor(mean_norm*(inverse_train_rate**alpha),requires_grad=False).cuda()
grad_norm_loss = torch.sum(torch.abs(norms - constant_norm))
self.loss_weights.grad = torch.autograd.grad(grad_norm_loss,self.loss_weights)[0]
def forward_test(self, image, label=None,record=False,cal_gradient=False,weight=None,cfg=None,attn_mask=False):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
image_features,visual_ctx,mask = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits = logit_scale * image_features @ text_features.t()
if self.prompt_learner.training:
if weight is None:
return F.cross_entropy(logits, label),logits
else:
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
        if record:  # store the embeddings
one_hot_label = F.one_hot(label,num_classes=text_features.shape[0]).to(torch.float16)
return image_features.detach(),(one_hot_label @ text_features).detach(), logits
if attn_mask:
return logits,mask
if cal_gradient:
#Treating this as initial gradient
# one_hot_label = F.one_hot(label,num_classes=text_features.shape[0]).to(torch.float16)
return F.cross_entropy(logits.requires_grad_(True), label), image_features.detach(), logits #,(one_hot_label @ text_features).detach()
return logits
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
@TRAINER_REGISTRY.register()
class MaPLe(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.MAPLE.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.MAPLE.PREC == "fp32" or cfg.TRAINER.MAPLE.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
# Make sure that VPT prompts are updated
if "VPT" in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# self.model.loss_weights.requires_grad_(True) #open gradient for loss_weights
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.selected_optim = build_optimizer(self.model, cfg.OPTIM_SELECTION)
self.selected_sched = build_lr_scheduler(self.optim, cfg.OPTIM_SELECTION)
self.register_model("MultiModalPromptLearner", self.model, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.MAPLE.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
# device_count = torch.cuda.device_count()
# if device_count > 1:
# print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
# self.model = nn.DataParallel(self.model)
# def generate_text_feature(self):
# cfg = self.cfg
# classnames = self.dm.dataset.classnames
# #
# # print(f"Loading Custom CLIP (backbone: {cfg.MODEL.BACKBONE.NAME}) for selection")
# # clip_model = lcp(cfg)
# # clip_model.to(self.device)
#
# temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
# prompts = [temp.format(c.replace("_", " ")) for c in classnames]
# print(f"Prompts: {prompts}")
# prompts = torch.cat([clip.tokenize(p) for p in prompts])
# prompts = prompts.to(self.device)
#
# p, _, deep_compound_prompts_text, _ = self.model.prompt_learner()
# with torch.no_grad():
# text = self.model.text_encoder(prompts)
# text_features = self.model.encode_text(prompts, tokenized_prompts, deep_compound_prompts_text)
# text_features = text_features / text_features.norm(dim=-1, keepdim=True)
#
# self.ori_text_features = text_features
def forward_backward(self, batch):
if self.sample_weights is not None:
image, label,index,mask = self.parse_batch_train_pair(batch)
else:
image, label,index,mask = self.parse_batch_train_pair(batch)
weight = None
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.MAPLE.PREC
if prec == "amp":
with autocast():
loss,_ = model(image, label, weight=weight,mask=mask)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss,_,loss_dict = model(image, label, weight=weight,epoch=self.epoch,index=index,cfg=self.cfg,mask=mask)
optim.zero_grad()
# optim.pc_backward(loss_task)
loss.backward()
# if self.epoch == 0:
# self.loss_o1 = loss_task.data.cpu().numpy()
# model.grad_norm(loss_task,self.loss_o1)
optim.step()
# normalized_coeff = 2 / torch.sum(model.loss_weights.data,dim=0)
# model.loss_weights.data *= normalized_coeff
loss_summary = loss_dict
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train_pair(self, batch):
input = batch["img"]
label = batch["label"]
index = batch["index"]
mask = batch['mask']
input = input.to(self.device)
label = label.to(self.device)
mask = mask.to(self.device)
if self.sample_weights is not None:
# weight = batch['weight'].cuda()
return input, label,index,mask
else:
return input, label,index,mask
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
index = batch["index"]
input = input.to(self.device)
label = label.to(self.device)
if self.sample_weights is not None:
weight = batch['weight'].cuda()
return input, label,weight,index
else:
return input, label,index
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "prompt_learner.token_prefix" in state_dict:
del state_dict["prompt_learner.token_prefix"]
if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
def before_train(self):
directory = self.cfg.OUTPUT_DIR
if self.cfg.RESUME:
directory = self.cfg.RESUME
# self.start_epoch = self.resume_model_if_exist(directory) #in case of loading pre-trained weight
# Redefine the dataloader
selected_res = self.selector()
if 'weights' in selected_res:
c_weight = np.zeros(len(self.dm.dataset.train_x))
c_weight[selected_res['indices']] = selected_res['weights']
self.sample_weights = c_weight[selected_res['indices']]
else:
self.sample_weights = None
self.build_final_data_loader(selected_res['indices'],self.sample_weights)
        print("Finished the selection process, now continuing to tune CLIP")
# Initialize summary writer
writer_dir = osp.join(self.output_dir, "tensorboard")
mkdir_if_missing(writer_dir)
self.init_writer(writer_dir)
# Remember the starting time (for computing the elapsed time)
self.time_start = time.time()
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")
if self.cfg.TRAINER.DAPT_MODE == 'dapt-s':
self.generate_mask_train()
else:
self.generate_gradcam_train(split='train')
def after_epoch(self):
last_epoch = (self.epoch + 1) == self.max_epoch
do_test = not self.cfg.TEST.NO_TEST
meet_checkpoint_freq = (
(self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0
if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False)
if do_test and self.cfg.TEST.FINAL_MODEL == "best_val":
curr_result = self.test(split="val")
is_best = curr_result > self.best_result
if is_best:
self.best_result = curr_result
self.save_model(
self.epoch,
self.output_dir,
val_result=curr_result,
model_name="model-best.pth.tar"
)
# if meet_checkpoint_freq or last_epoch:
# self.save_model(self.epoch, self.output_dir)
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")
if self.cfg.TRAINER.DAPT_MODE == 'dapt-s':
self.generate_mask_train()
else:
self.generate_gradcam_train(split='train')
def build_final_data_loader(self,selected_ind=None,weight=None):
new_dm = DataManager(self.cfg,self.dm.dataset,selected_ind,weight=weight)
self.train_loader_x = new_dm.train_loader_x
self.train_loader_xmore = new_dm.train_loader_xmore #for generate the attentive masking
self.mask_list = torch.zeros((selected_ind.shape[0], 1, *self.cfg.INPUT.SIZE),dtype=torch.float16)
def selector(self):
selection_ratio = self.cfg.DATASET.SELECTION_RATIO
seed = self.cfg.SEED
method = self.cfg.DATASET.SELECTION_METHOD
print(f"Selecting {selection_ratio*100}% data by {method}")
if self.cfg.DATASET.SELECTION_METHOD == 'Uniform':
selector = s_method.Uniform(self.dm, self.cfg,selection_ratio, seed)
else:
selector = s_method.__dict__[method](dst_train=self.dm,
args=self.cfg,
fraction=selection_ratio,
random_seed=seed,
specific_model=self.model,
optim = self.selected_optim,
schedule = self.selected_sched,
scar = self.scaler,
balance = True
)
return selector.select()
@torch.no_grad()
def test_withlabel(self, split=None):
"""A generic testing pipeline."""
self.set_model_mode("eval")
new_estimate = NewClassification(self.cfg,self.evaluator._lab2cname)
new_estimate.reset()
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
else:
split = "test" # in case val_loader is None
data_loader = self.test_loader
print(f"Evaluate on the *{split}* set")
for batch_idx, batch in enumerate(tqdm(data_loader)):
input, label = self.parse_batch_test(batch)
output = self.model.forward_test(input,label,cfg = self.cfg)
new_estimate.process(output, label)
results = new_estimate.evaluate()
for k, v in results.items():
tag = f"{split}/{k}"
self.write_scalar(tag, v, self.epoch)
return list(results.values())[0]
def generate_gradcam(self, split=None,attn_mask=False):
"""A generic pipeline for generating GradCAM"""
self.set_model_mode("eval")
model_dict = {'arch':self.model,'layer_name':'target.layer'}
cam = GradCAM(model_dict)
# new_estimate = NewClassification(self.cfg,self.evaluator._lab2cname)
# new_estimate.reset()
img_split = 'wrong' #true/wrong
if split is None:
split = self.cfg.TEST.SPLIT
if split == "val" and self.val_loader is not None:
data_loader = self.val_loader
else:
split = "test" # in case val_loader is None
data_loader = self.test_loader
print(f"Generate GradCAM on the *{split}* set")
save_path = self.cfg.OUTPUT_DIR + '/'+f'{split}_{img_split}_promptcamother'
if not os.path.exists(save_path):
os.mkdir(save_path)
for batch_idx, batch in enumerate(tqdm(data_loader)):
input, label = self.parse_batch_test(batch)
img_name = batch['impath'][0].split('/')[-1]
img_save_path = os.path.join(save_path, img_name)
img0 = denorm(batch['img0'].numpy(),self.cfg.INPUT.PIXEL_MEAN,self.cfg.INPUT.PIXEL_STD)
saliency_map = cam.forward(input,label,cfg = self.cfg,split=img_split,attn_mask=attn_mask)
            if saliency_map is not None:
final_map = cam.show_cam(img0,saliency_map.detach().cpu(),img_save_path)
def generate_mask_train(self):
for batch_idx, batch in enumerate(tqdm(self.train_loader_xmore)):
input, _, index = self.parse_batch_train(batch)
b,c,h,w = input.shape
mask = torch.ones((1,h,w),dtype=torch.float16)
grid_sizes = [32,16]
hide_prob = 0.5
grid_size = grid_sizes[torch.randint(0,len(grid_sizes),size=(1,))]
if (grid_size != 0):
for x in range(0,h,grid_size):
for y in range(0,w,grid_size):
x_end,y_end = min(h, x+grid_size),min(w,y+grid_size)
if (random.random() <= hide_prob):
mask[:,x:x_end,y:y_end] = 0
self.mask_list[index, :] = mask
self.model.loaded_mask = self.mask_list
def generate_mask_bg(self):
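        """Same random grid masking as generate_mask_train, but with coarser grids (64 or 128 pixels), intended for background-oriented masks."""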
for batch_idx, batch in enumerate(tqdm(self.train_loader_xmore)):
input, _, index = self.parse_batch_train(batch)
b,c,h,w = input.shape
mask = torch.ones((1,h,w),dtype=torch.float16)
grid_sizes = [64,128]
hide_prob = 0.5
            grid_size = grid_sizes[torch.randint(0, len(grid_sizes), size=(1,)).item()]
            if grid_size != 0:
for x in range(0,h,grid_size):
for y in range(0,w,grid_size):
x_end,y_end = min(h, x+grid_size),min(w,y+grid_size)
if (random.random() <= hide_prob):
mask[:,x:x_end,y:y_end] = 0
self.mask_list[index, :] = mask
self.model.loaded_mask = self.mask_list
def generate_gradcam_train(self, split=None,attn_mask=False):
"""A generic pipeline for generating GradCAM"""
self.set_model_mode("eval")
model_dict = {'arch':self.model,'layer_name':'target.layer'}
cam = GradCAM(model_dict)
# new_estimate = NewClassification(self.cfg,self.evaluator._lab2cname)
# new_estimate.reset()
print(f"Generate GradCAM on the *{split}* set")
# save_path = self.cfg.OUTPUT_DIR + '/'+f'{split}_{img_split}_promptcamother'
# if not os.path.exists(save_path):
# os.mkdir(save_path)
for batch_idx, batch in enumerate(tqdm(self.train_loader_xmore)):
input, label, index = self.parse_batch_train(batch)
# img0 = denorm(batch['img0'].numpy(),self.cfg.INPUT.PIXEL_MEAN,self.cfg.INPUT.PIXEL_STD)
saliency_map = cam.forward_train(input,label,cfg = self.cfg,attn_mask=attn_mask)
self.mask_list[index,:] = saliency_map.detach().cpu()
# if saliency_map != None:
# final_map = cam.show_cam(img0,saliency_map.detach().cpu(),img_save_path)
self.model.loaded_mask = self.mask_list
class NewClassification(Classification):
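    """Classification evaluator extended with macro F1, per-class accuracy dumps, and the indices of misclassified samples."""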
def __init__(self, cfg, lab2cname=None, **kwargs):
super(NewClassification, self).__init__(cfg,lab2cname)
self._lab2cname = lab2cname
self._correct = 0
self._total = 0
self._per_class_res = None
self._y_true = []
self._y_pred = []
if cfg.TEST.PER_CLASS_RESULT:
assert lab2cname is not None
self._per_class_res = defaultdict(list)
def evaluate(self):
results = OrderedDict()
acc = 100.0 * self._correct / self._total
err = 100.0 - acc
macro_f1 = 100.0 * f1_score(
self._y_true,
self._y_pred,
average="macro",
labels=np.unique(self._y_true)
)
# The first value will be returned by trainer.test()
results["accuracy"] = acc
results["error_rate"] = err
results["macro_f1"] = macro_f1
wrong_ind = np.array(self._y_true) != np.array(self._y_pred)
        np.save(osp.join(self.cfg.OUTPUT_DIR, 'wrongind.npy'), wrong_ind)
print(
"=> result\n"
f"* total: {self._total:,}\n"
f"* correct: {self._correct:,}\n"
f"* accuracy: {acc:.1f}%\n"
f"* error: {err:.1f}%\n"
f"* macro_f1: {macro_f1:.1f}%"
)
if self._per_class_res is not None:
labels = list(self._per_class_res.keys())
labels.sort()
print("=> per-class result")
accs = []
for label in labels:
classname = self._lab2cname[label]
res = self._per_class_res[label]
correct = sum(res)
total = len(res)
acc = 100.0 * correct / total
accs.append(acc)
print(
f"* class: {label} ({classname})\t"
f"total: {total:,}\t"
f"correct: {correct:,}\t"
f"acc: {acc:.1f}%"
)
mean_acc = np.mean(accs)
            np.save(osp.join(self.cfg.OUTPUT_DIR, 'per-class.npy'), {'per_cls': accs, 'mean_acc': mean_acc})
print(f"* average: {mean_acc:.1f}%")
results["perclass_accuracy"] = mean_acc
if self.cfg.TEST.COMPUTE_CMAT:
cmat = confusion_matrix(
self._y_true, self._y_pred, normalize="true"
)
save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
torch.save(cmat, save_path)
print(f"Confusion matrix is saved to {save_path}")
return results

265
trainers/util.py Normal file
View File

@@ -0,0 +1,265 @@
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import os
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
]
class GradCAM(object):
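    """Grad-CAM for the CLIP visual transformer: patch-token activations from a hooked layer are weighted by their gradients to form a saliency map."""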
def __init__(self,model_dict):
layer_name = model_dict['layer_name']
self.model_arch = model_dict['arch']
self.gradient = dict()
self.activation = dict()
self.gradient_t = dict()
self.activation_t = dict()
def backward_hook(module,grad_input,grad_output):
self.gradient['value'] = grad_output[0]
return None
def forward_hook(module,input,output):
self.activation['value'] = output
return None
def backward_hook_t(module,grad_input,grad_output):
self.gradient_t['value'] = grad_output[0]
return None
def forward_hook_t(module,input,output):
self.activation_t['value'] = output
return None
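        # register hooks on the last visual transformer block's first LayerNorm to cache its forward activation and the gradient of its output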
target_layer = self.model_arch.image_encoder.transformer.resblocks[-1].ln_1
# target_layer_t = self.model_arch.image_encoder.transformer.resblocks[-2].mlp.c_proj
target_layer.register_forward_hook(forward_hook)
target_layer.register_backward_hook(backward_hook)
# target_layer_t.register_forward_hook(forward_hook_t)
# target_layer_t.register_backward_hook(backward_hook_t)
def forward(self,input,labels,cfg=None,retain_graph=False,split=None,attn_mask=False):
b,c,h,w = input.shape
patch_num,ori_size = self.model_arch.image_encoder.patch_num, self.model_arch.image_encoder.input_resolution
if attn_mask:
logit,mask = self.model_arch.forward_test(input,labels,cfg=cfg,attn_mask=attn_mask)
cls_mask = mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].reshape(b,-1,patch_num,patch_num) #+ mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].permute(0,2,1)
aff = mask[:,1:-self.model_arch.prompt_learner.n_ctx, 1:-self.model_arch.prompt_learner.n_ctx]
# aff = (aff + aff.permute(0,2,1)) / 2
aff = aff / (aff.sum(dim=1,keepdim=True) + 1e-6)
# aff = aff / (aff.sum(dim=1,keepdim=True) + 1e-6)
# aff = (aff + aff.permute(0,2,1)) / 2
# aff = torch.bmm(aff,aff)
# aff = F.softmax(aff,dim=1)
# cls_mask = torch.bmm(cls_mask, aff).reshape(b,-1,patch_num,patch_num)
# cls_mask = mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].permute(0,2,1).reshape(b,-1,patch_num,patch_num)
# # cls_mask = mask[:,-self.model_arch.prompt_learner.n_ctx:,1:-self.model_arch.prompt_learner.n_ctx].reshape(b,-1,patch_num,patch_num).mean(dim=1,keepdim=True)
# final_cls_mask = F.upsample(cls_mask, size=(ori_size, ori_size), mode='bilinear',
# align_corners=True)
# final_cls_feature_min, final_cls_feature_max = final_cls_mask.min(), final_cls_mask.max()
# final_cls_mask = (final_cls_mask - final_cls_feature_min) / (
# final_cls_feature_max - final_cls_feature_min + 1e-6)
# final_cls_mask = final_cls_mask / (final_cls_mask.max() + 1e-6)
else:
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
pred_label = torch.argmax(logit[:,:-len(BACKGROUND_CATEGORY)])
sign = pred_label == labels
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
# print(f'Ignore the not {split} sample')
# return None
# if attn_mask:
# return final_cls_mask
pred = logit[:,:-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]-len(BACKGROUND_CATEGORY)).to(torch.float16)
        loss = (F.softmax(logit[:, :-len(BACKGROUND_CATEGORY)], dim=-1) * one_hot_labels).mean()  # + background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
# score = logit[:,labels]
self.model_arch.zero_grad()
loss.backward(retain_graph=retain_graph)
gradients = self.gradient['value']
activations = self.activation['value']
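        # hooked tensors are in (seq_len, batch, dim) layout; the slicing below assumes the visual sequence is [CLS, patch tokens, learned prompt tokens]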
# gradients_t = self.gradient_t['value']
# activations_t = self.activation_t['value']
visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
# visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
# cls = gradients[1:-self.model_arch.prompt_learner.n_ctx,:,:]
# cls_token_gradient = gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].mean(dim=0,keepdim=True)#gradients[:1,:,:]
cls_token_gradient,prompt_gradient = gradients[:1,:,:], gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].mean(keepdim=True,dim=0)
visual_gradient = torch.mean(gradients[1:-self.model_arch.prompt_learner.n_ctx],keepdim=True,dim=0)
lam = 0.5
# cls_token_gradient = cls_token_gradient / (cls_token_gradient.max(dim=-1,keepdim=True)[0] + 1e-6)
# prompt_gradient = prompt_gradient / (prompt_gradient.max(dim=-1,keepdim=True)[0] + 1e-6)
# sim = F.cosine_similarity(prompt_gradient.mean(dim=0,keepdim=True),cls_token_gradient,dim=-1)
# print(sim)
# cls_token_gradient = gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].max(dim=0,keepdim=True)[0]#gradients[:1,:,:]
# token_gradient = cls_token_gradient
# token_gradient = cls_token_gradient#*(prompt_gradient.mean(dim=0,keepdim=True))
        # prompt_mean = prompt_gradient.mean(dim=0, keepdim=True)
token_gradient = visual_gradient
final_visual_feature = torch.bmm(visual_feature.permute(1,0,2),token_gradient.permute(1,2,0))
final_visual_feature = F.relu(final_visual_feature).permute(0,2,1)
# if attn_mask:
# final_visual_feature = torch.bmm(final_visual_feature, aff)
final_visual_feature = final_visual_feature.reshape(final_visual_feature.shape[0],1, patch_num, patch_num)
        final_visual_feature = F.interpolate(final_visual_feature, size=(ori_size, ori_size), mode='bilinear', align_corners=True)
# saliency_map = final_visual_feature / final_visual_feature.max()
final_visual_feature_min, final_visual_feature_max = final_visual_feature.min(), final_visual_feature.max()
saliency_map = final_visual_feature / (final_visual_feature_max + 1e-6)#(final_visual_feature-final_visual_feature_min) / (final_visual_feature_max - final_visual_feature_min + 1e-6)
threshold = 0.5
# saliency_map[saliency_map >= threshold] = 1
saliency_map[saliency_map < threshold] = 0
return saliency_map
def forward_train(self,input,labels,cfg=None,retain_graph=False,split=None,attn_mask=False):
b,c,h,w = input.shape
patch_num,ori_size = self.model_arch.image_encoder.patch_num, self.model_arch.image_encoder.input_resolution
if attn_mask:
logit,mask = self.model_arch.forward_test(input,labels,cfg=cfg,attn_mask=attn_mask)
cls_mask = mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].reshape(b,-1,patch_num,patch_num) #+ mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].permute(0,2,1)
aff = mask[:,1:-self.model_arch.prompt_learner.n_ctx, 1:-self.model_arch.prompt_learner.n_ctx]
# aff = (aff + aff.permute(0,2,1)) / 2
aff = aff / (aff.sum(dim=1,keepdim=True) + 1e-6)
# aff = aff / (aff.sum(dim=1,keepdim=True) + 1e-6)
# aff = (aff + aff.permute(0,2,1)) / 2
# aff = torch.bmm(aff,aff)
# aff = F.softmax(aff,dim=1)
# cls_mask = torch.bmm(cls_mask, aff).reshape(b,-1,patch_num,patch_num)
# cls_mask = mask[:,1:-self.model_arch.prompt_learner.n_ctx,:1].permute(0,2,1).reshape(b,-1,patch_num,patch_num)
# # cls_mask = mask[:,-self.model_arch.prompt_learner.n_ctx:,1:-self.model_arch.prompt_learner.n_ctx].reshape(b,-1,patch_num,patch_num).mean(dim=1,keepdim=True)
# final_cls_mask = F.upsample(cls_mask, size=(ori_size, ori_size), mode='bilinear',
# align_corners=True)
# final_cls_feature_min, final_cls_feature_max = final_cls_mask.min(), final_cls_mask.max()
# final_cls_mask = (final_cls_mask - final_cls_feature_min) / (
# final_cls_feature_max - final_cls_feature_min + 1e-6)
# final_cls_mask = final_cls_mask / (final_cls_mask.max() + 1e-6)
else:
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
pred_label = torch.argmax(logit)
sign = pred_label == labels
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
# print(f'Ignore the not {split} sample')
# return None
# if attn_mask:
# return final_cls_mask
# pred = logit[:,-len(BACKGROUND_CATEGORY):].argmax(dim=-1)
# background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (logit*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
# score = logit[:,labels]
self.model_arch.zero_grad()
loss.backward(retain_graph=retain_graph)
gradients = self.gradient['value']
activations = self.activation['value']
# gradients_t = self.gradient_t['value']
# activations_t = self.activation_t['value']
visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
# visual_feature = activations[1:-self.model_arch.prompt_learner.n_ctx]
# cls = gradients[1:-self.model_arch.prompt_learner.n_ctx,:,:]
# cls_token_gradient = gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].mean(dim=0,keepdim=True)#gradients[:1,:,:]
cls_token_gradient,prompt_gradient = gradients[:1,:,:], gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].mean(keepdim=True,dim=0)
visual_gradient = torch.mean(gradients[1:-self.model_arch.prompt_learner.n_ctx],keepdim=True,dim=0)
lam = 0.5
# cls_token_gradient = cls_token_gradient / (cls_token_gradient.max(dim=-1,keepdim=True)[0] + 1e-6)
# prompt_gradient = prompt_gradient / (prompt_gradient.max(dim=-1,keepdim=True)[0] + 1e-6)
# sim = F.cosine_similarity(prompt_gradient.mean(dim=0,keepdim=True),cls_token_gradient,dim=-1)
# print(sim)
# cls_token_gradient = gradients[-self.model_arch.prompt_learner.n_ctx:,:,:].max(dim=0,keepdim=True)[0]#gradients[:1,:,:]
# token_gradient = cls_token_gradient
# token_gradient = cls_token_gradient#*(prompt_gradient.mean(dim=0,keepdim=True))
        # prompt_mean = prompt_gradient.mean(dim=0, keepdim=True)
token_gradient = visual_gradient
final_visual_feature = torch.bmm(visual_feature.permute(1,0,2),token_gradient.permute(1,2,0))
final_visual_feature = F.relu(final_visual_feature).permute(0,2,1)
# if attn_mask:
# final_visual_feature = torch.bmm(final_visual_feature, aff)
final_visual_feature = final_visual_feature.reshape(final_visual_feature.shape[0],1, patch_num, patch_num)
        final_visual_feature = F.interpolate(final_visual_feature, size=(ori_size, ori_size), mode='bilinear', align_corners=True)
# saliency_map = final_visual_feature / final_visual_feature.max()
final_visual_feature_min, final_visual_feature_max = final_visual_feature.min(), final_visual_feature.max()
saliency_map = final_visual_feature / (final_visual_feature_max + 1e-6)#(final_visual_feature-final_visual_feature_min) / (final_visual_feature_max - final_visual_feature_min + 1e-6)
threshold = 0.5
saliency_map[saliency_map >= threshold] = 1
saliency_map[saliency_map < threshold] = 0
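        # hard-binarize the training saliency map; generate_gradcam_train caches it in mask_list and attaches it to the model as loaded_mask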
return saliency_map
def show_cam(self,img,mask,save_path=None):
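        """Blend the saliency map, rendered as a JET colormap, with the denormalized image and save the overlay to save_path."""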
heat_map = cv2.applyColorMap(np.uint8(255*mask.squeeze()), cv2.COLORMAP_JET)
heatmap = torch.from_numpy(heat_map).permute(2,0,1).float().div(255)
b,g,r = heatmap.split(1)
heatmap = torch.cat([r,g,b])
rate = 0.5
res = rate*heatmap + (1-rate)*img
res = res.div(res.max()).squeeze()
res = np.transpose(np.uint8(255*res),(1,2,0))
pil_image = Image.fromarray(res)
# pil_image.save('test1.jpg')
pil_image.save(save_path)
return pil_image
def denorm(img,mean,std):
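    """Undo channel-wise normalization; expects a channel-first image (C, H, W) or batch (N, C, H, W) with per-channel mean/std."""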
mean,std = np.array(mean),np.array(std)
img = img*std[:, None, None] + mean[:, None, None]
# img = np.clip(img*255, 0, 255) #.clamp(0,255)
# img = img / 255
return img

239
trainers/vpt.py Normal file
View File

@@ -0,0 +1,239 @@
import os.path as osp
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
design_details = { "trainer": "VPT",
"vision_depth": cfg.TRAINER.VPT.PROMPT_DEPTH_VISION,
"vision_ctx": cfg.TRAINER.VPT.N_CTX_VISION,
"language_depth": 0,
"language_ctx": 0}
assert cfg.TRAINER.VPT.PROMPT_DEPTH_VISION >= 1, "For Vision Prompting, PROMPT_DEPTH_VISION should be >= 1"
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model.float()
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class FixedEmbeddings():
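    """Pre-compute frozen text embeddings from a fixed hand-crafted prompt; with VPT only the visual prompts are learned."""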
def __init__(self, cfg, classnames, clip_model):
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
prompt_prefix = "a photo of a"
print('Vision Prompting Design')
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.VPT.N_CTX_VISION}")
print(f"Using fixed hand crated prompts")
classnames = [name.replace("_", " ") for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
text_features = clip_model.encode_text(tokenized_prompts)
self.fixed_embeddings = text_features
def return_fixed_embeddings(self):
return self.fixed_embeddings
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.embeddings = FixedEmbeddings(cfg, classnames, clip_model)
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
def forward(self, image, label=None, training=False):
logit_scale = self.logit_scale.exp()
text_features = self.embeddings.return_fixed_embeddings().cuda()
image_features = self.image_encoder(image.type(self.dtype))
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits = logit_scale * image_features @ text_features.t()
if training:
return F.cross_entropy(logits, label)
return logits
@TRAINER_REGISTRY.register()
class VPT(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.VPT.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.VPT.PREC == "fp32" or cfg.TRAINER.VPT.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
name_to_update = "prompt_learner"
for name, param in self.model.named_parameters():
if name_to_update not in name:
# Make sure that VPT prompts are updated
if "VPT" in name:
param.requires_grad_(True)
else:
param.requires_grad_(False)
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.VPT.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.VPT.PREC
if prec == "amp":
            with autocast():
                loss = model(image, label, training=True)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss = model(image, label, training=True)
optim.zero_grad()
loss.backward()
optim.step()
loss_summary = {"loss": loss.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "prompt_learner.token_prefix" in state_dict:
del state_dict["prompt_learner.token_prefix"]
if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)

99
trainers/zsclip.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.model import convert_weights
from .coop import load_clip_to_cpu
from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
CUSTOM_TEMPLATES = {
"OxfordPets": "a photo of a {}, a type of pet.",
"OxfordFlowers": "a photo of a {}, a type of flower.",
"FGVCAircraft": "a photo of a {}, a type of aircraft.",
"DescribableTextures": "{} texture.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
"Food101": "a photo of {}, a type of food.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
@TRAINER_REGISTRY.register()
class ZeroshotCLIP(TrainerX):
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
clip_model.to(self.device)
temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
prompts = [temp.format(c.replace("_", " ")) for c in classnames]
print(f"Prompts: {prompts}")
prompts = torch.cat([clip.tokenize(p) for p in prompts])
prompts = prompts.to(self.device)
with torch.no_grad():
text_features = clip_model.encode_text(prompts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
self.text_features = text_features
self.clip_model = clip_model
def model_inference(self, image):
image_features = self.clip_model.encode_image(image)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
logit_scale = self.clip_model.logit_scale.exp()
logits = logit_scale * image_features @ self.text_features.t()
return logits
@TRAINER_REGISTRY.register()
class ZeroshotCLIP2(ZeroshotCLIP):
"""Prompt ensembling."""
# templates = IMAGENET_TEMPLATES
templates = IMAGENET_TEMPLATES_SELECT
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
clip_model.to(self.device)
for params in clip_model.parameters():
params.requires_grad_(False)
# add custom-made prompt
if cfg.DATASET.NAME != "ImageNet":
self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]
num_temp = len(self.templates)
print(f"Prompt ensembling (n={num_temp})")
mean_text_features = 0
for i, temp in enumerate(self.templates):
prompts = [temp.format(c.replace("_", " ")) for c in classnames]
prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
text_features = clip_model.encode_text(prompts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
mean_text_features = mean_text_features + text_features
mean_text_features = mean_text_features / num_temp
mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)
self.text_features = mean_text_features
self.clip_model = clip_model