import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from scipy.ndimage import gaussian_filter
from skimage.io import imread 

from ipywidgets import interact, interactive, fixed, interact_manual, IntSlider, Button
import ipywidgets as widgets
from ipycanvas import MultiCanvas

from slgbuilder import GraphObject, MaxflowBuilder

def generate_synthetic_data(n_layers, smoothness, min_distance, blurring):
    """Generate, plot and return a random 256x256 layered image and its labels.

    The image is split into horizontal layers by randomly wavy boundary lines
    (one row value per column). Each layer is filled with Gaussian noise
    (mean drawn uniformly from [112, 144), std 4). The whole image is then
    Gaussian-blurred and additive Gaussian noise (std 4) is added.

    Args:
        n_layers: Requested number of layers. Fewer are produced when
            ``min_distance`` leaves no room for another boundary.
        smoothness: Sigma of the Gaussian filter applied along each boundary
            line; larger values give smoother boundaries.
        min_distance: Minimum separation, in rows, between the mean positions
            of any two interior boundaries.
        blurring: Sigma of the Gaussian blur applied to the final image.

    Returns:
        ``(synthetic_data, ground_truth)``:
        - ``synthetic_data`` is the image cast to ``int32``. It is not
          clipped, so values may fall outside 0-255.
        - ``ground_truth`` is a float array holding, per pixel, the number of
          boundary lines whose row value is greater than the pixel's row.

    Also shows a three-panel figure: the data, the ground truth, and the data
    with the boundary lines overlaid.
    """
    std = 30
    size = 256
    sigma = smoothness

    # Mean row of each interior boundary. Each candidate row must be at least
    # `min_distance` rows from every location chosen so far; when no row
    # qualifies, that boundary (and thus one layer) is dropped.
    line_locs = []
    for _ in range(n_layers - 1):
        candidates = np.arange(size)
        for loc in line_locs:
            # Symmetric window: exclude rows strictly closer than min_distance.
            too_close = np.arange(loc - min_distance + 1, loc + min_distance)
            candidates = np.setdiff1d(candidates, too_close)
        if len(candidates) > 0:
            line_locs.append(candidates[np.random.randint(len(candidates))])
    line_locs = np.sort(line_locs)

    # Boundary 0 is the top edge of the image. Each interior boundary is a
    # noisy row profile around its mean location, smoothed along the columns.
    boundary_lines = [np.zeros(size)]
    for mean in line_locs:
        boundary_line = np.random.normal(loc=mean, scale=std, size=size)
        boundary_lines.append(gaussian_filter(boundary_line, sigma=sigma))

    # Row index of every pixel (constant along each row). Comparing it with a
    # boundary (one value per column) broadcasts across columns.
    rows, _ = np.meshgrid(np.arange(size), np.arange(size), indexing='ij')

    synthetic_data = np.zeros((size, size))
    ground_truth = np.zeros((size, size))
    n_boundaries = len(boundary_lines)
    for i, boundary in enumerate(boundary_lines):
        # Layer i lies below boundary i and at or above boundary i+1. Layer 0
        # starts at the top row; the last layer extends to the bottom edge.
        # With a single layer the region is the whole image.
        if i == 0:
            layer_region = np.ones((size, size), dtype=int)
        else:
            layer_region = (rows > boundary).astype(int)
        if i + 1 < n_boundaries:
            layer_region *= (rows <= boundary_lines[i + 1])

        ground_truth += (rows < boundary).astype(int)

        # Boundaries are random and may cross. Where layer regions overlap,
        # the later layer overwrites the earlier one.
        synthetic_data = synthetic_data * (1 - layer_region)
        loc = 112 + np.random.rand() * 32
        scale = 4
        synthetic_data += layer_region * np.random.normal(loc=loc, scale=scale, size=(size, size))

    loc = 0
    scale = 4
    synthetic_data = gaussian_filter(synthetic_data, blurring) + np.random.normal(loc=loc, scale=scale, size=(size, size))

    _, ax = plt.subplots(1, 3, figsize=(16, 6))
    ax[0].set_title('Synthetic data')
    ax[0].imshow(synthetic_data, cmap='gray', vmin=0, vmax=255)
    ax[1].set_title('Ground truth')
    ax[1].imshow(ground_truth)
    ax[2].set_title('Synthetic data w/ layers')
    ax[2].imshow(synthetic_data, cmap='gray', vmin=0, vmax=255)
    ax[2].set_xlim(0, size)
    ax[2].set_ylim(size, 0)
    for line in boundary_lines:
        ax[2].plot(line)
    plt.show()

    return synthetic_data.astype(np.int32), ground_truth

