Compare commits

2 Commits

Author SHA1 Message Date
7fcf319dcf fix conf 2026-02-05 12:12:11 +08:00
1925ddfc86 dual and softmax conf 2026-02-05 12:11:29 +08:00
19 changed files with 311 additions and 219 deletions

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
# PromptSRC: Prompting with Self-regularizing constraints
DATALOADER:
TRAIN_X:
BATCH_SIZE: 4
@@ -30,15 +30,14 @@ MODEL:
NAME: "ViT-B/16"
TRAINER:
DZGCOOP:
PROMPTSRC:
N_CTX_VISION: 4
N_CTX_TEXT: 4
CTX_INIT: "a photo of a"
PREC: "fp16"
PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9
IMAGE_LOSS_WEIGHT: 8
TEXT_LOSS_WEIGHT_STRONG: 24
TEXT_LOSS_WEIGHT_WEAK: 8
EWA_MEAN: 15
EWA_STD: 1
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 15
GPA_STD: 1

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
# PromptSRC: Prompting with Self-regularizing constraints
DATALOADER:
TRAIN_X:
BATCH_SIZE: 4
@@ -23,7 +23,6 @@ OPTIM:
WARMUP_CONS_LR: 1e-5
TRAIN:
CHECKPOINT_FREQ: 5
PRINT_FREQ: 20
MODEL:
@@ -31,7 +30,7 @@ MODEL:
NAME: "ViT-B/16"
TRAINER:
DZGCOOP:
PROMPTSRC:
N_CTX_VISION: 4
N_CTX_TEXT: 4
CTX_INIT: "a photo of a"
@@ -40,5 +39,5 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
EWA_MEAN: 6
EWA_STD: 10
GPA_MEAN: 6
GPA_STD: 10

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
# PromptSRC: Prompting with Self-regularizing constraints
DATALOADER:
TRAIN_X:
BATCH_SIZE: 4
@@ -16,7 +16,7 @@ INPUT:
OPTIM:
NAME: "sgd"
LR: 0.0025
MAX_EPOCH: 5
MAX_EPOCH: 50
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
@@ -30,14 +30,18 @@ MODEL:
NAME: "ViT-B/16"
TRAINER:
DZGCOOP:
PROMPTSRC:
N_CTX_VISION: 4
N_CTX_TEXT: 4
CTX_INIT: "a photo of a"
PREC: "fp16"
PROMPT_DEPTH_VISION: 3
PROMPT_DEPTH_TEXT: 3
PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
EWA_MEAN: 6
EWA_STD: 10
# Use the below configuration for: ImageNet, Caltech101, OxfordPets, Food101, UCF101 and SUN397
GPA_MEAN: 30
GPA_STD: 30
# Use the below configuration for: StanfordCars, Flowers102, FGVCAircraft, DTD and EuroSAT
# GPA_MEAN: 45
# GPA_STD: 5

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC
#### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as EWA STD, EWA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as GPA STD, GPA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet.

View File

@@ -109,7 +109,7 @@ def print_model_results(results, model_name):
def main():
root_dir = 'output' # 修改为你的output目录路径
target_model = 'DZGCoOp' # 指定要分析的模型
target_model = 'PromptSRC' # 指定要分析的模型
results = collect_model_results(root_dir, target_model)
print_model_results(results, target_model)

View File

@@ -1,30 +0,0 @@
#!/bin/bash
DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp
SRC_DATASETS=imagenet
SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
for SEED in 1 2 3
do
DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${SRC_DATASETS}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done

View File

@@ -1,46 +0,0 @@
#!/bin/bash
# custom config
DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp
SRC_DATASETS=imagenet
SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
LOADEP=20
DATASETS=(dtd eurosat fgvc_aircraft food101 oxford_flowers oxford_pets stanford_cars ucf101 caltech101 sun397)
SEEDS=(1 2 3)
for DATASET in "${DATASETS[@]}"
do
for SEED in "${SEEDS[@]}"
do
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
DIR=output_xd/base2new/test_new/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
echo "Loading model from ${MODEL_DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only
fi
done
done

View File

