diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 3e2b24faf3bf8774c7b8d3391055b615642b7709..3a13775ebecab0fdc1e44d705ff32971eddfb98c 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -23,6 +23,11 @@ class AverageMeter: self.sum = 0 self.count = 0 + for attr in filter(lambda a: not a.startswith('__'), dir(self)): + obj = getattr(self, attr) + if isinstance(obj, AverageMeter): + AverageMeter.reset(obj) + def update(self, *args, n=1): self.val = self.compute(*args) self.sum += self.val * n