Newer
Older
return args[0]
else:
raise NotImplementedError
def reset(self):
if self.calc_avg:
return "val: {} avg: {} cnt: {}".format(self.val, self.avg, self.count)
else:
return "val: {} cnt: {}".format(self.val, self.count)
def __init__(self, n_classes=2, mode='separ', reduction='binary'):
self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False)
self.mode = mode
raise ValueError("Binary reduction only works in 2-class cases.")
def calculate(self, pred, true, n=1):
self._cm.update(true.ravel(), pred.ravel())
if self.mode == 'accum':
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
return self._calculate_metric(cm).mean()
elif self.reduction == 'binary':
def reset(self):
    """Reset the parent's state, then clear the confusion-matrix meter stored in ``self._cm``."""
    super().reset()
    # The confusion-matrix meter lives on this object, so it is cleared explicitly here.
    confusion_meter = self._cm
    confusion_meter.reset()
def __init__(self, n_classes=2, mode='separ'):
    """Initialise via the parent class with ``reduction`` pinned to ``'none'``.

    Args:
        n_classes: number of classes, forwarded unchanged to the parent initialiser.
        mode: mode string forwarded unchanged to the parent initialiser
            (presumably 'separ' or 'accum' -- TODO confirm against the parent class).
    """
    # Callers cannot choose the reduction here; this subclass always uses 'none'.
    super().__init__(reduction='none', mode=mode, n_classes=n_classes)
prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
return np.nan_to_num(2*(prec*recall) / (prec+recall))