release code
This commit is contained in:
127
Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py
Normal file
127
Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from collections import OrderedDict, defaultdict
|
||||
import torch
|
||||
from sklearn.metrics import f1_score, confusion_matrix
|
||||
|
||||
from .build import EVALUATOR_REGISTRY
|
||||
|
||||
|
||||
class EvaluatorBase:
|
||||
"""Base evaluator."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
def reset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def process(self, mo, gt):
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@EVALUATOR_REGISTRY.register()
|
||||
class Classification(EvaluatorBase):
|
||||
"""Evaluator for classification."""
|
||||
|
||||
def __init__(self, cfg, lab2cname=None, **kwargs):
|
||||
super().__init__(cfg)
|
||||
self._lab2cname = lab2cname
|
||||
self._correct = 0
|
||||
self._total = 0
|
||||
self._per_class_res = None
|
||||
self._y_true = []
|
||||
self._y_pred = []
|
||||
if cfg.TEST.PER_CLASS_RESULT:
|
||||
assert lab2cname is not None
|
||||
self._per_class_res = defaultdict(list)
|
||||
|
||||
def reset(self):
|
||||
self._correct = 0
|
||||
self._total = 0
|
||||
self._y_true = []
|
||||
self._y_pred = []
|
||||
if self._per_class_res is not None:
|
||||
self._per_class_res = defaultdict(list)
|
||||
|
||||
def process(self, mo, gt):
|
||||
# mo (torch.Tensor): model output [batch, num_classes]
|
||||
# gt (torch.LongTensor): ground truth [batch]
|
||||
pred = mo.max(1)[1]
|
||||
matches = pred.eq(gt).float()
|
||||
self._correct += int(matches.sum().item())
|
||||
self._total += gt.shape[0]
|
||||
|
||||
self._y_true.extend(gt.data.cpu().numpy().tolist())
|
||||
self._y_pred.extend(pred.data.cpu().numpy().tolist())
|
||||
|
||||
if self._per_class_res is not None:
|
||||
for i, label in enumerate(gt):
|
||||
label = label.item()
|
||||
matches_i = int(matches[i].item())
|
||||
self._per_class_res[label].append(matches_i)
|
||||
|
||||
def evaluate(self):
|
||||
results = OrderedDict()
|
||||
acc = 100.0 * self._correct / self._total
|
||||
err = 100.0 - acc
|
||||
macro_f1 = 100.0 * f1_score(
|
||||
self._y_true,
|
||||
self._y_pred,
|
||||
average="macro",
|
||||
labels=np.unique(self._y_true)
|
||||
)
|
||||
|
||||
# The first value will be returned by trainer.test()
|
||||
results["accuracy"] = acc
|
||||
results["error_rate"] = err
|
||||
results["macro_f1"] = macro_f1
|
||||
|
||||
print(
|
||||
"=> result\n"
|
||||
f"* total: {self._total:,}\n"
|
||||
f"* correct: {self._correct:,}\n"
|
||||
f"* accuracy: {acc:.2f}%\n"
|
||||
f"* error: {err:.2f}%\n"
|
||||
f"* macro_f1: {macro_f1:.2f}%"
|
||||
)
|
||||
|
||||
if self._per_class_res is not None:
|
||||
labels = list(self._per_class_res.keys())
|
||||
labels.sort()
|
||||
|
||||
print("=> per-class result")
|
||||
accs = []
|
||||
|
||||
for label in labels:
|
||||
classname = self._lab2cname[label]
|
||||
res = self._per_class_res[label]
|
||||
correct = sum(res)
|
||||
total = len(res)
|
||||
acc = 100.0 * correct / total
|
||||
accs.append(acc)
|
||||
print(
|
||||
"* class: {} ({})\t"
|
||||
"total: {:,}\t"
|
||||
"correct: {:,}\t"
|
||||
"acc: {:.2f}%".format(
|
||||
label, classname, total, correct, acc
|
||||
)
|
||||
)
|
||||
mean_acc = np.mean(accs)
|
||||
print("* average: {:.2f}%".format(mean_acc))
|
||||
|
||||
results["perclass_accuracy"] = mean_acc
|
||||
|
||||
if self.cfg.TEST.COMPUTE_CMAT:
|
||||
cmat = confusion_matrix(
|
||||
self._y_true, self._y_pred, normalize="true"
|
||||
)
|
||||
save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
|
||||
torch.save(cmat, save_path)
|
||||
print('Confusion matrix is saved to "{}"'.format(save_path))
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user