Skip to content
Snippets Groups Projects
Lebedev.py 1.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    from glob import glob
    from os.path import join, basename
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    
    class LebedevDataset(CDDataset):
        """Change-detection dataset laid out in the Lebedev directory structure.

        Samples are gathered from up to three subsets, each living in its own
        directory under ``root``:

        * ``'real'``          -> ``<root>/Real/subset``
        * ``'with_shift'``    -> ``<root>/Model/with_shift``
        * ``'without_shift'`` -> ``<root>/Model/without_shift``

        Inside a subset directory, the files of a phase are expected under
        ``<phase>/A`` (first image), ``<phase>/B`` (second image) and
        ``<phase>/OUT`` (label), sharing the same file name.

        Args:
            root: Root directory of the dataset.
            phase: Name of the phase sub-directory to read (default ``'train'``).
            transforms: Forwarded unchanged to ``CDDataset``.
            repeats: Forwarded unchanged to ``CDDataset``.
            subsets: Iterable of subset keys to include; every key must be one
                of ``'real'``, ``'with_shift'`` or ``'without_shift'``.
        """

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1,
            subsets=('real', 'with_shift', 'without_shift')
        ):
            # Stored before the base initializer is invoked; presumably the
            # base class calls `_read_file_paths` (which reads `self.subsets`)
            # during construction -- verify against CDDataset.
            self.subsets = subsets
            super().__init__(root, phase, transforms, repeats)

        def _read_file_paths(self):
            """Collect the file paths of all samples in the selected subsets.

            Returns:
                A ``(t1_list, t2_list, label_list)`` tuple of three parallel
                lists holding the paths of the first image, the second image
                and the label of every sample.

            Raises:
                RuntimeError: If ``self.subsets`` contains an unknown key.
            """
            # Subset key -> path components of its directory below the root.
            subset_locations = {
                'real': ('Real', 'subset'),
                'with_shift': ('Model', 'with_shift'),
                'without_shift': ('Model', 'without_shift'),
            }

            first_images, second_images, labels = [], [], []

            for subset in self.subsets:
                try:
                    location = subset_locations[subset]
                except (KeyError, TypeError):
                    # TypeError covers unhashable keys, so that any unknown
                    # key is reported with the same exception.
                    raise RuntimeError('unrecognized key encountered') from None
                subset_dir = join(self.root, *location)

                # Labels are matched as JPEG files everywhere except in the
                # test/val phases of the 'with_shift' subset, where BMP is used.
                if subset == 'with_shift' and self.phase in ('test', 'val'):
                    pattern = '*.bmp'
                else:
                    pattern = '*.jpg'

                # The label files drive the listing; both images of a pair are
                # looked up under the same file name as their label.
                found_labels = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
                file_names = [basename(path) for path in found_labels]

                first_images.extend(
                    [join(subset_dir, self.phase, 'A', name) for name in file_names]
                )
                second_images.extend(
                    [join(subset_dir, self.phase, 'B', name) for name in file_names]
                )
                labels.extend(found_labels)

            return first_images, second_images, labels

        def fetch_label(self, label_path):
            """Load a label through the base class and rescale it to {0, 1}.

            Args:
                label_path: Path of the label file to load.

            Returns:
                The loaded label divided by 255 and cast to ``np.uint8``.
            """
            raw_label = super().fetch_label(label_path)
            # To {0,1}. NOTE(review): assumes the base loader yields an array
            # holding only 0 and 255; the integer cast drops any fractional
            # part of in-between values -- confirm this is intended.
            scaled = raw_label / 255.0
            return scaled.astype(np.uint8)