@@ -1,46 +0,0 @@
#!/bin/bash
# custom config
DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp
SRC_DATASETS=imagenet
SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
LOADEP=20
DATASETS=(imagenetv2 imagenet_sketch imagenet_a imagenet_r)
SEEDS=(1 2 3)
for DATASET in "${DATASETS[@]}"
do
for SEED in "${SEEDS[@]}"
do
MODEL_DIR=output_xd/base2new/train_base/${SRC_DATASETS}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
DIR=output_xd/base2new/test_new/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
echo "Loading model from ${MODEL_DIR}"
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only
fi
done
done

View File

@@ -15,8 +15,8 @@ datasets=(
for dataset in "${datasets[@]}"; do
for seed in "${seeds[@]}"; do
bash scripts/dzgcoop/base2new_train.sh "$dataset" "$seed"
bash scripts/dzgcoop/base2new_test.sh "$dataset" "$seed"
bash scripts/promptsrc/base2new_train.sh "$dataset" "$seed"
bash scripts/promptsrc/base2new_test.sh "$dataset" "$seed"
done
done

View File

@@ -3,7 +3,7 @@
# custom config
DATA="~/Datasets/CoOp"
TRAINER=DZGCoOp
TRAINER=PromptSRC
DATASET=$1
SEED=$2

View File

@@ -2,7 +2,7 @@
# custom config
DATA="~/Datasets/CoOp"
TRAINER=DZGCoOp
TRAINER=PromptSRC
DATASET=$1
SEED=$2

View File

@@ -0,0 +1,27 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
CFG=vit_b16_c2_ep50_batch4_4+4ctx_few_shot
SHOTS=$2
for SEED in 1 2 3
do
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo " The results exist at ${DIR}"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi
done

View File

@@ -0,0 +1,54 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep20_batch4_4+4ctx
SHOTS=16
LOADEP=20
SUB_base=base
SUB_novel=new
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
if [ -d "$DIR" ]; then
echo "Results are already available in ${DIR}. Skipping..."
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
# Evaluate on base classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_base} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_base}
# Evaluate on novel classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_novel} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
fi

View File

@@ -0,0 +1,34 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SHOTS=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep50_batch4_4+4ctx_few_shot
LOADEP=50
for SEED in 1 2 3
do
MODEL_DIR=${WEIGHTSPATH}/${SHOTS}shot/seed${SEED}
DIR=output/few_shot/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo " The results exist at ${DIR}"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS}
fi
done

View File

@@ -0,0 +1,36 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
SHOTS=16
LOADEP=20
MODEL_DIR=${WEIGHTSPATH}/seed${SEED}
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are already available in ${DIR}. Skipping..."
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
# Evaluate on evaluation datasets
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
fi

View File

@@ -0,0 +1,31 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets
SHOTS=16
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
--load-epoch 20 \
--eval-only
fi

View File

@@ -0,0 +1,29 @@
#!/bin/bash
# custom config
DATA="/path/to/dataset/folder"
TRAINER=PromptSRC
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets
SHOTS=16
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}."
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi

View File

@@ -28,7 +28,7 @@ import trainers.cocoop
import trainers.zsclip
import trainers.maple
import trainers.independentVL
import trainers.dzgcoop
import trainers.promptsrc
def print_args(args, cfg):
@@ -110,19 +110,23 @@ def extend_cfg(cfg):
cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for DZGCoOp
cfg.TRAINER.DZGCOOP = CN()
cfg.TRAINER.DZGCOOP.N_CTX_VISION = 4 # number of context vectors at the vision branch
cfg.TRAINER.DZGCOOP.N_CTX_TEXT = 4 # number of context vectors at the language branch
cfg.TRAINER.DZGCOOP.CTX_INIT = "a photo of a" # initialization words
cfg.TRAINER.DZGCOOP.PREC = "fp16" # fp16, fp32, amp
cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight
cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.DZGCOOP.EWA_MEAN = 15
cfg.TRAINER.DZGCOOP.EWA_STD = 1
# Config for PromptSRC
cfg.TRAINER.PROMPTSRC = CN()
cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4 # number of context vectors at the vision branch
cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4 # number of context vectors at the language branch
cfg.TRAINER.PROMPTSRC.CTX_INIT = "a photo of a" # initialization words
cfg.TRAINER.PROMPTSRC.PREC = "fp16" # fp16, fp32, amp
cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5 # lambda3: weak text constraint weight
cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
cfg.TRAINER.PROMPTSRC.GPA_STD = 1
cfg.TRAINER.PROMPTSRC.CONFIDENCE_TYPE = "max_margin" # entropy, max_prob, margin, max_margin
cfg.TRAINER.PROMPTSRC.CONFIDENCE_TEMPERATURE = 2.0 # temperature for confidence calculation
cfg.TRAINER.PROMPTSRC.CONFIDENCE_MOMENTUM = 0.95 # momentum for running confidence
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp)

