remove gpa

This commit is contained in:
2026-02-05 12:20:17 +08:00
parent 1925ddfc86
commit dd7615f1eb
5 changed files with 2 additions and 59 deletions

View File

@@ -39,5 +39,3 @@ TRAINER:
PROMPT_DEPTH_TEXT: 9
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 15
GPA_STD: 1

View File

@@ -39,5 +39,3 @@ TRAINER:
PROMPT_DEPTH_TEXT: 3
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
GPA_MEAN: 6
GPA_STD: 10

View File

@@ -40,8 +40,3 @@ TRAINER:
TEXT_LOSS_WEIGHT: 25
IMAGE_LOSS_WEIGHT: 10
# Use the below configuration for: ImageNet, Caltech101, OxfordPets, Food101, UCF101 and SUN397
GPA_MEAN: 30
GPA_STD: 30
# Use the below configuration for: StanfordCars, Flowers102, FGVCAircraft, DTD and EuroSAT
# GPA_MEAN: 45
# GPA_STD: 5

View File

@@ -122,8 +122,6 @@ def extend_cfg(cfg):
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_STRONG = 25 # lambda2: strong text constraint weight
cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT_WEAK = 2.5 # lambda3: weak text constraint weight
cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
cfg.TRAINER.PROMPTSRC.GPA_STD = 1
cfg.TRAINER.PROMPTSRC.CONFIDENCE_TYPE = "max_margin" # entropy, max_prob, margin, max_margin
cfg.TRAINER.PROMPTSRC.CONFIDENCE_TEMPERATURE = 2.0 # temperature for confidence calculation
cfg.TRAINER.PROMPTSRC.CONFIDENCE_MOMENTUM = 0.95 # momentum for running confidence

View File

@@ -1,11 +1,9 @@
import copy
import os.path as osp
import numpy as np import json
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
import json
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.utils import load_pretrained_weights, load_checkpoint
@@ -353,12 +351,6 @@ class PromptSRC(TrainerX):
# Cosine scheduler
self.total_epochs = cfg.OPTIM.MAX_EPOCH
self.step_counter = 1
N = cfg.OPTIM.MAX_EPOCH
mean = cfg.TRAINER.PROMPTSRC.GPA_MEAN
stdev = cfg.TRAINER.PROMPTSRC.GPA_STD
gauss = self.get_gauss(mean, stdev)
self.gauss = np.array([gauss(a) for a in range(1, N + 1)])
self.gauss = self.gauss / sum(self.gauss)
self.scaler = GradScaler() if cfg.TRAINER.PROMPTSRC.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
@@ -366,8 +358,6 @@ class PromptSRC(TrainerX):
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
# Keep model with GPA
self.previous_model_gpa = None
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
@@ -413,46 +403,10 @@ class PromptSRC(TrainerX):
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
# Means one epoch is completed, perform GPA
self.step_counter = self.step_counter + 1
current_epoch_weight = self.gauss[self.step_counter - 2]
current_model_weights = copy.deepcopy(model.state_dict())
weighted_state_dict = self.state_dict_weighting(current_model_weights, current_epoch_weight)
if self.previous_model_gpa is None:
self.previous_model_gpa = weighted_state_dict
else:
self.previous_model_gpa = self.state_dict_add(weighted_state_dict, self.previous_model_gpa)
if self.step_counter == self.model.total_epochs + 1:
print("Using GPA model for final inference...")
model.load_state_dict(self.previous_model_gpa)
self.model.load_state_dict(self.previous_model_gpa)
return loss_summary
def state_dict_weighting(self, main_dict, weightage, prompt_only=False):
# Average all parameters
updated_dict = copy.deepcopy(main_dict)
if not prompt_only:
for key in main_dict:
updated_dict[key] = main_dict[key] * weightage
return updated_dict
else:
return main_dict * weightage
def state_dict_add(self, dict1, dict2, prompt_only=False):
# Average all parameters
if not prompt_only:
modified_dict = dict2
for key in dict1:
modified_dict[key] = (modified_dict[key] + dict1[key])
return modified_dict
else:
return dict1 + dict2
def get_gauss(self, mu, sigma):
gauss = lambda x: (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
return gauss
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]