From e80c87ff7b3b85ee46a3e099c0f1f4ae85551757 Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Tue, 14 Apr 2020 13:13:24 +0800 Subject: [PATCH] Fix filename without suffix --- src/data/_AirChange.py | 3 +++ src/data/__init__.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py index 89d257e..00e17a1 100644 --- a/src/data/_AirChange.py +++ b/src/data/_AirChange.py @@ -59,6 +59,9 @@ class _AirChangeDataset(CDDataset): label = (label / 255.0).astype(np.uint8) # To 0,1 return label if self.phase == 'train' else self.cropper(label) + def get_name(self, index): + return '{loc}-{id}-cm.bmp'.format(loc=self.LOCATION, id=index) + @staticmethod def _bmp_loader(bmp_path_wo_ext): # Case insensitive .bmp loader diff --git a/src/data/__init__.py b/src/data/__init__.py index 14a3da5..1f81254 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,4 +1,4 @@ -from os.path import join, expanduser, basename, exists +from os.path import join, expanduser, basename, exists, splitext import torch import torch.utils.data as data @@ -41,7 +41,7 @@ class CDDataset(data.Dataset): if self.phase == 'train': return t1, t2, label else: - return basename(self.label_list[index]), t1, t2, label + return self.get_name(index), t1, t2, label def _read_file_paths(self): raise NotImplementedError @@ -52,6 +52,9 @@ class CDDataset(data.Dataset): def fetch_image(self, image_path): return default_loader(image_path) + def get_name(self, index): + return splitext(basename(self.label_list[index]))[0]+'.bmp' + def preprocess(self, t1, t2, label): if self.transforms[0] is not None: # Applied on all -- GitLab