diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml b/configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx.yaml
similarity index 91%
rename from configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml
rename to configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx.yaml
index fbe381d..7e3826e 100644
--- a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml
+++ b/configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx.yaml
@@ -1,4 +1,4 @@
-# PromptSRC: Prompting with Self-regularizing constraints
+# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
 DATALOADER:
   TRAIN_X:
     BATCH_SIZE: 4
@@ -30,7 +30,7 @@ MODEL:
     NAME: "ViT-B/16"
 
 TRAINER:
-  PROMPTSRC:
+  DZGCOOP:
     N_CTX_VISION: 4
     N_CTX_TEXT: 4
     CTX_INIT: "a photo of a"
diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml b/configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml
similarity index 91%
rename from configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml
rename to configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml
index 895912a..57d3e64 100644
--- a/configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml
+++ b/configs/trainers/DZGCoOp/vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets.yaml
@@ -1,4 +1,4 @@
-# PromptSRC: Prompting with Self-regularizing constraints
+# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
 DATALOADER:
   TRAIN_X:
     BATCH_SIZE: 4
@@ -31,7 +31,7 @@ MODEL:
     NAME: "ViT-B/16"
 
 TRAINER:
-  PROMPTSRC:
+  DZGCOOP:
     N_CTX_VISION: 4
     N_CTX_TEXT: 4
     CTX_INIT: "a photo of a"
diff --git a/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml b/configs/trainers/DZGCoOp/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml
similarity index 91%
rename from configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml
rename to configs/trainers/DZGCoOp/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml
index 2dbbacd..336a810 100644
--- a/configs/trainers/PromptSRC/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml
+++ b/configs/trainers/DZGCoOp/vit_b16_c2_ep5_batch4_4+4ctx_cross_datasets.yaml
@@ -1,4 +1,4 @@
-# PromptSRC: Prompting with Self-regularizing constraints
+# DZGCoOp: Dual-branch Zero-shot Guidance CoOp
 DATALOADER:
   TRAIN_X:
     BATCH_SIZE: 4
@@ -30,7 +30,7 @@ MODEL:
     NAME: "ViT-B/16"
 
 TRAINER:
-  PROMPTSRC:
+  DZGCOOP:
     N_CTX_VISION: 4
     N_CTX_TEXT: 4
     CTX_INIT: "a photo of a"
diff --git a/extract_acc.py b/extract_acc.py
index dcd7e55..15723b4 100644
--- a/extract_acc.py
+++ b/extract_acc.py
@@ -109,7 +109,7 @@ def print_model_results(results, model_name):
 
 def main():
     root_dir = 'output'  # change this to your output directory path
-    target_model = 'PromptSRC'  # the model to analyze
+    target_model = 'DZGCoOp'  # the model to analyze
 
     results = collect_model_results(root_dir, target_model)
     print_model_results(results, target_model)
diff --git a/scripts/dzgcoop/base2new_all.sh b/scripts/dzgcoop/base2new_all.sh
new file mode 100644
index 0000000..c9be780
--- /dev/null
+++ b/scripts/dzgcoop/base2new_all.sh
@@ -0,0 +1,22 @@
+seeds=(1 2 3)
+datasets=(
+    "ucf101"
+    "eurosat"
+    "oxford_pets"
+    "food101"
+    "oxford_flowers"
+    "dtd"
+    "caltech101"
+    "fgvc_aircraft"
+    "stanford_cars"
+    # "sun397"
+    # "imagenet"
+)
+
+for dataset in "${datasets[@]}"; do
+    for seed in "${seeds[@]}"; do
+        bash scripts/dzgcoop/base2new_train.sh "$dataset" "$seed"
+        bash scripts/dzgcoop/base2new_test.sh "$dataset" "$seed"
+    done
+done
+
diff --git a/scripts/promptsrc/base2new_test.sh b/scripts/dzgcoop/base2new_test.sh
similarity index 98%
rename from scripts/promptsrc/base2new_test.sh
rename to scripts/dzgcoop/base2new_test.sh
index c1a68dd..67b566c 100644
--- a/scripts/promptsrc/base2new_test.sh
+++ b/scripts/dzgcoop/base2new_test.sh
@@ -3,7 +3,7 @@
 
 # custom config
 DATA="~/Datasets/CoOp"
-TRAINER=PromptSRC
+TRAINER=DZGCoOp
 
 DATASET=$1
 SEED=$2
diff --git a/scripts/promptsrc/base2new_train.sh b/scripts/dzgcoop/base2new_train.sh
similarity index 98%
rename from scripts/promptsrc/base2new_train.sh
rename to scripts/dzgcoop/base2new_train.sh
index 95f0eda..846dbb9 100644
--- a/scripts/promptsrc/base2new_train.sh
+++ b/scripts/dzgcoop/base2new_train.sh
@@ -2,7 +2,7 @@
 
 # custom config
 DATA="~/Datasets/CoOp"
-TRAINER=PromptSRC
+TRAINER=DZGCoOp
 
 DATASET=$1
 SEED=$2
