2026-02-03 10:21:07 +08:00
parent e556f17ebc
commit 0c2ae25cf8
81 changed files with 572 additions and 76 deletions

.gitignore vendored Normal file (+212)

@@ -0,0 +1,212 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data;
# see https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Custom
**/output*
**/__pycache__

.gitmodules vendored Normal file (+3)

@@ -0,0 +1,3 @@
[submodule "Dassl.pytorch"]
path = Dassl.pytorch
url = https://github.com/KaiyangZhou/Dassl.pytorch.git

Dassl.pytorch Submodule (+1)

Submodule Dassl.pytorch added at c61a1b570a

Binary file not shown.


@@ -15,7 +15,7 @@ INPUT:
OPTIM:
NAME: "sgd"
LR: 0.0035
MAX_EPOCH: 5
MAX_EPOCH: 10
LR_SCHEDULER: "cosine"
WARMUP_EPOCH: 1
WARMUP_TYPE: "constant"
@@ -38,4 +38,4 @@ TRAINER:
N_CTX: 2
CTX_INIT: "a photo of a"
PREC: "fp16"
PROMPT_DEPTH: 9
PROMPT_DEPTH: 9

run.sh Normal file (+43)

@@ -0,0 +1,43 @@
#!/bin/bash
# List of random seeds
seeds=(1 2 3)
# List of datasets
datasets=(
"ucf101"
"eurosat"
"oxford_pets"
"food101"
"oxford_flowers"
"dtd"
"caltech101"
"fgvc_aircraft"
"stanford_cars"
# "sun397"
# "imagenet"
)
# Number of shots (assumed to be 16, matching the other scripts; SHOTS was previously undefined)
SHOTS=16
# For each seed, iterate over all datasets
for seed in "${seeds[@]}"; do
for dataset in "${datasets[@]}"; do
echo "Running training: dataset=${dataset}, seed=${seed}"
# Run the training command
CUDA_VISIBLE_DEVICES=0 python train.py \
--root ~/Datasets/CoOp \
--seed "$seed" \
--trainer MaPLe \
--dataset-config-file "configs/datasets/${dataset}.yaml" \
--config-file configs/trainers/MaPLe/vit_b16_t.yaml \
--output-dir "output/DAPT_${dataset}_seed${seed}" \
--mode dapt-g \
DATASET.NUM_SHOTS ${SHOTS}
echo "Finished: dataset=${dataset}, seed=${seed}"
echo "----------------------------------------"
done
done
echo "All training jobs finished!"


@@ -0,0 +1,54 @@
#!/bin/bash
#cd ../..
# custom config
DATA="${HOME}/Datasets/CoOp"  # note: a quoted "~" is not expanded by bash
TRAINER=MaPLe
DATASET=$1
SEED=$2
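# Usage (assumed from the positional parameters): bash <this script> <dataset> <seed>
# e.g. <dataset>=caltech101, <seed>=1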
CFG=vit_b16_c2_ep5_batch4_2ctx
SHOTS=16
# LOADEP=10
SUB=new
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
DIR=output/base2new/test_${SUB}/${COMMON_DIR}
if [ -d "$DIR" ]; then
echo "Evaluating model"
echo "Results are available in ${DIR}. Resuming..."
# To load a specific checkpoint, add: --load-epoch ${LOADEP}
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB}
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
# --load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB}
fi


@@ -0,0 +1,40 @@
#!/bin/bash
#cd ../..
# custom config
DATA="${HOME}/Datasets/CoOp"  # note: a quoted "~" is not expanded by bash
TRAINER=MaPLe
DATASET=$1
SEED=$2
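# Usage (assumed from the positional parameters): bash <this script> <dataset> <seed>; trains on the "base" class split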
CFG=vit_b16_c2_ep5_batch4_2ctx
# CFG=vit_b16_base
SHOTS=16
DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Resuming..."
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES base
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES base
fi


@@ -0,0 +1,58 @@
#!/bin/bash
#cd ../..
# custom config
DATA="/path/to/dataset/folder"
TRAINER=MaPLe
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
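# Usage (assumed from the positional parameters): bash <this script> <dataset> <seed> <path to pre-trained base weights>
# ${WEIGHTSPATH}/base/seed${SEED} is expected to contain the checkpoint loaded below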
CFG=vit_b16_c2_ep5_batch4_2ctx
SHOTS=16
LOADEP=5
SUB_base=base
SUB_novel=new
COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
if [ -d "$DIR_base" ] && [ -d "$DIR_novel" ]; then
echo "Results are already available in ${DIR_base} and ${DIR_novel}. Skipping..."
else
echo "Evaluating model"
echo "Running the evaluation jobs and saving the output to ${DIR_base} and ${DIR_novel}"
# Evaluate on base classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_base} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_base}
# Evaluate on novel classes
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR_novel} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS} \
DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
fi


@@ -0,0 +1,38 @@
#!/bin/bash
#cd ../..
# custom config
DATA="/path/to/dataset/folder"
TRAINER=MaPLe
DATASET=$1
SEED=$2
WEIGHTSPATH=$3
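# Usage (assumed from the positional parameters): bash <this script> <dataset> <seed> <path to cross-dataset weights>
# ${WEIGHTSPATH}/seed${SEED} is expected to contain the checkpoint loaded below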
CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
SHOTS=16
LOADEP=2
MODEL_DIR=${WEIGHTSPATH}/seed${SEED}
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are already available in ${DIR}. Skipping..."
else
echo "Evaluating model"
echo "Runing the first phase job and save the output to ${DIR}"
# Evaluate on evaluation datasets
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir ${MODEL_DIR} \
--load-epoch ${LOADEP} \
--eval-only \
DATASET.NUM_SHOTS ${SHOTS}
fi


@@ -0,0 +1,32 @@
#!/bin/bash
#cd ../..
# custom config
DATA="/path/to/dataset/folder"
TRAINER=MaPLe
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
SHOTS=16
DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}. Skip this job"
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
--model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
--load-epoch 2 \
--eval-only
fi


@@ -0,0 +1,30 @@
#!/bin/bash
#cd ../..
# custom config
DATA="/path/to/dataset/folder"
TRAINER=MaPLe
DATASET=$1
SEED=$2
CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
SHOTS=16
DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
if [ -d "$DIR" ]; then
echo "Results are available in ${DIR}."
else
echo "Run this job and save the output to ${DIR}"
python train.py \
--root ${DATA} \
--seed ${SEED} \
--trainer ${TRAINER} \
--dataset-config-file configs/datasets/${DATASET}.yaml \
--config-file configs/trainers/${TRAINER}/${CFG}.yaml \
--output-dir ${DIR} \
DATASET.NUM_SHOTS ${SHOTS}
fi


@@ -291,7 +291,7 @@ def main(args):
trainer.test_withlabel()
# trainer.test_multi_label() Evaluation for VOC12
import shutil
shutil.rmtree(original_weight_output+'/MultiModalPromptLearner')
# shutil.rmtree(original_weight_output+'/MultiModalPromptLearner')
else:
original_weight_output = 'output/' + '/'.join(trainer.output_dir.split('/')[1:])
trainer.load_model(original_weight_output)
@@ -304,7 +304,7 @@ def main(args):
trainer.train()
if cfg.DATASET.SUBSAMPLE_CLASSES != 'base':
import shutil
shutil.rmtree(trainer.output_dir+'/MultiModalPromptLearner')
# shutil.rmtree(trainer.output_dir+'/MultiModalPromptLearner')
if __name__ == "__main__":


@@ -19,7 +19,6 @@ from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, mkdir_if_missing
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.evaluation import Classification,EvaluatorBase
from pygrad.pcgrad import PCGrad
from datasets.data_manager import DataManager
from dassl.data.datasets import build_dataset
@@ -35,16 +34,6 @@ from .util import GradCAM,denorm
import cv2
_tokenizer = _Tokenizer()
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',]
#['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
#'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
#]
BACKGROUND_CATEGORY_FOOD = ['table','forks','tablecloth','hands','spoon','glasses','dishes']
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
@@ -159,28 +148,18 @@ class MultiModalPromptLearner(nn.Module):
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
###Introduce Background
bg_template = 'a clean origami {}.'
bg_classesnames = [bg_template.format(name) for name in BACKGROUND_CATEGORY +BACKGROUND_CATEGORY_FOOD ]
tokenized_bg_prompts = torch.cat([clip.tokenize(bg) for bg in bg_classesnames])
bg_num = len(BACKGROUND_CATEGORY) + len(BACKGROUND_CATEGORY_FOOD)
tokenized_prompts = torch.cat((tokenized_prompts,tokenized_bg_prompts),dim=0)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
self.bg_embeding = embedding[-bg_num:]
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:-bg_num, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:-bg_num, 1 + n_ctx:, :]) # CLS, EOS
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num 77] [:-bg_num]
self.tokenized_prompts = tokenized_prompts # torch.Tensor [class_num, 77]
self.name_lens = name_lens
def construct_prompts(self, ctx, prefix, suffix, label=None):
@@ -204,8 +183,7 @@ class MultiModalPromptLearner(nn.Module):
dim=1,
)
final_prompts = torch.cat((prompts,self.bg_embeding.cuda()),dim=0)
return final_prompts
return prompts
def forward(self):
ctx = self.ctx
@@ -264,17 +242,44 @@ class CustomCLIP(nn.Module):
def cos_sim(self,a,b):
return F.cosine_similarity(a,b)
def contrastive_loss(self, anchor, positive, negative, temperature=0.07):
"""
InfoNCE contrastive loss for foreground-background discrimination
Args:
anchor: Complete image features [B, D]
positive: Foreground features [B, D]
negative: Background features [B, D]
temperature: Temperature parameter for softmax
Returns:
loss: Contrastive learning loss value
"""
# Calculate similarity
sim_pos = F.cosine_similarity(anchor, positive, dim=-1) # [B]
sim_neg = F.cosine_similarity(anchor, negative, dim=-1) # [B]
# Apply temperature scaling
sim_pos = sim_pos / temperature
sim_neg = sim_neg / temperature
# InfoNCE loss: -log(exp(sim_pos) / (exp(sim_pos) + exp(sim_neg)))
logits = torch.stack([sim_pos, sim_neg], dim=1) # [B, 2]
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
loss = F.cross_entropy(logits, labels)
return loss
def forward(self, image, label=None,record=False,cal_gradient=False,weight=None,epoch=None,index=None,cfg=None,mask=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
text_features_fg = text_features[:-len(BACKGROUND_CATEGORY)]
ori_image_input = image.type(self.dtype)
# text_features = text_features + self.get_learnable_noise(text_features.shape)
text_features_fg = text_features_fg / text_features_fg.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features, visual_ctx, mask_similarity = self.image_encoder(ori_image_input, shared_ctx,
deep_compound_prompts_vision)
@@ -285,18 +290,15 @@ class CustomCLIP(nn.Module):
# if label is not None:
# image_features = image_features + self.get_uniform_ball_noise(image_features.shape)
logits = logit_scale * image_features @ text_features_fg.t()
logits = logit_scale * image_features @ text_features.t()
loss_re = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
loss_contrastive = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if mask != None:
text_features_bg = text_features[-len(BACKGROUND_CATEGORY):]
text_features_bg = text_features_bg / text_features_bg.norm(dim=-1, keepdim=True)
image_features_fg,_,_ = self.image_encoder(ori_image_input*mask, shared_ctx, deep_compound_prompts_vision) #, shared_ctx, deep_compound_prompts_vision
@@ -305,37 +307,27 @@ class CustomCLIP(nn.Module):
image_features_bg = image_features_bg / image_features_bg.norm(dim=-1, keepdim=True)
loss_re1 = F.triplet_margin_loss(image_features,image_features_fg.detach(),image_features_bg.detach(),margin=1.5)
# image_features_fg_ori = self.image_encoder_ori(ori_image_input*mask_random)
# image_features_bg_ori = self.image_encoder_ori(ori_image_input*(1-mask_random))
# image_features_fg_ori = image_features_fg_ori / image_features_fg_ori.norm(dim=-1, keepdim=True)
# image_features_bg_ori = image_features_bg_ori / image_features_bg_ori.norm(dim=-1,keepdim=True)
# image_features_all_ori = image_features_fg_ori + image_features_bg_ori
# image_features_all_ori = image_features_all_ori / image_features_all_ori.norm(dim=-1,keepdim=True)
# loss_reo = torch.abs(image_features_all_ori.detach() - image_features).mean()
foreground_score = logit_scale*image_features_fg.detach()@text_features_fg.t()
pseudo_label = torch.argmax(image_features_bg @ text_features_bg.t(), dim=-1)
logits_bg = logit_scale*(image_features_bg) @ text_features_bg.t()
para_bg = 0.5
para_fg = 0.1
para_vd = 0.8
loss_contrastive = self.contrastive_loss(image_features, image_features_fg.detach(), image_features_bg.detach(), temperature=0.07)
loss_bg = F.cross_entropy(logits_bg,pseudo_label)
loss_fg = F.cross_entropy(foreground_score,label)
para_fg = 0.2
para_vd = 0.6
if epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg + para_bg*loss_bg
if label is not None:
loss_fg = F.cross_entropy(logit_scale*image_features_fg.detach()@text_features.t(), label)
else:
loss_re = para_vd*loss_re1 #loss_reo would be effective in base2novel setting
loss_fg = torch.tensor(0.0, dtype=self.dtype, device=image.device)
if epoch is not None and epoch > 6: #Tunable parameters
loss_re = para_fg*loss_fg
else:
loss_re = para_vd*loss_contrastive
if self.prompt_learner.training:
if weight is None:
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_vd':loss_re1.item(),'loss_bg':loss_bg.item(),'loss_fg':loss_fg.item()}
return F.cross_entropy(logits,label)+loss_re,logits,{'loss_contrastive':loss_contrastive.item(),'loss_fg':loss_fg.item()}
else:
return F.cross_entropy(weight.unsqueeze(-1)*logits,label), logits
@@ -674,8 +666,8 @@ class MaPLe(TrainerX):
model_name="model-best.pth.tar"
)
# if meet_checkpoint_freq or last_epoch:
# self.save_model(self.epoch, self.output_dir)
if meet_checkpoint_freq or last_epoch:
self.save_model(self.epoch, self.output_dir)
print(f"Now generate the attentive masking in {self.cfg.TRAINER.DAPT_MODE} \n")


@@ -5,10 +5,6 @@ import cv2
from PIL import Image
import os
BACKGROUND_CATEGORY = ['ground','land','grass','tree','building','wall','sky','lake','water','river','sea','railway','railroad','keyboard','helmet',
'cloud','house','mountain','ocean','road','rock','street','valley','bridge','sign',
]
class GradCAM(object):
def __init__(self,model_dict):
layer_name = model_dict['layer_name']
@@ -80,7 +76,7 @@ class GradCAM(object):
else:
logit = self.model_arch.forward_test(input,labels,cfg=cfg)
pred_label = torch.argmax(logit[:,:-len(BACKGROUND_CATEGORY)])
pred_label = torch.argmax(logit)
sign = pred_label == labels
# if (split == 'true' and sign == False) or (split == 'wrong' and sign == True):
# print(f'Ignore the not {split} sample')
@@ -88,11 +84,10 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
pred = logit[:,:-len(BACKGROUND_CATEGORY)].argmax(dim=-1)
background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]-len(BACKGROUND_CATEGORY)).to(torch.float16)
pred = logit.argmax(dim=-1)
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (F.softmax(logit[:,:-len(BACKGROUND_CATEGORY)])*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (F.softmax(logit)*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
@@ -186,10 +181,8 @@ class GradCAM(object):
# if attn_mask:
# return final_cls_mask
# pred = logit[:,-len(BACKGROUND_CATEGORY):].argmax(dim=-1)
# background_logit = logit[:,-len(BACKGROUND_CATEGORY):]
one_hot_labels = F.one_hot(labels, num_classes=logit.shape[1]).to(torch.float16)
loss = (logit*one_hot_labels).mean() #+ background_logit.mean() #(logit[:,:-len(BACKGROUND_CATEGORY)]*one_hot_labels).mean() #F.cross_entropy(logit.requires_grad_(True), labels)
loss = (logit*one_hot_labels).mean()
# score = logit[:,labels]
self.model_arch.zero_grad()
loss.backward(retain_graph=retain_graph)