Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from sklearn import metrics
class AverageMeter:
    """Track the most recent value and a weighted running average of values.

    Each :meth:`update` call turns its positional arguments into one value via
    :meth:`compute`, stores it in ``val``, and folds it into ``sum``/``count``
    with weight ``n``; ``avg`` is ``sum / count``.
    """

    def __init__(self, callback=None):
        """Create an empty meter.

        Args:
            callback: Optional callable applied to the positional arguments of
                :meth:`update` to produce the value to record. When ``None``,
                :meth:`update` must receive exactly one positional argument,
                which is recorded as-is.
        """
        super().__init__()
        self.callback = callback
        self.reset()

    def compute(self, *args):
        """Return the value to record for one :meth:`update` call.

        Raises:
            NotImplementedError: if no callback is set and the number of
                positional arguments is not exactly one.
        """
        if self.callback is not None:
            return self.callback(*args)
        if len(args) == 1:
            return args[0]
        raise NotImplementedError(
            f"{type(self).__name__} has no callback; expected exactly one "
            f"value, got {len(args)} arguments"
        )

    def reset(self):
        """Clear the last value, running sum, total weight and average."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, *args, n=1):
        """Record a new value with weight ``n`` and refresh the running average.

        Args:
            *args: Forwarded to :meth:`compute` to obtain the value.
            n: Weight applied to this value in the running sum and count.
        """
        self.val = self.compute(*args)
        self.sum += self.val * n
        self.count += n
        # Only zero-weight updates so far (e.g. update(x, n=0) on a fresh
        # meter): keep the previous average instead of dividing by zero.
        if self.count:
            self.avg = self.sum / self.count
class Metric(AverageMeter):
    """Average meter whose recorded value comes from a scoring function.

    ``update(pred, true)`` flattens both inputs with ``.ravel()`` and records
    ``callback(true_flat, pred_flat, **configs)`` -- note the ground truth is
    passed first, even though ``update``/``compute`` take predictions first.
    """

    # Display name read from instances (``metric.__name__``); subclasses
    # override it with their own short label.
    __name__ = 'Metric'

    def __init__(self, callback, **configs):
        """Wrap *callback*; *configs* are forwarded to it on every call."""
        super().__init__(callback)
        self.configs = configs

    def compute(self, pred, true):
        """Score *pred* against *true* on their flattened forms."""
        flat_true = true.ravel()
        flat_pred = pred.ravel()
        return self.callback(flat_true, flat_pred, **self.configs)
class Precision(Metric):
    """Running average of ``metrics.precision_score``.

    Keyword arguments are stored and forwarded to the scorer on each update.
    """

    __name__ = 'Prec.'

    def __init__(self, **configs):
        scorer = metrics.precision_score
        super().__init__(scorer, **configs)
class Recall(Metric):
    """Running average of ``metrics.recall_score``.

    Keyword arguments are stored and forwarded to the scorer on each update.
    """

    __name__ = 'Recall'

    def __init__(self, **configs):
        scorer = metrics.recall_score
        super().__init__(scorer, **configs)
class Accuracy(Metric):
    """Running average of ``metrics.accuracy_score`` (displayed as ``OA``).

    Keyword arguments are stored and forwarded to the scorer on each update.
    """

    __name__ = 'OA'

    def __init__(self, **configs):
        scorer = metrics.accuracy_score
        super().__init__(scorer, **configs)
class F1Score(Metric):
    """Running average of ``metrics.f1_score``.

    Keyword arguments are stored and forwarded to the scorer on each update.
    """

    __name__ = 'F1'

    def __init__(self, **configs):
        scorer = metrics.f1_score
        super().__init__(scorer, **configs)