import itertools
from PIL import Image, ImageTk, TiffImagePlugin
import os
from collections import Counter
import numpy as np

class DatasetManager:
    """Load an image stack into memory as a numpy volume and write frames back.

    Two on-disk layouts are supported, selected by ``dataset_format``:

    * ``'multi image'`` -- a directory of consecutively numbered image files
      (e.g. ``slice001.png``, ``slice002.png``, ...).
    * ``'multipage file'`` -- a single multi-frame image file (e.g. a TIFF).

    The whole dataset is read at construction time and the frames are stacked
    along axis 0 of the volume.
    """

    MULTI_IMAGE = 'multi image'
    MULTIPAGE_FILE = 'multipage file'

    # Extensions (matched case-insensitively) accepted in 'multi image' mode.
    _accepted_formats = ['.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp']

    def __init__(self, dataset_directory, dataset_format):
        """Read the dataset immediately.

        Args:
            dataset_directory: directory of numbered images (``'multi image'``)
                or path of the multi-frame file (``'multipage file'``); any
                value accepted by ``os.fsencode``.
            dataset_format: ``'multi image'`` or ``'multipage file'``.

        Raises:
            ValueError: if ``dataset_format`` is not supported.
            IOError: if no usable or consecutively numbered images are found.
        """
        self._dataset_directory = os.fsencode(dataset_directory)
        self.dataset_format = dataset_format
        self._volume, self.filenames = self._read_dataset()
        self._axis = 0

    def __len__(self):
        """Number of frames (size of axis 0 of the volume)."""
        return self._volume.shape[0]

    def __getitem__(self, key):
        """Return frame(s) ``key``, always indexed along axis 0 (ignores ``axis``)."""
        return self._volume[key, :, :]

    @property
    def dtype(self):
        """numpy dtype of the loaded volume; also used for extracted patches."""
        # Derived from the volume so it can never be stale or unset.
        return self._volume.dtype

    @property
    def shape(self):
        """Shape of the loaded volume."""
        return self._volume.shape

    @property
    def axis(self):
        """Axis along which ``get_image_patch`` selects its 2-D plane."""
        return self._axis

    @property
    def axis_shape(self):
        """Volume shape with the current ``axis`` dimension removed."""
        return tuple(dim for idx, dim in enumerate(self.shape) if idx != self.axis)

    def set_axis(self, axis):
        """Set the axis used by ``get_image_patch`` (expected 0, 1 or 2)."""
        self._axis = axis

    def _read_dataset(self):
        """Dispatch to the reader matching ``self.dataset_format``.

        Returns:
            The ``(volume, filenames)`` tuple produced by the selected reader.

        Raises:
            ValueError: for an unsupported ``dataset_format``.
        """
        if self.dataset_format == self.MULTI_IMAGE:
            return self._get_multi_image_volume_and_filenames()
        if self.dataset_format == self.MULTIPAGE_FILE:
            return self._get_multipage_image_volume_and_filename()
        raise ValueError(
            'Unsupported dataset format {!r}; expected {!r} or {!r}'.format(
                self.dataset_format, self.MULTI_IMAGE, self.MULTIPAGE_FILE))

    @classmethod
    def _split_numbered_filename(cls, filename):
        """Split ``filename`` into alternating digit / non-digit runs.

        Returns the runs (e.g. ``'img007.png'`` -> ``['img', '007', '.png']``)
        when the last run is an accepted extension (case-insensitive) preceded
        by a number; otherwise returns ``None``.
        """
        runs = [''.join(group)
                for _, group in itertools.groupby(filename, key=str.isdigit)]
        # groupby alternates digit/non-digit runs and every accepted extension
        # is non-digit, so whenever this check passes runs[-2] is all digits.
        if len(runs) < 2 or runs[-1].lower() not in cls._accepted_formats:
            return None
        return runs

    @staticmethod
    def _load_image(path):
        """Read one image file into a numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img)

    def _get_multi_image_volume_and_filenames(self):
        """Load a directory of consecutively numbered images.

        Only files named ``<anything><number><accepted extension>`` are
        considered. If several extensions are present, only files with the
        most common one are used. Files are ordered by that trailing number,
        which must increase by exactly one from each file to the next. Sets
        ``self.min_idx`` to the smallest number.

        Returns:
            ``(volume, filenames)``: the images stacked along axis 0, and the
            bare file names in the same order.

        Raises:
            IOError: if no matching file exists, or the numbering has a gap
                or a duplicate.
        """
        directory = os.fsdecode(self._dataset_directory)
        extension_counts = Counter()
        candidates = []  # (lower-cased extension, number, file name)

        for entry in os.listdir(self._dataset_directory):
            name = os.fsdecode(entry)
            runs = self._split_numbered_filename(name)
            if runs is None:
                continue
            ext = runs[-1].lower()
            extension_counts[ext] += 1
            candidates.append((ext, int(runs[-2]), name))

        if not extension_counts:
            raise IOError("No valid image files found")

        chosen_ext = extension_counts.most_common(1)[0][0]

        numbered = sorted(
            (idx, os.path.join(directory, name), name)
            for ext, idx, name in candidates
            if ext == chosen_ext
        )
        filenames = [name for _, _, name in numbered]

        # Check that sections are consecutive (a difference of 0 is a duplicate).
        for (idx_a, _, name_a), (idx_b, _, name_b) in zip(numbered, numbered[1:]):
            if idx_b - idx_a != 1:
                err_msg_info = name_a + ' with number ' + str(idx_a) \
                    + ' and ' + name_b + ' with number ' + str(idx_b)
                raise IOError('Files must be numbered consecutively, got: ' + err_msg_info)

        self.min_idx = numbered[0][0]

        volume = np.stack([self._load_image(path) for _, path, _ in numbered])
        return volume, filenames

    @staticmethod
    def _split_stem_and_ext(path):
        """Return ``(stem, extension)`` of the file name at ``path`` (str or bytes).

        Example: ``'/data/my.stack.tif'`` -> ``('my.stack', '.tif')``.
        """
        return os.path.splitext(os.path.basename(os.fsdecode(path)))

    def _get_multipage_image_volume_and_filename(self):
        """Load every frame of a multi-frame image file.

        Frame ``i`` of ``<stem><ext>`` is given the synthetic name
        ``<stem>_<i><ext>``.

        Returns:
            ``(volume, filenames)``: the frames stacked along axis 0, and one
            synthetic file name per frame.
        """
        stem, ext = self._split_stem_and_ext(self._dataset_directory)
        frames = []
        filenames = []

        with Image.open(self._dataset_directory) as img:
            frame_index = 0
            while True:
                try:
                    img.seek(frame_index)
                except EOFError:
                    # Seeking past the last frame: all frames have been read.
                    break
                frames.append(np.array(img))
                filenames.append(stem + '_' + str(frame_index) + ext)
                frame_index += 1

        return np.stack(frames), filenames

    def get_image_patch(self, center, shape):
        """Extract a 2-D patch centred on ``center`` within the current axis' plane.

        Args:
            center: 3-element index into the volume. The component along
                ``self.axis`` selects the plane; the other two (in order) give
                the patch centre's row and column within that plane.
            shape: ``(ph, pw)`` patch size; both dimensions must be odd.

        Returns:
            ``(patch, (left, top))``. ``patch`` is a ``shape``-sized array of
            ``self.dtype`` in which pixels outside the plane are zero.
            ``(left, top)`` is the column/row of the patch's top-left corner
            in the plane and may be negative.

        Raises:
            ValueError: if ``self.axis`` is not 0, 1 or 2, or either patch
                dimension is even.
        """
        if self.axis == 0:
            image_slice = self._volume[center[0], :, :]
            i, j = center[1], center[2]
        elif self.axis == 1:
            image_slice = self._volume[:, center[1], :]
            i, j = center[0], center[2]
        elif self.axis == 2:
            image_slice = self._volume[:, :, center[2]]
            i, j = center[0], center[1]
        else:
            raise ValueError('Axis must be 0, 1 or 2, got: ' + str(self.axis))

        h, w = image_slice.shape[0], image_slice.shape[1]
        ph, pw = shape

        if ph % 2 == 0 or pw % 2 == 0:
            raise ValueError('Shape widths must be odd')

        top = i - ph // 2
        left = j - pw // 2
        patch = np.zeros(shape, dtype=self.dtype)

        # Copy only the part of the window that overlaps the plane; the rest
        # stays zero (equivalent to a per-pixel bounds check).
        r0, r1 = max(top, 0), min(top + ph, h)
        c0, c1 = max(left, 0), min(left + pw, w)
        if r0 < r1 and c0 < c1:
            patch[r0 - top:r1 - top, c0 - left:c1 - left] = image_slice[r0:r1, c0:c1]

        return patch, (left, top)

    def initializeDatasetOutput(self, output_dir, output_shape, dataset_type):
        """Prepare the destination used by ``saveTransformedImage``.

        Args:
            output_dir: for ``'multi image'``, the directory each frame is
                written into (created if missing). For ``'multipage file'``,
                the path of the TIFF file to write.
            output_shape: currently unused.
            dataset_type: ``'multi image'`` or ``'multipage file'``.
        """
        # Finalise any writer left over from a previous initialisation.
        self.closeDatasetOutput()

        self.output_dir = output_dir

        if dataset_type == self.MULTI_IMAGE:
            # Only a directory in this mode; in multipage mode output_dir is a
            # file path and must not be created as a directory.
            os.makedirs(output_dir, exist_ok=True)
        elif dataset_type == self.MULTIPAGE_FILE:
            # Second argument is AppendingTiffWriter's ``new`` flag.
            # NOTE: the attribute keeps its historical spelling for callers.
            self.ouput_writer = TiffImagePlugin.AppendingTiffWriter(output_dir, True)

    def closeDatasetOutput(self):
        """Finalise and close the multipage TIFF writer, if one is open.

        Safe to call repeatedly, or when no writer was ever created.
        """
        writer = getattr(self, 'ouput_writer', None)
        if writer is not None:
            writer.close()
            self.ouput_writer = None

    def saveTransformedImage(self, k, image, dataset_type):
        """Write ``image`` as frame ``k`` of the output dataset.

        The array is converted to PIL mode ``'L'`` before saving.

        * In ``'multi image'`` mode it is saved in ``output_dir`` under the
          ``k``-th input file name.
        * In ``'multipage file'`` mode it is appended as a new frame of the
          TIFF opened by ``initializeDatasetOutput``, and ``k`` is unused.
          Call ``closeDatasetOutput`` when done.

        Raises:
            ValueError: for an unsupported ``dataset_type``.
        """
        if dataset_type == self.MULTI_IMAGE:
            output_filepath = os.path.join(self.output_dir, self.filenames[k])
            output_image = Image.fromarray(image).convert('L')
            output_image.save(output_filepath)
        elif dataset_type == self.MULTIPAGE_FILE:
            output_image = Image.fromarray(image).convert('L')
            output_image.save(self.ouput_writer)
            self.ouput_writer.newFrame()
        else:
            raise ValueError(
                'Unsupported dataset type {!r}; expected {!r} or {!r}'.format(
                    dataset_type, self.MULTI_IMAGE, self.MULTIPAGE_FILE))