1 Commits
multi ... uma

Author SHA1 Message Date
1d7d93ede5 Last-k Average 2026-02-07 15:58:51 +08:00
13 changed files with 86 additions and 111 deletions

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp # PromptSRC: Prompting with Self-regularizing constraints
DATALOADER: DATALOADER:
TRAIN_X: TRAIN_X:
BATCH_SIZE: 4 BATCH_SIZE: 4
@@ -30,15 +30,13 @@ MODEL:
NAME: "ViT-B/16" NAME: "ViT-B/16"
TRAINER: TRAINER:
DZGCOOP: PROMPTSRC:
N_CTX_VISION: 4 N_CTX_VISION: 4
N_CTX_TEXT: 4 N_CTX_TEXT: 4
CTX_INIT: "a photo of a" CTX_INIT: "a photo of a"
PREC: "fp16" PREC: "fp16"
PROMPT_DEPTH_VISION: 9 PROMPT_DEPTH_VISION: 9
PROMPT_DEPTH_TEXT: 9 PROMPT_DEPTH_TEXT: 9
IMAGE_LOSS_WEIGHT: 8 TEXT_LOSS_WEIGHT: 25
TEXT_LOSS_WEIGHT_STRONG: 24 IMAGE_LOSS_WEIGHT: 10
TEXT_LOSS_WEIGHT_WEAK: 8 LAST_K: 5
EWA_MEAN: 15
EWA_STD: 1

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp # PromptSRC: Prompting with Self-regularizing constraints
DATALOADER: DATALOADER:
TRAIN_X: TRAIN_X:
BATCH_SIZE: 4 BATCH_SIZE: 4
@@ -31,7 +31,7 @@ MODEL:
NAME: "ViT-B/16" NAME: "ViT-B/16"
TRAINER: TRAINER:
DZGCOOP: PROMPTSRC:
N_CTX_VISION: 4 N_CTX_VISION: 4
N_CTX_TEXT: 4 N_CTX_TEXT: 4
CTX_INIT: "a photo of a" CTX_INIT: "a photo of a"
@@ -40,5 +40,4 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
EWA_MEAN: 6 LAST_K: 5
EWA_STD: 10

View File

@@ -1,4 +1,4 @@
# DZGCoOp: Dual-branch Zero-shot Guidance CoOp # PromptSRC: Prompting with Self-regularizing constraints
DATALOADER: DATALOADER:
TRAIN_X: TRAIN_X:
BATCH_SIZE: 4 BATCH_SIZE: 4
@@ -30,7 +30,7 @@ MODEL:
NAME: "ViT-B/16" NAME: "ViT-B/16"
TRAINER: TRAINER:
DZGCOOP: PROMPTSRC:
N_CTX_VISION: 4 N_CTX_VISION: 4
N_CTX_TEXT: 4 N_CTX_TEXT: 4
CTX_INIT: "a photo of a" CTX_INIT: "a photo of a"
@@ -39,5 +39,4 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3 PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25 TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10 IMAGE_LOSS_WEIGHT: 10
EWA_MEAN: 6 LAST_K: 5
EWA_STD: 10

View File

@@ -11,7 +11,7 @@ Training PromptSRC on ImageNet for 20 epochs takes around 6 hours for a single s
## PromptSRC ## PromptSRC
#### (1) Base-to-Novel class generalization setting #### (1) Base-to-Novel class generalization setting
The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as EWA STD, EWA Mean, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file. The base-to-novel PromptSRC configuration is provided in config file at `configs/trainers/PromptSRC/vit_b16_c2_ep20_batch4_4+4ctx.yaml`. All hyper-parameters such as LAST_K, SCL loss weights coefficients, prompt length and prompt depth etc., can be modified using this config file.
Run the commands below to train PromptSRC on ImageNet. Run the commands below to train PromptSRC on ImageNet.

View File

@@ -109,7 +109,7 @@ def print_model_results(results, model_name):
def main(): def main():
root_dir = 'output' # 修改为你的output目录路径 root_dir = 'output' # 修改为你的output目录路径
target_model = 'DZGCoOp' # 指定要分析的模型 target_model = 'PromptSRC' # 指定要分析的模型
results = collect_model_results(root_dir, target_model) results = collect_model_results(root_dir, target_model)
print_model_results(results, target_model) print_model_results(results, target_model)