def create_synthetic_data_widget():
    """Return an ``interactive`` widget that drives ``generate_synthetic_data``.

    One integer slider is created per parameter. All sliders use
    ``continuous_update=False``, so the generator reruns only when a slider is
    released, not on every drag step.
    """
    def slider(value, lo, hi, label):
        # All sliders share step size and update policy; only range/label vary.
        return IntSlider(value=value, min=lo, max=hi, step=1,
                         continuous_update=False, description=label)

    controls = {
        'n_layers': slider(2, 2, 6, '# of layers'),
        'smoothness': slider(20, 1, 50, 'Smoothness'),
        'min_distance': slider(10, 1, 150, 'Min distance'),
        'blurring': slider(2, 0, 20, 'Blurring'),
    }
    return interactive(generate_synthetic_data, **controls)

def estimate_mean(I, x, y, width, height):
    """Compute, display and return the mean of a rectangular region of ``I``.

    The region covers rows ``y:y+height`` and columns ``x:x+width``. NumPy
    slicing clips it to the image, so a region starting outside the image is
    empty and its mean is NaN.

    Args:
        I: Image array indexed as ``I[row, column]``; shown in grayscale with
            a fixed 0-255 display range.
        x: Column of the region's top-left corner.
        y: Row of the region's top-left corner.
        width: Region width in pixels.
        height: Region height in pixels.

    Returns:
        The mean of the selected pixels.
    """
    mean = np.mean(I[y:y + height, x:x + width])

    # Rectangle anchor is (x, y) = (column, row), matching imshow's axes.
    rect = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')

    _, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.set_title(f'Mean: {mean:.04f}')
    ax.imshow(I, cmap='gray', vmin=0, vmax=255)
    ax.add_patch(rect)
    # Show explicitly, as the other plotting helpers in this module do. This
    # function runs as an ipywidgets callback, and the ipywidgets docs
    # recommend plt.show() in plotting callbacks.
    plt.show()

    return mean

def create_mean_estimator_widget(I):
    """Return an ``interactive`` widget for exploring region means of image ``I``.

    Slider ranges are derived from ``I.shape``:
    - The rectangle's top-left corner always lies inside the image. A corner
      past the last row/column would give an empty slice and a NaN mean.
    - Images of any size are supported.

    Args:
        I: Image array indexed as ``I[row, column]``; passed to
            ``estimate_mean`` unchanged.

    Returns:
        The ``ipywidgets.interactive`` widget wrapping ``estimate_mean``.
    """
    n_rows, n_cols = I.shape[0], I.shape[1]
    return interactive(estimate_mean,
                       I=fixed(I),
                       x=IntSlider(value=0, min=0, max=n_cols - 1, step=1, continuous_update=False, description='X'),
                       y=IntSlider(value=0, min=0, max=n_rows - 1, step=1, continuous_update=False, description='Y'),
                       width=IntSlider(value=min(20, n_cols), min=1, max=n_cols, step=1, continuous_update=False, description='Width'),
                       height=IntSlider(value=min(20, n_rows), min=1, max=n_rows, step=1, continuous_update=False, description='Height'))


def display_results(synthetic_data, segmentations, segmentation_lines):
    """Show the input image, the combined segmentation, and the boundary overlay.

    Draws three side-by-side panels:
    1. ``synthetic_data`` in grayscale with a fixed 0-255 display range.
    2. The element-wise sum of ``segmentations`` along axis 0.
    3. ``synthetic_data`` again, with every entry of ``segmentation_lines``
       plotted on top.
    """
    _, (ax_image, ax_labels, ax_overlay) = plt.subplots(1, 3, figsize=(16, 6))
    grayscale = dict(cmap='gray', vmin=0, vmax=255)

    ax_image.imshow(synthetic_data, **grayscale)
    ax_labels.imshow(np.sum(segmentations, axis=0))
    ax_overlay.imshow(synthetic_data, **grayscale)
    for boundary in segmentation_lines:
        ax_overlay.plot(boundary)

    plt.show()


