from glob import glob
from os.path import join, basename

import numpy as np

from . import CDDataset
from .common import default_loader

class LebedevDataset(CDDataset):
    """Bi-temporal change-detection dataset laid out in Lebedev-style folders.

    File paths are collected from the following layout below ``root``::

        Real/subset/<phase>/{A,B,OUT}
        Model/with_shift/<phase>/{A,B,OUT}
        Model/without_shift/<phase>/{A,B,OUT}

    ``OUT`` holds the label masks; ``A`` and ``B`` hold the first and second
    image of each pair, using the same file names as the masks.

    Args:
        root: Dataset root directory.
        phase: Name of the split sub-directory (e.g. ``'train'``, ``'val'``,
            ``'test'``).
        transforms: Forwarded unchanged to the base class.
        repeats: Forwarded unchanged to the base class.
        subsets: Subset keys to include, any of ``'real'``, ``'with_shift'``
            and ``'without_shift'``. A single key may also be passed as a
            plain string.
    """

    # Directory components (below ``root``) for every supported subset key.
    _SUBSET_DIRS = {
        'real': ('Real', 'subset'),
        'with_shift': ('Model', 'with_shift'),
        'without_shift': ('Model', 'without_shift'),
    }

    def __init__(
        self, 
        root, phase='train', 
        transforms=(None, None, None), 
        repeats=1,
        subsets=('real', 'with_shift', 'without_shift')
    ):
        # A bare string would otherwise be iterated character by character
        # and always end in the "unrecognized key" error below.
        if isinstance(subsets, str):
            subsets = (subsets,)
        # Assigned before the base constructor runs; presumably the base
        # class calls _read_file_paths() during __init__ -- TODO confirm.
        self.subsets = subsets
        super().__init__(root, phase, transforms, repeats)

    def _read_file_paths(self):
        """Collect the image and label paths of all selected subsets.

        Returns:
            A tuple ``(t1_list, t2_list, label_list)`` of equally long lists.
            Entries at the same index belong to the same sample.

        Raises:
            RuntimeError: If ``self.subsets`` contains an unknown key.
        """
        t1_list, t2_list, label_list = [], [], []

        for subset in self.subsets:
            try:
                dir_parts = self._SUBSET_DIRS[subset]
            except (KeyError, TypeError):
                # TypeError covers unhashable keys, which the former
                # if/elif chain also rejected with RuntimeError.
                raise RuntimeError('unrecognized key encountered') from None
            phase_dir = join(self.root, *dir_parts, self.phase)

            # The test/val splits of the 'with_shift' subset are globbed as
            # BMP files; every other subset/phase combination uses JPEG.
            uses_bmp = subset == 'with_shift' and self.phase in ('test', 'val')
            pattern = '*.bmp' if uses_bmp else '*.jpg'

            # Sorted so the sample order is deterministic across runs.
            refs = sorted(glob(join(phase_dir, 'OUT', pattern)))

            # Image paths are derived from the label file name, so A/ and B/
            # are expected to contain identically named files.
            label_list.extend(refs)
            t1_list.extend([join(phase_dir, 'A', basename(ref)) for ref in refs])
            t2_list.extend([join(phase_dir, 'B', basename(ref)) for ref in refs])

        return t1_list, t2_list, label_list

    def fetch_label(self, label_path):
        """Load a label mask and binarize it to ``{0, 1}`` as ``np.uint8``.

        Assumes the base-class loader returns an array with values in the
        0..255 range -- TODO confirm against ``CDDataset.fetch_label``.

        The mask is thresholded at the midpoint instead of being cast
        directly: ``(x / 255.0).astype(np.uint8)`` truncates, so any pixel
        that is not exactly 255 (e.g. slightly darkened by JPEG compression)
        would be mapped to 0.
        """
        label = super().fetch_label(label_path)
        return (label / 255.0 >= 0.5).astype(np.uint8)