Skip to content
Snippets Groups Projects
_AirChange.py 2.73 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    import abc
    from os.path import join, basename
    from multiprocessing import Manager
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    from .augmentation import Crop
    
    
    class _AirChangeDataset(CDDataset):
        def __init__(
            self, 
            root, phase='train', 
            transforms=(None, None, None), 
            repeats=1
        ):
            super().__init__(root, phase, transforms, repeats)
            self.cropper = Crop(bounds=(0, 0, 748, 448))
    
            self._manager = Manager()
            sync_list = self._manager.list
            self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)])
            self.labels = sync_list([None]*self.N_PAIRS)
    
        @property
        @abc.abstractmethod
        def LOCATION(self):
            return ''
    
        @property
        @abc.abstractmethod
        def TEST_SAMPLE_IDS(self):
            return ()
    
        @property
        @abc.abstractmethod
        def N_PAIRS(self):
            return 0
    
        def _read_file_paths(self):
            if self.phase == 'train':
                sample_ids = range(self.N_PAIRS)
                t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
                t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
                label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
            else:
                t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS]
                t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS]
                label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS]
    
            return t1_list, t2_list, label_list
    
        def fetch_image(self, image_name):
            _, i, t = image_name.split('-')
            i, t = int(i), int(t[:-4])
            if self.images[t][i] is None:
                image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1)))
                self.images[t][i] = image if self.phase == 'train' else self.cropper(image)
            return self.images[t][i]
    
        def fetch_label(self, label_name):
            index = int(label_name.split('-')[1])
            if self.labels[index] is None:
                label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt'))
                label = (label / 255.0).astype(np.uint8)    # To 0,1
                self.labels[index] = label if self.phase == 'train' else self.cropper(label)
            return self.labels[index]
    
        @staticmethod
        def _bmp_loader(bmp_path_wo_ext):
            # Case insensitive .bmp loader
            try:
                return default_loader(bmp_path_wo_ext+'.bmp')
            except FileNotFoundError:
                return default_loader(bmp_path_wo_ext+'.BMP')