1 Commits

Author SHA1 Message Date
miunangel
fa24c48109 Uncertain fuse 2026-02-01 20:52:22 +08:00
9 changed files with 217 additions and 62 deletions

View File

@@ -31,4 +31,3 @@ MODEL:
TRAINER: TRAINER:
COOP: COOP:
CTX_INIT: True CTX_INIT: True
ATTENTION_REG_WEIGHT: 0.01

View File

@@ -3,26 +3,36 @@
TRAINER=$1 TRAINER=$1
KG_WEIGHT=$2 KG_WEIGHT=$2
MP_WEIGHT=$3 MP_WEIGHT=$3
ATTN_REG_WEIGHT=$4
# Define datasets array CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
datasets=( CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
"ucf101"
"eurosat"
"oxford_pets"
"food101"
"oxford_flowers"
"dtd"
"caltech101"
"fgvc_aircraft"
"stanford_cars"
"sun397"
"imagenet"
)
# Loop through datasets CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
for dataset in "${datasets[@]}"; do CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT}
done
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}
CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}

View File

@@ -0,0 +1,39 @@
#!/bin/bash
# Run base-to-new training and evaluation with entropy-based uncertainty
# prompt fusion across the full benchmark suite.
#
# Usage: bash <this script> TRAINER KG_WEIGHT MP_WEIGHT UNC_TEMPERATURE
TRAINER=$1
KG_WEIGHT=$2
MP_WEIGHT=$3
UNC_TEMPERATURE=$4

# Datasets in the exact order the commands originally ran.
DATASETS=(
    ucf101
    eurosat
    oxford_pets
    food101
    oxford_flowers
    dtd
    caltech101
    fgvc_aircraft
    stanford_cars
    sun397
    imagenet
)

for DS in "${DATASETS[@]}"; do
    # Train on base classes, then evaluate on the held-out new classes.
    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train_unc.sh ${TRAINER} ${DS} ${KG_WEIGHT} ${MP_WEIGHT} ${UNC_TEMPERATURE}
    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test_unc.sh ${TRAINER} ${DS} ${KG_WEIGHT} ${MP_WEIGHT} ${UNC_TEMPERATURE}
done

View File

@@ -7,7 +7,6 @@ DATASET=$2
N_PROMPTS=4 N_PROMPTS=4
KG_WEIGHT=$3 KG_WEIGHT=$3
MP_WEIGHT=$4 MP_WEIGHT=$4
ATTN_REG_WEIGHT=$5
#CFG=rn50_ep100 # config file #CFG=rn50_ep100 # config file
CFG=vit_b16_ep100_ctxv1 CFG=vit_b16_ep100_ctxv1
CTP=end # class token position (end or middle) CTP=end # class token position (end or middle)
@@ -20,7 +19,7 @@ SUB=new
for SEED in 1 2 3 for SEED in 1 2 3
do do
COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED} COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=output/base2new/train_base/${COMMON_DIR} MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
DIR=output/base2new/test_${SUB}/${COMMON_DIR} DIR=output/base2new/test_${SUB}/${COMMON_DIR}

View File

@@ -0,0 +1,51 @@
#!/bin/bash
# Evaluate base-trained checkpoints on the "new" class split with
# uncertainty-weighted prompt integration enabled.
#
# Usage: bash <this script> TRAINER DATASET KG_WEIGHT MP_WEIGHT UNC_TEMPERATURE
# custom config
DATA=~/Datasets/CoOp
TRAINER=$1
DATASET=$2
N_PROMPTS=4
KG_WEIGHT=$3
MP_WEIGHT=$4      # accepted for call-site symmetry with the train script; not referenced below
UNC_TEMPERATURE=$5
#CFG=rn50_ep100 # config file
CFG=vit_b16_ep100_ctxv1
CTP=end # class token position (end or middle)
NCTX=4 # number of context tokens
SHOTS=16 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)
LOADEP=100
SUB=new

for SEED in 1 2 3; do
    # NOTE: must mirror the directory naming used by base2new_train_unc.sh,
    # otherwise the checkpoint lookup below fails.
    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_unc${UNC_TEMPERATURE}/${TRAINER}/${CFG}/seed${SEED}
    MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
    DIR=output/base2new/test_${SUB}/${COMMON_DIR}

    # Skip evaluations that already produced output.
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
        continue
    fi

    echo "Run this job and save the output to ${DIR}"
    python train.py \
    --root ${DATA} \
    --seed ${SEED} \
    --trainer ${TRAINER} \
    --dataset-config-file configs/datasets/${DATASET}.yaml \
    --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
    --output-dir ${DIR} \
    --model-dir ${MODEL_DIR} \
    --load-epoch ${LOADEP} \
    --eval-only \
    TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
    TRAINER.COOP.N_CTX ${NCTX} \
    TRAINER.COOP.CSC ${CSC} \
    TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
    DATASET.NUM_SHOTS ${SHOTS} \
    DATASET.SUBSAMPLE_CLASSES ${SUB} \
    TRAINER.COOP.UNC_ENABLED True \
    TRAINER.COOP.UNC_TEMPERATURE ${UNC_TEMPERATURE}
