Compare commits
1 Commits
fuse
...
attn_cert_
| Author | SHA1 | Date | |
|---|---|---|---|
| 17c6c4c309 |
@@ -35,12 +35,12 @@ Follow [DATASETS.md](DATASETS.md) to install the datasets.
|
||||
|
||||
## Generalization From Base to New Classes
|
||||
|
||||
You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts have three input argument, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT`.
|
||||
You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts have three input argument, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT ATTN_REG_WEIGHT UNCERTAINTY_SCALE`.
|
||||
|
||||
You can run base to new on all datasets as follow:
|
||||
|
||||
```bash
|
||||
bash scripts/base2new_all.sh MSGCoOp 8.0 1.0
|
||||
bash scripts/base2new_all.sh MSGCoOp 8.0 1.0 0.01 0.5
|
||||
```
|
||||
|
||||
When the evaluation is done, you can use `extract_acc.py` (replace the `root_dir` in the `main` function to your output dir) to automatically calculate the average results. For instance, after you finish the trainning using the aforementioned commands, you would get
|
||||
@@ -91,13 +91,13 @@ Then, you will get the avarage accuracy.
|
||||
Fisrt, you need train on all classes over ImageNet:
|
||||
|
||||
```bash
|
||||
bash scripts/xd_train.sh MSGCoOp 8.0 1.0
|
||||
bash scripts/xd_train.sh MSGCoOp 8.0 1.0 0.01 0.5
|
||||
```
|
||||
|
||||
Then you can evaluate the performance on other ImageNet variants by run:
|
||||
|
||||
```bash
|
||||
bash scripts/xdo_test.sh MSGCoOp 8.0 1.0
|
||||
bash scripts/xdo_test.sh MSGCoOp 8.0 1.0 0.01 0.5
|
||||
```
|
||||
|
||||
And you will get the `output_xdo` after script finish. You can get the accuracy by `extract_acc.py` (need modify the `root_dir` to `output_xdo` ).
|
||||
|
||||
@@ -32,3 +32,4 @@ TRAINER:
|
||||
COOP:
|
||||
CTX_INIT: True
|
||||
ATTENTION_REG_WEIGHT: 0.01
|
||||
UNCERTAINTY_SCALE: 0.5
|
||||
|
||||
@@ -98,7 +98,7 @@ def print_model_results(results, model_name):
|
||||
print("No complete dataset results found for this model.")
|
||||
|
||||
def main():
|
||||
root_dir = 'output_xda' # 修改为你的output目录路径
|
||||
root_dir = 'output' # 修改为你的output目录路径
|
||||
target_model = 'MSGCoOp' # 指定要分析的模型
|
||||
|
||||
results = collect_model_results(root_dir, target_model)
|
||||
|
||||
@@ -4,6 +4,7 @@ TRAINER=$1
|
||||
KG_WEIGHT=$2
|
||||
MP_WEIGHT=$3
|
||||
ATTN_REG_WEIGHT=$4
|
||||
UNCERTAINTY_SCALE=$5
|
||||
|
||||
# Define datasets array
|
||||
datasets=(
|
||||
@@ -22,7 +23,7 @@ datasets=(
|
||||
|
||||
# Loop through datasets
|
||||
for dataset in "${datasets[@]}"; do
|
||||
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
|
||||
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
|
||||
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
|
||||
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
|
||||
done
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ N_PROMPTS=4
|
||||
KG_WEIGHT=$3
|
||||
MP_WEIGHT=$4
|
||||
ATTN_REG_WEIGHT=$5
|
||||
UNCERTAINTY_SCALE=$6
|
||||
#CFG=rn50_ep100 # config file
|
||||
CFG=vit_b16_ep100_ctxv1
|
||||
CTP=end # class token position (end or middle)
|
||||
@@ -20,7 +21,7 @@ SUB=new
|
||||
|
||||
for SEED in 1 2 3
|
||||
do
|
||||
COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED}
|
||||
COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
|
||||
MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
|
||||
DIR=output/base2new/test_${SUB}/${COMMON_DIR}
|
||||
|
||||
@@ -44,6 +45,7 @@ do
|
||||
TRAINER.COOP.CSC ${CSC} \
|
||||
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
|
||||
DATASET.NUM_SHOTS ${SHOTS} \
|
||||
DATASET.SUBSAMPLE_CLASSES ${SUB}
|
||||
DATASET.SUBSAMPLE_CLASSES ${SUB} \
|
||||
TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -7,6 +7,7 @@ DATASET=$2
|
||||
KG_WEIGHT=$3
|
||||
MP_WEIGHT=$4
|
||||
ATTN_REG_WEIGHT=$5
|
||||
UNCERTAINTY_SCALE=$6
|
||||
N_PROMPTS=4
|
||||
#CFG=rn50_ep100 # config file\
|
||||
CFG=vit_b16_ep100_ctxv1
|
||||
@@ -17,7 +18,7 @@ CSC=False # class-specific context (False or True)
|
||||
|
||||
for SEED in 1 2 3
|
||||
do
|
||||
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED}
|
||||
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
|
||||
if [ -d "$DIR" ]; then
|
||||
echo "Results are available in ${DIR}. Skip this job"
|
||||
else
|
||||
@@ -34,9 +35,10 @@ do
|
||||
TRAINER.COOP.W ${KG_WEIGHT} \
|
||||
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
|
||||
DATASET.NUM_SHOTS ${SHOTS} \
|
||||
DATASET.SUBSAMPLE_CLASSES base \
|
||||
DATASET.SUBSAMPLE_CLASSES base \
|
||||
TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
|
||||
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
|
||||
TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT}
|
||||
TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
|
||||
TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -6,6 +6,8 @@ TRAINER=$1
|
||||
N_PROMPTS=3
|
||||
KG_WEIGHT=$2
|
||||
MP_WEIGHT=$3
|
||||
ATTN_REG_WEIGHT=$4
|
||||
UNCERTAINTY_SCALE=$5
|
||||
CFG=vit_b16_ep100_ctxv1
|
||||
CTP=end # class token position (end or middle)
|
||||
NCTX=4 # number of context tokens
|
||||
@@ -18,7 +20,7 @@ for DATASET in ${SRC_DATASETS}
|
||||
do
|
||||
for SEED in 1 2 3
|
||||
do
|
||||
DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
|
||||
DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
|
||||
if [ -d "$DIR" ]; then
|
||||
echo "Results are available in ${DIR}. Skip this job"
|
||||
else
|
||||
@@ -36,7 +38,9 @@ do
|
||||
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
|
||||
DATASET.NUM_SHOTS ${SHOTS} \
|
||||
TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
|
||||
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
|
||||
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
|
||||
TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
|
||||
TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
@@ -6,6 +6,8 @@ TRAINER=$1
|
||||
N_PROMPTS=3
|
||||
KG_WEIGHT=$2
|
||||
MP_WEIGHT=$3
|
||||
ATTN_REG_WEIGHT=$4
|
||||
UNCERTAINTY_SCALE=$5
|
||||
CFG=vit_b16_ep100_ctxv1
|
||||
CTP=end # class token position (end or middle)
|
||||
NCTX=4 # number of context tokens
|
||||
@@ -19,8 +21,8 @@ for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r
|
||||
do
|
||||
for SEED in 1 2 3
|
||||
do
|
||||
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
|
||||
DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
|
||||
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
|
||||
DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
|
||||
if [ -d "$DIR" ]; then
|
||||
echo "Results are available in ${DIR}. Skip this job"
|
||||
else
|
||||
@@ -41,7 +43,9 @@ do
|
||||
TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
|
||||
DATASET.NUM_SHOTS ${SHOTS} \
|
||||
TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
|
||||
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
|
||||
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
|
||||
TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
|
||||
TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
@@ -106,6 +106,7 @@ def extend_cfg(cfg):
|
||||
cfg.TRAINER.COOP.DIV_WEIGHT = 0.1
|
||||
cfg.TRAINER.COOP.N_PROMPTS = 3
|
||||
cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01
|
||||
cfg.TRAINER.COOP.UNCERTAINTY_SCALE = 0.5
|
||||
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
"""
|
||||
|
||||
@@ -224,7 +224,7 @@ class Adapter(nn.Module):
|
||||
return x
|
||||
|
||||
class AttentionBasedIntegrator(nn.Module):
|
||||
def __init__(self, img_dim=512, n_prompts=4, dtype=None):
|
||||
def __init__(self, img_dim=512, n_prompts=4, dtype=None, uncertainty_scale=0.5):
|
||||
super().__init__()
|
||||
self.attention = nn.Sequential(
|
||||
nn.Linear(img_dim, img_dim // 4),
|
||||
@@ -232,18 +232,48 @@ class AttentionBasedIntegrator(nn.Module):
|
||||
nn.Linear(img_dim // 4, n_prompts)
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.uncertainty_scale = uncertainty_scale
|
||||
|
||||
if dtype is not None:
|
||||
self.attention = self.attention.to(dtype)
|
||||
|
||||
def forward(self, image_features, all_logits):
|
||||
batch_size = image_features.shape[0]
|
||||
n_prompts = all_logits.shape[0]
|
||||
|
||||
# 注意力权重
|
||||
attn_scores = self.attention(image_features)
|
||||
attn_weights = F.softmax(attn_scores, dim=-1) # [batch, n_prompts]
|
||||
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
# 不确定性权重(基于预测熵)- 使用log-sum-exp技巧提高稳定性
|
||||
probs = F.softmax(all_logits, dim=-1) # [n_prompts, batch, n_classes]
|
||||
log_probs = F.log_softmax(all_logits, dim=-1)
|
||||
|
||||
weighted_logits = torch.einsum('bp,pbc->bc', attn_weights, all_logits)
|
||||
# 计算熵:H = -sum(p * log(p))
|
||||
entropy = -(probs * log_probs).sum(dim=-1) # [n_prompts, batch]
|
||||
|
||||
return weighted_logits, attn_weights
|
||||
# 转换为确定性(熵越小,确定性越大)
|
||||
certainty = -entropy.t() # [batch, n_prompts]
|
||||
|
||||
# 归一化不确定性权重(使用log-sum-exp提高稳定性)
|
||||
certainty_scaled = certainty * self.uncertainty_scale
|
||||
log_uncertainty_weights = certainty_scaled - torch.logsumexp(certainty_scaled, dim=-1, keepdim=True)
|
||||
uncertainty_weights = torch.exp(log_uncertainty_weights)
|
||||
|
||||
# 混合权重:结合注意力权重和不确定性权重(使用对数域)
|
||||
log_attn_weights = torch.log(attn_weights + 1e-8)
|
||||
log_uncertainty_weights = torch.log(uncertainty_weights + 1e-8)
|
||||
|
||||
# 在对数域中相加,然后指数化
|
||||
log_hybrid_weights = torch.log(torch.tensor(0.5, device=attn_weights.device)) + \
|
||||
log_attn_weights + log_uncertainty_weights
|
||||
log_hybrid_weights = log_hybrid_weights - torch.logsumexp(log_hybrid_weights, dim=-1, keepdim=True)
|
||||
hybrid_weights = torch.exp(log_hybrid_weights)
|
||||
|
||||
# 加权集成
|
||||
weighted_logits = torch.einsum('bp,pbc->bc', hybrid_weights, all_logits)
|
||||
|
||||
return weighted_logits, hybrid_weights, entropy
|
||||
|
||||
class CustomCLIP(nn.Module):
|
||||
def __init__(self, cfg, classnames, clip_model):
|
||||
@@ -259,10 +289,13 @@ class CustomCLIP(nn.Module):
|
||||
self.meta_net = self.prompt_learner.meta_net
|
||||
self.adapter = Adapter(512, 4).to(clip_model.dtype)
|
||||
|
||||
uncertainty_scale = getattr(cfg.TRAINER.COOP, 'UNCERTAINTY_SCALE', 0.5)
|
||||
|
||||
self.prompt_integrator = AttentionBasedIntegrator(
|
||||
img_dim=clip_model.visual.output_dim,
|
||||
n_prompts=self.n_prompts,
|
||||
dtype=clip_model.dtype
|
||||
dtype=clip_model.dtype,
|
||||
uncertainty_scale=uncertainty_scale
|
||||
)
|
||||
|
||||
def compute_diversity_loss(self, text_features):
|
||||
@@ -314,9 +347,9 @@ class CustomCLIP(nn.Module):
|
||||
|
||||
all_logits = torch.stack(all_logits)
|
||||
|
||||
logits, attn_weights = self.prompt_integrator(image_features, all_logits)
|
||||
logits, hybrid_weights, entropy = self.prompt_integrator(image_features, all_logits)
|
||||
|
||||
return logits, score, diversity_loss, attn_weights
|
||||
return logits, score, diversity_loss, hybrid_weights, entropy
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
|
||||
@@ -385,10 +418,10 @@ class MSGCoOp(TrainerX):
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
else:
|
||||
output, score, diversity_loss, attn_weights = self.model(image)
|
||||
output, score, diversity_loss, hybrid_weights, entropy = self.model(image)
|
||||
|
||||
# Add attention regularization to encourage balanced prompt usage
|
||||
attn_reg = -(attn_weights * torch.log(attn_weights + 1e-8)).mean()
|
||||
attn_reg = -(hybrid_weights * torch.log(hybrid_weights + 1e-8)).mean()
|
||||
|
||||
loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg
|
||||
self.model_backward_and_update(loss)
|
||||
|
||||
Reference in New Issue
Block a user