From 0ba13ffbbd7b805c2efa88491ceed2c64d0cd263 Mon Sep 17 00:00:00 2001 From: rain-bus Date: Sat, 31 Jan 2026 23:48:05 +0800 Subject: [PATCH] Attn fuse --- .../trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml | 1 + MSGCoOp/scripts/base2new_all.sh | 50 +++++++---------- MSGCoOp/scripts/base2new_test.sh | 3 +- MSGCoOp/scripts/base2new_train.sh | 6 ++- MSGCoOp/train.py | 1 + MSGCoOp/trainers/msgcoop.py | 53 ++++++++++++++++--- 6 files changed, 73 insertions(+), 41 deletions(-) diff --git a/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml b/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml index bf55dab..a7e11f9 100644 --- a/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml +++ b/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml @@ -31,3 +31,4 @@ MODEL: TRAINER: COOP: CTX_INIT: True + ATTENTION_REG_WEIGHT: 0.01 diff --git a/MSGCoOp/scripts/base2new_all.sh b/MSGCoOp/scripts/base2new_all.sh index 6543a71..6e73bfc 100644 --- a/MSGCoOp/scripts/base2new_all.sh +++ b/MSGCoOp/scripts/base2new_all.sh @@ -3,36 +3,26 @@ TRAINER=$1 KG_WEIGHT=$2 MP_WEIGHT=$3 +ATTN_REG_WEIGHT=$4 -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT} +# Define datasets array +datasets=( + "ucf101" + "eurosat" + "oxford_pets" + "food101" + "oxford_flowers" + "dtd" + "caltech101" + "fgvc_aircraft" + "stanford_cars" + "sun397" + "imagenet" +) -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT} +# Loop through datasets +for dataset in "${datasets[@]}"; do + CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} + CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} 
+done -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT} - -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT} -CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT} diff --git a/MSGCoOp/scripts/base2new_test.sh b/MSGCoOp/scripts/base2new_test.sh index 869f280..662eec9 100644 --- 
a/MSGCoOp/scripts/base2new_test.sh +++ b/MSGCoOp/scripts/base2new_test.sh @@ -7,6 +7,7 @@ DATASET=$2 N_PROMPTS=4 KG_WEIGHT=$3 MP_WEIGHT=$4 +ATTN_REG_WEIGHT=${5:-0.01} # attention-regularization weight; falls back to the config default #CFG=rn50_ep100 # config file CFG=vit_b16_ep100_ctxv1 CTP=end # class token position (end or middle) @@ -19,7 +20,7 @@ SUB=new for SEED in 1 2 3 do - COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} + COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} MODEL_DIR=output/base2new/train_base/${COMMON_DIR} DIR=output/base2new/test_${SUB}/${COMMON_DIR} diff --git a/MSGCoOp/scripts/base2new_train.sh b/MSGCoOp/scripts/base2new_train.sh index 94d54f0..c9651ea 100644 --- a/MSGCoOp/scripts/base2new_train.sh +++ b/MSGCoOp/scripts/base2new_train.sh @@ -6,6 +6,7 @@ TRAINER=$1 DATASET=$2 KG_WEIGHT=$3 MP_WEIGHT=$4 +ATTN_REG_WEIGHT=${5:-0.01} # attention-regularization weight; falls back to the config default N_PROMPTS=4 #CFG=rn50_ep100 # config file\ CFG=vit_b16_ep100_ctxv1 @@ -16,7 +17,7 @@ CSC=False # class-specific context (False or True) for SEED in 1 2 3 do - DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} + DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} if [ -d "$DIR" ]; then echo "Results are available in ${DIR}. 
Skip this job" else @@ -35,6 +36,7 @@ do DATASET.NUM_SHOTS ${SHOTS} \ DATASET.SUBSAMPLE_CLASSES base \ TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \ - TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} + TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \ + TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} fi done diff --git a/MSGCoOp/train.py b/MSGCoOp/train.py index 1258a76..db4a054 100644 --- a/MSGCoOp/train.py +++ b/MSGCoOp/train.py @@ -105,6 +105,7 @@ def extend_cfg(cfg): cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp cfg.TRAINER.COOP.DIV_WEIGHT = 0.1 cfg.TRAINER.COOP.N_PROMPTS = 3 + cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01 cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new """ diff --git a/MSGCoOp/trainers/msgcoop.py b/MSGCoOp/trainers/msgcoop.py index 0047355..8313e70 100644 --- a/MSGCoOp/trainers/msgcoop.py +++ b/MSGCoOp/trainers/msgcoop.py @@ -223,6 +223,28 @@ class Adapter(nn.Module): x = self.fc(x) return x +class AttentionBasedIntegrator(nn.Module): + def __init__(self, img_dim=512, n_prompts=4, dtype=None): + super().__init__() + self.attention = nn.Sequential( + nn.Linear(img_dim, img_dim // 4), + nn.Tanh(), + nn.Linear(img_dim // 4, n_prompts) + ) + self.dtype = dtype + + if dtype is not None: + self.attention = self.attention.to(dtype) + + def forward(self, image_features, all_logits): + attn_scores = self.attention(image_features) + + attn_weights = F.softmax(attn_scores, dim=-1) + + weighted_logits = torch.einsum('bp,pbc->bc', attn_weights, all_logits) + + return weighted_logits, attn_weights + class CustomCLIP(nn.Module): def __init__(self, cfg, classnames, clip_model): super().__init__() @@ -236,6 +258,12 @@ class CustomCLIP(nn.Module): self.dtype = clip_model.dtype self.meta_net = self.prompt_learner.meta_net self.adapter = Adapter(512, 4).to(clip_model.dtype) + + self.prompt_integrator = AttentionBasedIntegrator( + img_dim=clip_model.visual.output_dim, + n_prompts=self.n_prompts, + dtype=clip_model.dtype + ) def compute_diversity_loss(self, text_features): if 
self.n_prompts == 1: @@ -283,10 +311,12 @@ class CustomCLIP(nn.Module): text_features_i = text_features_i / text_features_i.norm(dim=-1, keepdim=True) logits_i = logit_scale * image_features @ text_features_i.t() all_logits.append(logits_i) + + all_logits = torch.stack(all_logits) + + logits, attn_weights = self.prompt_integrator(image_features, all_logits) - logits = torch.stack(all_logits).mean(dim=0) - - return logits, score, diversity_loss + return logits, score, diversity_loss, attn_weights @TRAINER_REGISTRY.register() @@ -310,10 +340,11 @@ class MSGCoOp(TrainerX): self.model = CustomCLIP(cfg, classnames, clip_model) self.w = cfg.TRAINER.COOP.W self.diversity_weight = cfg.TRAINER.COOP.DIV_WEIGHT + self.attn_reg_weight = cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT if hasattr(cfg.TRAINER.COOP, 'ATTENTION_REG_WEIGHT') else 0.01 print("Turning off gradients in both the image and the text encoder") for name, param in self.model.named_parameters(): - if "ctx" not in name: + if "ctx" not in name and "prompt_integrator" not in name: param.requires_grad_(False) else: print(name) @@ -322,8 +353,12 @@ class MSGCoOp(TrainerX): load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) self.model.to(self.device) - # NOTE: only give prompt_learner to the optimizer - self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) + + # NOTE: give prompt_learner and prompt_integrator to the optimizer + trainable_params = list(self.model.prompt_learner.parameters()) + list(self.model.prompt_integrator.parameters()) + self.optim = build_optimizer([{'params': trainable_params}], cfg.OPTIM) self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) + # NOTE: also register prompt_integrator (optimizer/scheduler are shared above) so its trained weights are saved and reloaded with checkpoints + self.register_model("prompt_integrator", self.model.prompt_integrator) @@ -352,8 +387,12 @@ class MSGCoOp(TrainerX): self.scaler.step(self.optim) self.scaler.update() else: - output, score, diversity_loss = self.model(image) - loss = F.cross_entropy(output, label)+self.w*score + diversity_loss * 
self.diversity_weight + output, score, diversity_loss, attn_weights = self.model(image) + + # Negative entropy of the prompt attention, computed in fp32 (1e-8 underflows in fp16); minimizing it encourages balanced prompt usage + attn_reg = (attn_weights.float() * torch.log(attn_weights.float() + 1e-8)).mean() + + loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg self.model_backward_and_update(loss) loss_summary = {