Compare commits
1 commit
| Author | SHA1 | Date |
|---|---|---|
| | 0ba13ffbbd | |
Trainer config (YAML):

```diff
@@ -31,3 +31,4 @@ MODEL:
 TRAINER:
   COOP:
     CTX_INIT: True
+    ATTENTION_REG_WEIGHT: 0.01
```
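The new key sits under `TRAINER.COOP` alongside `CTX_INIT`; a value set here overrides the matching default that this commit registers in `extend_cfg` below.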
Base-to-new sweep script:

```diff
@@ -3,36 +3,26 @@
 TRAINER=$1
 KG_WEIGHT=$2
 MP_WEIGHT=$3
+ATTN_REG_WEIGHT=$4
 
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
-
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}
+# Define datasets array
+datasets=(
+    "ucf101"
+    "eurosat"
+    "oxford_pets"
+    "food101"
+    "oxford_flowers"
+    "dtd"
+    "caltech101"
+    "fgvc_aircraft"
+    "stanford_cars"
+    "sun397"
+    "imagenet"
+)
+
+# Loop through datasets
+for dataset in "${datasets[@]}"; do
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
+done
```
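The sweep now takes a fourth positional argument and forwards it to both sub-scripts; an illustrative invocation (the script name and weight values here are made up) would be `bash run_base2new.sh MSGCoOp 8.0 0.1 0.01`.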
scripts/base2new_test.sh:

```diff
@@ -7,6 +7,7 @@ DATASET=$2
 N_PROMPTS=4
 KG_WEIGHT=$3
 MP_WEIGHT=$4
+ATTN_REG_WEIGHT=$5
 #CFG=rn50_ep100 # config file
 CFG=vit_b16_ep100_ctxv1
 CTP=end # class token position (end or middle)
@@ -19,7 +20,7 @@ SUB=new
 
 for SEED in 1 2 3
 do
-    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
     MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
     DIR=output/base2new/test_${SUB}/${COMMON_DIR}
 
```
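Because `COMMON_DIR` now encodes all three loss weights, runs that differ only in `MP_WEIGHT` or `ATTN_REG_WEIGHT` no longer collide; with illustrative values (SHOTS=16, weights 8.0/0.1/0.01, trainer MSGCoOp), test results for ucf101 would land under `output/base2new/test_new/ucf101/shots_16_8.0_0.1_0.01/MSGCoOp/vit_b16_ep100_ctxv1/seed1`.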
scripts/base2new_train.sh:

```diff
@@ -6,6 +6,7 @@ TRAINER=$1
 DATASET=$2
 KG_WEIGHT=$3
 MP_WEIGHT=$4
+ATTN_REG_WEIGHT=$5
 N_PROMPTS=4
 #CFG=rn50_ep100 # config file
 CFG=vit_b16_ep100_ctxv1
@@ -16,7 +17,7 @@ CSC=False # class-specific context (False or True)
 
 for SEED in 1 2 3
 do
-    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
     if [ -d "$DIR" ]; then
         echo "Results are available in ${DIR}. Skip this job"
     else
@@ -35,6 +36,7 @@ do
         DATASET.NUM_SHOTS ${SHOTS} \
         DATASET.SUBSAMPLE_CLASSES base \
         TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
-        TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
+        TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
+        TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT}
     fi
 done
```
Config defaults (extend_cfg):

```diff
@@ -105,6 +105,7 @@ def extend_cfg(cfg):
     cfg.TRAINER.COCOOP.PREC = "fp16"  # fp16, fp32, amp
     cfg.TRAINER.COOP.DIV_WEIGHT = 0.1
     cfg.TRAINER.COOP.N_PROMPTS = 3
+    cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01
 
     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
     """
```
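A minimal sketch of how this default interacts with the `KEY VALUE` pairs that base2new_train.sh appends (this assumes the yacs-style CfgNode used by CoOp/Dassl codebases):

```python
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAINER = CN()
cfg.TRAINER.COOP = CN()
cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01   # default registered in extend_cfg

# Equivalent of passing "TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT}"
# on the command line; the string value is coerced to float against the default.
cfg.merge_from_list(["TRAINER.COOP.ATTENTION_REG_WEIGHT", "0.05"])
print(cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT)   # 0.05
```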
Trainer module:

```diff
@@ -223,6 +223,28 @@ class Adapter(nn.Module):
         x = self.fc(x)
         return x
 
+
+class AttentionBasedIntegrator(nn.Module):
+    def __init__(self, img_dim=512, n_prompts=4, dtype=None):
+        super().__init__()
+        self.attention = nn.Sequential(
+            nn.Linear(img_dim, img_dim // 4),
+            nn.Tanh(),
+            nn.Linear(img_dim // 4, n_prompts)
+        )
+        self.dtype = dtype
+
+        if dtype is not None:
+            self.attention = self.attention.to(dtype)
+
+    def forward(self, image_features, all_logits):
+        attn_scores = self.attention(image_features)
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        weighted_logits = torch.einsum('bp,pbc->bc', attn_weights, all_logits)
+
+        return weighted_logits, attn_weights
+
+
 class CustomCLIP(nn.Module):
     def __init__(self, cfg, classnames, clip_model):
         super().__init__()
```
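For reference, a minimal sketch exercising the new module on dummy tensors (sizes are invented; this assumes the class above is in scope along with the file's `torch` and `torch.nn.functional as F` imports):

```python
import torch

integrator = AttentionBasedIntegrator(img_dim=512, n_prompts=4)

image_features = torch.randn(8, 512)   # (batch, feature dim)
all_logits = torch.randn(4, 8, 100)    # (n_prompts, batch, n_classes)

logits, attn_weights = integrator(image_features, all_logits)
print(logits.shape)              # torch.Size([8, 100])
print(attn_weights.shape)        # torch.Size([8, 4])
print(attn_weights.sum(dim=-1))  # each row sums to 1 after the softmax
```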
```diff
@@ -236,6 +258,12 @@ class CustomCLIP(nn.Module):
         self.dtype = clip_model.dtype
         self.meta_net = self.prompt_learner.meta_net
         self.adapter = Adapter(512, 4).to(clip_model.dtype)
+
+        self.prompt_integrator = AttentionBasedIntegrator(
+            img_dim=clip_model.visual.output_dim,
+            n_prompts=self.n_prompts,
+            dtype=clip_model.dtype
+        )
 
     def compute_diversity_loss(self, text_features):
         if self.n_prompts == 1:
```
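For the vit_b16 config used by these scripts, CLIP's visual `output_dim` is 512, so the integrator's input width matches the `Adapter(512, 4)` above.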
```diff
@@ -283,10 +311,12 @@ class CustomCLIP(nn.Module):
             text_features_i = text_features_i / text_features_i.norm(dim=-1, keepdim=True)
             logits_i = logit_scale * image_features @ text_features_i.t()
             all_logits.append(logits_i)
 
-        logits = torch.stack(all_logits).mean(dim=0)
+        all_logits = torch.stack(all_logits)
 
-        return logits, score, diversity_loss
+        logits, attn_weights = self.prompt_integrator(image_features, all_logits)
+
+        return logits, score, diversity_loss, attn_weights
 
 
 @TRAINER_REGISTRY.register()
```
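The removed line averaged the per-prompt logits with equal weight; the integrator generalizes this to a learned, per-image convex combination. A quick sketch (dummy tensors) showing that uniform attention weights reproduce the old `mean(dim=0)` exactly:

```python
import torch

P, B, C = 4, 2, 10                     # prompts, batch, classes
all_logits = torch.randn(P, B, C)
uniform = torch.full((B, P), 1.0 / P)  # what a "no-op" integrator would output

old_logits = all_logits.mean(dim=0)                           # removed behaviour
new_logits = torch.einsum('bp,pbc->bc', uniform, all_logits)  # integrator contraction

assert torch.allclose(old_logits, new_logits, atol=1e-6)
```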
```diff
@@ -310,10 +340,11 @@ class MSGCoOp(TrainerX):
         self.model = CustomCLIP(cfg, classnames, clip_model)
         self.w = cfg.TRAINER.COOP.W
         self.diversity_weight = cfg.TRAINER.COOP.DIV_WEIGHT
+        self.attn_reg_weight = cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT if hasattr(cfg.TRAINER.COOP, 'ATTENTION_REG_WEIGHT') else 0.01
 
         print("Turning off gradients in both the image and the text encoder")
         for name, param in self.model.named_parameters():
-            if "ctx" not in name:
+            if "ctx" not in name and "prompt_integrator" not in name:
                 param.requires_grad_(False)
             else:
                 print(name)
```
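The `hasattr` guard keeps configs that predate `ATTENTION_REG_WEIGHT` working, falling back to 0.01, and the widened name filter leaves the integrator's parameters trainable while everything outside the prompt context stays frozen.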
```diff
@@ -322,8 +353,10 @@ class MSGCoOp(TrainerX):
             load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
 
         self.model.to(self.device)
-        # NOTE: only give prompt_learner to the optimizer
-        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
+        # NOTE: give prompt_learner and prompt_integrator to the optimizer
+        trainable_params = list(self.model.prompt_learner.parameters()) + list(self.model.prompt_integrator.parameters())
+        self.optim = build_optimizer([{'params': trainable_params}], cfg.OPTIM)
         self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
         self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
 
```
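`build_optimizer` is handed a torch-style parameter-group list covering both trainable modules. A torch-native sketch of the same idea (the modules here are stand-ins, not the real classes):

```python
import torch
import torch.nn as nn

prompt_learner = nn.Linear(512, 512)   # stand-in for the real prompt learner
prompt_integrator = nn.Linear(512, 4)  # stand-in for the real integrator

# One parameter group holding everything that should receive gradients.
trainable_params = list(prompt_learner.parameters()) + list(prompt_integrator.parameters())
optim = torch.optim.SGD([{'params': trainable_params}], lr=0.002)
```

Note that only `prompt_learner` is passed to `register_model`; if the surrounding trainer checkpoints only registered models, the integrator's learned weights would need a similar registration to be persisted.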
```diff
@@ -352,8 +385,12 @@ class MSGCoOp(TrainerX):
             self.scaler.step(self.optim)
             self.scaler.update()
         else:
-            output, score, diversity_loss = self.model(image)
-            loss = F.cross_entropy(output, label)+self.w*score + diversity_loss * self.diversity_weight
+            output, score, diversity_loss, attn_weights = self.model(image)
+
+            # Add attention regularization to encourage balanced prompt usage
+            attn_reg = -(attn_weights * torch.log(attn_weights + 1e-8)).mean()
+
+            loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg
             self.model_backward_and_update(loss)
 
         loss_summary = {
```
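For intuition, a small sketch of the added regularizer (values invented): `attn_reg` is the mean of the elementwise entropy terms of the softmaxed weights, so it is largest for uniform weights and smallest for peaked ones.

```python
import torch

def attn_entropy_term(attn_weights: torch.Tensor) -> torch.Tensor:
    # Same expression as in the diff: mean of -w * log(w + eps) over all entries.
    return -(attn_weights * torch.log(attn_weights + 1e-8)).mean()

uniform = torch.full((1, 4), 0.25)                 # balanced prompt usage
peaked = torch.tensor([[0.97, 0.01, 0.01, 0.01]])  # one prompt dominates

print(attn_entropy_term(uniform))  # ~0.3466 (= ln(4)/4), the maximum here
print(attn_entropy_term(peaked))   # ~0.0419, much smaller
```

Since the term is added to the loss with a positive coefficient, minimizing the loss drives the weights toward the peaked end; encouraging balanced usage, as the in-diff comment states, would require the entropy term to enter with the opposite sign.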