diff --git a/scripts/promptsrc/xd_train.sh b/scripts/dzgcoop/xd_train.sh
similarity index 97%
rename from scripts/promptsrc/xd_train.sh
rename to scripts/dzgcoop/xd_train.sh
index 2268129..7e54f22 100644
--- a/scripts/promptsrc/xd_train.sh
+++ b/scripts/dzgcoop/xd_train.sh
@@ -2,7 +2,7 @@
 
 DATA=" ~/Datasets/CoOp"
 
-TRAINER=PromptSRC
+TRAINER=DZGCoOp
 SRC_DATASETS=imagenet
 SHOTS=16
 CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets
diff --git a/scripts/promptsrc/xda_test.sh b/scripts/dzgcoop/xda_test.sh
similarity index 98%
rename from scripts/promptsrc/xda_test.sh
rename to scripts/dzgcoop/xda_test.sh
index 7a54645..cba7deb 100644
--- a/scripts/promptsrc/xda_test.sh
+++ b/scripts/dzgcoop/xda_test.sh
@@ -3,7 +3,7 @@
 
 # custom config
 DATA=" ~/Datasets/CoOp"
-TRAINER=PromptSRC
+TRAINER=DZGCoOp
 
 SRC_DATASETS=imagenet
 
diff --git a/scripts/promptsrc/xdo_test.sh b/scripts/dzgcoop/xdo_test.sh
similarity index 98%
rename from scripts/promptsrc/xdo_test.sh
rename to scripts/dzgcoop/xdo_test.sh
index 00221fb..77493c6 100644
--- a/scripts/promptsrc/xdo_test.sh
+++ b/scripts/dzgcoop/xdo_test.sh
@@ -3,7 +3,7 @@
 
 # custom config
 DATA=" ~/Datasets/CoOp"
-TRAINER=PromptSRC
+TRAINER=DZGCoOp
 
 SRC_DATASETS=imagenet
 
diff --git a/scripts/promptsrc/base2new_all.sh b/scripts/promptsrc/base2new_all.sh
deleted file mode 100644
index 65fdd42..0000000
--- a/scripts/promptsrc/base2new_all.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-seeds=(1 2 3)
-datasets=(
-    # "ucf101"
-    # "eurosat"
-    # "oxford_pets"
-    # "food101"
-    # "oxford_flowers"
-    # "dtd"
-    # "caltech101"
-    # "fgvc_aircraft"
-    # "stanford_cars"
-    # "sun397"
-    "imagenet"
-)
-
-for dataset in "${datasets[@]}"; do
-    for seed in "${seeds[@]}"; do
-        bash scripts/promptsrc/base2new_train.sh "$dataset" "$seed"
-        bash scripts/promptsrc/base2new_test.sh "$dataset" "$seed"
-    done
-done
-
diff --git a/train.py b/train.py
index 09fab30..fab7f4b 100644
--- a/train.py
+++ b/train.py
@@ -28,7 +28,7 @@ import trainers.cocoop
 import trainers.zsclip
 import trainers.maple
 import trainers.independentVL
-import trainers.promptsrc
+import trainers.dzgcoop
 
 
 def print_args(args, cfg):
@@ -110,20 +110,20 @@ def extend_cfg(cfg):
     cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9  # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
 
-    # Config for PromptSRC
-    cfg.TRAINER.PROMPTSRC = CN()
-    cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4  # number of context vectors at the vision branch
-    cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4  # number of context vectors at the language branch
-    cfg.TRAINER.PROMPTSRC.CTX_INIT = "a photo of a"  # initialization words
-    cfg.TRAINER.PROMPTSRC.PREC = "fp16"  # fp16, fp32, amp
-    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
-    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
-    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
-    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25  # lambda2: strong text constraint weight
-    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5  # lambda3: weak text constraint weight
-    cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
-    cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
-    cfg.TRAINER.PROMPTSRC.GPA_STD = 1
+    # Config for DZGCoOp
+    cfg.TRAINER.DZGCOOP = CN()
+    cfg.TRAINER.DZGCOOP.N_CTX_VISION = 4  # number of context vectors at the vision branch
+    cfg.TRAINER.DZGCOOP.N_CTX_TEXT = 4  # number of context vectors at the language branch
+    cfg.TRAINER.DZGCOOP.CTX_INIT = "a photo of a"  # initialization words
+    cfg.TRAINER.DZGCOOP.PREC = "fp16"  # fp16, fp32, amp
+    cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
+    cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT = 9  # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
+    cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT = 25
+    cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25  # lambda2: strong text constraint weight
+    cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 2.5  # lambda3: weak text constraint weight
+    cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10
+    cfg.TRAINER.DZGCOOP.GPA_MEAN = 15
+    cfg.TRAINER.DZGCOOP.GPA_STD = 1
     cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new
 
     # Config for independent Vision Language prompting (independent-vlp)
