Skip to content
Snippets Groups Projects

Update outdated code

Open manli requested to merge github/fork/Bobholamovic/master into master
1 file
+ 21
47
Compare changes
  • Side-by-side
  • Inline
+ 21
47
import os
from glob import glob
from os.path import join, basename
from multiprocessing import Manager
@@ -17,13 +18,14 @@ class OSCDDataset(CDDataset):
root, phase='train',
transforms=(None, None, None),
repeats=1,
cache_labels=True
cache_level=1
):
super().__init__(root, phase, transforms, repeats)
self.cache_on = cache_labels
if self.cache_on:
# 0 for no cache, 1 for caching labels only, 2 and higher for caching all
self.cache_level = int(cache_level)
if self.cache_level > 0:
self._manager = Manager()
self.label_pool = self._manager.dict()
self._pool = self._manager.dict()
def _read_file_paths(self):
image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
@@ -38,62 +40,34 @@ class OSCDDataset(CDDataset):
else:
# For validation, use the remaining 3 pairs
cities = cities[-3:]
# t1_list, t2_list = [], []
# for city in cities:
# t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif'))
# t1_list.append(t1s) # Populate t1_list
# # Recognize t2 from t1
# prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5]
# t2_list.append([prefix+t1[-5:] for t1 in t1s])
#
# Use resampled images
t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
#准备数据
print('preparing %s data ... \n'%self.phase)
pb = tqdm(list(range(len(t1_list))))
self.t1_imgs = []
self.t2_imgs = []
for i in pb:
self.t1_imgs.append(self.fetch_image(t1_list[i]))
self.t2_imgs.append(self.fetch_image(t2_list[i]))
return t1_list, t2_list, label_list
#重写该方法
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
t1 = self.t1_imgs[index]
t2 = self.t2_imgs[index]
label = self.fetch_label(self.label_list[index])
t1, t2, label = self.preprocess(t1, t2, label)
if self.phase == 'train':
return t1, t2, label
else:
return self.get_name(index), t1, t2, label
def fetch_image(self, image_paths):
return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
key = '-'.join(image_paths[0].split(os.sep)[-3:-1])
if self.cache_level >= 2:
image = self._pool.get(key, None)
if image is not None:
return image
image = np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
if self.cache_level >= 2:
self._pool[key] = image
return image
def fetch_label(self, label_path):
if self.cache_on:
label = self.label_pool.get(label_path, None)
key = basename(label_path)
if self.cache_level >= 1:
label = self._pool.get(key, None)
if label is not None:
return label
# In the tif labels, 1 for NC and 2 for C
# Thus a -1 offset is needed
label = default_loader(label_path) - 1
if self.cache_on:
self.label_pool[label_path] = label
if self.cache_level >= 1:
self._pool[key] = label
return label
Loading