class MeanEstimatorTool:
    """Brush tool for measuring the mean intensity of a hand-painted region.

    ``background_image`` is drawn on layer 0 of a two-layer ``MultiCanvas``.
    The user paints a green mask on layer 1 with the mouse. The
    "Calculate mean" button averages ``background_image`` over the painted
    pixels, and "Clear" erases the mask. All controls and the canvas are
    displayed on construction.

    NOTE(review): ``display`` is not imported in this module. It is the
    builtin IPython injects into notebook sessions, so this class presumably
    only works inside Jupyter unless ``from IPython.display import display``
    is added.
    """

    def __init__(self, background_image):
        """Build the canvas and controls for ``background_image`` and display them.

        Args:
            background_image: Image drawn as the canvas background. Its first
                two dimensions set the canvas height and width.
                ``compute_mean`` multiplies it by a 2-D mask, so it presumably
                must be a 2-D (grayscale) array — TODO confirm.
        """
        # Mouse-drawing state: whether a stroke is in progress, and the last
        # pointer position (start of the next stroke segment).
        self.drawing = False
        self.ix = None
        self.iy = None

        self.background_image = background_image
        initial_brush_size = 20

        self.canvas = MultiCanvas(2, width=background_image.shape[1], height=background_image.shape[0])
        # Sync the mask layer's pixels back to Python so get_image_data() works.
        self.canvas[1].sync_image_data = True
        self.canvas[0].put_image_data(background_image, 0, 0)

        self.canvas[1].on_mouse_down(self._on_mouse_down)
        self.canvas[1].on_mouse_move(self._on_mouse_move)
        self.canvas[1].on_mouse_up(self._on_mouse_up)
        self.canvas[1].stroke_style = '#00FF00'
        self.canvas[1].fill_style = '#00FF00'
        self.canvas[1].global_alpha = 1
        # Set line_width/radius and the canvas stroke width now, rather than
        # relying on the slider callback having already run.
        self.update_brush_size(initial_brush_size)

        brush_slider = interactive(self.update_brush_size,
                                   brush_size=IntSlider(value=initial_brush_size,
                                                        min=2,
                                                        max=20,
                                                        step=1,
                                                        continuous_update=True,
                                                        description='Brush size'))

        clear_button = Button(description='Clear')
        clear_button.on_click(self.clear)

        compute_button = Button(description='Calculate mean')
        compute_button.on_click(self.compute_mean)

        display(brush_slider)

        display(compute_button)
        display(clear_button)
        display(self.canvas)

    def update_brush_size(self, brush_size):
        """Set the brush diameter, in pixels, used for strokes and dots on the mask layer."""
        self.line_width = brush_size
        # Dots painted at pointer positions use the same diameter as strokes.
        self.radius = self.line_width / 2
        self.canvas[1].line_width = self.line_width

    def compute_mean(self, b):
        """Print and return the mean of ``background_image`` over painted pixels.

        A pixel counts as painted when the mean over the last axis of the mask
        layer's image data is positive. ``b`` is the clicked button (unused).
        Returns NaN, printed as ``nan``, when nothing has been painted.
        """
        mask = np.mean(self.canvas[1].get_image_data(), axis=-1) > 0
        n_selected = np.count_nonzero(mask)
        if n_selected == 0:
            # Nothing painted: return NaN directly instead of computing 0/0,
            # which would also emit a RuntimeWarning.
            mean = np.nan
        else:
            mean = np.sum(self.background_image * mask) / n_selected
        print(f'{mean:.04f}', end='\r')
        return mean

    def clear(self, b):
        """Erase the painted mask layer. ``b`` is the clicked button (unused)."""
        self.canvas[1].clear()

    def _on_mouse_down(self, x, y):
        """Start a stroke by painting a dot at (x, y) and remembering the position."""
        self.drawing = True
        self.canvas[1].fill_circle(x, y, self.radius)
        self.ix = x
        self.iy = y

    def _on_mouse_move(self, x, y):
        """While a stroke is active, paint from the last position to (x, y)."""
        if self.drawing:
            # The line connects consecutive pointer samples. The filled circle
            # gives the segment end the same width as the brush.
            self.canvas[1].stroke_line(self.ix, self.iy, x, y)
            self.canvas[1].fill_circle(x, y, self.radius)
        self.ix = x
        self.iy = y

    def _on_mouse_up(self, x, y):
        """End the current stroke."""
        self.drawing = False
        self.ix = x
        self.iy = y