diff --git a/config_base.yaml b/config_base.yaml index 23154e126550632ada941725534be645e41a32c0..f581c7126650b139df6c0f1e46ccbc823867f67a 100644 --- a/config_base.yaml +++ b/config_base.yaml @@ -19,7 +19,7 @@ step: 2 # Training related batch_size: 32 -num_epochs: 20 +num_epochs: 40 resume: '' load_optim: True anew: False diff --git a/src/core/trainers.py b/src/core/trainers.py index ddefda42c5e2d3adeddc52af3ed77887a3343fd4..2c5b139b894a9262a67495d47b88ac6578e0f056 100644 --- a/src/core/trainers.py +++ b/src/core/trainers.py @@ -267,6 +267,6 @@ class CDTrainer(Trainer): self.logger.dump(desc) if store: - self.save_image(name[0], CM, epoch) + self.save_image(name[0], CM.squeeze(-1), epoch) return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) \ No newline at end of file