done

View File

@@ -6,7 +6,6 @@ TRAINER=$1
DATASET=$2 DATASET=$2
KG_WEIGHT=$3 KG_WEIGHT=$3
MP_WEIGHT=$4 MP_WEIGHT=$4
ATTN_REG_WEIGHT=$5
N_PROMPTS=4 N_PROMPTS=4
#CFG=rn50_ep100 # config file\ #CFG=rn50_ep100 # config file\
CFG=vit_b16_ep100_ctxv1 CFG=vit_b16_ep100_ctxv1
@@ -17,7 +16,7 @@ CSC=False # class-specific context (False or True)
for SEED in 1 2 3 for SEED in 1 2 3
do do
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGH}/${TRAINER}/${CFG}/seed${SEED} DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job" echo "Results are available in ${DIR}. Skip this job"
else else
@@ -36,7 +35,6 @@ do
DATASET.NUM_SHOTS ${SHOTS} \ DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES base \ DATASET.SUBSAMPLE_CLASSES base \
TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \ TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \ TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT}
fi fi
done done

View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Train prompt learners on the base class split with entropy-based
# uncertainty prompt fusion enabled, once per seed.
#
# Usage: bash <this script> TRAINER DATASET KG_WEIGHT MP_WEIGHT UNC_TEMPERATURE
# custom config
DATA=~/Datasets/CoOp
TRAINER=$1
DATASET=$2
KG_WEIGHT=$3
MP_WEIGHT=$4
UNC_TEMPERATURE=$5
N_PROMPTS=4
#CFG=rn50_ep100 # config file
CFG=vit_b16_ep100_ctxv1
CTP=end # class token position (end or middle)
NCTX=4 # number of context tokens
SHOTS=16 # number of shots (1, 2, 4, 8, 16)
CSC=False # class-specific context (False or True)

for SEED in 1 2 3; do
    # NOTE: base2new_test_unc.sh reconstructs this exact path to locate
    # the checkpoint — keep the naming scheme in sync.
    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_unc${UNC_TEMPERATURE}/${TRAINER}/${CFG}/seed${SEED}

    # Skip seeds that already have a completed run.
    if [ -d "$DIR" ]; then
        echo "Results are available in ${DIR}. Skip this job"
        continue
    fi

    echo "Run this job and save the output to ${DIR}"
    python train.py \
    --root ${DATA} \
    --seed ${SEED} \
    --trainer ${TRAINER} \
    --dataset-config-file configs/datasets/${DATASET}.yaml \
    --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
    --output-dir ${DIR} \
    TRAINER.COOP.N_CTX ${NCTX} \
    TRAINER.COOP.CSC ${CSC} \
    TRAINER.COOP.W ${KG_WEIGHT} \
    TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
    DATASET.NUM_SHOTS ${SHOTS} \
    DATASET.SUBSAMPLE_CLASSES base \
    TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
    TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
    TRAINER.COOP.UNC_ENABLED True \
    TRAINER.COOP.UNC_TEMPERATURE ${UNC_TEMPERATURE}
done

View File

@@ -105,7 +105,10 @@ def extend_cfg(cfg):
cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
cfg.TRAINER.COOP.DIV_WEIGHT = 0.1 cfg.TRAINER.COOP.DIV_WEIGHT = 0.1
cfg.TRAINER.COOP.N_PROMPTS = 3 cfg.TRAINER.COOP.N_PROMPTS = 3
cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01
# 不确定性集成配置
cfg.TRAINER.COOP.UNC_ENABLED = False # 是否启用基于熵的不确定性集成
cfg.TRAINER.COOP.UNC_TEMPERATURE = 1.0 # 控制权重分布的平滑度
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
""" """

View File

