Skip to content
Snippets Groups Projects
Select Git revision
  • d3d59c95ba92a40bee1bbd8a24d3214bebc7764b
  • master default protected
  • github/fork/Bobholamovic/master
3 results

OSCD.py

Blame
  • user avatar
    Bobholamovic authored
    d3d59c95
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    OSCD.py 2.65 KiB
    import os
    from glob import glob
    from os.path import join, basename
    from multiprocessing import Manager
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    
    class OSCDDataset(CDDataset):
        """Change-detection dataset reading the Onera Satellite Change Detection (OSCD) layout.

        Each sample is described by two lists of per-band ``.tif`` paths (one list
        per acquisition date) and one change-map ``.tif`` path. Loaded arrays can
        optionally be cached in a ``multiprocessing.Manager`` dict.

        NOTE(review): how ``CDDataset`` consumes ``_read_file_paths``,
        ``fetch_image`` and ``fetch_label`` is not visible here -- confirm
        against the base class.
        """

        # File stems of the band images read for every city, in stacking order.
        __BAND_NAMES = (
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06',
            'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        )

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1,
            cache_level=1
        ):
            """Create the dataset.

            Args:
                root: Directory containing the two OSCD folders
                    ("... - Images" and "... - Train Labels").
                phase: ``'train'`` selects all but the last three cities listed in
                    ``train.txt``; any other value selects the last three.
                transforms: Forwarded unchanged to ``CDDataset.__init__``.
                repeats: Forwarded unchanged to ``CDDataset.__init__``.
                cache_level: Converted with ``int()``. 0 disables caching, 1 caches
                    labels only, 2 and higher caches labels and images.
            """
            super().__init__(root, phase, transforms, repeats)
            # 0 for no cache, 1 for caching labels only, 2 and higher for caching all
            self.cache_level = int(cache_level)
            if self.cache_level > 0:
                # The pool is a Manager dict proxy; presumably this lets worker
                # processes share one cache -- confirm against the data loader setup.
                self._manager = Manager()
                self._pool = self._manager.dict()

        def _read_file_paths(self):
            """Build the path lists for the cities selected by ``self.phase``.

            Returns:
                A tuple ``(t1_list, t2_list, label_list)``. ``t1_list`` and
                ``t2_list`` hold, per city, one list of band paths (ordered as
                ``__BAND_NAMES``) under ``imgs_1_rect`` and ``imgs_2_rect``
                respectively; ``label_list`` holds one ``<city>-cm.tif`` path per
                city.
            """
            image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
            label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
            txt_file = join(image_dir, 'train.txt')

            # train.txt is a comma-separated list of city names. Blank entries
            # (e.g. from a trailing comma) are dropped so they cannot turn into
            # paths with an empty city component.
            with open(txt_file, 'r') as f:
                cities = [city.strip() for city in f.read().split(',') if city.strip()]

            if self.phase == 'train':
                # For training, use everything except the last 3 cities
                # (the first 11 pairs, per the original author's note).
                cities = cities[:-3]
            else:
                # For validation, use the remaining 3 pairs
                cities = cities[-3:]

            def band_paths(city, subdir):
                # One path per band for a single city and acquisition folder.
                return [join(image_dir, city, subdir, band+'.tif') for band in self.__BAND_NAMES]

            # Use resampled images
            t1_list = [band_paths(city, 'imgs_1_rect') for city in cities]
            t2_list = [band_paths(city, 'imgs_2_rect') for city in cities]
            label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]

            return t1_list, t2_list, label_list

        def _fetch_cached(self, key, min_level, load):
            """Return the pool entry for ``key``, calling ``load()`` on a miss.

            The pool is consulted and updated only when ``self.cache_level`` is at
            least ``min_level``; otherwise ``load()`` is simply called. An entry
            whose stored value is ``None`` is treated as a miss.
            """
            use_cache = self.cache_level >= min_level
            if use_cache:
                cached = self._pool.get(key, None)
                if cached is not None:
                    return cached
            value = load()
            if use_cache:
                self._pool[key] = value
            return value

        def fetch_image(self, image_paths):
            """Load every band in ``image_paths`` and stack them on a new last axis.

            Args:
                image_paths: Non-empty sequence of band file paths belonging to one
                    city and acquisition folder.

            Returns:
                The stacked bands cast to ``np.float32``. Served from the pool when
                ``cache_level >= 2``.
            """
            # Cache key is "<city>-<acquisition folder>", taken from the two
            # directory components directly above the first band file.
            key = '-'.join(image_paths[0].split(os.sep)[-3:-1])
            return self._fetch_cached(
                key, 2,
                lambda: np.stack(
                    [default_loader(p) for p in image_paths], axis=-1
                ).astype(np.float32)
            )

        def fetch_label(self, label_path):
            """Load the change map at ``label_path``, shifted down by one.

            Returns:
                ``default_loader(label_path) - 1``. Served from the pool when
                ``cache_level >= 1``; the cache key is the file's base name.
            """
            key = basename(label_path)
            # In the tif labels, 1 for NC and 2 for C
            # Thus a -1 offset is needed
            return self._fetch_cached(key, 1, lambda: default_loader(label_path) - 1)