View File

@@ -10,13 +10,13 @@ datasets=(
"fgvc_aircraft" "fgvc_aircraft"
"stanford_cars" "stanford_cars"
"sun397" "sun397"
# "imagenet" "imagenet"
) )
for dataset in "${datasets[@]}"; do for dataset in "${datasets[@]}"; do
for seed in "${seeds[@]}"; do for seed in "${seeds[@]}"; do
bash scripts/dzgcoop/base2new_train.sh "$dataset" "$seed" bash scripts/promptsrc/base2new_train.sh "$dataset" "$seed"
bash scripts/dzgcoop/base2new_test.sh "$dataset" "$seed" bash scripts/promptsrc/base2new_test.sh "$dataset" "$seed"
done done
done done

View File

@@ -3,7 +3,7 @@
# custom config # custom config
DATA="~/Datasets/CoOp" DATA="~/Datasets/CoOp"
TRAINER=DZGCoOp TRAINER=PromptSRC
DATASET=$1 DATASET=$1
SEED=$2 SEED=$2

View File

@@ -2,7 +2,7 @@
# custom config # custom config
DATA="~/Datasets/CoOp" DATA="~/Datasets/CoOp"
TRAINER=DZGCoOp TRAINER=PromptSRC
DATASET=$1 DATASET=$1
SEED=$2 SEED=$2

View File

@@ -2,7 +2,7 @@
DATA=" ~/Datasets/CoOp" DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp TRAINER=PromptSRC
SRC_DATASETS=imagenet SRC_DATASETS=imagenet
SHOTS=16 SHOTS=16
CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets CFG=vit_b16_c2_ep20_batch4_4+4ctx_cross_datasets

View File

@@ -3,7 +3,7 @@
# custom config # custom config
DATA=" ~/Datasets/CoOp" DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp TRAINER=PromptSRC
SRC_DATASETS=imagenet SRC_DATASETS=imagenet

View File

@@ -3,7 +3,7 @@
# custom config # custom config
DATA=" ~/Datasets/CoOp" DATA=" ~/Datasets/CoOp"
TRAINER=DZGCoOp TRAINER=PromptSRC
SRC_DATASETS=imagenet SRC_DATASETS=imagenet

View File

@@ -28,7 +28,7 @@ import trainers.cocoop
import trainers.zsclip import trainers.zsclip
import trainers.maple import trainers.maple
import trainers.independentVL import trainers.independentVL
import trainers.dzgcoop import trainers.promptsrc
def print_args(args, cfg): def print_args(args, cfg):
@@ -110,19 +110,19 @@ def extend_cfg(cfg):
cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1) cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for DZGCoOp # Config for PromptSRC
cfg.TRAINER.DZGCOOP = CN() cfg.TRAINER.PROMPTSRC = CN()
cfg.TRAINER.DZGCOOP.N_CTX_VISION = 4 # number of context vectors at the vision branch cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4 # number of context vectors at the vision branch
cfg.TRAINER.DZGCOOP.N_CTX_TEXT = 4 # number of context vectors at the language branch cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4 # number of context vectors at the language branch
cfg.TRAINER.DZGCOOP.CTX_INIT = "a photo of a" # initialization words cfg.TRAINER.PROMPTSRC.CTX_INIT = "a photo of a" # initialization words
cfg.TRAINER.DZGCOOP.PREC = "fp16" # fp16, fp32, amp cfg.TRAINER.PROMPTSRC.PREC = "fp16" # fp16, fp32, amp
cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1) cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1) cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will be using shallow IVLP prompting (J=1)
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK = 10 # lambda3: weak text constraint weight cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT = 10 cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5 # lambda3: weak text constraint weight
cfg.TRAINER.DZGCOOP.EWA_MEAN = 15 cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.DZGCOOP.EWA_STD = 1 cfg.TRAINER.PROMPTSRC.LAST_K = 5
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
# Config for independent Vision Language prompting (independent-vlp) # Config for independent Vision Language prompting (independent-vlp)

View File

