Compare commits

3 Commits

Author SHA1 Message Date
1d7d93ede5 Last-k Average 2026-02-07 15:58:51 +08:00
f3a7993665 xda xdo script 2026-02-06 17:38:54 +08:00
91e873c365 dual and softmax conf 2026-02-05 18:46:37 +08:00
15 changed files with 234 additions and 318 deletions

View File

@@ -39,5 +39,4 @@ TRAINER:
PROMPT_DEPTH_TEXT: 9 PROMPT_DEPTH_TEXT: 9
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 15 LAST_K: 5
GPA_STD: 1

View File

@@ -23,6 +23,7 @@ OPTIM:
WARMUP_CONS_LR: 1e-5 WARMUP_CONS_LR: 1e-5
TRAIN: TRAIN:
CHECKPOINT_FREQ: 5
PRINT_FREQ: 20 PRINT_FREQ: 20
MODEL: MODEL:
@@ -39,5 +40,4 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6 LAST_K: 5
GPA_STD: 10

View File

@@ -16,7 +16,7 @@ INPUT:
OPTIM: OPTIM:
NAME: "sgd" NAME: "sgd"
LR: 0.0025 LR: 0.0025
MAX_EPOCH: 50 MAX_EPOCH: 5
LR_SCHEDULER: "cosine" LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1 WARMUP_EPOCH: 1
WARMUP_TYPE: "constant" WARMUP_TYPE: "constant"
@@ -35,13 +35,8 @@ TRAINER:
N_CTX_TEXT: 4 N_CTX_TEXT: 4
CTX_INIT: "a photo of a" CTX_INIT: "a photo of a"
PREC: "fp16" PREC: "fp16"
PROMPT_DEPTH_VISION: 9 PROMPT_DEPTH_VISION: 3
PROMPT_DEPTH_TEXT: 9 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
# Use the below configuration for: ImageNet, Caltech101, OxfordPets, Food101, UCF101 and SUN397 LAST_K: 5
GPA_MEAN: 30
GPA_STD: 30
# Use the below configuration for: StanfordCars, Flowers102, FGVCAircraft, DTD and EuroSAT
# GPA_MEAN: 45
# GPA_STD: 5

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC ## PromptSRC
#### (1) Base-to-Novel class generalization setting #### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file. The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as LAST_K, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet. Run the commands below to train PromptSRC on ImageNet.

View File

@@ -1,15 +1,15 @@
seeds=(1 2 3) seeds=(1 2 3)
datasets=( datasets=(
# "ucf101" "ucf101"
# "eurosat" "eurosat"
# "oxford_pets" "oxford_pets"
# "food101" "food101"
# "oxford_flowers" "oxford_flowers"
# "dtd" "dtd"
# "caltech101" "caltech101"
# "fgvc_aircraft" "fgvc_aircraft"
# "stanford_cars" "stanford_cars"
# "sun397" "sun397"
"imagenet" "imagenet"
) )

View File

@@ -1,27 +0,0 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
CFG=vit_b16_c2_ep50_batch4_4+4ctx_few_shot
SHOTS=$2
for SEED in 1 2 3
do
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo " The results exist at ${DIR}"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done

View File

@@ -1,54 +0,0 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep20_batch4_4+4ctx
SHOTS=16
LOADEP=20
SUB_base=base
SUB_novel=new
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
if [ -d "$DIR" ]; then
echo "Results are already available in ${DIR}. Skipping..."
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
# Evaluate on base classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_base} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_base}
# Evaluate on novel classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_novel} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
fi

View File

@@ -1,34 +0,0 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SHOTS=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep50_batch4_4+4ctx_few_shot
LOADEP=50
for SEED in 1 2 3
do
MODEL_DIR=${WEIGHTSPATH}/${SHOTS}shot/seed${SEED}
DIR=output/few_shot/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo " The results exist at ${DIR}"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS}
fi
done

View File

@@ -1,36 +0,0 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
SHOTS=16
LOADEP=20
MODEL_DIR=${WEIGHTSPATH}/seed${SEED}
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are already available in ${DIR}. Skipping..."
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
# Evaluate on evaluation datasets
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
fi