View File

@@ -51,10 +51,10 @@ def load_clip_to_cpu(cfg, zero_shot_model=False):
state_dict = torch.load(model_path, map_location="cpu")
if not zero_shot_model:
design_details = {"trainer": 'IVLP',
"vision_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION,
"language_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT,
"vision_ctx": cfg.TRAINER.DZGCOOP.N_CTX_VISION,
"language_ctx": cfg.TRAINER.DZGCOOP.N_CTX_TEXT}
"vision_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION,
"language_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT,
"vision_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_VISION,
"language_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_TEXT}
model = clip.build_model(state_dict or model.state_dict(), design_details)
else:
# Return original CLIP model for generating frozen VL features
@@ -95,17 +95,18 @@ class VLPromptLearner(nn.Module):
super().__init__()
n_cls = len(classnames)
# Make sure Language depth >= 1
assert cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
assert cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
"\nPlease use VPT trainer if you want to learn only vision " \
"branch"
n_ctx = cfg.TRAINER.DZGCOOP.N_CTX_TEXT
ctx_init = cfg.TRAINER.DZGCOOP.CTX_INIT
n_ctx = cfg.TRAINER.PROMPTSRC.N_CTX_TEXT
ctx_init = cfg.TRAINER.PROMPTSRC.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
# Strong constraint branch initialization
if ctx_init and n_ctx <= 4:
ctx_init = ctx_init.replace("_", " ")
prompt = clip.tokenize(ctx_init)
@@ -118,6 +119,7 @@ class VLPromptLearner(nn.Module):
nn.init.normal_(ctx_vectors_strong, std=0.02)
prompt_prefix_strong = " ".join(["X"] * n_ctx)
# Weak constraint branch - random initialization
ctx_vectors_weak = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors_weak, std=0.02)
prompt_prefix_weak = " ".join(["X"] * n_ctx)
@@ -126,7 +128,7 @@ class VLPromptLearner(nn.Module):
print(f'Strong branch initial text context: "{prompt_prefix_strong}"')
print(f'Weak branch initial text context: "{prompt_prefix_weak}"')
print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.DZGCOOP.N_CTX_VISION}")
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}")
self.ctx_strong = nn.Parameter(ctx_vectors_strong)
self.ctx_weak = nn.Parameter(ctx_vectors_weak)
@@ -142,7 +144,7 @@ class VLPromptLearner(nn.Module):
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.ZS_image_encoder = clip_model_temp_image.visual
# Now pre-compute the frozen VL embeddings from LLM descriptions
semantic_guidance_features = []
all_teacher_features = []
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
with open(desc_file, "r") as f:
all_desc = json.load(f)
@@ -155,9 +157,9 @@ class VLPromptLearner(nn.Module):
cls_feature = clip_model_temp.encode_text(cls_token)
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
cls_feature = torch.mean(cls_feature, dim=0)
semantic_guidance_features.append(cls_feature)
all_teacher_features.append(cls_feature)
self.semantic_embeddings = torch.stack(semantic_guidance_features)
self.fixed_embeddings = torch.stack(all_teacher_features)
print(f"Using LLM descriptions from: {desc_file}")
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
@@ -238,32 +240,29 @@ class CustomCLIP(nn.Module):
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
semantic_embeddings = self.prompt_learner.semantic_embeddings
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True)
fixed_embeddings = self.prompt_learner.fixed_embeddings
fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t()
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t()
zs_probs = F.softmax(zero_shot_logits, dim=1)
confidence = zs_probs.max(dim=1).values
alpha = confidence.unsqueeze(1)
alpha = 0.5
logits_final = alpha * logits_strong + (1 - alpha) * logits_weak
if self.prompt_learner.training:
loss_ce = F.cross_entropy(logits_final, label)
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
else:
return logits_final
@TRAINER_REGISTRY.register()
class DZGCoOp(TrainerX):
class PromptSRC(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.DZGCOOP.PREC in ["fp16", "fp32", "amp"]
assert cfg.TRAINER.PROMPTSRC.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
@@ -272,7 +271,7 @@ class DZGCoOp(TrainerX):
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.DZGCOOP.PREC == "fp32" or cfg.TRAINER.DZGCOOP.PREC == "amp":
if cfg.TRAINER.PROMPTSRC.PREC == "fp32" or cfg.TRAINER.PROMPTSRC.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
@@ -312,20 +311,20 @@ class DZGCoOp(TrainerX):
self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH
mean = cfg.TRAINER.DZGCOOP.EWA_MEAN
stdev = cfg.TRAINER.DZGCOOP.EWA_STD
normal = self.get_normal(mean, stdev)
self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
self.normal_weights = self.normal_weights / sum(self.normal_weights)
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN
stdev = cfg.TRAINER.PROMPTSRC.GPA_STD
gauss = self.get_gauss(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss)
self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
# Keep model with EWA
self.previous_model_ewa = None
# Keep model with GPA
self.previous_model_gpa = None
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
@@ -334,7 +333,7 @@ class DZGCoOp(TrainerX):
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.DZGCOOP.PREC
prec = self.cfg.TRAINER.PROMPTSRC.PREC
if prec == "amp":
with autocast():
loss = model(image, label)
@@ -343,26 +342,26 @@ class DZGCoOp(TrainerX):
scaler.step(optim)
scaler.update()
else:
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \
loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
lambda2 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
lambda1 = self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
lambda2 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG
lambda3 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
loss_scl_text_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
loss_scl_text_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_zpg = F.kl_div(
L_SCL_logits = F.kl_div(
F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum',
log_target=True
) * (1 * 1) / logits_final.numel()
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
loss = (loss_ce + L_zg)
L_SCL = (L_SCL_logits + loss_scl_text_strong + loss_scl_text_weak + loss_scl_image)
loss = (loss_ce + L_SCL)
optim.zero_grad()
loss.backward()
optim.step()
@@ -371,22 +370,20 @@ class DZGCoOp(TrainerX):
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
# Means one epoch is completed, perform EWA
# Means one epoch is completed, perform GPA
self.step_counter = self.step_counter + 1
current_epoch_weight = self.normal_weights[self.step_counter - 2]
current_epoch_weight = self.gauss[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict())
for key in current_model_weights:
current_model_weights[key] = current_model_weights[key].cpu()
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
if self.previous_model_ewa is None:
self.previous_model_ewa = weighted_state_dict
if self.previous_model_gpa is None:
self.previous_model_gpa = weighted_state_dict
else:
self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
if self.step_counter == self.model.total_epochs + 1:
print("Using EWA model for final inference...")
model.load_state_dict(self.previous_model_ewa)
self.model.load_state_dict(self.previous_model_ewa)
print("Using GPA model for final inference...")
model.load_state_dict(self.previous_model_gpa)
self.model.load_state_dict(self.previous_model_gpa)
return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
@@ -394,24 +391,24 @@ class DZGCoOp(TrainerX):
updated_dict = copy.deepcopy(main_dict)
if not prompt_only:
for key in main_dict:
updated_dict[key] = main_dict[key].cpu() * weightage
updated_dict[key] = main_dict[key] * weightage
return updated_dict
else:
return main_dict.cpu() * weightage
return main_dict * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
modified_dict[key] = (modified_dict[key] + dict1[key])
return modified_dict
else:
return dict1.cpu() + dict2.cpu()
return dict1 + dict2
def get_normal(self, mu, sigma):
normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return normal
def get_gauss(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss
def parse_batch_train(self, batch):
input = batch["img"]