Skip to content
Snippets Groups Projects
Commit d3d59c95 authored by Bobholamovic's avatar Bobholamovic
Browse files

Add multi-level cache for OSCD

parent 7d951e6d
No related branches found
No related tags found
1 merge request!2Update outdated code
import os
from glob import glob from glob import glob
from os.path import join, basename from os.path import join, basename
from multiprocessing import Manager from multiprocessing import Manager
...@@ -17,13 +18,14 @@ class OSCDDataset(CDDataset): ...@@ -17,13 +18,14 @@ class OSCDDataset(CDDataset):
root, phase='train', root, phase='train',
transforms=(None, None, None), transforms=(None, None, None),
repeats=1, repeats=1,
cache_labels=True cache_level=1
): ):
super().__init__(root, phase, transforms, repeats) super().__init__(root, phase, transforms, repeats)
self.cache_on = cache_labels # 0 for no cache, 1 for caching labels only, 2 and higher for caching all
if self.cache_on: self.cache_level = int(cache_level)
if self.cache_level > 0:
self._manager = Manager() self._manager = Manager()
self.label_pool = self._manager.dict() self._pool = self._manager.dict()
def _read_file_paths(self): def _read_file_paths(self):
image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images') image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
...@@ -38,62 +40,34 @@ class OSCDDataset(CDDataset): ...@@ -38,62 +40,34 @@ class OSCDDataset(CDDataset):
else: else:
# For validation, use the remaining 3 pairs # For validation, use the remaining 3 pairs
cities = cities[-3:] cities = cities[-3:]
# t1_list, t2_list = [], []
# for city in cities:
# t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif'))
# t1_list.append(t1s) # Populate t1_list
# # Recognize t2 from t1
# prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5]
# t2_list.append([prefix+t1[-5:] for t1 in t1s])
#
# Use resampled images # Use resampled images
t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities] label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
#准备数据
print('preparing %s data ... \n'%self.phase)
pb = tqdm(list(range(len(t1_list))))
self.t1_imgs = []
self.t2_imgs = []
for i in pb:
self.t1_imgs.append(self.fetch_image(t1_list[i]))
self.t2_imgs.append(self.fetch_image(t2_list[i]))
return t1_list, t2_list, label_list return t1_list, t2_list, label_list
#重写该方法
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
t1 = self.t1_imgs[index]
t2 = self.t2_imgs[index]
label = self.fetch_label(self.label_list[index])
t1, t2, label = self.preprocess(t1, t2, label)
if self.phase == 'train':
return t1, t2, label
else:
return self.get_name(index), t1, t2, label
def fetch_image(self, image_paths): def fetch_image(self, image_paths):
return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32) key = '-'.join(image_paths[0].split(os.sep)[-3:-1])
if self.cache_level >= 2:
image = self._pool.get(key, None)
if image is not None:
return image
image = np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
if self.cache_level >= 2:
self._pool[key] = image
return image
def fetch_label(self, label_path): def fetch_label(self, label_path):
if self.cache_on: key = basename(label_path)
label = self.label_pool.get(label_path, None) if self.cache_level >= 1:
label = self._pool.get(key, None)
if label is not None: if label is not None:
return label return label
# In the tif labels, 1 for NC and 2 for C # In the tif labels, 1 for NC and 2 for C
# Thus a -1 offset is needed # Thus a -1 offset is needed
label = default_loader(label_path) - 1 label = default_loader(label_path) - 1
if self.cache_on: if self.cache_level >= 1:
self.label_pool[label_path] = label self._pool[key] = label
return label return label
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment