diff --git a/src/core/factories.py b/src/core/factories.py index 23198c7883589b4b5213b74392cb41875771bb93..98963cefdb0df754e418ceda650d1f7305e89b9b 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -175,7 +175,7 @@ def _get_basic_configs(ds_name, C): return dict( root = constants.IMDB_AIRCHANGE ) - elif ds_name.startswith('Lebedev'): + elif ds_name == 'Lebedev': return dict( root = constants.IMDB_LEBEDEV ) diff --git a/src/data/_AirChange.py b/src/data/_AirChange.py index 4d8555d843bc8b5933e166a4c9aa1612b252cf49..89d257e6a47c3cbf6e14cc80ce3428ed23f5e6da 100644 --- a/src/data/_AirChange.py +++ b/src/data/_AirChange.py @@ -1,6 +1,6 @@ import abc from os.path import join, basename -from multiprocessing import Manager +from functools import lru_cache import numpy as np @@ -19,11 +19,6 @@ class _AirChangeDataset(CDDataset): super().__init__(root, phase, transforms, repeats) self.cropper = Crop(bounds=(0, 0, 748, 448)) - self._manager = Manager() - sync_list = self._manager.list - self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)]) - self.labels = sync_list([None]*self.N_PAIRS) - @property @abc.abstractmethod def LOCATION(self): @@ -41,32 +36,28 @@ class _AirChangeDataset(CDDataset): def _read_file_paths(self): if self.phase == 'train': - sample_ids = range(self.N_PAIRS) - t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] - t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] - label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] + sample_ids = [i for i in range(self.N_PAIRS) if i not in self.TEST_SAMPLE_IDS] + t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in sample_ids] + t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in sample_ids] + label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in sample_ids] else: - t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS] - t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS] - label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS] + t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in self.TEST_SAMPLE_IDS] + t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in self.TEST_SAMPLE_IDS] + label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in self.TEST_SAMPLE_IDS] return t1_list, t2_list, label_list + + @lru_cache(maxsize=8) def fetch_image(self, image_name): - _, i, t = image_name.split('-') - i, t = int(i), int(t[:-4]) - if self.images[t][i] is None: - image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1))) - self.images[t][i] = image if self.phase == 'train' else self.cropper(image) - return self.images[t][i] + image = self._bmp_loader(image_name) + return image if self.phase == 'train' else self.cropper(image) + @lru_cache(maxsize=8) def fetch_label(self, label_name): - index = int(label_name.split('-')[1]) - if self.labels[index] is None: - label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt')) - label = (label / 255.0).astype(np.uint8) # To 0,1 - self.labels[index] = label if self.phase == 'train' else self.cropper(label) - return self.labels[index] + label = self._bmp_loader(label_name) + label = (label / 255.0).astype(np.uint8) # To 0,1 + return label if self.phase == 'train' else self.cropper(label) @staticmethod def _bmp_loader(bmp_path_wo_ext): @@ -74,6 +65,4 @@ class _AirChangeDataset(CDDataset): try: return default_loader(bmp_path_wo_ext+'.bmp') except FileNotFoundError: - return default_loader(bmp_path_wo_ext+'.BMP') - - \ No newline at end of file + return default_loader(bmp_path_wo_ext+'.BMP') \ No newline at end of file