diff --git a/trainers/promptsrc.py b/trainers/dzgcoop.py
similarity index 91%
rename from trainers/promptsrc.py
rename to trainers/dzgcoop.py
index f12e9cd..3451776 100644
--- a/trainers/promptsrc.py
+++ b/trainers/dzgcoop.py
@@ -51,10 +51,10 @@ def load_clip_to_cpu(cfg, zero_shot_model=False):
     state_dict = torch.load(model_path, map_location="cpu")
     if not zero_shot_model:
         design_details = {"trainer": 'IVLP',
-                          "vision_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION,
-                          "language_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT,
-                          "vision_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_VISION,
-                          "language_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_TEXT}
+                          "vision_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION,
+                          "language_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT,
+                          "vision_ctx": cfg.TRAINER.DZGCOOP.N_CTX_VISION,
+                          "language_ctx": cfg.TRAINER.DZGCOOP.N_CTX_TEXT}
         model = clip.build_model(state_dict or model.state_dict(), design_details)
     else:
         # Return original CLIP model for generating frozen VL features
@@ -95,11 +95,11 @@ class VLPromptLearner(nn.Module):
         super().__init__()
         n_cls = len(classnames)
         # Make sure Language depth >= 1
-        assert cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
+        assert cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
                                                              "\nPlease use VPT trainer if you want to learn only vision " \
                                                              "branch"
-        n_ctx = cfg.TRAINER.PROMPTSRC.N_CTX_TEXT
-        ctx_init = cfg.TRAINER.PROMPTSRC.CTX_INIT
+        n_ctx = cfg.TRAINER.DZGCOOP.N_CTX_TEXT
+        ctx_init = cfg.TRAINER.DZGCOOP.CTX_INIT
         dtype = clip_model.dtype
         ctx_dim = clip_model.ln_final.weight.shape[0]
         clip_imsize = clip_model.visual.input_resolution
@@ -126,7 +126,7 @@
         print(f'Strong branch initial text context: "{prompt_prefix_strong}"')
         print(f'Weak branch initial text context: "{prompt_prefix_weak}"')
         print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
-        print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}")
+        print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.DZGCOOP.N_CTX_VISION}")
 
         self.ctx_strong = nn.Parameter(ctx_vectors_strong)
         self.ctx_weak = nn.Parameter(ctx_vectors_weak)
@@ -261,9 +261,9 @@ class CustomCLIP(nn.Module):
 
 
 @TRAINER_REGISTRY.register()
-class PromptSRC(TrainerX):
+class DZGCoOp(TrainerX):
     def check_cfg(self, cfg):
-        assert cfg.TRAINER.PROMPTSRC.PREC in ["fp16", "fp32", "amp"]
+        assert cfg.TRAINER.DZGCOOP.PREC in ["fp16", "fp32", "amp"]
 
     def build_model(self):
         cfg = self.cfg
@@ -272,7 +272,7 @@ class PromptSRC(TrainerX):
         print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
         clip_model = load_clip_to_cpu(cfg)
 
-        if cfg.TRAINER.PROMPTSRC.PREC == "fp32" or cfg.TRAINER.PROMPTSRC.PREC == "amp":
+        if cfg.TRAINER.DZGCOOP.PREC == "fp32" or cfg.TRAINER.DZGCOOP.PREC == "amp":
             # CLIP's default precision is fp16
             clip_model.float()
 
@@ -312,12 +312,12 @@ class PromptSRC(TrainerX):
         self.total_epochs = cfg.OPTIM.MAX_EPOCH
         self.step_counter = 1
         N = cfg.OPTIM.MAX_EPOCH
-        mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN
-        stdev = cfg.TRAINER.PROMPTSRC.GPA_STD
+        mean = cfg.TRAINER.DZGCOOP.GPA_MEAN
+        stdev = cfg.TRAINER.DZGCOOP.GPA_STD
         gauss = self.get_gauss(mean, stdev)
         self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
         self.gauss = self.gauss / sum(self.gauss)
-        self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
+        self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
         # Note that multi-gpu training could be slow because CLIP's size is
         # big, which slows down the copy operation in DataParallel
         device_count = torch.cuda.device_count()
@@ -334,7 +334,7 @@ class PromptSRC(TrainerX):
         optim = self.optim
         scaler = self.scaler
 
-        prec = self.cfg.TRAINER.PROMPTSRC.PREC
+        prec = self.cfg.TRAINER.DZGCOOP.PREC
        if prec == "amp":
             with autocast():
                 loss = model(image, label)
@@ -346,23 +346,23 @@ class PromptSRC(TrainerX):
             loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
                 zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
 
-            lambda1 = self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
-            lambda2 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG
-            lambda3 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK
+            lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT
+            lambda2 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG
+            lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK
 
-            loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
-            loss_scl_text_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
-            loss_scl_text_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
+            L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
+            L_sg_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
+            L_sg_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
 
-            L_SCL_logits = F.kl_div(
+            L_zpg = F.kl_div(
                 F.log_softmax(logits_final / 1, dim=1),
                 F.log_softmax(zero_shot_logits / 1, dim=1),
                 reduction='sum',
                 log_target=True
             ) * (1 * 1) / logits_final.numel()
 
-            L_SCL = (L_SCL_logits + loss_scl_text_strong + loss_scl_text_weak + loss_scl_image)
-            loss = (loss_ce + L_SCL)
+            L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
+            loss = (loss_ce + L_zg)
             optim.zero_grad()
             loss.backward()
             optim.step()