@@ -223,27 +223,40 @@ class Adapter(nn.Module):
x = self.fc(x) x = self.fc(x)
return x return x
class AttentionBasedIntegrator(nn.Module): class UncertaintyPromptIntegrator(nn.Module):
def __init__(self, img_dim=512, n_prompts=4, dtype=None): def __init__(self, temperature=1.0):
"""
基于预测熵的不确定性加权集成器
Args:
temperature: 控制权重分布的平滑度,值越大权重分布越平均
"""
super().__init__() super().__init__()
self.attention = nn.Sequential( self.temperature = temperature
nn.Linear(img_dim, img_dim // 4),
nn.Tanh(),
nn.Linear(img_dim // 4, n_prompts)
)
self.dtype = dtype
if dtype is not None: def forward(self, all_logits):
self.attention = self.attention.to(dtype) """
Args:
all_logits: [n_prompts, batch_size, n_classes]
def forward(self, image_features, all_logits): Returns:
attn_scores = self.attention(image_features) integrated_logits: [batch_size, n_classes]
prompt_weights: [n_prompts, batch_size]
entropy: [n_prompts, batch_size]
"""
n_prompts, batch_size, n_classes = all_logits.shape
attn_weights = F.softmax(attn_scores, dim=-1) log_probs = F.log_softmax(all_logits, dim=-1)
probs = log_probs.exp()
weighted_logits = torch.einsum('bp,pbc->bc', attn_weights, all_logits) entropy = -(probs * log_probs).sum(dim=-1)
return weighted_logits, attn_weights temperature = max(self.temperature, 1e-8)
weights = F.softmax(-entropy / temperature, dim=0)
integrated_logits = torch.einsum('pb,pbc->bc', weights, all_logits)
return integrated_logits, weights, entropy
class CustomCLIP(nn.Module): class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model): def __init__(self, cfg, classnames, clip_model):
@@ -259,11 +272,13 @@ class CustomCLIP(nn.Module):
self.meta_net = self.prompt_learner.meta_net self.meta_net = self.prompt_learner.meta_net
self.adapter = Adapter(512, 4).to(clip_model.dtype) self.adapter = Adapter(512, 4).to(clip_model.dtype)
self.prompt_integrator = AttentionBasedIntegrator( self.use_uncertainty_integration = cfg.TRAINER.COOP.get('UNC_ENABLED', False)
img_dim=clip_model.visual.output_dim, self.unc_temperature = cfg.TRAINER.COOP.get('UNC_TEMPERATURE', 1.0)
n_prompts=self.n_prompts,
dtype=clip_model.dtype if self.use_uncertainty_integration:
) self.unc_integrator = UncertaintyPromptIntegrator(
temperature=self.unc_temperature
)
def compute_diversity_loss(self, text_features): def compute_diversity_loss(self, text_features):
if self.n_prompts == 1: if self.n_prompts == 1:
@@ -314,9 +329,14 @@ class CustomCLIP(nn.Module):
all_logits = torch.stack(all_logits) all_logits = torch.stack(all_logits)
logits, attn_weights = self.prompt_integrator(image_features, all_logits) if self.use_uncertainty_integration:
logits, prompt_weights, entropy = self.unc_integrator(all_logits)
self.last_prompt_weights = prompt_weights.detach()
self.last_entropy = entropy.detach()
else:
logits = all_logits.mean(dim=0)
return logits, score, diversity_loss, attn_weights return logits, score, diversity_loss
@TRAINER_REGISTRY.register() @TRAINER_REGISTRY.register()
@@ -340,11 +360,10 @@ class MSGCoOp(TrainerX):
self.model = CustomCLIP(cfg, classnames, clip_model) self.model = CustomCLIP(cfg, classnames, clip_model)
self.w = cfg.TRAINER.COOP.W self.w = cfg.TRAINER.COOP.W
self.diversity_weight = cfg.TRAINER.COOP.DIV_WEIGHT self.diversity_weight = cfg.TRAINER.COOP.DIV_WEIGHT
self.attn_reg_weight = cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT if hasattr(cfg.TRAINER.COOP, 'ATTENTION_REG_WEIGHT') else 0.01
print("Turning off gradients in both the image and the text encoder") print("Turning off gradients in both the image and the text encoder")
for name, param in self.model.named_parameters(): for name, param in self.model.named_parameters():
if "ctx" not in name and "prompt_integrator" not in name: if "ctx" not in name:
param.requires_grad_(False) param.requires_grad_(False)
else: else:
print(name) print(name)
@@ -353,10 +372,8 @@ class MSGCoOp(TrainerX):
load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device) self.model.to(self.device)
# NOTE: only give prompt_learner to the optimizer
# NOTE: give prompt_learner and prompt_integrator to the optimizer self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
trainable_params = list(self.model.prompt_learner.parameters()) + list(self.model.prompt_integrator.parameters())
self.optim = build_optimizer([{'params': trainable_params}], cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
@@ -385,12 +402,8 @@ class MSGCoOp(TrainerX):
self.scaler.step(self.optim) self.scaler.step(self.optim)
self.scaler.update() self.scaler.update()
else: else:
output, score, diversity_loss, attn_weights = self.model(image) output, score, diversity_loss = self.model(image)
loss = F.cross_entropy(output, label)+self.w*score + diversity_loss * self.diversity_weight
# Add attention regularization to encourage balanced prompt usage
attn_reg = -(attn_weights * torch.log(attn_weights + 1e-8)).mean()
loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg
self.model_backward_and_update(loss) self.model_backward_and_update(loss)
loss_summary = { loss_summary = {