from glob import glob
from os.path import join, basename
from multiprocessing import Manager

import numpy as np

from . import CDDataset
from .common import default_loader

class OSCDDataset(CDDataset):
    __BAND_NAMES = (
        'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 
        'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
    )
    def __init__(
        self, 
        root, phase='train', 
        transforms=(None, None, None), 
        repeats=1,
        cache_labels=True
    ):
        super().__init__(root, phase, transforms, repeats)
        self.cache_on = cache_labels
        if self.cache_on:
            self._manager = Manager()
            self.label_pool = self._manager.dict()

    def _read_file_paths(self):
        image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
        label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
        txt_file = join(image_dir, 'train.txt')
        # Read cities
        with open(txt_file, 'r') as f:
            cities = [city.strip() for city in f.read().strip().split(',')]
        if self.phase == 'train':
            # For training, use the first 11 pairs
            cities = cities[:-3]
        else:
            # For validation, use the remaining 3 pairs
            cities = cities[-3:]
        # t1_list, t2_list = [], []
        # for city in cities:
        #     t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif'))
        #     t1_list.append(t1s) # Populate t1_list
        #     # Recognize t2 from t1
        #     prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5]
        #     t2_list.append([prefix+t1[-5:] for t1 in t1s])
        #
        # Use resampled images
        t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
        t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
        label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
        return t1_list, t2_list, label_list

    def fetch_image(self, image_paths):
        return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)

    def fetch_label(self, label_path):
        if self.cache_on:
            label = self.label_pool.get(label_path, None)
            if label is not None:
                return label
        # In the tif labels, 1 for NC and 2 for C
        # Thus a -1 offset is needed
        label = default_loader(label_path) - 1
        if self.cache_on:
            self.label_pool[label_path] = label
        return label