Upload to Main
This commit is contained in:
367
train.py
Normal file
367
train.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from dassl.utils import setup_logger, set_random_seed, collect_env_info
|
||||
from dassl.config import get_cfg_default
|
||||
from dassl.engine import build_trainer
|
||||
# os['CUDA_LAUNCH_BLOCKING'] = 1
|
||||
# custom
|
||||
import datasets.oxford_pets
|
||||
import datasets.oxford_flowers
|
||||
import datasets.fgvc_aircraft
|
||||
import datasets.dtd
|
||||
import datasets.eurosat
|
||||
import datasets.stanford_cars
|
||||
import datasets.food101
|
||||
import datasets.sun397
|
||||
import datasets.caltech101
|
||||
import datasets.ucf101
|
||||
import datasets.imagenet
|
||||
import datasets.pascal_voc
|
||||
import datasets.imagenet_sketch
|
||||
import datasets.imagenetv2
|
||||
import datasets.imagenet_a
|
||||
import datasets.imagenet_r
|
||||
|
||||
import trainers.coop
|
||||
import trainers.cocoop
|
||||
import trainers.zsclip
|
||||
import trainers.maple
|
||||
import trainers.independentVL
|
||||
import trainers.vpt
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
||||
import copy
|
||||
|
||||
def print_args(args, cfg):
|
||||
print("***************")
|
||||
print("** Arguments **")
|
||||
print("***************")
|
||||
optkeys = list(args.__dict__.keys())
|
||||
optkeys.sort()
|
||||
for key in optkeys:
|
||||
print("{}: {}".format(key, args.__dict__[key]))
|
||||
print("************")
|
||||
print("** Config **")
|
||||
print("************")
|
||||
print(cfg)
|
||||
|
||||
|
||||
def reset_cfg(cfg, args):
|
||||
if args.root:
|
||||
cfg.DATASET.ROOT = args.root
|
||||
|
||||
if args.output_dir:
|
||||
cfg.OUTPUT_DIR = args.output_dir
|
||||
|
||||
if args.resume:
|
||||
cfg.RESUME = args.resume
|
||||
|
||||
if args.seed:
|
||||
cfg.SEED = args.seed
|
||||
|
||||
if args.source_domains:
|
||||
cfg.DATASET.SOURCE_DOMAINS = args.source_domains
|
||||
|
||||
if args.target_domains:
|
||||
cfg.DATASET.TARGET_DOMAINS = args.target_domains
|
||||
|
||||
if args.transforms:
|
||||
cfg.INPUT.TRANSFORMS = args.transforms
|
||||
|
||||
if args.trainer:
|
||||
cfg.TRAINER.NAME = args.trainer
|
||||
|
||||
if args.backbone:
|
||||
cfg.MODEL.BACKBONE.NAME = args.backbone
|
||||
|
||||
if args.head:
|
||||
cfg.MODEL.HEAD.NAME = args.head
|
||||
|
||||
if args.dapt_mode:
|
||||
cfg.TRAINER.DAPT_MODE = args.dapt_mode
|
||||
|
||||
def extend_cfg(cfg):
|
||||
"""
|
||||
Add new config variables.
|
||||
|
||||
E.g.
|
||||
from yacs.config import CfgNode as CN
|
||||
cfg.TRAINER.MY_MODEL = CN()
|
||||
cfg.TRAINER.MY_MODEL.PARAM_A = 1.
|
||||
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
|
||||
cfg.TRAINER.MY_MODEL.PARAM_C = False
|
||||
"""
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
cfg.TRAINER.COOP = CN()
|
||||
cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
|
||||
cfg.TRAINER.COOP.CSC = False # class-specific context
|
||||
cfg.TRAINER.COOP.CTX_INIT = "" # initialization words
|
||||
cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
|
||||
|
||||
cfg.TRAINER.COCOOP = CN()
|
||||
cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
|
||||
cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words
|
||||
cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
|
||||
|
||||
# Config for MaPLe
|
||||
cfg.TRAINER.MAPLE = CN()
|
||||
cfg.TRAINER.MAPLE.N_CTX = 2 # number of context vectors
|
||||
cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a" # initialization words
|
||||
cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
# Config for independent Vision Language prompting (independent-vlp)
|
||||
cfg.TRAINER.IVLP = CN()
|
||||
cfg.TRAINER.IVLP.N_CTX_VISION = 2 # number of context vectors at the vision branch
|
||||
cfg.TRAINER.IVLP.N_CTX_TEXT = 2 # number of context vectors at the language branch
|
||||
cfg.TRAINER.IVLP.CTX_INIT = "a photo of a" # initialization words (only for language prompts)
|
||||
cfg.TRAINER.IVLP.PREC = "fp16" # fp16, fp32, amp
|
||||
# If both variables below are set to 0, 0, will the config will degenerate to COOP model
|
||||
cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
|
||||
cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
# Config for only vision side prompting
|
||||
cfg.TRAINER.VPT = CN()
|
||||
cfg.TRAINER.VPT.N_CTX_VISION = 2 # number of context vectors at the vision branch
|
||||
cfg.TRAINER.VPT.CTX_INIT = "a photo of a" # initialization words
|
||||
cfg.TRAINER.VPT.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.VPT.PROMPT_DEPTH_VISION = 1 # if set to 1, will represent shallow vision prompting only
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
#Configs for Data Selection
|
||||
#Method ['Uniform', 'Uncertainty' (Entropy, least confidence, margin),
|
||||
# 'Forgetting', 'Herding', 'Submodular' (GraphCut/Facility Location), 'Glister',
|
||||
# 'GraNd', 'Craig', 'Cal']
|
||||
cfg.DATASET.SELECTION_METHOD = 'Uniform'
|
||||
cfg.DATASET.SELECTION_RATIO = 1.0
|
||||
cfg.DATASET.SELECTION_BATCH_SIZE = 50
|
||||
|
||||
cfg.OPTIM_SELECTION = copy.deepcopy(cfg.OPTIM)
|
||||
cfg.OPTIM_SELECTION.NAME = 'sgd'
|
||||
cfg.OPTIM_SELECTION.LR = 0.0035
|
||||
cfg.OPTIM_SELECTION.MAX_EPOCH = 0 #Forgetting is needed
|
||||
cfg.OPTIM_SELECTION.LR_SCHEDULER = 'cosine'
|
||||
cfg.OPTIM_SELECTION.WARMUP_EPOCH = 1
|
||||
cfg.OPTIM_SELECTION.WARMUP_TYPE = 'constant'
|
||||
cfg.OPTIM_SELECTION.WARMUP_CONS_LR = 1e-5
|
||||
|
||||
|
||||
def extend_cfg2(cfg):
|
||||
"""
|
||||
Add new config variables.
|
||||
|
||||
E.g.
|
||||
from yacs.config import CfgNode as CN
|
||||
cfg.TRAINER.MY_MODEL = CN()
|
||||
cfg.TRAINER.MY_MODEL.PARAM_A = 1.
|
||||
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
|
||||
cfg.TRAINER.MY_MODEL.PARAM_C = False
|
||||
"""
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
cfg.TRAINER.COOP = CN()
|
||||
cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
|
||||
cfg.TRAINER.COOP.CSC = False # class-specific context
|
||||
cfg.TRAINER.COOP.CTX_INIT = "" # initialization words
|
||||
cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
|
||||
|
||||
cfg.TRAINER.COCOOP = CN()
|
||||
cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
|
||||
cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words
|
||||
cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
|
||||
|
||||
# Config for MaPLe
|
||||
cfg.TRAINER.MAPLE = CN()
|
||||
cfg.TRAINER.MAPLE.N_CTX = 2 # number of context vectors
|
||||
cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a" # initialization words
|
||||
cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
# Config for independent Vision Language prompting (independent-vlp)
|
||||
cfg.TRAINER.IVLP = CN()
|
||||
cfg.TRAINER.IVLP.N_CTX_VISION = 2 # number of context vectors at the vision branch
|
||||
cfg.TRAINER.IVLP.N_CTX_TEXT = 2 # number of context vectors at the language branch
|
||||
cfg.TRAINER.IVLP.CTX_INIT = "a photo of a" # initialization words (only for language prompts)
|
||||
cfg.TRAINER.IVLP.PREC = "fp16" # fp16, fp32, amp
|
||||
# If both variables below are set to 0, 0, will the config will degenerate to COOP model
|
||||
cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
|
||||
cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
# Config for only vision side prompting
|
||||
cfg.TRAINER.VPT = CN()
|
||||
cfg.TRAINER.VPT.N_CTX_VISION = 2 # number of context vectors at the vision branch
|
||||
cfg.TRAINER.VPT.CTX_INIT = "a photo of a" # initialization words
|
||||
cfg.TRAINER.VPT.PREC = "fp16" # fp16, fp32, amp
|
||||
cfg.TRAINER.VPT.PROMPT_DEPTH_VISION = 1 # if set to 1, will represent shallow vision prompting only
|
||||
cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
|
||||
|
||||
#selection
|
||||
#Method ['Uniform', 'Uncertainty' (Entropy, least confidence, margin),
|
||||
# 'Forgetting', 'Herding', 'Submodular' (GraphCut/Facility Location), 'Glister',
|
||||
# 'GraNd', 'Craig', 'Cal']
|
||||
cfg.DATASET.SELECTION_METHOD = 'GraNd'
|
||||
cfg.DATASET.SELECTION_RATIO = 0.5
|
||||
cfg.DATASET.SELECTION_BATCH_SIZE = 50
|
||||
|
||||
cfg.OPTIM_SELECTION = copy.deepcopy(cfg.OPTIM)
|
||||
cfg.OPTIM_SELECTION.NAME = 'sgd'
|
||||
cfg.OPTIM_SELECTION.LR = 0.0035
|
||||
cfg.OPTIM_SELECTION.MAX_EPOCH = 0 #Forgetting is needed
|
||||
cfg.OPTIM_SELECTION.LR_SCHEDULER = 'cosine'
|
||||
cfg.OPTIM_SELECTION.WARMUP_EPOCH = 1
|
||||
cfg.OPTIM_SELECTION.WARMUP_TYPE = 'constant'
|
||||
cfg.OPTIM_SELECTION.WARMUP_CONS_LR = 1e-5
|
||||
|
||||
def setup_cfg(args):
|
||||
cfg = get_cfg_default()
|
||||
extend_cfg(cfg)
|
||||
|
||||
# 1. From the dataset config file
|
||||
if args.dataset_config_file:
|
||||
cfg.merge_from_file(args.dataset_config_file)
|
||||
|
||||
# 2. From the method config file
|
||||
if args.config_file:
|
||||
cfg.merge_from_file(args.config_file)
|
||||
|
||||
# 3. From input arguments
|
||||
reset_cfg(cfg, args)
|
||||
|
||||
# 4. From optional input arguments
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR,cfg.DATASET.NAME,cfg.DATASET.SELECTION_METHOD+'_'+str(cfg.DATASET.SELECTION_RATIO),'seed'+str(cfg.SEED))
|
||||
cfg.freeze()
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def setup_cfg2(args):
|
||||
cfg = get_cfg_default()
|
||||
extend_cfg2(cfg)
|
||||
|
||||
# 1. From the dataset config file
|
||||
if args.dataset_config_file:
|
||||
cfg.merge_from_file(args.dataset_config_file)
|
||||
|
||||
# 2. From the method config file
|
||||
if args.config_file:
|
||||
cfg.merge_from_file(args.config_file)
|
||||
|
||||
# 3. From input arguments
|
||||
reset_cfg(cfg, args)
|
||||
|
||||
# 4. From optional input arguments
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR,cfg.DATASET.NAME,cfg.DATASET.SELECTION_METHOD+'_'+str(cfg.DATASET.SELECTION_RATIO),'seed'+str(cfg.SEED))
|
||||
cfg.freeze()
|
||||
|
||||
return cfg
|
||||
|
||||
def main(args):
|
||||
cfg = setup_cfg(args)
|
||||
cfg2 = setup_cfg2(args)
|
||||
if cfg.SEED >= 0:
|
||||
print("Setting fixed seed: {}".format(cfg.SEED))
|
||||
set_random_seed(cfg.SEED)
|
||||
setup_logger(cfg.OUTPUT_DIR)
|
||||
|
||||
if torch.cuda.is_available() and cfg.USE_CUDA:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
print_args(args, cfg)
|
||||
print("Collecting env info ...")
|
||||
print("** System info **\n{}\n".format(collect_env_info()))
|
||||
|
||||
trainer = build_trainer(cfg)
|
||||
# trainer2 = build_trainer(cfg2)
|
||||
|
||||
|
||||
if args.eval_only:
|
||||
if cfg.DATASET.SUBSAMPLE_CLASSES == 'new':
|
||||
original_weight_output = 'output/'+'/'.join(trainer.output_dir.split('/')[1:])
|
||||
trainer.load_model(original_weight_output)
|
||||
trainer.test_withlabel()
|
||||
# trainer.test_multi_label() Evaluation for VOC12
|
||||
import shutil
|
||||
shutil.rmtree(original_weight_output+'/MultiModalPromptLearner')
|
||||
else:
|
||||
original_weight_output = 'output/' + '/'.join(trainer.output_dir.split('/')[1:])
|
||||
trainer.load_model(original_weight_output)
|
||||
# trainer.load_model(args.model_dir,epoch=5)
|
||||
trainer.test()
|
||||
return
|
||||
|
||||
|
||||
if not args.no_train:
|
||||
trainer.train()
|
||||
if cfg.DATASET.SUBSAMPLE_CLASSES != 'base':
|
||||
import shutil
|
||||
shutil.rmtree(trainer.output_dir+'/MultiModalPromptLearner')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--root", type=str, default="", help="path to dataset")
|
||||
parser.add_argument("--output-dir", type=str, default="", help="output directory")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default="",
|
||||
help="checkpoint directory (from which the training resumes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=-1, help="only positive value enables a fixed seed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-domains", type=str, nargs="+", help="source domains for DA/DG"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-domains", type=str, nargs="+", help="target domains for DA/DG"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transforms", type=str, nargs="+", help="data augmentation methods"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-file", type=str, default="", help="path to config file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="path to config file for dataset setup",
|
||||
)
|
||||
parser.add_argument("--trainer", type=str, default="", help="name of trainer")
|
||||
parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
|
||||
parser.add_argument("--head", type=str, default="", help="name of head")
|
||||
parser.add_argument("--eval-only", action="store_true", help="evaluation only")
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="load model from this directory for eval-only mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-epoch", type=int, help="load model weights at this epoch for evaluation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-train", action="store_true", help="do not call trainer.train()"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dapt-mode", default="dapt-s", help="whether using dapt-s or dapt-g"
|
||||
)
|
||||
parser.add_argument(
|
||||
"opts",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER,
|
||||
help="modify config options using the command-line",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user