import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix

from .build import EVALUATOR_REGISTRY


class EvaluatorBase:
    """Base evaluator."""

    def __init__(self, cfg):
        self.cfg = cfg

    def reset(self):
        raise NotImplementedError

    def process(self, mo, gt):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError


@EVALUATOR_REGISTRY.register()
class Classification(EvaluatorBase):
    """Evaluator for classification."""

    def __init__(self, cfg, lab2cname=None, **kwargs):
        super().__init__(cfg)
        self._lab2cname = lab2cname
        self._correct = 0
        self._total = 0
        self._per_class_res = None
        self._y_true = []
        self._y_pred = []
        if cfg.TEST.PER_CLASS_RESULT:
            # Per-class accuracy needs class names for the printed report
            assert lab2cname is not None
            self._per_class_res = defaultdict(list)

    def reset(self):
        self._correct = 0
        self._total = 0
        self._y_true = []
        self._y_pred = []
        if self._per_class_res is not None:
            self._per_class_res = defaultdict(list)

    def process(self, mo, gt):
        # mo (torch.Tensor): model output [batch, num_classes]
        # gt (torch.LongTensor): ground truth [batch]
        pred = mo.max(1)[1]  # argmax over the class dimension
        matches = pred.eq(gt).float()
        self._correct += int(matches.sum().item())
        self._total += gt.shape[0]

        self._y_true.extend(gt.data.cpu().numpy().tolist())
        self._y_pred.extend(pred.data.cpu().numpy().tolist())

        if self._per_class_res is not None:
            for i, label in enumerate(gt):
                label = label.item()
                matches_i = int(matches[i].item())
                self._per_class_res[label].append(matches_i)

    def evaluate(self):
        results = OrderedDict()
        acc = 100.0 * self._correct / self._total
        err = 100.0 - acc
        # Restrict macro-F1 to the classes that actually appear in the labels
        macro_f1 = 100.0 * f1_score(
            self._y_true,
            self._y_pred,
            average="macro",
            labels=np.unique(self._y_true)
        )

        # The first value will be returned by trainer.test()
        results["accuracy"] = acc
        results["error_rate"] = err
        results["macro_f1"] = macro_f1

        print(
            "=> result\n"
            f"* total: {self._total:,}\n"
            f"* correct: {self._correct:,}\n"
            f"* accuracy: {acc:.2f}%\n"
            f"* error: {err:.2f}%\n"
            f"* macro_f1: {macro_f1:.2f}%"
        )

        if self._per_class_res is not None:
            labels = list(self._per_class_res.keys())
            labels.sort()

            print("=> per-class result")
            accs = []

            for label in labels:
                classname = self._lab2cname[label]
                res = self._per_class_res[label]
                correct = sum(res)
                total = len(res)
                acc = 100.0 * correct / total
                accs.append(acc)
                print(
                    "* class: {} ({})\t"
                    "total: {:,}\t"
                    "correct: {:,}\t"
                    "acc: {:.2f}%".format(
                        label, classname, total, correct, acc
                    )
                )
            mean_acc = np.mean(accs)
            print("* average: {:.2f}%".format(mean_acc))

            results["perclass_accuracy"] = mean_acc

        if self.cfg.TEST.COMPUTE_CMAT:
            # Row-normalized confusion matrix (each true class sums to 1),
            # stored as a numpy array via torch.save
            cmat = confusion_matrix(
                self._y_true, self._y_pred, normalize="true"
            )
            save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
            torch.save(cmat, save_path)
            print('Confusion matrix is saved to "{}"'.format(save_path))

        return results
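

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): drives the
    # evaluator the way a trainer's test loop would, i.e. reset(), process()
    # per batch, evaluate() at the end. The SimpleNamespace config and the
    # random tensors are stand-ins for the real yacs config and real model
    # outputs; running this file directly also assumes the package context
    # needed by the relative `.build` import above.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        TEST=SimpleNamespace(PER_CLASS_RESULT=True, COMPUTE_CMAT=False),
        OUTPUT_DIR=".",
    )
    evaluator = Classification(cfg, lab2cname={0: "cat", 1: "dog", 2: "bird"})
    evaluator.reset()

    torch.manual_seed(0)
    for _ in range(4):  # pretend the test set yields four batches
        logits = torch.randn(8, 3)          # [batch, num_classes]
        labels = torch.randint(0, 3, (8,))  # [batch]
        evaluator.process(logits, labels)

    results = evaluator.evaluate()  # "accuracy" is the headline metric
    print(results)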