diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py index 89d257e6a47c3cbf6e14cc80ce3428ed23f5e6da..00e17a1d13504271746f5591378388b2bfd8fb0e 100644 --- a/src/data/_AirChange.py +++ b/src/data/_AirChange.py @@ -59,6 +59,9 @@ class _AirChangeDataset(CDDataset): label = (label / 255.0).astype(np.uint8) # To 0,1 return label if self.phase == 'train' else self.cropper(label) + def get_name(self, index): + return '{loc}-{id}-cm.bmp'.format(loc=self.LOCATION, id=index) + @staticmethod def _bmp_loader(bmp_path_wo_ext): # Case insensitive .bmp loader diff --git a/src/data/__init__.py b/src/data/__init__.py index 14a3da5dbb783326e19999c8198160bdfd6fe4db..1f81254569e7cce132a383614235f34fae5aac6c 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,4 +1,4 @@ -from os.path import join, expanduser, basename, exists +from os.path import join, expanduser, basename, exists, splitext import torch import torch.utils.data as data @@ -41,7 +41,7 @@ class CDDataset(data.Dataset): if self.phase == 'train': return t1, t2, label else: - return basename(self.label_list[index]), t1, t2, label + return self.get_name(index), t1, t2, label def _read_file_paths(self): raise NotImplementedError @@ -52,6 +52,9 @@ class CDDataset(data.Dataset): def fetch_image(self, image_path): return default_loader(image_path) + def get_name(self, index): + return splitext(basename(self.label_list[index]))[0]+'.bmp' + def preprocess(self, t1, t2, label): if self.transforms[0] is not None: # Applied on all