import itertools
import os
from collections import Counter

import numpy as np
from PIL import Image, ImageTk, TiffImagePlugin


class DatasetManager:
    """Image stack held in memory as one NumPy volume.

    The stack is read either from a directory of numbered image files
    (``'multi image'``) or from a single multipage file
    (``'multipage file'``).  Slices are stacked along axis 0 of the volume.
    The class can also cut zero-padded patches out of a slice and write
    processed slices back to disk.
    """

    _accepted_formats = ['.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp']

    def __init__(self, dataset_directory, dataset_format):
        """Read the dataset into memory.

        Args:
            dataset_directory: Path of the directory holding the numbered
                images ('multi image') or of the multipage file itself
                ('multipage file').
            dataset_format: Either 'multi image' or 'multipage file'.

        Raises:
            ValueError: If ``dataset_format`` is not one of the two
                supported values.
            IOError: If no usable image is found, or the numbered images
                are not consecutive.
        """
        self._dataset_directory = os.fsencode(dataset_directory)
        self.dataset_format = dataset_format
        self._volume, self.filenames = self._read_dataset()
        # Previously this stayed None forever, so ``dtype`` reported None and
        # patches were silently created as float64.
        self._dtype = self._volume.dtype
        self._axis = 0

    def __len__(self):
        """Return the number of slices along axis 0 of the volume."""
        return self._volume.shape[0]

    def __getitem__(self, key):
        """Return the slice(s) selected by ``key`` along axis 0."""
        return self._volume[key, :, :]

    @property
    def dtype(self):
        """NumPy dtype of the loaded volume."""
        return self._dtype

    @property
    def shape(self):
        """Shape of the loaded volume."""
        return self._volume.shape

    @property
    def axis(self):
        """Index of the axis currently used for slicing patches."""
        return self._axis

    @property
    def axis_shape(self):
        """Volume shape with the current slicing axis removed."""
        return tuple(
            dim for axis, dim in enumerate(self.shape) if axis != self.axis
        )

    def set_axis(self, axis):
        """Select the axis used by ``get_image_patch`` and ``axis_shape``."""
        self._axis = axis

    def _read_dataset(self):
        """Dispatch to the reader matching ``self.dataset_format``.

        Returns:
            A ``(volume, filenames)`` pair.

        Raises:
            ValueError: If the format is not recognised (this used to
                surface as an UnboundLocalError).
        """
        if self.dataset_format == 'multi image':
            return self._get_multi_image_volume_and_filenames()
        if self.dataset_format == 'multipage file':
            return self._get_multipage_image_volume_and_filename()
        raise ValueError(
            'Unknown dataset format: {!r}'.format(self.dataset_format)
        )

    def _get_multi_image_volume_and_filenames(self):
        """Load a directory of consecutively numbered images.

        Each filename is split into alternating runs of digits and
        non-digits.  A file is considered only if its last run is exactly
        one of ``_accepted_formats`` (so the number must sit directly in
        front of the extension); the run before it is the slice number.
        Only files with the most frequent extension are loaded.

        Returns:
            ``(volume, filenames)`` where ``volume`` stacks the images in
            ascending slice number and ``filenames`` lists their names in
            the same order.  Also sets ``self.min_idx`` to the lowest
            slice number.

        Raises:
            IOError: If no valid image file is found or the slice numbers
                are not consecutive.
        """
        directory = os.fsdecode(self._dataset_directory)
        candidates = []
        format_counts = Counter()
        for entry in os.listdir(self._dataset_directory):
            name = os.fsdecode(entry)
            tokens = [
                ''.join(run)
                for _, run in itertools.groupby(name, key=str.isdigit)
            ]
            # A file named exactly like an extension (e.g. ".png") has no
            # number token; skip it instead of failing later on tokens[-2].
            if len(tokens) < 2 or tokens[-1] not in self._accepted_formats:
                continue
            format_counts[tokens[-1]] += 1
            candidates.append(tokens)

        if not format_counts:
            raise IOError("No valid image files found")
        chosen_ext = format_counts.most_common(1)[0][0]

        numbered = []
        for tokens in candidates:
            if tokens[-1] != chosen_ext:
                continue
            name = ''.join(tokens)
            numbered.append(
                (int(tokens[-2]), os.path.join(directory, name), name)
            )
        numbered.sort()
        filenames = [name for _, _, name in numbered]

        # Check that sections are consecutive
        for (idx_a, _, name_a), (idx_b, _, name_b) in zip(numbered[:-1],
                                                          numbered[1:]):
            if idx_b - idx_a != 1:
                err_msg_info = name_a + ' with number ' + str(idx_a) \
                    + ' and ' + name_b + ' with number ' + str(idx_b)
                raise IOError('Files must be numbered consecutively, got: '
                              + err_msg_info)

        self.min_idx = numbered[0][0]

        slices = []
        for _, filepath, _ in numbered:
            # Context manager releases the file handle once pixels are read.
            with Image.open(filepath) as img:
                slices.append(np.array(img))
        return np.stack(slices), filenames

    def _get_multipage_image_volume_and_filename(self):
        """Load every frame of a single multipage image file.

        Returns:
            ``(volume, filenames)`` where ``volume`` stacks the frames in
            file order and ``filenames`` holds one synthetic name per
            frame, ``<stem>_<frame index><extension>``.  Also sets
            ``self.min_idx`` to 0, the index of the first frame.
        """
        path = os.fsdecode(self._dataset_directory)
        # basename/splitext keep interior dots in the stem and work with
        # either path separator, unlike splitting on '/' and '.'.
        stem, ext = os.path.splitext(os.path.basename(path))

        frames = []
        filenames = []
        with Image.open(path) as img:
            for i in itertools.count():
                try:
                    img.seek(i)
                except EOFError:
                    # Not enough frames in img
                    break
                frames.append(np.array(img))
                filenames.append(stem + '_' + str(i) + ext)

        # Frames are numbered from 0; mirrors what the multi-image reader
        # records so the attribute exists in both modes.
        self.min_idx = 0
        return np.stack(frames), filenames

    def get_image_patch(self, center, shape):
        """Cut a zero-padded patch around ``center`` from the current slice.

        The slice is taken perpendicular to ``self.axis`` at the matching
        coordinate of ``center``; the two remaining coordinates give the
        patch centre inside that slice.  Parts of the patch falling outside
        the slice are left at zero.

        Args:
            center: Three volume coordinates, indexed in axis order.
            shape: ``(height, width)`` of the patch; both must be odd.

        Returns:
            ``(patch, (x0, y0))`` where ``patch`` has the requested shape
            and the dtype reported by ``self.dtype``, and ``(x0, y0)`` is
            the column and row of the patch's top-left corner within the
            slice (may be negative).

        Raises:
            ValueError: If a patch dimension is even or ``self.axis`` is
                not 0, 1 or 2.
        """
        axis = self.axis
        if axis == 0:
            image_slice = self._volume[center[0], :, :]
            i, j = center[1], center[2]
        elif axis == 1:
            image_slice = self._volume[:, center[1], :]
            i, j = center[0], center[2]
        elif axis == 2:
            image_slice = self._volume[:, :, center[2]]
            i, j = center[0], center[1]
        else:
            raise ValueError('Axis must be 0, 1 or 2, got: ' + str(axis))
        h, w = image_slice.shape[:2]

        ph, pw = shape
        if ph % 2 == 0 or pw % 2 == 0:
            raise ValueError('Shape widths must be odd')

        top, left = i - ph // 2, j - pw // 2
        # Clip the patch window to the slice; everything outside stays zero.
        row_lo, row_hi = max(top, 0), min(top + ph, h)
        col_lo, col_hi = max(left, 0), min(left + pw, w)

        patch = np.zeros(shape, dtype=self.dtype)
        if row_lo < row_hi and col_lo < col_hi:
            patch[row_lo - top:row_hi - top, col_lo - left:col_hi - left] = \
                image_slice[row_lo:row_hi, col_lo:col_hi]
        return patch, (left, top)

    def initializeDatasetOutput(self, output_dir, output_shape, dataset_type):
        """Remember the output location and open a TIFF writer if needed.

        Args:
            output_dir: Directory receiving the images ('multi image') or
                path of the TIFF file to create ('multipage file').
            output_shape: Not used by this method.
            dataset_type: 'multi image' or 'multipage file'.
        """
        self.output_dir = output_dir
        # NOTE: the output directory is not created here; for 'multi image'
        # output it has to exist before saveTransformedImage is called.
        if dataset_type == 'multipage file':
            # NOTE(review): this writer is never closed in this class.
            self.ouput_writer = TiffImagePlugin.AppendingTiffWriter(
                output_dir, True)

    def saveTransformedImage(self, k, image, dataset_type):
        """Write one processed slice as an 8-bit greyscale image.

        Args:
            k: Index into ``self.filenames``; used for 'multi image' only.
            image: Array holding the slice to save.
            dataset_type: 'multi image' saves to
                ``output_dir/<original filename>``; 'multipage file'
                appends a frame to the writer opened by
                ``initializeDatasetOutput``.  Any other value is ignored.
        """
        if dataset_type == 'multi image':
            output_filepath = os.path.join(self.output_dir, self.filenames[k])
            Image.fromarray(image).convert('L').save(output_filepath)
        elif dataset_type == 'multipage file':
            Image.fromarray(image).convert('L').save(self.ouput_writer)
            self.ouput_writer.newFrame()