@@ -51,10 +51,10 @@ def load_clip_to_cpu(cfg, zero_shot_model=False):
state_dict = torch.load(model_path, map_location="cpu") state_dict = torch.load(model_path, map_location="cpu")
if not zero_shot_model: if not zero_shot_model:
design_details = {"trainer": 'IVLP', design_details = {"trainer": 'IVLP',
"vision_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_VISION, "vision_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION,
"language_depth": cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT, "language_depth": cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT,
"vision_ctx": cfg.TRAINER.DZGCOOP.N_CTX_VISION, "vision_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_VISION,
"language_ctx": cfg.TRAINER.DZGCOOP.N_CTX_TEXT} "language_ctx": cfg.TRAINER.PROMPTSRC.N_CTX_TEXT}
model = clip.build_model(state_dict or model.state_dict(), design_details) model = clip.build_model(state_dict or model.state_dict(), design_details)
else: else:
# Return original CLIP model for generating frozen VL features # Return original CLIP model for generating frozen VL features
@@ -95,11 +95,11 @@ class VLPromptLearner(nn.Module):
super().__init__() super().__init__()
n_cls = len(classnames) n_cls = len(classnames)
# Make sure Language depth >= 1 # Make sure Language depth >= 1
assert cfg.TRAINER.DZGCOOP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \ assert cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
"\nPlease use VPT trainer if you want to learn only vision " \ "\nPlease use VPT trainer if you want to learn only vision " \
"branch" "branch"
n_ctx = cfg.TRAINER.DZGCOOP.N_CTX_TEXT n_ctx = cfg.TRAINER.PROMPTSRC.N_CTX_TEXT
ctx_init = cfg.TRAINER.DZGCOOP.CTX_INIT ctx_init = cfg.TRAINER.PROMPTSRC.CTX_INIT
dtype = clip_model.dtype dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0] ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution clip_imsize = clip_model.visual.input_resolution
@@ -126,7 +126,7 @@ class VLPromptLearner(nn.Module):
print(f'Strong branch initial text context: "{prompt_prefix_strong}"') print(f'Strong branch initial text context: "{prompt_prefix_strong}"')
print(f'Weak branch initial text context: "{prompt_prefix_weak}"') print(f'Weak branch initial text context: "{prompt_prefix_weak}"')
print(f"Number of context words (tokens) for Language prompting: {n_ctx}") print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.DZGCOOP.N_CTX_VISION}") print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.PROMPTSRC.N_CTX_VISION}")
self.ctx_strong = nn.Parameter(ctx_vectors_strong) self.ctx_strong = nn.Parameter(ctx_vectors_strong)
self.ctx_weak = nn.Parameter(ctx_vectors_weak) self.ctx_weak = nn.Parameter(ctx_vectors_weak)
@@ -142,7 +142,7 @@ class VLPromptLearner(nn.Module):
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.ZS_image_encoder = clip_model_temp_image.visual self.ZS_image_encoder = clip_model_temp_image.visual
# Now pre-compute the frozen VL embeddings from LLM descriptions # Now pre-compute the frozen VL embeddings from LLM descriptions
semantic_guidance_features = [] all_teacher_features = []
desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json" desc_file = f"./desc/{DESC_LLM}/descriptions_top{DESC_TOPK}/{cfg.DATASET.NAME}.json"
with open(desc_file, "r") as f: with open(desc_file, "r") as f:
all_desc = json.load(f) all_desc = json.load(f)
@@ -155,9 +155,9 @@ class VLPromptLearner(nn.Module):
cls_feature = clip_model_temp.encode_text(cls_token) cls_feature = clip_model_temp.encode_text(cls_token)
cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True) cls_feature = cls_feature / cls_feature.norm(dim=-1, keepdim=True)
cls_feature = torch.mean(cls_feature, dim=0) cls_feature = torch.mean(cls_feature, dim=0)
semantic_guidance_features.append(cls_feature) all_teacher_features.append(cls_feature)
self.semantic_embeddings = torch.stack(semantic_guidance_features) self.fixed_embeddings = torch.stack(all_teacher_features)
print(f"Using LLM descriptions from: {desc_file}") print(f"Using LLM descriptions from: {desc_file}")
# These token vectors will be saved when in save_model(), # These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use # but they should be ignored in load_model() as we want to use
@@ -238,10 +238,10 @@ class CustomCLIP(nn.Module):
text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts) text_features_weak = self.text_encoder(prompts_weak, tokenized_prompts)
text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True) text_features_weak = text_features_weak / text_features_weak.norm(dim=-1, keepdim=True)
semantic_embeddings = self.prompt_learner.semantic_embeddings fixed_embeddings = self.prompt_learner.fixed_embeddings
semantic_embeddings = semantic_embeddings / semantic_embeddings.norm(dim=-1, keepdim=True) fixed_embeddings = fixed_embeddings / fixed_embeddings.norm(dim=-1, keepdim=True)
zero_shot_logits = logit_scale * zero_shot_features.cuda() @ semantic_embeddings.half().cuda().t() zero_shot_logits = logit_scale * zero_shot_features.cuda() @ fixed_embeddings.half().cuda().t()
logits_strong = logit_scale * image_features @ text_features_strong.t() logits_strong = logit_scale * image_features @ text_features_strong.t()
logits_weak = logit_scale * image_features @ text_features_weak.t() logits_weak = logit_scale * image_features @ text_features_weak.t()
@@ -255,15 +255,15 @@ class CustomCLIP(nn.Module):
if self.prompt_learner.training: if self.prompt_learner.training:
loss_ce = F.cross_entropy(logits_final, label) loss_ce = F.cross_entropy(logits_final, label)
return loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final return loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zero_shot_features, image_features, zero_shot_logits, logits_strong, logits_weak, logits_final
else: else:
return logits_final return logits_final
@TRAINER_REGISTRY.register() @TRAINER_REGISTRY.register()
class DZGCoOp(TrainerX): class PromptSRC(TrainerX):
def check_cfg(self, cfg): def check_cfg(self, cfg):
assert cfg.TRAINER.DZGCOOP.PREC in ["fp16", "fp32", "amp"] assert cfg.TRAINER.PROMPTSRC.PREC in ["fp16", "fp32", "amp"]
def build_model(self): def build_model(self):
cfg = self.cfg cfg = self.cfg
@@ -272,7 +272,7 @@ class DZGCoOp(TrainerX):
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg) clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.DZGCOOP.PREC == "fp32" or cfg.TRAINER.DZGCOOP.PREC == "amp": if cfg.TRAINER.PROMPTSRC.PREC == "fp32" or cfg.TRAINER.PROMPTSRC.PREC == "amp":
# CLIP's default precision is fp16 # CLIP's default precision is fp16
clip_model.float() clip_model.float()
@@ -311,21 +311,15 @@ class DZGCoOp(TrainerX):
# Cosine scheduler # Cosine scheduler
self.total_epochs = cfg.OPTIM.MAX_EPOCH self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1 self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH self.max_k = cfg.TRAINER.PROMPTSRC.LAST_K
mean = cfg.TRAINER.DZGCOOP.EWA_MEAN self.last_k_models = []
stdev = cfg.TRAINER.DZGCOOP.EWA_STD self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
normal = self.get_normal(mean, stdev)
self.normal_weights = np.array([normal(a) for a in range(1, N + 1)])
self.normal_weights = self.normal_weights / sum(self.normal_weights)
self.scaler = GradScaler() if cfg.TRAINER.DZGCOOP.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is # Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel # big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count() device_count = torch.cuda.device_count()
if device_count > 1: if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model) self.model = nn.DataParallel(self.model)
# Keep model with EWA
self.previous_model_ewa = None
def forward_backward(self, batch): def forward_backward(self, batch):
image, label = self.parse_batch_train(batch) image, label = self.parse_batch_train(batch)
@@ -334,7 +328,7 @@ class DZGCoOp(TrainerX):
optim = self.optim optim = self.optim
scaler = self.scaler scaler = self.scaler
prec = self.cfg.TRAINER.DZGCOOP.PREC prec = self.cfg.TRAINER.PROMPTSRC.PREC
if prec == "amp": if prec == "amp":
with autocast(): with autocast():
loss = model(image, label) loss = model(image, label)
@@ -343,26 +337,26 @@ class DZGCoOp(TrainerX):
scaler.step(optim) scaler.step(optim)
scaler.update() scaler.update()
else: else:
loss_ce, text_features_strong, text_features_weak, semantic_embeddings, zs_image_embedd, image_ft, \ loss_ce, text_features_strong, text_features_weak, fixed_embeddings, zs_image_embedd, image_ft, \
zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label) zero_shot_logits, logits_strong, logits_weak, logits_final = model(image, label)
lambda1 = self.cfg.TRAINER.DZGCOOP.IMAGE_LOSS_WEIGHT lambda1 = self.cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT
lambda2 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_STRONG lambda2 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG
lambda3 = self.cfg.TRAINER.DZGCOOP.TEXT_LOSS_WEIGHT_WEAK lambda3 = self.cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK
L_zvg = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1 loss_scl_image = F.l1_loss(image_ft, zs_image_embedd.cuda(), reduction='mean') * lambda1
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2 loss_scl_text_strong = F.l1_loss(text_features_strong, fixed_embeddings.cuda(), reduction='mean') * lambda2
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3 loss_scl_text_weak = F.l1_loss(text_features_weak, fixed_embeddings.cuda(), reduction='mean') * lambda3
L_zpg = F.kl_div( L_SCL_logits = F.kl_div(
F.log_softmax(logits_final / 1, dim=1), F.log_softmax(logits_final / 1, dim=1),
F.log_softmax(zero_shot_logits / 1, dim=1), F.log_softmax(zero_shot_logits / 1, dim=1),
reduction='sum', reduction='sum',
log_target=True log_target=True
) * (1 * 1) / logits_final.numel() ) * (1 * 1) / logits_final.numel()
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg) L_SCL = (L_SCL_logits + loss_scl_text_strong + loss_scl_text_weak + loss_scl_image)
loss = (loss_ce + L_zg) loss = (loss_ce + L_SCL)
optim.zero_grad() optim.zero_grad()
loss.backward() loss.backward()
optim.step() optim.step()
@@ -371,47 +365,32 @@ class DZGCoOp(TrainerX):
if (self.batch_idx + 1) == self.num_batches: if (self.batch_idx + 1) == self.num_batches:
self.update_lr() self.update_lr()
# Means one epoch is completed, perform EWA
self.step_counter = self.step_counter + 1 self.step_counter = self.step_counter + 1
current_epoch_weight = self.normal_weights[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict()) current_model_weights = copy.deepcopy(model.state_dict())
for key in current_model_weights: for key in current_model_weights:
current_model_weights[key] = current_model_weights[key].cpu() current_model_weights[key] = current_model_weights[key].cpu()
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight) self.last_k_models.append(current_model_weights)
if self.previous_model_ewa is None: if len(self.last_k_models) > self.max_k:
self.previous_model_ewa = weighted_state_dict self.last_k_models.pop(0)
else: torch.cuda.empty_cache()
self.previous_model_ewa = self.state_dict_add(weighted_state_dict, self.previous_model_ewa)
if self.step_counter == self.model.total_epochs + 1: if self.step_counter == self.model.total_epochs + 1:
print("Using EWA model for final inference...") print(f"Using Last-K Averaging (K={len(self.last_k_models)}) model for final inference...")
model.load_state_dict(self.previous_model_ewa) averaged_state_dict = self._average_last_k_models()
self.model.load_state_dict(self.previous_model_ewa) for key in averaged_state_dict:
averaged_state_dict[key] = averaged_state_dict[key].cuda()
model.load_state_dict(averaged_state_dict)
self.model.load_state_dict(averaged_state_dict)
return loss_summary return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False): def _average_last_k_models(self):
# Average all parameters if not self.last_k_models:
updated_dict = copy.deepcopy(main_dict) return {}
if not prompt_only: averaged_dict = {}
for key in main_dict: for key in self.last_k_models[0]:
updated_dict[key] = main_dict[key].cpu() * weightage stacked = torch.stack([model_state[key] for model_state in self.last_k_models])
return updated_dict averaged_dict[key] = torch.mean(stacked, dim=0)
else: return averaged_dict
return main_dict.cpu() * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = modified_dict[key].cpu() + dict1[key].cpu()
return modified_dict
else:
return dict1.cpu() + dict2.cpu()
def get_normal(self, mu, sigma):
normal = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return normal
def parse_batch_train(self, batch): def parse_batch_train(self, batch):
input = batch["img"] input = batch["img"]
@@ -458,4 +437,4 @@ class DZGCoOp(TrainerX):
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False # set strict=False
self._models[name].load_state_dict(state_dict, strict=False) self._models[name].load_state_dict(state_dict, strict=False)