Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import abc
from os.path import join, basename
from multiprocessing import Manager
import numpy as np
from . import CDDataset
from .common import default_loader
from .augmentation import Crop
class _AirChangeDataset(CDDataset):
    """Common base for the per-location AirChange change-detection datasets.

    Concrete subclasses provide ``LOCATION``, ``TEST_SAMPLE_IDS`` and
    ``N_PAIRS``. Samples are addressed by synthetic names of the form
    ``<LOCATION>-<index>-0.bmp``, ``<LOCATION>-<index>-1.bmp`` and
    ``<LOCATION>-<index>-cm.bmp`` (see ``_read_file_paths``); these names
    are translated back into real paths by ``fetch_image`` and
    ``fetch_label``:

    * images: ``<root>/<LOCATION>/<index + 1>/im<t + 1>.bmp`` (or ``.BMP``)
    * labels: ``<root>/<LOCATION>/<index + 1>/gt.bmp`` (or ``.BMP``)

    Everything that has been loaded is cached in lists owned by a
    ``multiprocessing.Manager`` so the cache lives outside this process
    (presumably so that data-loading worker processes can share it --
    confirm against the training loop).
    """

    def __init__(
        self,
        root, phase='train',
        transforms=(None, None, None),
        repeats=1
    ):
        """Initialise the base dataset and allocate the empty caches.

        Args:
            root: Dataset root directory; joined with ``LOCATION`` when
                files are read.
            phase: ``'train'`` selects every pair not listed in
                ``TEST_SAMPLE_IDS``; any other value selects exactly the
                pairs listed there.
            transforms: Forwarded unchanged to the base class.
            repeats: Forwarded unchanged to the base class.
        """
        super().__init__(root, phase, transforms, repeats)
        # Applied to every image/label outside the 'train' phase.
        # NOTE(review): the meaning of the four bounds is defined by
        # ``Crop`` in .augmentation -- confirm before changing them.
        self.cropper = Crop(bounds=(0, 0, 748, 448))

        self._manager = Manager()
        sync_list = self._manager.list
        # images[t][i] caches epoch ``t`` (0 or 1) of pair ``i``;
        # labels[i] caches the change mask of pair ``i``.
        # ``None`` marks a slot that has not been loaded yet.
        self.images = sync_list([
            sync_list([None] * self.N_PAIRS),
            sync_list([None] * self.N_PAIRS),
        ])
        self.labels = sync_list([None] * self.N_PAIRS)

    @property
    @abc.abstractmethod
    def LOCATION(self):
        """Name of the location; used as a directory name and name prefix."""
        return ''

    @property
    @abc.abstractmethod
    def TEST_SAMPLE_IDS(self):
        """Zero-based pair indices held out from the 'train' phase."""
        return ()

    @property
    @abc.abstractmethod
    def N_PAIRS(self):
        """Total number of image pairs available for this location."""
        return 0

    def _read_file_paths(self):
        """Build the synthetic sample names for the current phase.

        Returns:
            A tuple ``(t1_list, t2_list, label_list)`` of equally long
            lists holding ``<LOCATION>-<i>-0.bmp``, ``<LOCATION>-<i>-1.bmp``
            and ``<LOCATION>-<i>-cm.bmp`` for every selected index ``i``.
        """
        test_ids = self.TEST_SAMPLE_IDS
        if self.phase == 'train':
            # Build the lookup once instead of scanning the id collection
            # for every index of every list.
            held_out = frozenset(test_ids)
            sample_ids = [i for i in range(self.N_PAIRS) if i not in held_out]
        else:
            sample_ids = list(test_ids)

        location = self.LOCATION

        def names(suffix):
            return ['-'.join([location, str(i), suffix]) for i in sample_ids]

        return names('0.bmp'), names('1.bmp'), names('cm.bmp')

    def fetch_image(self, image_name):
        """Return the image for a name produced by ``_read_file_paths``.

        Args:
            image_name: ``<LOCATION>-<index>-<t>.bmp`` with ``t`` being
                0 or 1.

        Returns:
            The loaded image; outside the 'train' phase it has been passed
            through ``self.cropper``. The value is cached after the first
            load.
        """
        # Split from the right so a hyphen inside LOCATION cannot shift
        # the index and epoch fields.
        _, i, t = image_name.rsplit('-', 2)
        i, t = int(i), int(t[:-4])  # t[:-4] drops the '.bmp' extension

        # Every access to the manager list is a round trip to the manager
        # process, so read the slot only once.
        image = self.images[t][i]
        if image is None:
            # On disk both the pair directory and the epoch are 1-based.
            image = self._bmp_loader(
                join(self.root, self.LOCATION, str(i + 1), 'im' + str(t + 1))
            )
            if self.phase != 'train':
                image = self.cropper(image)
            self.images[t][i] = image
        return image

    def fetch_label(self, label_name):
        """Return the 0/1 change mask for a name from ``_read_file_paths``.

        Args:
            label_name: ``<LOCATION>-<index>-cm.bmp``.

        Returns:
            The mask as ``np.uint8``; outside the 'train' phase it has
            been passed through ``self.cropper``. The value is cached
            after the first load.
        """
        # Split from the right for the same reason as in fetch_image.
        index = int(label_name.rsplit('-', 2)[1])

        label = self.labels[index]
        if label is None:
            label = self._bmp_loader(
                join(self.root, self.LOCATION, str(index + 1), 'gt')
            )
            # To 0,1. The cast truncates, so only values >= 255 become 1
            # (assumes the loader yields values in 0..255 -- confirm).
            label = (label / 255.0).astype(np.uint8)
            if self.phase != 'train':
                label = self.cropper(label)
            self.labels[index] = label
        return label

    @staticmethod
    def _bmp_loader(bmp_path_wo_ext):
        """Load a bitmap whose extension is either ``.bmp`` or ``.BMP``.

        Args:
            bmp_path_wo_ext: Path to the file without its extension.

        Returns:
            Whatever ``default_loader`` returns for the file.

        Raises:
            FileNotFoundError: If neither spelling of the extension exists.
        """
        try:
            return default_loader(bmp_path_wo_ext + '.bmp')
        except FileNotFoundError:
            return default_loader(bmp_path_wo_ext + '.BMP')