View File

@@ -1,31 +0,0 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets
SHOTS=16
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
--load-epoch 20 \
--eval-only
fi

View File

@@ -1,29 +1,30 @@
#!/bin/bash #!/bin/bash
# custom config DATA=" ~/Datasets/CoOp"
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC TRAINER=PromptSRC
SRC_DATASETS=imagenet
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets
SHOTS=16 SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} for SEED in 1 2 3
if [ -d "$DIR" ]; then do
echo "Results are available in ${DIR}." DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
else
echo "Run this job and save the output to ${DIR}" if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${SRC_DATASETS}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi

View File

@@ -0,0 +1,46 @@
#!/bin/bash
# custom config
DATA=" ~/Datasets/CoOp"
TRAINER=PromptSRC
SRC_DATASETS=imagenet
SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
LOADEP=20
DATASETS=(dtd eurosat fgvc_aircraft food101 oxford_flowers oxford_pets stanford_cars ucf101 caltech101 sun397)
SEEDS=(1 2 3)
for DATASET in "${DATASETS[@]}"
do
for SEED in "${SEEDS[@]}"
do
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
DIR=output_xd/base2new/test_new/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
echo "Loading model from ${MODEL_DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only
fi
done
done

View File

@@ -0,0 +1,46 @@
#!/bin/bash
# custom config
DATA=" ~/Datasets/CoOp"
TRAINER=PromptSRC
SRC_DATASETS=imagenet
SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
LOADEP=20
DATASETS=(imagenetv2 imagenet_sketch imagenet_a imagenet_r)
SEEDS=(1 2 3)
for DATASET in "${DATASETS[@]}"
do
for SEED in "${SEEDS[@]}"
do
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
DIR=output_xd/base2new/test_new/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
echo "Loading model from ${MODEL_DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only
fi
done
done

View File

@@ -119,9 +119,10 @@ def extend_cfg(cfg):
cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1) cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1) cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25 cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5 # lambda3: weak text constraint weight
cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10 cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15 cfg.TRAINER.PROMPTSRC.LAST_K = 5
cfg.TRAINER.PROMPTSRC.GPA_STD = 1
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp) # Config for independent Vision Language prompting (independent-vlp)

View File

