2 Commits

SHA1        Message          Date
17c6c4c309  Uncertain fuse   2026-02-01 17:48:47 +08:00
0ba13ffbbd  Attn fuse        2026-01-31 23:48:05 +08:00
10 changed files with 133 additions and 53 deletions

View File

@@ -35,12 +35,12 @@ Follow [DATASETS.md](DATASETS.md) to install the datasets.
 ## Generalization From Base to New Classes
-You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts have three input argument, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT`.
+You will need `base2new_train.sh`, `base2new_test.sh`, and `base2new_all.sh`. The scripts with the prefix `base2new_train` train a model on base classes while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts have five input arguments, i.e., `TRAINER SG_WEIGHT DIV_WEIGHT ATTN_REG_WEIGHT UNCERTAINTY_SCALE`.
 You can run base to new on all datasets as follows:
 ```bash
-bash scripts/base2new_all.sh MSGCoOp 8.0 1.0
+bash scripts/base2new_all.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 When the evaluation is done, you can use `extract_acc.py` (replace the `root_dir` in the `main` function with your output dir) to automatically calculate the average results. For instance, after you finish the training using the aforementioned commands, you would get
@@ -91,13 +91,13 @@ Then, you will get the average accuracy.
 First, you need to train on all classes over ImageNet:
 ```bash
-bash scripts/xd_train.sh MSGCoOp 8.0 1.0
+bash scripts/xd_train.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 Then you can evaluate the performance on other ImageNet variants by running:
 ```bash
-bash scripts/xdo_test.sh MSGCoOp 8.0 1.0
+bash scripts/xdo_test.sh MSGCoOp 8.0 1.0 0.01 0.5
 ```
 And you will get `output_xdo` after the script finishes. You can get the accuracy with `extract_acc.py` (you need to modify the `root_dir` to `output_xdo`).

View File

@@ -31,3 +31,5 @@ MODEL:
 TRAINER:
   COOP:
     CTX_INIT: True
+    ATTENTION_REG_WEIGHT: 0.01
+    UNCERTAINTY_SCALE: 0.5

View File

@@ -98,7 +98,7 @@ def print_model_results(results, model_name):
print("No complete dataset results found for this model.") print("No complete dataset results found for this model.")
def main(): def main():
root_dir = 'output_xda' # 修改为你的output目录路径 root_dir = 'output' # 修改为你的output目录路径
target_model = 'MSGCoOp' # 指定要分析的模型 target_model = 'MSGCoOp' # 指定要分析的模型
results = collect_model_results(root_dir, target_model) results = collect_model_results(root_dir, target_model)

View File

@@ -3,36 +3,27 @@
 TRAINER=$1
 KG_WEIGHT=$2
 MP_WEIGHT=$3
+ATTN_REG_WEIGHT=$4
+UNCERTAINTY_SCALE=$5

-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ucf101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} eurosat ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_pets ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} food101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} oxford_flowers ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} dtd ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} caltech101 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} fgvc_aircraft ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} stanford_cars ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} sun397 ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}
-CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} imagenet ${KG_WEIGHT} ${MP_WEIGHT}
+# Define datasets array
+datasets=(
+    "ucf101"
+    "eurosat"
+    "oxford_pets"
+    "food101"
+    "oxford_flowers"
+    "dtd"
+    "caltech101"
+    "fgvc_aircraft"
+    "stanford_cars"
+    "sun397"
+    "imagenet"
+)
+
+# Loop through datasets
+for dataset in "${datasets[@]}"; do
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_train.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
+    CUDA_VISIBLE_DEVICES=0 bash scripts/base2new_test.sh ${TRAINER} ${dataset} ${KG_WEIGHT} ${MP_WEIGHT} ${ATTN_REG_WEIGHT} ${UNCERTAINTY_SCALE}
+done
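If you prefer to drive the sweep from Python rather than the shell loop, here is a minimal sketch using the same interface; the `run_base2new` helper is hypothetical, while the script paths, dataset names, and positional argument order are taken from the diff above:

```python
import os
import subprocess

DATASETS = ["ucf101", "eurosat", "oxford_pets", "food101", "oxford_flowers",
            "dtd", "caltech101", "fgvc_aircraft", "stanford_cars", "sun397", "imagenet"]

def run_base2new(trainer="MSGCoOp", kg_weight="8.0", mp_weight="1.0",
                 attn_reg_weight="0.01", uncertainty_scale="0.5"):
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
    for dataset in DATASETS:
        for script in ("scripts/base2new_train.sh", "scripts/base2new_test.sh"):
            # Same positional interface as the shell loop:
            # TRAINER DATASET KG_WEIGHT MP_WEIGHT ATTN_REG_WEIGHT UNCERTAINTY_SCALE
            subprocess.run(["bash", script, trainer, dataset, kg_weight,
                            mp_weight, attn_reg_weight, uncertainty_scale],
                           check=True, env=env)
```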

View File

@@ -7,6 +7,8 @@ DATASET=$2
 N_PROMPTS=4
 KG_WEIGHT=$3
 MP_WEIGHT=$4
+ATTN_REG_WEIGHT=$5
+UNCERTAINTY_SCALE=$6
 #CFG=rn50_ep100 # config file
 CFG=vit_b16_ep100_ctxv1
 CTP=end # class token position (end or middle)
@@ -19,7 +21,7 @@ SUB=new
 for SEED in 1 2 3
 do
-    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+    COMMON_DIR=${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
     MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
     DIR=output/base2new/test_${SUB}/${COMMON_DIR}
@@ -43,6 +45,7 @@ do
         TRAINER.COOP.CSC ${CSC} \
         TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
         DATASET.NUM_SHOTS ${SHOTS} \
-        DATASET.SUBSAMPLE_CLASSES ${SUB}
+        DATASET.SUBSAMPLE_CLASSES ${SUB} \
+        TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
     fi
 done
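With every hyperparameter encoded in `COMMON_DIR`, runs with different `MP_WEIGHT`, `ATTN_REG_WEIGHT`, or `UNCERTAINTY_SCALE` no longer overwrite each other, and the test script locates the checkpoint written by the matching training run. A sketch of the resulting layout; the `common_dir` helper is hypothetical, and `SHOTS` is set elsewhere in the script (16 is used here purely for illustration):

```python
from pathlib import Path

def common_dir(dataset, shots, kg_w, mp_w, attn_w, unc_scale, trainer, cfg, seed):
    # Mirrors COMMON_DIR in base2new_test.sh
    return Path(dataset) / f"shots_{shots}_{kg_w}_{mp_w}_{attn_w}_{unc_scale}" / trainer / cfg / f"seed{seed}"

d = common_dir("ucf101", 16, 8.0, 1.0, 0.01, 0.5, "MSGCoOp", "vit_b16_ep100_ctxv1", 1)
print("output/base2new/train_base" / d)  # where the trained model is read from
print("output/base2new/test_new" / d)    # where the new-class results are written
```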

View File

@@ -6,6 +6,8 @@ TRAINER=$1
 DATASET=$2
 KG_WEIGHT=$3
 MP_WEIGHT=$4
+ATTN_REG_WEIGHT=$5
+UNCERTAINTY_SCALE=$6
 N_PROMPTS=4
 #CFG=rn50_ep100 # config file\
 CFG=vit_b16_ep100_ctxv1
@@ -16,7 +18,7 @@ CSC=False # class-specific context (False or True)
 for SEED in 1 2 3
 do
-    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
     if [ -d "$DIR" ]; then
         echo "Results are available in ${DIR}. Skip this job"
     else
@@ -33,8 +35,10 @@ do
         TRAINER.COOP.W ${KG_WEIGHT} \
         TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
         DATASET.NUM_SHOTS ${SHOTS} \
         DATASET.SUBSAMPLE_CLASSES base \
         TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
-        TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
+        TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
+        TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
+        TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
     fi
 done

View File

@@ -6,6 +6,8 @@ TRAINER=$1
 N_PROMPTS=3
 KG_WEIGHT=$2
 MP_WEIGHT=$3
+ATTN_REG_WEIGHT=$4
+UNCERTAINTY_SCALE=$5
 CFG=vit_b16_ep100_ctxv1
 CTP=end # class token position (end or middle)
 NCTX=4 # number of context tokens
@@ -18,7 +20,7 @@ for DATASET in ${SRC_DATASETS}
 do
     for SEED in 1 2 3
     do
-        DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+        DIR=output_xd/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
         if [ -d "$DIR" ]; then
             echo "Results are available in ${DIR}. Skip this job"
         else
@@ -36,7 +38,9 @@ do
             TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
             DATASET.NUM_SHOTS ${SHOTS} \
             TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
-            TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
+            TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
+            TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
+            TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
         fi
     done
 done

View File

@@ -6,6 +6,8 @@ TRAINER=$1
 N_PROMPTS=3
 KG_WEIGHT=$2
 MP_WEIGHT=$3
+ATTN_REG_WEIGHT=$4
+UNCERTAINTY_SCALE=$5
 CFG=vit_b16_ep100_ctxv1
 CTP=end # class token position (end or middle)
 NCTX=4 # number of context tokens
@@ -19,8 +21,8 @@ for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r
 do
     for SEED in 1 2 3
     do
-        MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
-        DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}/${TRAINER}/${CFG}/seed${SEED}
+        MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
+        DIR=output_xdo/base2new/train_base/${DATASET}/shots_${SHOTS}_${KG_WEIGHT}_${MP_WEIGHT}_${ATTN_REG_WEIGHT}_${UNCERTAINTY_SCALE}/${TRAINER}/${CFG}/seed${SEED}
         if [ -d "$DIR" ]; then
             echo "Results are available in ${DIR}. Skip this job"
         else
@@ -41,7 +43,9 @@ do
             TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
             DATASET.NUM_SHOTS ${SHOTS} \
             TRAINER.COOP.N_PROMPTS ${N_PROMPTS} \
-            TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT}
+            TRAINER.COOP.DIV_WEIGHT ${MP_WEIGHT} \
+            TRAINER.COOP.ATTENTION_REG_WEIGHT ${ATTN_REG_WEIGHT} \
+            TRAINER.COOP.UNCERTAINTY_SCALE ${UNCERTAINTY_SCALE}
         fi
     done
 done

View File

@@ -105,6 +105,8 @@ def extend_cfg(cfg):
     cfg.TRAINER.COCOOP.PREC = "fp16"  # fp16, fp32, amp
     cfg.TRAINER.COOP.DIV_WEIGHT = 0.1
     cfg.TRAINER.COOP.N_PROMPTS = 3
+    cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT = 0.01
+    cfg.TRAINER.COOP.UNCERTAINTY_SCALE = 0.5
     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
     """

View File

@@ -223,6 +223,58 @@ class Adapter(nn.Module):
         x = self.fc(x)
         return x

+
+class AttentionBasedIntegrator(nn.Module):
+    def __init__(self, img_dim=512, n_prompts=4, dtype=None, uncertainty_scale=0.5):
+        super().__init__()
+        self.attention = nn.Sequential(
+            nn.Linear(img_dim, img_dim // 4),
+            nn.Tanh(),
+            nn.Linear(img_dim // 4, n_prompts)
+        )
+        self.dtype = dtype
+        self.uncertainty_scale = uncertainty_scale
+        if dtype is not None:
+            self.attention = self.attention.to(dtype)
+
+    def forward(self, image_features, all_logits):
+        batch_size = image_features.shape[0]
+        n_prompts = all_logits.shape[0]
+        # Attention weights
+        attn_scores = self.attention(image_features)
+        attn_weights = F.softmax(attn_scores, dim=-1)  # [batch, n_prompts]
+        # Uncertainty weights (based on predictive entropy); use the
+        # log-sum-exp trick for numerical stability
+        probs = F.softmax(all_logits, dim=-1)  # [n_prompts, batch, n_classes]
+        log_probs = F.log_softmax(all_logits, dim=-1)
+        # Entropy H = -sum(p * log(p))
+        entropy = -(probs * log_probs).sum(dim=-1)  # [n_prompts, batch]
+        # Convert to certainty (the lower the entropy, the higher the certainty)
+        certainty = -entropy.t()  # [batch, n_prompts]
+        # Normalize the uncertainty weights with log-sum-exp for stability
+        certainty_scaled = certainty * self.uncertainty_scale
+        log_uncertainty_weights = certainty_scaled - torch.logsumexp(certainty_scaled, dim=-1, keepdim=True)
+        uncertainty_weights = torch.exp(log_uncertainty_weights)
+        # Hybrid weights: combine the attention and uncertainty weights (in the log domain)
+        log_attn_weights = torch.log(attn_weights + 1e-8)
+        log_uncertainty_weights = torch.log(uncertainty_weights + 1e-8)
+        # Sum in the log domain, then exponentiate
+        log_hybrid_weights = torch.log(torch.tensor(0.5, device=attn_weights.device)) + \
+                             log_attn_weights + log_uncertainty_weights
+        log_hybrid_weights = log_hybrid_weights - torch.logsumexp(log_hybrid_weights, dim=-1, keepdim=True)
+        hybrid_weights = torch.exp(log_hybrid_weights)
+        # Weighted ensembling of the per-prompt logits
+        weighted_logits = torch.einsum('bp,pbc->bc', hybrid_weights, all_logits)
+        return weighted_logits, hybrid_weights, entropy
+
+
 class CustomCLIP(nn.Module):
     def __init__(self, cfg, classnames, clip_model):
         super().__init__()
@@ -236,6 +288,15 @@ class CustomCLIP(nn.Module):
         self.dtype = clip_model.dtype
         self.meta_net = self.prompt_learner.meta_net
         self.adapter = Adapter(512, 4).to(clip_model.dtype)
+        uncertainty_scale = getattr(cfg.TRAINER.COOP, 'UNCERTAINTY_SCALE', 0.5)
+        self.prompt_integrator = AttentionBasedIntegrator(
+            img_dim=clip_model.visual.output_dim,
+            n_prompts=self.n_prompts,
+            dtype=clip_model.dtype,
+            uncertainty_scale=uncertainty_scale
+        )

     def compute_diversity_loss(self, text_features):
         if self.n_prompts == 1:
@@ -283,10 +344,12 @@ class CustomCLIP(nn.Module):
             text_features_i = text_features_i / text_features_i.norm(dim=-1, keepdim=True)
             logits_i = logit_scale * image_features @ text_features_i.t()
             all_logits.append(logits_i)
-        logits = torch.stack(all_logits).mean(dim=0)
-        return logits, score, diversity_loss
+        all_logits = torch.stack(all_logits)
+        logits, hybrid_weights, entropy = self.prompt_integrator(image_features, all_logits)
+        return logits, score, diversity_loss, hybrid_weights, entropy
 @TRAINER_REGISTRY.register()
@@ -310,10 +373,11 @@ class MSGCoOp(TrainerX):
         self.model = CustomCLIP(cfg, classnames, clip_model)
         self.w = cfg.TRAINER.COOP.W
         self.diversity_weight = cfg.TRAINER.COOP.DIV_WEIGHT
+        self.attn_reg_weight = cfg.TRAINER.COOP.ATTENTION_REG_WEIGHT if hasattr(cfg.TRAINER.COOP, 'ATTENTION_REG_WEIGHT') else 0.01

         print("Turning off gradients in both the image and the text encoder")
         for name, param in self.model.named_parameters():
-            if "ctx" not in name:
+            if "ctx" not in name and "prompt_integrator" not in name:
                 param.requires_grad_(False)
             else:
                 print(name)
@@ -322,8 +386,10 @@ class MSGCoOp(TrainerX):
             load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

         self.model.to(self.device)

-        # NOTE: only give prompt_learner to the optimizer
-        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
+        # NOTE: give prompt_learner and prompt_integrator to the optimizer
+        trainable_params = list(self.model.prompt_learner.parameters()) + list(self.model.prompt_integrator.parameters())
+        self.optim = build_optimizer([{'params': trainable_params}], cfg.OPTIM)
         self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
         self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
@@ -352,8 +418,12 @@ class MSGCoOp(TrainerX):
             self.scaler.step(self.optim)
             self.scaler.update()
         else:
-            output, score, diversity_loss = self.model(image)
-            loss = F.cross_entropy(output, label)+self.w*score + diversity_loss * self.diversity_weight
+            output, score, diversity_loss, hybrid_weights, entropy = self.model(image)
+            # Attention regularization: mean entropy of the hybrid prompt weights
+            attn_reg = -(hybrid_weights * torch.log(hybrid_weights + 1e-8)).mean()
+            loss = F.cross_entropy(output, label) + self.w * score + diversity_loss * self.diversity_weight + self.attn_reg_weight * attn_reg
             self.model_backward_and_update(loss)

         loss_summary = {
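Taken together, the integrator replaces the earlier uniform mean over per-prompt logits with a per-sample hybrid weighting. Below is a self-contained sketch of just the math, with random tensors standing in for the CLIP features and the attention MLP output; the commit's constant `log(0.5)` is omitted here because it cancels under the normalization:

```python
import torch
import torch.nn.functional as F

n_prompts, batch, n_classes = 4, 2, 10
all_logits = torch.randn(n_prompts, batch, n_classes)  # per-prompt logits
attn_scores = torch.randn(batch, n_prompts)            # stand-in for the attention MLP output
uncertainty_scale = 0.5

# 1) Attention weights over prompts, conditioned on the image
attn_weights = F.softmax(attn_scores, dim=-1)          # [batch, n_prompts]

# 2) Predictive entropy per prompt; a confident prompt has low entropy
log_probs = F.log_softmax(all_logits, dim=-1)
entropy = -(log_probs.exp() * log_probs).sum(dim=-1).t()  # [batch, n_prompts]

# 3) softmax(-scale * entropy) equals the commit's logsumexp-normalized certainty
uncertainty_weights = F.softmax(-uncertainty_scale * entropy, dim=-1)

# 4) Hybrid weights: renormalized product of the two weightings, in the log domain
log_hybrid = torch.log(attn_weights + 1e-8) + torch.log(uncertainty_weights + 1e-8)
hybrid_weights = torch.exp(log_hybrid - torch.logsumexp(log_hybrid, dim=-1, keepdim=True))

# 5) Per-sample weighted ensemble of the prompt logits
logits = torch.einsum('bp,pbc->bc', hybrid_weights, all_logits)
assert logits.shape == (batch, n_classes)
```

One caveat worth noting: `attn_reg` in `forward_backward` is the (positive) mean entropy of `hybrid_weights`, so adding it to the loss with a positive `ATTENTION_REG_WEIGHT` penalizes high entropy and sharpens the per-sample weighting; to push the weights toward balanced prompt usage instead, the term would be subtracted.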