Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from glob import glob
from os.path import join, basename
import numpy as np
from . import CDDataset
from .common import default_loader
class LebedevDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1,
subsets=('real', 'with_shift', 'without_shift')
):
self.subsets = subsets
super().__init__(root, phase, transforms, repeats)
def _read_file_paths(self):
t1_list, t2_list, label_list = [], [], []
for subset in self.subsets:
# Get subset directory
if subset == 'real':
subset_dir = join(self.root, 'Real', 'subset')
elif subset == 'with_shift':
subset_dir = join(self.root, 'Model', 'with_shift')
elif subset == 'without_shift':
subset_dir = join(self.root, 'Model', 'without_shift')
else:
raise RuntimeError('unrecognized key encountered')
pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
label_list.extend(refs)
t1_list.extend(t1s)
t2_list.extend(t2s)
return t1_list, t2_list, label_list
def fetch_label(self, label_path):
# To {0,1}
return (super().fetch_label(label_path) / 255.0).astype(np.uint8)