Skip to content
Snippets Groups Projects
OSCD.py 2.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    from glob import glob
    from os.path import join, basename
    from multiprocessing import Manager
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    
    class OSCDDataset(CDDataset):
        """Change-detection dataset reading the Onera Satellite Change Detection folders.

        The city list is read from ``train.txt`` inside the images folder. The
        ``'train'`` phase uses every city except the last three; any other
        phase uses the last three. Each sample is described by two lists of
        per-band image paths (one per acquisition date) and one label path.
        """

        # File stems of the per-band images, in the order fetch_image stacks them.
        __BAND_NAMES = (
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06',
            'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        )

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1,
            cache_labels=True
        ):
            """Set up the dataset and, optionally, the label cache.

            Args:
                root: Forwarded to ``CDDataset.__init__``; also used here as the
                    directory containing the two OSCD folders.
                phase: Forwarded to ``CDDataset.__init__``; ``'train'`` selects
                    the training cities, anything else the validation cities.
                transforms: Forwarded unchanged to ``CDDataset.__init__``.
                repeats: Forwarded unchanged to ``CDDataset.__init__``.
                cache_labels: When true, labels loaded by ``fetch_label`` are
                    kept in a ``multiprocessing.Manager`` dict and reused.
            """
            super().__init__(root, phase, transforms, repeats)
            self.cache_on = cache_labels
            if self.cache_on:
                # NOTE(review): a Manager dict is presumably used so that the
                # cache is shared between data-loading worker processes.
                # Keep a reference to the manager so its server stays alive.
                self._manager = Manager()
                self.label_pool = self._manager.dict()

        def _read_file_paths(self):
            """Build the path lists for the cities belonging to the current phase.

            Returns:
                A tuple ``(t1_list, t2_list, label_list)``. ``t1_list`` and
                ``t2_list`` hold, per city, the list of band image paths for the
                first and second acquisition; ``label_list`` holds, per city,
                the path of the change-map tif.
            """
            image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
            label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
            txt_file = join(image_dir, 'train.txt')

            # Read the comma-separated city names. Blank entries are dropped so
            # that a trailing comma in train.txt cannot produce an empty city
            # name (which would otherwise end up in the validation split).
            with open(txt_file, 'r') as f:
                cities = [city.strip() for city in f.read().split(',') if city.strip()]

            if self.phase == 'train':
                # For training, use all but the last 3 pairs
                # (the first 11 pairs according to the original note)
                cities = cities[:-3]
            else:
                # For validation, use the remaining 3 pairs
                cities = cities[-3:]

            def band_paths(city, subdir):
                # One path per band, following the order of __BAND_NAMES
                return [join(image_dir, city, subdir, band+'.tif') for band in self.__BAND_NAMES]

            # Use the resampled images stored under imgs_?_rect (named by band
            # only) instead of globbing the raw '*_B??.tif' files in imgs_1/imgs_2.
            t1_list = [band_paths(city, 'imgs_1_rect') for city in cities]
            t2_list = [band_paths(city, 'imgs_2_rect') for city in cities]
            label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
            return t1_list, t2_list, label_list

        def fetch_image(self, image_paths):
            """Load every band in ``image_paths`` and stack them into one array.

            Args:
                image_paths: Iterable of paths, one per band.

            Returns:
                A ``float32`` array with the bands stacked along the last axis.
            """
            return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)

        def fetch_label(self, label_path):
            """Load the label at ``label_path``, going through the cache if enabled.

            Args:
                label_path: Path of the change-map tif.

            Returns:
                The loaded label shifted by -1.
            """
            if self.cache_on:
                label = self.label_pool.get(label_path, None)
                if label is not None:
                    return label
            # In the tif labels, 1 for NC and 2 for C
            # Thus a -1 offset is needed
            label = default_loader(label_path) - 1
            if self.cache_on:
                self.label_pool[label_path] = label
            return label