release code
This commit is contained in:
181
Dassl.ProGrad.pytorch/tools/parse_test_res.py
Normal file
181
Dassl.ProGrad.pytorch/tools/parse_test_res.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Goal
|
||||
---
|
||||
1. Read test results from log.txt files
|
||||
2. Compute mean and std across different folders (seeds)
|
||||
|
||||
Usage
|
||||
---
|
||||
Assume the output files are saved under output/my_experiment,
|
||||
which contains results of different seeds, e.g.,
|
||||
|
||||
my_experiment/
|
||||
seed1/
|
||||
log.txt
|
||||
seed2/
|
||||
log.txt
|
||||
seed3/
|
||||
log.txt
|
||||
|
||||
Run the following command from the root directory:
|
||||
|
||||
$ python tools/parse_test_res.py output/my_experiment
|
||||
|
||||
Add --ci95 to the argument if you wanna get 95% confidence
|
||||
interval instead of standard deviation:
|
||||
|
||||
$ python tools/parse_test_res.py output/my_experiment --ci95
|
||||
|
||||
If my_experiment/ has the following structure,
|
||||
|
||||
my_experiment/
|
||||
exp-1/
|
||||
seed1/
|
||||
log.txt
|
||||
...
|
||||
seed2/
|
||||
log.txt
|
||||
...
|
||||
seed3/
|
||||
log.txt
|
||||
...
|
||||
exp-2/
|
||||
...
|
||||
exp-3/
|
||||
...
|
||||
|
||||
Run
|
||||
|
||||
$ python tools/parse_test_res.py output/my_experiment --multi-exp
|
||||
"""
|
||||
import re
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import argparse
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
from dassl.utils import check_isfile, listdir_nohidden
|
||||
|
||||
|
||||
def compute_ci95(res):
|
||||
return 1.96 * np.std(res) / np.sqrt(len(res))
|
||||
|
||||
|
||||
def parse_function(*metrics, directory="", args=None, end_signal=None):
|
||||
print(f"Parsing files in {directory}")
|
||||
subdirs = listdir_nohidden(directory, sort=True)
|
||||
|
||||
outputs = []
|
||||
|
||||
for subdir in subdirs:
|
||||
fpath = osp.join(directory, subdir, "log.txt")
|
||||
assert check_isfile(fpath)
|
||||
good_to_go = False
|
||||
output = OrderedDict()
|
||||
|
||||
with open(fpath, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
if line == end_signal:
|
||||
good_to_go = True
|
||||
|
||||
for metric in metrics:
|
||||
match = metric["regex"].search(line)
|
||||
if match and good_to_go:
|
||||
if "file" not in output:
|
||||
output["file"] = fpath
|
||||
num = float(match.group(1))
|
||||
name = metric["name"]
|
||||
output[name] = num
|
||||
|
||||
if output:
|
||||
outputs.append(output)
|
||||
|
||||
assert len(outputs) > 0, f"Nothing found in {directory}"
|
||||
|
||||
metrics_results = defaultdict(list)
|
||||
|
||||
for output in outputs:
|
||||
msg = ""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, float):
|
||||
msg += f"{key}: {value:.2f}%. "
|
||||
else:
|
||||
msg += f"{key}: {value}. "
|
||||
if key != "file":
|
||||
metrics_results[key].append(value)
|
||||
print(msg)
|
||||
|
||||
output_results = OrderedDict()
|
||||
|
||||
print("===")
|
||||
print(f"Summary of directory: {directory}")
|
||||
for key, values in metrics_results.items():
|
||||
avg = np.mean(values)
|
||||
std = compute_ci95(values) if args.ci95 else np.std(values)
|
||||
print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
|
||||
output_results[key] = avg
|
||||
print("===")
|
||||
|
||||
return output_results
|
||||
|
||||
|
||||
def main(args, end_signal):
|
||||
metric = {
|
||||
"name": args.keyword,
|
||||
"regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
|
||||
}
|
||||
|
||||
if args.multi_exp:
|
||||
final_results = defaultdict(list)
|
||||
|
||||
for directory in listdir_nohidden(args.directory, sort=True):
|
||||
directory = osp.join(args.directory, directory)
|
||||
results = parse_function(
|
||||
metric, directory=directory, args=args, end_signal=end_signal
|
||||
)
|
||||
|
||||
for key, value in results.items():
|
||||
final_results[key].append(value)
|
||||
|
||||
print("Average performance")
|
||||
for key, values in final_results.items():
|
||||
avg = np.mean(values)
|
||||
print(f"* {key}: {avg:.2f}%")
|
||||
|
||||
else:
|
||||
parse_function(
|
||||
metric, directory=args.directory, args=args, end_signal=end_signal
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("directory", type=str, help="path to directory")
|
||||
parser.add_argument(
|
||||
"--ci95",
|
||||
action="store_true",
|
||||
help=r"compute 95\% confidence interval"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-log", action="store_true", help="parse test-only logs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multi-exp", action="store_true", help="parse multiple experiments"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keyword",
|
||||
default="accuracy",
|
||||
type=str,
|
||||
help="which keyword to extract"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
end_signal = "Finished training"
|
||||
if args.test_log:
|
||||
end_signal = "=> result"
|
||||
|
||||
main(args, end_signal)
|
||||
69
Dassl.ProGrad.pytorch/tools/replace_text.py
Normal file
69
Dassl.ProGrad.pytorch/tools/replace_text.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Replace text in python files.
|
||||
"""
|
||||
import glob
|
||||
import os.path as osp
|
||||
import argparse
|
||||
import fileinput
|
||||
|
||||
EXTENSION = ".py"
|
||||
|
||||
|
||||
def is_python_file(filename):
|
||||
ext = osp.splitext(filename)[1]
|
||||
return ext == EXTENSION
|
||||
|
||||
|
||||
def update_file(filename, text_to_search, replacement_text):
|
||||
print("Processing {}".format(filename))
|
||||
with fileinput.FileInput(filename, inplace=True, backup="") as file:
|
||||
for line in file:
|
||||
print(line.replace(text_to_search, replacement_text), end="")
|
||||
|
||||
|
||||
def recursive_update(directory, text_to_search, replacement_text):
|
||||
filenames = glob.glob(osp.join(directory, "*"))
|
||||
|
||||
for filename in filenames:
|
||||
if osp.isfile(filename):
|
||||
if not is_python_file(filename):
|
||||
continue
|
||||
update_file(filename, text_to_search, replacement_text)
|
||||
elif osp.isdir(filename):
|
||||
recursive_update(filename, text_to_search, replacement_text)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"file_or_dir", type=str, help="path to file or directory"
|
||||
)
|
||||
parser.add_argument("text_to_search", type=str, help="name to be replaced")
|
||||
parser.add_argument("replacement_text", type=str, help="new name")
|
||||
parser.add_argument(
|
||||
"--ext", type=str, default=".py", help="file extension"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
file_or_dir = args.file_or_dir
|
||||
text_to_search = args.text_to_search
|
||||
replacement_text = args.replacement_text
|
||||
extension = args.ext
|
||||
|
||||
global EXTENSION
|
||||
EXTENSION = extension
|
||||
|
||||
if osp.isfile(file_or_dir):
|
||||
if not is_python_file(file_or_dir):
|
||||
return
|
||||
update_file(file_or_dir, text_to_search, replacement_text)
|
||||
elif osp.isdir(file_or_dir):
|
||||
recursive_update(file_or_dir, text_to_search, replacement_text)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
190
Dassl.ProGrad.pytorch/tools/train.py
Normal file
190
Dassl.ProGrad.pytorch/tools/train.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from dassl.utils import setup_logger, set_random_seed, collect_env_info
|
||||
from dassl.config import get_cfg_default
|
||||
from dassl.engine import build_trainer
|
||||
|
||||
|
||||
def print_args(args, cfg):
|
||||
print("***************")
|
||||
print("** Arguments **")
|
||||
print("***************")
|
||||
optkeys = list(args.__dict__.keys())
|
||||
optkeys.sort()
|
||||
for key in optkeys:
|
||||
print("{}: {}".format(key, args.__dict__[key]))
|
||||
print("************")
|
||||
print("** Config **")
|
||||
print("************")
|
||||
print(cfg)
|
||||
|
||||
|
||||
def reset_cfg(cfg, args):
|
||||
if args.root:
|
||||
cfg.DATASET.ROOT = args.root
|
||||
|
||||
if args.output_dir:
|
||||
cfg.OUTPUT_DIR = args.output_dir
|
||||
|
||||
if args.resume:
|
||||
cfg.RESUME = args.resume
|
||||
|
||||
if args.seed:
|
||||
cfg.SEED = args.seed
|
||||
|
||||
if args.source_domains:
|
||||
cfg.DATASET.SOURCE_DOMAINS = args.source_domains
|
||||
|
||||
if args.target_domains:
|
||||
cfg.DATASET.TARGET_DOMAINS = args.target_domains
|
||||
|
||||
if args.transforms:
|
||||
cfg.INPUT.TRANSFORMS = args.transforms
|
||||
|
||||
if args.trainer:
|
||||
cfg.TRAINER.NAME = args.trainer
|
||||
|
||||
if args.backbone:
|
||||
cfg.MODEL.BACKBONE.NAME = args.backbone
|
||||
|
||||
if args.head:
|
||||
cfg.MODEL.HEAD.NAME = args.head
|
||||
|
||||
|
||||
def extend_cfg(cfg):
|
||||
"""
|
||||
Add new config variables.
|
||||
|
||||
E.g.
|
||||
from yacs.config import CfgNode as CN
|
||||
cfg.TRAINER.MY_MODEL = CN()
|
||||
cfg.TRAINER.MY_MODEL.PARAM_A = 1.
|
||||
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
|
||||
cfg.TRAINER.MY_MODEL.PARAM_C = False
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def setup_cfg(args):
|
||||
cfg = get_cfg_default()
|
||||
extend_cfg(cfg)
|
||||
|
||||
# 1. From the dataset config file
|
||||
if args.dataset_config_file:
|
||||
cfg.merge_from_file(args.dataset_config_file)
|
||||
|
||||
# 2. From the method config file
|
||||
if args.config_file:
|
||||
cfg.merge_from_file(args.config_file)
|
||||
|
||||
# 3. From input arguments
|
||||
reset_cfg(cfg, args)
|
||||
|
||||
# 4. From optional input arguments
|
||||
cfg.merge_from_list(args.opts)
|
||||
|
||||
cfg.freeze()
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def main(args):
|
||||
cfg = setup_cfg(args)
|
||||
if cfg.SEED >= 0:
|
||||
print("Setting fixed seed: {}".format(cfg.SEED))
|
||||
set_random_seed(cfg.SEED)
|
||||
setup_logger(cfg.OUTPUT_DIR)
|
||||
|
||||
if torch.cuda.is_available() and cfg.USE_CUDA:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
print_args(args, cfg)
|
||||
print("Collecting env info ...")
|
||||
print("** System info **\n{}\n".format(collect_env_info()))
|
||||
|
||||
trainer = build_trainer(cfg)
|
||||
|
||||
if args.eval_only:
|
||||
trainer.load_model(args.model_dir, epoch=args.load_epoch)
|
||||
trainer.test()
|
||||
return
|
||||
|
||||
if not args.no_train:
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--root", type=str, default="", help="path to dataset")
|
||||
parser.add_argument(
|
||||
"--output-dir", type=str, default="", help="output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default="",
|
||||
help="checkpoint directory (from which the training resumes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="only positive value enables a fixed seed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-domains",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="source domains for DA/DG"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-domains",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="target domains for DA/DG"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--transforms", type=str, nargs="+", help="data augmentation methods"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config-file", type=str, default="", help="path to config file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="path to config file for dataset setup",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trainer", type=str, default="", help="name of trainer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backbone", type=str, default="", help="name of CNN backbone"
|
||||
)
|
||||
parser.add_argument("--head", type=str, default="", help="name of head")
|
||||
parser.add_argument(
|
||||
"--eval-only", action="store_true", help="evaluation only"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="load model from this directory for eval-only mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-epoch",
|
||||
type=int,
|
||||
help="load model weights at this epoch for evaluation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-train", action="store_true", help="do not call trainer.train()"
|
||||
)
|
||||
parser.add_argument(
|
||||
"opts",
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER,
|
||||
help="modify config options using the command-line",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user