Skip to content
Snippets Groups Projects
DatasetManager.py 5.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • import itertools
    from PIL import Image, ImageTk, TiffImagePlugin
    import os
    from collections import Counter
    import numpy as np
    
    class DatasetManager:
    
        _accepted_formats = ['.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp']
    
        def __init__(self, dataset_directory, dataset_format):
            self._dataset_directory = os.fsencode(dataset_directory)
            self.dataset_format = dataset_format
            self._volume, self.filenames = self._read_dataset()
            self._dtype = None
            self._axis = 0
    
        def __len__(self):
            return self._volume.shape[0]
    
        def __getitem__(self, key):
            return self._volume[key, :, :]
    
        @property
        def dtype(self):
            return self._dtype
    
        @property
        def shape(self):
            return self._volume.shape
    
        @property
        def axis(self):
            return self._axis
    
        @property
        def axis_shape(self):
            return tuple([dim for axis, dim in enumerate(self.shape) if axis != self.axis])
    
        def set_axis(self, axis):
            self._axis = axis
    
        def _read_dataset(self):
    
            if self.dataset_format == 'multi image':
                vfilenames = self._get_multi_image_volume_and_filenames()
            elif self.dataset_format == 'multipage file':
                vfilenames = self._get_multipage_image_volume_and_filename()
    
            return vfilenames
    
        def _get_multi_image_volume_and_filenames(self):
    
            filtered_string_list = []
            file_formats_counter = Counter()
    
            for file in os.listdir(self._dataset_directory):
                filename = os.fsdecode(file)
    
                formatted_string_list = [''.join(x) for _, x in itertools.groupby(filename, key=str.isdigit)]
    
                file_format = formatted_string_list[-1]
    
                if file_format in self._accepted_formats:
                    file_formats_counter[file_format] += 1
                    filtered_string_list.append(formatted_string_list)
    
            if len(file_formats_counter) == 0:
                raise IOError("No valid image files found")
    
            chosen_ext = file_formats_counter.most_common()[0][0]
    
            number_filename_list = []
    
            for string_list in filtered_string_list:
                ext = string_list[-1]
                if chosen_ext == ext:
                    idx = int(string_list[-2])
                    filename = ''.join(string_list)
                    full_image_path = os.path.join(os.fsdecode(self._dataset_directory), filename)
                    number_filename_list.append([idx, full_image_path, filename])
    
            number_filename_list = sorted(number_filename_list)
            filenames = [filename for _, _, filename in number_filename_list]
    
            fpairs = zip(number_filename_list[:-1], number_filename_list[1:])
    
            # Check that sections are consecutive
            for (idx_a, path_a, filename_a), (idx_b, path_b, filename_b) in fpairs:
                if idx_b - idx_a != 1:
                    err_msg_info = filename_a + ' with number ' + str(idx_a) \
                         + ' and ' + filename_b + ' with number ' + str(idx_b)
    
                    raise IOError('Files must be numbered consecutively, got: ' + err_msg_info)
    
            self.min_idx = number_filename_list[0][0]
    
            volume = np.stack([np.array(Image.open(filepath)) for (N, filepath, _) in number_filename_list])
            return volume, filenames
    
        def _get_multipage_image_volume_and_filename(self):
            """Read every frame of a multipage image file into a volume.

            ``self._dataset_directory`` is used as the path of the image file.

            Returns:
                tuple: ``(volume, filenames)`` where ``volume`` is the
                ``np.stack`` of all frames and ``filenames`` holds one
                generated name per frame of the form
                ``<stem>_<frame index><extension>``.
            """
            # basename/splitext honour the platform's path separator and keep
            # any dots that are part of the stem (e.g. 'my.stack.tif').
            full_filename = os.path.basename(os.fsdecode(self._dataset_directory))
            stem, ext = os.path.splitext(full_filename)

            image_list = []
            filenames = []

            # The context manager closes the file once all frames are copied.
            with Image.open(self._dataset_directory) as img:
                i = 0
                while True:
                    try:
                        img.seek(i)
                        image_list.append(np.array(img))
                    except EOFError:
                        # Not enough frames in img
                        break
                    filenames.append(stem + '_' + str(i) + ext)
                    i += 1

            volume = np.stack(image_list)

            return volume, filenames
    
        def get_image_patch(self, center, shape):
    
            if self.axis == 0:
                image_slice = self._volume[center[0], :, :]
                i, j = center[1], center[2]
                h, w = self._volume.shape[1], self._volume.shape[2]
            if self.axis == 1:
                image_slice = self._volume[:, center[1], :]
                i, j = center[0], center[2]
                h, w = self._volume.shape[0], self._volume.shape[2]
            if self.axis == 2:
                image_slice = self._volume[:, :, center[2]]
                i, j = center[0], center[1]
                h, w = self._volume.shape[0], self._volume.shape[1]
    
            ph, pw = shape
    
            if ph % 2 == 0 or pw % 2 == 0:
                raise ValueError('Shape widths must be odd')
    
            patch = np.zeros(shape, dtype=self.dtype)
    
            for si in range(ph):
                for sj in range(pw):
                    ii = i + si - ph//2
                    jj = j + sj - pw//2
                    if ii >= 0 and ii < h and jj >= 0 and jj < w:
                        patch[si, sj] = image_slice[ii, jj]
    
            return patch, (j - pw//2, i - ph//2)
    
    
        def initializeDatasetOutput(self, output_dir, output_shape, dataset_type):
    
            self.output_dir = output_dir
    
            #if not os.path.exists(output_dir):
            #    os.makedirs(output_dir)
    
            if dataset_type == 'multipage file':
                self.ouput_writer = TiffImagePlugin.AppendingTiffWriter(output_dir,True)
    
        def saveTransformedImage(self, k, image, dataset_type):
    
            if dataset_type == 'multi image':
                filename = self.filenames[k]
                output_filepath = os.path.join(self.output_dir, filename)
                output_image = Image.fromarray(image).convert('L')
                output_image.save(output_filepath)
    
            elif dataset_type == 'multipage file':
                im = Image.fromarray(image).convert('L')
                im.save(self.ouput_writer)
                self.ouput_writer.newFrame()