Skip to content
Snippets Groups Projects
Commit a4d70bc2 authored by Bobholamovic's avatar Bobholamovic
Browse files

Fix mp error on Windows

parent 909b6859
No related branches found
No related tags found
No related merge requests found
...@@ -175,7 +175,7 @@ def _get_basic_configs(ds_name, C): ...@@ -175,7 +175,7 @@ def _get_basic_configs(ds_name, C):
return dict( return dict(
root = constants.IMDB_AIRCHANGE root = constants.IMDB_AIRCHANGE
) )
elif ds_name.startswith('Lebedev'): elif ds_name == 'Lebedev':
return dict( return dict(
root = constants.IMDB_LEBEDEV root = constants.IMDB_LEBEDEV
) )
......
import abc import abc
from os.path import join, basename from os.path import join, basename
from multiprocessing import Manager from functools import lru_cache
import numpy as np import numpy as np
...@@ -19,11 +19,6 @@ class _AirChangeDataset(CDDataset): ...@@ -19,11 +19,6 @@ class _AirChangeDataset(CDDataset):
super().__init__(root, phase, transforms, repeats) super().__init__(root, phase, transforms, repeats)
self.cropper = Crop(bounds=(0, 0, 748, 448)) self.cropper = Crop(bounds=(0, 0, 748, 448))
self._manager = Manager()
sync_list = self._manager.list
self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)])
self.labels = sync_list([None]*self.N_PAIRS)
@property @property
@abc.abstractmethod @abc.abstractmethod
def LOCATION(self): def LOCATION(self):
...@@ -41,32 +36,28 @@ class _AirChangeDataset(CDDataset): ...@@ -41,32 +36,28 @@ class _AirChangeDataset(CDDataset):
def _read_file_paths(self): def _read_file_paths(self):
if self.phase == 'train': if self.phase == 'train':
sample_ids = range(self.N_PAIRS) sample_ids = [i for i in range(self.N_PAIRS) if i not in self.TEST_SAMPLE_IDS]
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in sample_ids]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in sample_ids]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS] label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in sample_ids]
else: else:
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS] t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS] t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS] label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in self.TEST_SAMPLE_IDS]
return t1_list, t2_list, label_list return t1_list, t2_list, label_list
@lru_cache(maxsize=8)
def fetch_image(self, image_name): def fetch_image(self, image_name):
_, i, t = image_name.split('-') image = self._bmp_loader(image_name)
i, t = int(i), int(t[:-4]) return image if self.phase == 'train' else self.cropper(image)
if self.images[t][i] is None:
image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1)))
self.images[t][i] = image if self.phase == 'train' else self.cropper(image)
return self.images[t][i]
@lru_cache(maxsize=8)
def fetch_label(self, label_name): def fetch_label(self, label_name):
index = int(label_name.split('-')[1]) label = self._bmp_loader(label_name)
if self.labels[index] is None: label = (label / 255.0).astype(np.uint8) # To 0,1
label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt')) return label if self.phase == 'train' else self.cropper(label)
label = (label / 255.0).astype(np.uint8) # To 0,1
self.labels[index] = label if self.phase == 'train' else self.cropper(label)
return self.labels[index]
@staticmethod @staticmethod
def _bmp_loader(bmp_path_wo_ext): def _bmp_loader(bmp_path_wo_ext):
...@@ -74,6 +65,4 @@ class _AirChangeDataset(CDDataset): ...@@ -74,6 +65,4 @@ class _AirChangeDataset(CDDataset):
try: try:
return default_loader(bmp_path_wo_ext+'.bmp') return default_loader(bmp_path_wo_ext+'.bmp')
except FileNotFoundError: except FileNotFoundError:
return default_loader(bmp_path_wo_ext+'.BMP') return default_loader(bmp_path_wo_ext+'.BMP')
\ No newline at end of file
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment