diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 3a13775ebecab0fdc1e44d705ff32971eddfb98c..6610bfb5cf5f0110a1f347acfdafdf509ddbc2a0 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -41,7 +41,7 @@ class AverageMeter: # These metrics only for numpy arrays class Metric(AverageMeter): __name__ = 'Metric' - def __init__(self, n_classes=2, mode='accum', reduction='binary'): + def __init__(self, n_classes=2, mode='separ', reduction='binary'): super().__init__(None) self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes))) assert mode in ('accum', 'separ') @@ -94,7 +94,7 @@ class Recall(Metric): class Accuracy(Metric): __name__ = 'OA' - def __init__(self, n_classes=2, mode='accum'): + def __init__(self, n_classes=2, mode='separ'): super().__init__(n_classes=n_classes, mode=mode, reduction='none') def _compute(self, cm): return np.nan_to_num(np.diag(cm).sum()/cm.sum())