Skip to content
Snippets Groups Projects
_AirChange.py 2.15 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    import abc
    from os.path import join, basename
    
    Bobholamovic's avatar
    Bobholamovic committed
    from functools import lru_cache
    
    Bobholamovic's avatar
    Bobholamovic committed
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    from .augmentation import Crop
    
    
    class _AirChangeDataset(CDDataset):
        """Change-detection dataset for one location of the AirChange data.

        Subclasses provide the location-specific constants ``LOCATION``,
        ``TEST_SAMPLE_IDS`` and ``N_PAIRS``. Sample ``i`` (0-based) is read from
        ``<root>/<LOCATION>/<i + 1>/`` as the bitmaps ``im1``, ``im2`` and ``gt``,
        whose extension may be ``.bmp`` or ``.BMP``.
        """

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1
        ):
            super().__init__(root, phase, transforms, repeats)
            # Applied to images and labels in every phase except 'train'.
            # NOTE(review): exact meaning of `bounds` is defined by `Crop`;
            # presumably the (0, 0)-(748, 448) region — confirm.
            self.cropper = Crop(bounds=(0, 0, 748, 448))

        # The three properties below are meant to be overridden by each
        # location subclass. NOTE(review): `abc.abstractmethod` is only
        # enforced if CDDataset uses ABCMeta — confirm.

        @property
        @abc.abstractmethod
        def LOCATION(self):
            """Sub-directory of ``root`` holding this location's samples."""
            return ''

        @property
        @abc.abstractmethod
        def TEST_SAMPLE_IDS(self):
            """0-based ids of the samples reserved for non-training phases."""
            return ()

        @property
        @abc.abstractmethod
        def N_PAIRS(self):
            """Total number of image pairs available for this location."""
            return 0

        def _sample_path(self, index, name):
            """Return the extension-less path of bitmap ``name`` for sample ``index``."""
            # Sample ids are 0-based while the on-disk folders start at 1.
            return join(self.root, self.LOCATION, str(index + 1), name)

        def _read_file_paths(self):
            """Return ``(t1_list, t2_list, label_list)`` for the current phase.

            Paths carry no extension; ``_bmp_loader`` resolves it when loading.
            In the 'train' phase every id in ``range(N_PAIRS)`` that is not in
            ``TEST_SAMPLE_IDS`` is used; any other phase uses exactly
            ``TEST_SAMPLE_IDS``, in its given order.
            """
            if self.phase == 'train':
                excluded = set(self.TEST_SAMPLE_IDS)
                sample_ids = [i for i in range(self.N_PAIRS) if i not in excluded]
            else:
                # Materialise once so all three lists see the same ids.
                sample_ids = list(self.TEST_SAMPLE_IDS)

            t1_list = [self._sample_path(i, 'im1') for i in sample_ids]
            t2_list = [self._sample_path(i, 'im2') for i in sample_ids]
            label_list = [self._sample_path(i, 'gt') for i in sample_ids]
            return t1_list, t2_list, label_list

        def fetch_image(self, image_name):
            """Load the image at extension-less path ``image_name``.

            The image is cropped with ``self.cropper`` unless the phase is 'train'.
            """
            image = self._load_bmp_cached(image_name)
            return image if self.phase == 'train' else self.cropper(image)

        def fetch_label(self, label_name):
            """Load the change mask at ``label_name`` and convert it to 0/1 values.

            The mask is cropped with ``self.cropper`` unless the phase is 'train'.
            """
            label = self._load_bmp_cached(label_name)
            # To 0,1: for pixel values in [0, 255], the cast truncates, so only
            # 255 maps to 1 and every smaller value maps to 0.
            label = (label / 255.0).astype(np.uint8)
            return label if self.phase == 'train' else self.cropper(label)

        @classmethod
        @lru_cache(maxsize=16)
        def _load_bmp_cached(cls, bmp_path_wo_ext):
            """Memoized ``_bmp_loader``, keyed on (class, path).

            The cache key is the class, not ``self``. As a result, the cache does
            not keep dataset instances alive, and instances of the same location
            share decoded files. The budget of 16 covers the 8 images plus 8
            labels of the former per-method caches.

            NOTE: the cached array object itself is handed back to callers (in
            the 'train' phase, ``fetch_image`` returns it directly), so callers
            must not modify it in place.
            """
            return cls._bmp_loader(bmp_path_wo_ext)

        @staticmethod
        def _bmp_loader(bmp_path_wo_ext):
            """Load ``<path>.bmp``, falling back to ``<path>.BMP``.

            The fallback happens only when ``default_loader`` raises
            ``FileNotFoundError`` for the lowercase extension.
            """
            try:
                return default_loader(bmp_path_wo_ext+'.bmp')
            except FileNotFoundError:
                return default_loader(bmp_path_wo_ext+'.BMP')