From 17c6c4c30932ea29b0b5e1e707980369d897e141 Mon Sep 17 00:00:00 2001
From: rain-bus
Date: Sun, 1 Feb 2026 17:48:28 +0800
Subject: [PATCH] Uncertainty fusion

---
 MSGCoOp/README.md                                  |  8 +--
 .../trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml      |  1 +
 MSGCoOp/extract_acc.py                             |  2 +-
 MSGCoOp/scripts/base2new_all.sh                    |  5 +-
 MSGCoOp/scripts/base2new_test.sh                   |  6 ++-
 MSGCoOp/scripts/base2new_train.sh                  |  8 +--
 MSGCoOp/scripts/xd_train.sh                        |  8 ++-
 MSGCoOp/scripts/xdo_test.sh                        | 10 ++--
 MSGCoOp/train.py                                   |  1 +
 MSGCoOp/trainers/msgcoop.py                        | 51 +++++++++++++++----
 10 files changed, 74 insertions(+), 26 deletions(-)

diff --git a/MSGCoOp/README.md b/MSGCoOp/README.md
index 363032f..0239966 100644
--- a/MSGCoOp/README.md
+++ b/MSGCoOp/README.md
@@ -35,12 +35,12 @@ Follow [DATASETS.md](DATASETS.md) to install the datasets.
 
 ## Generalization From Base to New Classes
 
-You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts have three input argument, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT`.
+You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts take five input arguments, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT ATTN_REG_WEIGHT UNCERTAINTY_SCALE`.
 
 You can run base to new on all datasets as follows:
 
 ```bash
-bash scripts/base2new_all.sh MSGCoOp 8.0 1.0
+bash scripts/base2new_all.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 
 When the evaluation is done, you can use `extract_acc.py` (replace the `root_dir` in the `main` function with your output dir) to automatically calculate the average results. For instance, after you finish training with the aforementioned commands, you would get
@@ -91,13 +91,13 @@ Then, you will get the average accuracy.
 First, you need to train on all classes of ImageNet:
 
 ```bash
-bash scripts/xd_train.sh MSGCoOp 8.0 1.0
+bash scripts/xd_train.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 
 Then you can evaluate the performance on other ImageNet variants by running:
 
 ```bash
-bash scripts/xdo_test.sh MSGCoOp 8.0 1.0
+bash scripts/xdo_test.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 
 And you will get the `output_xdo` directory after the script finishes. You can get the accuracy with `extract_acc.py` (you need to modify the `root_dir` to `output_xdo`).
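A note on the two new hyperparameters documented above: `ATTN_REG_WEIGHT` scales an entropy regularizer on the prompt weights, and `UNCERTAINTY_SCALE` acts as a temperature when per-prompt predictive entropy is converted into ensemble weights. Below is a minimal sketch of that conversion, assuming only PyTorch; the helper name `uncertainty_weights` is illustrative and not part of this patch (the actual logic lives in `MSGCoOp/trainers/msgcoop.py` further down).

```python
# Illustrative sketch (not patch code): how UNCERTAINTY_SCALE tempers the
# mapping from per-prompt predictive entropy to ensemble weights.
import torch
import torch.nn.functional as F

def uncertainty_weights(all_logits: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
    """all_logits: [n_prompts, batch, n_classes] -> weights: [batch, n_prompts]."""
    log_probs = F.log_softmax(all_logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)  # [n_prompts, batch]
    certainty = -entropy.t()                    # lower entropy -> higher certainty
    # Softmax over prompts; `scale` is an inverse temperature: scale -> 0 gives
    # uniform weights, larger values concentrate weight on confident prompts.
    # This equals the logsumexp-based normalization used in the patch.
    return F.softmax(certainty * scale, dim=-1)

logits = torch.randn(4, 8, 100)  # 4 prompts, batch of 8, 100 classes
w = uncertainty_weights(logits, scale=0.5)
assert torch.allclose(w.sum(dim=-1), torch.ones(8))
```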
diff --git a/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml b/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml
index a7e11f9..734687f 100644
--- a/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml
+++ b/MSGCoOp/configs/trainers/MSGCoOp/vit_b16_ep100_ctxv1.yaml
@@ -32,3 +32,4 @@ TRAINER:
   COOP:
     CTX_INIT: True
     ATTENTION_REG_WEIGHT: 0.01
+    UNCERTAINTY_SCALE: 0.5
diff --git a/MSGCoOp/extract_acc.py b/MSGCoOp/extract_acc.py
index ef2f12c..f2d276e 100644
--- a/MSGCoOp/extract_acc.py
+++ b/MSGCoOp/extract_acc.py
@@ -98,7 +98,7 @@ def print_model_results(results, model_name):
         print("No complete dataset results found for this model.")
 
 def main():
-    root_dir = 'output_xda'  # change this to your output directory path
+    root_dir = 'output'  # change this to your output directory path
     target_model = 'MSGCoOp'  # specify the model to analyze
 
     results = collect_model_results(root_dir, target_model)
diff --git a/MSGCoOp/scripts/base2new_all.sh b/MSGCoOp/scripts/base2new_all.sh
index 6e73bfc..87a9dac 100644
--- a/MSGCoOp/scripts/base2new_all.sh
+++ b/MSGCoOp/scripts/base2new_all.sh
@@ -4,6 +4,7 @@ TRAINER=$1
 KG_WEIGHT=$2
 MP_WEIGHT=$3
 ATTN_REG_WEIGHT=$4
+UNCERTAINTY_SCALE=$5
 
 # Define datasets array
 datasets=(
@@ -22,7 +23,7 @@ datasets=(
 
 # Loop through datasets
 for dataset in "${datasets[@]}"; do
-    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
-    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
 done
diff --git a/MSGCoOp/scripts/base2new_test.sh b/MSGCoOp/scripts/base2new_test.sh
index 662eec9..94d7513 100644
--- a/MSGCoOp/scripts/base2new_test.sh
+++ b/MSGCoOp/scripts/base2new_test.sh
@@ -8,6 +8,7 @@ N_PROMPTS=4
 KG_WEIGHT=$3
 MP_WEIGHT=$4
 ATTN_REG_WEIGHT=$5
+UNCERTAINTY_SCALE=$6
 #CFG=rn50_ep100  # config file
 CFG=vit_b16_ep100_ctxv1
 CTP=end  # class token position (end or middle)
@@ -20,7 +21,7 @@ SUB=new
 
 for SEED in 1 2 3
 do
-    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED}
+    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
     MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
     DIR=output/base2new/test_${SUB}/${COMMON_DIR}
@@ -44,6 +45,7 @@ do
         TRAINER.COOP.CSC ${CSC} \
         TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
         DATASET.NUM_SHOTS ${SHOTS} \
-        DATASET.SUBSAMPLE_CLASSES ${SUB}
+        DATASET.SUBSAMPLE_CLASSES ${SUB} \
+        TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
     fi
 done
diff --git a/MSGCoOp/scripts/base2new_train.sh b/MSGCoOp/scripts/base2new_train.sh
index c9651ea..6fd579d 100644
--- a/MSGCoOp/scripts/base2new_train.sh
+++ b/MSGCoOp/scripts/base2new_train.sh
@@ -7,6 +7,7 @@ DATASET=$2
 KG_WEIGHT=$3
 MP_WEIGHT=$4
 ATTN_REG_WEIGHT=$5
+UNCERTAINTY_SCALE=$6
 N_PROMPTS=4
 #CFG=rn50_ep100  # config file
 CFG=vit_b16_ep100_ctxv1
@@ -17,7 +18,7 @@ CSC=False  # class-specific context (False or True)
 
 for SEED in 1 2 3
 do
-    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED}
+    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
     if [ -d "$DIR" ]; then
[ -d "$DIR" ]; then echo "Results are available in ${DIR}. Skip this job" else @@ -34,9 +35,10 @@ do TRAINER.COOP.W ${KG_WEIGHT} \ TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ DATASET.NUM_SHOTS ${SHOTS} \ - DATASET.SUBSAMPLE_CLASSES base \ + DATASET.SUBSAMPLE_CLASSES base \ TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \ TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \ - TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} + TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \ + TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE} fi done diff --git a/MSGCoOp/scripts/xd_train.sh b/MSGCoOp/scripts/xd_train.sh index ec6170a..567637b 100644 --- a/MSGCoOp/scripts/xd_train.sh +++ b/MSGCoOp/scripts/xd_train.sh @@ -6,6 +6,8 @@ TRAINER=$1 N_PROMPTS=3 KG_WEIGHT=$2 MP_WEIGHT=$3 +ATTN_REG_WEIGHT=$4 +UNCERTAINTY_SCALE=$5 CFG=vit_b16_ep100_ctxv1 CTP=end # class token position (end or middle) NCTX=4 # number of context tokens @@ -18,7 +20,7 @@ for DATASET in ${SRC_DATASETS} do for SEED in 1 2 3 do - DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} + DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED} if [ -d "$DIR" ]; then echo "Results are available in ${DIR}. Skip this job" else @@ -36,7 +38,9 @@ do TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ DATASET.NUM_SHOTS ${SHOTS} \ TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \ - TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} + TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \ + TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \ + TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE} fi done done diff --git a/MSGCoOp/scripts/xdo_test.sh b/MSGCoOp/scripts/xdo_test.sh index c8f93e4..8198028 100644 --- a/MSGCoOp/scripts/xdo_test.sh +++ b/MSGCoOp/scripts/xdo_test.sh @@ -6,6 +6,8 @@ TRAINER=$1 N_PROMPTS=3 KG_WEIGHT=$2 MP_WEIGHT=$3 +ATTN_REG_WEIGHT=$4 +UNCERTAINTY_SCALE=$5 CFG=vit_b16_ep100_ctxv1 CTP=end # class token position (end or middle) NCTX=4 # number of context tokens @@ -19,8 +21,8 @@ for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r do for SEED in 1 2 3 do - MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} - DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED} + MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED} + DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED} if [ -d "$DIR" ]; then echo "Results are available in ${DIR}. 
Skip this job" else @@ -41,7 +43,9 @@ do TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ DATASET.NUM_SHOTS ${SHOTS} \ TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \ - TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} + TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \ + TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \ + TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE} fi done done diff --git a/MSGCoOp/train.py b/MSGCoOp/train.py index db4a054..22b0896 100644 --- a/MSGCoOp/train.py +++ b/MSGCoOp/train.py @@ -106,6 +106,7 @@ def extend_cfg(cfg): cfg.TRAINER.COOP.DIV_WEIGHT = 0.1 cfg.TRAINER.COOP.N_PROMPTS = 3 cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01 + cfg.TRAINER.COOP.UNCERTAINTY_SCALE = 0.5 cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new """ diff --git a/MSGCoOp/trainers/msgcoop.py b/MSGCoOp/trainers/msgcoop.py index 8313e70..26852c3 100644 --- a/MSGCoOp/trainers/msgcoop.py +++ b/MSGCoOp/trainers/msgcoop.py @@ -224,7 +224,7 @@ class Adapter(nn.Module): return x class AttentionBasedIntegrator(nn.Module): - def __init__(self, img_dim=512, n_prompts=4, dtype=None): + def __init__(self, img_dim=512, n_prompts=4, dtype=None, uncertainty_scale=0.5): super().__init__() self.attention = nn.Sequential( nn.Linear(img_dim, img_dim // 4), @@ -232,18 +232,48 @@ class AttentionBasedIntegrator(nn.Module): nn.Linear(img_dim // 4, n_prompts) ) self.dtype = dtype + self.uncertainty_scale = uncertainty_scale if dtype is not None: self.attention = self.attention.to(dtype) def forward(self, image_features, all_logits): + batch_size = image_features.shape[0] + n_prompts = all_logits.shape[0] + + # 注意力权重 attn_scores = self.attention(image_features) + attn_weights = F.softmax(attn_scores, dim=-1) # [batch, n_prompts] - attn_weights = F.softmax(attn_scores, dim=-1) + # 不确定性权重(基于预测熵)- 使用log-sum-exp技巧提高稳定性 + probs = F.softmax(all_logits, dim=-1) # [n_prompts, batch, n_classes] + log_probs = F.log_softmax(all_logits, dim=-1) - weighted_logits = torch.einsum('bp,pbc->bc', attn_weights, all_logits) + # 计算熵:H = -sum(p * log(p)) + entropy = -(probs * log_probs).sum(dim=-1) # [n_prompts, batch] - return weighted_logits, attn_weights + # 转换为确定性(熵越小,确定性越大) + certainty = -entropy.t() # [batch, n_prompts] + + # 归一化不确定性权重(使用log-sum-exp提高稳定性) + certainty_scaled = certainty * self.uncertainty_scale + log_uncertainty_weights = certainty_scaled - torch.logsumexp(certainty_scaled, dim=-1, keepdim=True) + uncertainty_weights = torch.exp(log_uncertainty_weights) + + # 混合权重:结合注意力权重和不确定性权重(使用对数域) + log_attn_weights = torch.log(attn_weights + 1e-8) + log_uncertainty_weights = torch.log(uncertainty_weights + 1e-8) + + # 在对数域中相加,然后指数化 + log_hybrid_weights = torch.log(torch.tensor(0.5, device=attn_weights.device)) + \ + log_attn_weights + log_uncertainty_weights + log_hybrid_weights = log_hybrid_weights - torch.logsumexp(log_hybrid_weights, dim=-1, keepdim=True) + hybrid_weights = torch.exp(log_hybrid_weights) + + # 加权集成 + weighted_logits = torch.einsum('bp,pbc->bc', hybrid_weights, all_logits) + + return weighted_logits, hybrid_weights, entropy class CustomCLIP(nn.Module): def __init__(self, cfg, classnames, clip_model): @@ -259,10 +289,13 @@ class CustomCLIP(nn.Module): self.meta_net = self.prompt_learner.meta_net self.adapter = Adapter(512, 4).to(clip_model.dtype) + uncertainty_scale = getattr(cfg.TRAINER.COOP, 'UNCERTAINTY_SCALE', 0.5) + self.prompt_integrator = AttentionBasedIntegrator( img_dim=clip_model.visual.output_dim, n_prompts=self.n_prompts, - dtype=clip_model.dtype + dtype=clip_model.dtype, + uncertainty_scale=uncertainty_scale ) def 
     def compute_diversity_loss(self, text_features):
@@ -314,9 +347,9 @@ class CustomCLIP(nn.Module):
 
         all_logits = torch.stack(all_logits)
 
-        logits, attn_weights = self.prompt_integrator(image_features, all_logits)
+        logits, hybrid_weights, entropy = self.prompt_integrator(image_features, all_logits)
 
-        return logits, score, diversity_loss, attn_weights
+        return logits, score, diversity_loss, hybrid_weights, entropy
 
 
 @TRAINER_REGISTRY.register()
@@ -385,10 +418,10 @@ class MSGCoOp(TrainerX):
                 self.scaler.step(self.optim)
                 self.scaler.update()
         else:
-            output, score, diversity_loss, attn_weights = self.model(image)
+            output, score, diversity_loss, hybrid_weights, entropy = self.model(image)
 
             # Add attention regularization to encourage balanced prompt usage
-            attn_reg = -(attn_weights * torch.log(attn_weights + 1e-8)).mean()
+            attn_reg = -(hybrid_weights * torch.log(hybrid_weights + 1e-8)).mean()
             loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg
 
             self.model_backward_and_update(loss)
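To make the integration and regularization steps concrete: the patched `AttentionBasedIntegrator.forward` multiplies the attention and uncertainty distributions in the log domain and renormalizes (the additive `log(0.5)` constant cancels under the `logsumexp` normalization), and `forward_backward` reuses the entropy of the resulting hybrid weights as `attn_reg`. Below is a self-contained sketch with random placeholder inputs, assuming only PyTorch; it mirrors the patch's shapes and epsilon but is not the patch code itself.

```python
# Illustrative sketch of the log-domain weight fusion and the entropy
# regularizer; shapes and the 1e-8 epsilon mirror msgcoop.py, inputs are random.
import torch
import torch.nn.functional as F

n_prompts, batch, n_classes = 4, 8, 100
all_logits = torch.randn(n_prompts, batch, n_classes)  # per-prompt logits
attn_weights = F.softmax(torch.randn(batch, n_prompts), dim=-1)
uncertainty_weights = F.softmax(torch.randn(batch, n_prompts), dim=-1)

# Multiply the two distributions in the log domain, then renormalize; the
# constant log(0.5) term from the patch would cancel at this step.
log_hybrid = torch.log(attn_weights + 1e-8) + torch.log(uncertainty_weights + 1e-8)
log_hybrid = log_hybrid - torch.logsumexp(log_hybrid, dim=-1, keepdim=True)
hybrid_weights = log_hybrid.exp()  # each row sums to 1

# Ensemble the per-prompt logits with the hybrid weights.
weighted_logits = torch.einsum('bp,pbc->bc', hybrid_weights, all_logits)

# Entropy of the hybrid weights, added to the loss as attn_reg.
attn_reg = -(hybrid_weights * torch.log(hybrid_weights + 1e-8)).mean()
print(weighted_logits.shape, attn_reg.item())  # torch.Size([8, 100]) ...
```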