@@ -107,28 +107,32 @@ class VLPromptLearner(nn.Module):
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
if ctx_init and n_ctx <= 4: if ctx_init and n_ctx <= 4:
# use given words to initialize context vectors
ctx_init = ctx_init.replace("_", " ") ctx_init = ctx_init.replace("_", " ")
n_ctx = n_ctx
prompt = clip.tokenize(ctx_init) prompt = clip.tokenize(ctx_init)
with torch.no_grad(): with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype) embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :] ctx_vectors_strong = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init prompt_prefix_strong = ctx_init
else: else:
# random initialization ctx_vectors_strong = torch.empty(n_ctx, ctx_dim, dtype=dtype)
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) nn.init.normal_(ctx_vectors_strong, std=0.02)
nn.init.normal_(ctx_vectors, std=0.02) prompt_prefix_strong = " ".join(["X"] * n_ctx)
prompt_prefix = " ".join(["X"] * n_ctx)
print(f"Independent V-L design") ctx_vectors_weak = torch.empty(n_ctx, ctx_dim, dtype=dtype)
print(f'Initial text context: "{prompt_prefix}"') nn.init.normal_(ctx_vectors_weak, std=0.02)
prompt_prefix_weak = " ".join(["X"] * n_ctx)
print(f"Independent V-L design with Dual Prompt Branches")
print(f'Strong branch initial text context: "{prompt_prefix_strong}"')
print(f'Weak branch initial text context: "{prompt_prefix_weak}"')
print(f"Number of context words (tokens) for Language prompting: {n_ctx}") print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}") print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}")
self.ctx = nn.Parameter(ctx_vectors) self.ctx_strong = nn.Parameter(ctx_vectors_strong)
self.ctx_weak = nn.Parameter(ctx_vectors_weak)
classnames = [name.replace("_", " ") for name in classnames] classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames] name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames] prompts = [prompt_prefix_strong + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn) tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
# Also create frozen CLIP # Also create frozen CLIP
@@ -188,15 +192,19 @@ class VLPromptLearner(nn.Module):
return prompts return prompts
def forward(self): def forward(self):
ctx = self.ctx ctx_strong = self.ctx_strong
if ctx.dim() == 2: ctx_weak = self.ctx_weak
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
if ctx_strong.dim() == 2:
ctx_strong = ctx_strong.unsqueeze(0).expand(self.n_cls, -1, -1)
ctx_weak = ctx_weak.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix prefix = self.token_prefix
suffix = self.token_suffix suffix = self.token_suffix
prompts = self.construct_prompts(ctx, prefix, suffix) prompts_strong = self.construct_prompts(ctx_strong, prefix, suffix)
prompts_weak = self.construct_prompts(ctx_weak, prefix, suffix)
return prompts return prompts_strong, prompts_weak
class CustomCLIP(nn.Module): class CustomCLIP(nn.Module):
@@ -215,29 +223,41 @@ class CustomCLIP(nn.Module):
tokenized_prompts = self.tokenized_prompts tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp() logit_scale = self.logit_scale.exp()
prompts = self.prompt_learner() prompts_strong, prompts_weak = self.prompt_learner()
# Compute the prompted image and text features
text_features = self.text_encoder(prompts, tokenized_prompts) with torch.no_grad():
zero_shot_features = self.prompt_learner.ZS_image_encoder(image.type(self.dtype))
zero_shot_features = zero_shot_features / zero_shot_features.norm(dim=-1, keepdim=True)
image_features = self.image_encoder(image.type(self.dtype)) image_features = self.image_encoder(image.type(self.dtype))
image_features = image_features / image_features.norm(dim=-1, keepdim=True) image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Compute the prompted logits
logits = logit_scale * image_features @ text_features.t()
if self.prompt_learner.training:
# Now calculate the frozen pre-trained features
fixed_embeddings = self.prompt_learner.fixed_embeddings # precomputed pre-trained frozen textual features
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
with torch.no_grad():
zero_shot_features = self.prompt_learner.ZS_image_encoder(image.type(self.dtype))
zero_shot_features = zero_shot_features / zero_shot_features.norm(dim=-1, keepdim=True)
# Compute pre-trained frozen visual features
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
return F.cross_entropy(logits, text_features_strong = self.text_encoder(prompts_strong, tokenized_prompts)
label), text_features, fixed_embeddings, zero_shot_features, \ text_features_strong = text_features_strong / text_features_strong.norm(dim=-1, keepdim=True)
image_features, zero_shot_logits, logits
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
fixed_embeddings = self.prompt_learner.fixed_embeddings
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t()
zs_probs = F.softmax(zero_shot_logits, dim=1)
confidence = zs_probs.max(dim=1).values
alpha = confidence.unsqueeze(1)
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
if self.prompt_learner.training:
loss_ce = F.cross_entropy(logits_final, label)
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
else: else:
return logits return logits_final
@TRAINER_REGISTRY.register() @TRAINER_REGISTRY.register()
@@ -291,12 +311,8 @@ class PromptSRC(TrainerX):
# Cosine scheduler # Cosine scheduler
self.total_epochs = cfg.OPTIM.MAX_EPOCH self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1 self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH self.max_k = cfg.TRAINER.PROMPTSRC.LAST_K
mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN self.last_k_models = []
stdev = cfg.TRAINER.PROMPTSRC.GPA_STD
gauss = self.get_gauss(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss)
self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is # Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel # big, which slows down the copy operation in DataParallel
@@ -304,8 +320,6 @@ class PromptSRC(TrainerX):
if device_count > 1: if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model) self.model = nn.DataParallel(self.model)
# Keep model with GPA
self.previous_model_gpa = None
def forward_backward(self, batch): def forward_backward(self, batch):
image, label = self.parse_batch_train(batch) image, label = self.parse_batch_train(batch)
@@ -323,22 +337,25 @@ class PromptSRC(TrainerX):
scaler.step(optim) scaler.step(optim)
scaler.update() scaler.update()
else: else:
loss_ce, normalized_text_features, zs_clip_text_embeddings, zs_image_embedd, image_ft, \ loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
zero_shot_logits, logits = model(image, label) zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
# Calculate the L_SCL_text loss
loss_scl_text = F.l1_loss(normalized_text_features, zs_clip_text_embeddings.cuda(), lambda1 = self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
reduction='mean') * self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT lambda2 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG
# Calculate the L_SCL_image loss lambda3 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK
loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(),
reduction='mean') * self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
# Now calculate L_SCL_logits loss_scl_text_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
loss_scl_text_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_SCL_logits = F.kl_div( L_SCL_logits = F.kl_div(
F.log_softmax(logits / 1, dim=1), F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1), F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum', reduction='sum',
log_target=True log_target=True
) * (1 * 1) / logits.numel() ) * (1 * 1) / logits_final.numel()
L_SCL = (L_SCL_logits + loss_scl_text + loss_scl_image)
L_SCL = (L_SCL_logits + loss_scl_text_strong + loss_scl_text_weak + loss_scl_image)
loss = (loss_ce + L_SCL) loss = (loss_ce + L_SCL)
optim.zero_grad() optim.zero_grad()
loss.backward() loss.backward()
@@ -348,45 +365,32 @@ class PromptSRC(TrainerX):
if (self.batch_idx + 1) == self.num_batches: if (self.batch_idx + 1) == self.num_batches:
self.update_lr() self.update_lr()
# Means one epoch is completed, perform GPA
self.step_counter = self.step_counter + 1 self.step_counter = self.step_counter + 1
current_epoch_weight = self.gauss[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict()) current_model_weights = copy.deepcopy(model.state_dict())
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight) for key in current_model_weights:
if self.previous_model_gpa is None: current_model_weights[key] = current_model_weights[key].cpu()
self.previous_model_gpa = weighted_state_dict self.last_k_models.append(current_model_weights)
else: if len(self.last_k_models) > self.max_k:
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa) self.last_k_models.pop(0)
torch.cuda.empty_cache()
if self.step_counter == self.model.total_epochs + 1: if self.step_counter == self.model.total_epochs + 1:
print("Using GPA model for final inference...") print(f"Using Last-K Averaging (K={len(self.last_k_models)}) model for final inference...")
model.load_state_dict(self.previous_model_gpa) averaged_state_dict = self._average_last_k_models()
self.model.load_state_dict(self.previous_model_gpa) for key in averaged_state_dict:
averaged_state_dict[key] = averaged_state_dict[key].cuda()
model.load_state_dict(averaged_state_dict)
self.model.load_state_dict(averaged_state_dict)
return loss_summary return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False): def _average_last_k_models(self):
# Average all parameters if not self.last_k_models:
updated_dict = copy.deepcopy(main_dict) return {}
if not prompt_only: averaged_dict = {}
for key in main_dict: for key in self.last_k_models[0]:
updated_dict[key] = main_dict[key] * weightage stacked = torch.stack([model_state[key] for model_state in self.last_k_models])
return updated_dict averaged_dict[key] = torch.mean(stacked, dim=0)
else: return averaged_dict
return main_dict * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = (modified_dict[key] + dict1[key])
return modified_dict
else:
return dict1 + dict2
def get_gauss(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss
def parse_batch_train(self, batch): def parse_batch_train(self, batch):
input = batch["img"] input = batch["img"]
@@ -425,6 +429,12 @@ class PromptSRC(TrainerX):
if "prompt_learner.token_suffix" in state_dict: if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"] del state_dict["prompt_learner.token_suffix"]
# Handle backward compatibility: if old checkpoint has ctx, initialize both ctx_strong and ctx_weak
if "prompt_learner.ctx" in state_dict:
ctx = state_dict.pop("prompt_learner.ctx")
state_dict["prompt_learner.ctx_strong"] = ctx.clone()
state_dict["prompt_learner.ctx_weak"] = ctx.clone()
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False # set strict=False
self._models[name].load_state_dict(state_dict, strict=False) self._models[name].load_state_dict(state_dict, strict=False)