From e80c87ff7b3b85ee46a3e099c0f1f4ae85551757 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Tue, 14 Apr 2020 13:13:24 +0800
Subject: [PATCH] Fix filename without suffix

---
 src/data/_AirChange.py | 3 +++
 src/data/__init__.py   | 7 +++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py
index 89d257e..00e17a1 100644
--- a/src/data/_AirChange.py
+++ b/src/data/_AirChange.py
@@ -59,6 +59,9 @@ class _AirChangeDataset(CDDataset):
         label = (label / 255.0).astype(np.uint8)    # To 0,1
         return label if self.phase == 'train' else self.cropper(label)
 
+    def get_name(self, index):
+        return '{loc}-{id}-cm.bmp'.format(loc=self.LOCATION, id=index)
+
     @staticmethod
     def _bmp_loader(bmp_path_wo_ext):
         # Case insensitive .bmp loader
diff --git a/src/data/__init__.py b/src/data/__init__.py
index 14a3da5..1f81254 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1,4 +1,4 @@
-from os.path import join, expanduser, basename, exists
+from os.path import join, expanduser, basename, exists, splitext
 
 import torch
 import torch.utils.data as data
@@ -41,7 +41,7 @@ class CDDataset(data.Dataset):
         if self.phase == 'train':
             return t1, t2, label
         else:
-            return basename(self.label_list[index]), t1, t2, label
+            return self.get_name(index), t1, t2, label
 
     def _read_file_paths(self):
         raise NotImplementedError
@@ -52,6 +52,9 @@ class CDDataset(data.Dataset):
     def fetch_image(self, image_path):
         return default_loader(image_path)
 
+    def get_name(self, index):
+        return splitext(basename(self.label_list[index]))[0]+'.bmp'
+
     def preprocess(self, t1, t2, label):
         if self.transforms[0] is not None:
             # Applied on all
-- 
GitLab