Skip to content
Snippets Groups Projects
helpers.py 7.28 KiB
Newer Older
  • Learn to ignore specific revisions
  • willap's avatar
    willap committed
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    
    from scipy.ndimage import gaussian_filter
    from skimage.io import imread 
    
    from ipywidgets import interact, interactive, fixed, interact_manual, IntSlider, Button
    import ipywidgets as widgets
    from ipycanvas import MultiCanvas
    
    from slgbuilder import GraphObject, MaxflowBuilder
    
    def generate_synthetic_data(n_layers, smoothness, min_distance, blurring):
        """Generate, plot and return a random layered image and its label map.

        A 256x256 image is split into horizontal layers separated by smooth
        random boundary lines. Each layer is filled with Gaussian noise around
        a randomly chosen grey level, the image is blurred and extra noise is
        added. The result is shown in a three-panel matplotlib figure.

        Args:
            n_layers: Requested number of layers. Fewer layers are produced
                when no admissible boundary location is left. Values of one
                or less yield a single layer without boundaries.
            smoothness: Sigma of the Gaussian filter applied to each noisy
                boundary line.
            min_distance: Rows in ``[loc - min_distance, loc + min_distance)``
                around an already chosen boundary location ``loc`` are not
                available as mean rows for further boundaries.
            blurring: Sigma of the Gaussian filter applied to the final image.

        Returns:
            A tuple ``(synthetic_data, ground_truth)``: the image cast to
            ``np.int32`` and a float array of the same shape in which each
            pixel holds the number of boundary lines whose value is larger
            than the pixel's row index.
        """
        std = 30
        size = 256

        # Draw the mean row of every boundary. The first draw has no exclusions,
        # so it is equivalent to drawing directly from ``range(size)``.
        line_locs = []
        for _ in range(n_layers - 1):
            possible_indices = np.arange(size)
            for chosen in line_locs:
                possible_indices = np.setdiff1d(
                    possible_indices,
                    np.arange(chosen - min_distance, chosen + min_distance))
            # When every row is excluded this boundary is dropped, which
            # reduces the number of layers by one.
            if len(possible_indices) > 0:
                line_locs.append(possible_indices[np.random.randint(len(possible_indices))])
        line_locs = np.sort(line_locs)

        # The first entry is a dummy all-zero line that stands for the top edge.
        boundary_lines = [np.zeros(size)]
        for mean in line_locs:
            noisy_line = np.random.normal(loc=mean, scale=std, size=(size))
            boundary_lines.append(gaussian_filter(noisy_line, sigma=smoothness))

        # Column vector of row indices; it broadcasts against a boundary line
        # (one value per column) to a full (size, size) comparison.
        rows = np.arange(size)[:, np.newaxis]
        last = len(boundary_lines) - 1

        synthetic_data = np.zeros((size, size))
        ground_truth = np.zeros((size, size))
        for i, boundary in enumerate(boundary_lines):
            # A layer lies below its own boundary (except the first layer,
            # which starts at the top edge) and above the next boundary
            # (except the last layer, which extends to the bottom edge).
            layer_region = np.ones((size, size), dtype=int)
            if i > 0:
                layer_region = layer_region * (rows > boundary)
            if i < last:
                layer_region = layer_region * (rows <= boundary_lines[i + 1])

            ground_truth += (rows < boundary).astype(int)

            # Clear the region first so a later layer overwrites an earlier
            # one wherever their regions overlap.
            synthetic_data = synthetic_data * (1 - layer_region)
            loc = 112 + np.random.rand() * 32
            scale = 4
            synthetic_data += layer_region * np.random.normal(loc=loc, scale=scale, size=(size, size))

        loc = 0
        scale = 4
        synthetic_data = gaussian_filter(synthetic_data, blurring) + np.random.normal(loc=loc, scale=scale, size=(size, size))

        f, ax = plt.subplots(1, 3, figsize=(16, 6))
        ax[0].set_title('Synthetic data')
        ax[0].imshow(synthetic_data, cmap='gray', vmin=0, vmax=255)
        ax[1].set_title('Ground truth')
        ax[1].imshow(ground_truth)
        ax[2].set_title('Synthetic data w/ layers')
        ax[2].imshow(synthetic_data, cmap='gray', vmin=0, vmax=255)
        ax[2].set_xlim(0, size)
        ax[2].set_ylim(size, 0)
        for line in boundary_lines:
            ax[2].plot(line)
        plt.show()

        return synthetic_data.astype(np.int32), ground_truth
    
    def create_synthetic_data_widget():
        """Return an ipywidgets ``interactive`` wrapper around ``generate_synthetic_data``.

        Each parameter of ``generate_synthetic_data`` is bound to an integer
        slider that only triggers an update once it is released.
        """
        def make_slider(label, start, lowest, highest):
            # All sliders share step size and the deferred-update behaviour.
            return IntSlider(value=start, min=lowest, max=highest, step=1,
                             continuous_update=False, description=label)

        controls = {
            'n_layers': make_slider('# of layers', 2, 2, 6),
            'smoothness': make_slider('Smoothness', 20, 1, 50),
            'min_distance': make_slider('Min distance', 10, 1, 150),
            'blurring': make_slider('Blurring', 2, 0, 20),
        }
        return interactive(generate_synthetic_data, **controls)
    
    def estimate_mean(I, x, y, width, height):
        """Return the mean of a rectangular region of ``I`` and plot the region.

        The region spans rows ``y`` to ``y + height`` and columns ``x`` to
        ``x + width`` (end exclusive). The image is drawn in grey scale with
        the region outlined in red and the mean shown in the title.
        """
        region = I[y:y + height, x:x + width]
        mean = np.mean(region)

        figure, axis = plt.subplots(1, 1, figsize=(6, 6))
        axis.set_title(f'Mean: {mean:.04f}')
        axis.imshow(I, cmap='gray', vmin=0, vmax=255)

        # Outline only (no fill) so the underlying pixels stay visible.
        outline = patches.Rectangle((x, y), width, height,
                                    linewidth=1, edgecolor='r', facecolor='none')
        axis.add_patch(outline)

        return mean
    
    def create_mean_estimator_widget(I):
        """Return an ipywidgets ``interactive`` wrapper around ``estimate_mean``.

        The slider ranges are derived from the first two dimensions of ``I``
        (rows, columns), so the rectangle's top-left corner always lies inside
        the image and its size cannot exceed the image size.

        Args:
            I: Image array passed unchanged to ``estimate_mean``.

        Returns:
            The ``interactive`` widget with sliders for ``x``, ``y``,
            ``width`` and ``height``.
        """
        n_rows, n_cols = np.shape(I)[:2]

        # The corner must index an existing pixel; otherwise the selected
        # region is empty and its mean is NaN. The ``max`` guards keep the
        # slider ranges valid even for a degenerate (empty) image.
        max_x = max(n_cols - 1, 0)
        max_y = max(n_rows - 1, 0)
        max_width = max(n_cols, 1)
        max_height = max(n_rows, 1)

        return interactive(estimate_mean,
                           I=fixed(I),
                           x=IntSlider(value=0, min=0, max=max_x, step=1, continuous_update=False, description='X'),
                           y=IntSlider(value=0, min=0, max=max_y, step=1, continuous_update=False, description='Y'),
                           width=IntSlider(value=min(20, max_width), min=1, max=max_width, step=1, continuous_update=False, description='Width'),
                           height=IntSlider(value=min(20, max_height), min=1, max=max_height, step=1, continuous_update=False, description='Height'))
    
    
    def display_results(synthetic_data, segmentations, segmentation_lines):
        """Show the data, the summed segmentations and the data with lines overlaid.

        Three panels are drawn side by side: the image in grey scale, the sum
        of ``segmentations`` along their first axis, and the image again with
        every entry of ``segmentation_lines`` plotted on top.
        """
        gray_style = dict(cmap='gray', vmin=0, vmax=255)

        figure, (data_axis, labels_axis, overlay_axis) = plt.subplots(1, 3, figsize=(16, 6))

        data_axis.imshow(synthetic_data, **gray_style)
        labels_axis.imshow(np.sum(segmentations, axis=0))

        overlay_axis.imshow(synthetic_data, **gray_style)
        for boundary in segmentation_lines:
            overlay_axis.plot(boundary)

        plt.show()
    
    
    class MeanEstimatorTool:
        """Interactive brush tool for estimating the mean intensity of an image region.

        The image is put on the bottom layer of a two-layer ``MultiCanvas``;
        the user paints on the top layer with the mouse. ``compute_mean``
        averages the image over all pixels that have been painted.
        Constructing the tool displays the brush-size slider, the two buttons
        and the canvas.
        """

        def __init__(self, background_image):
            """Create the canvas, wire up mouse handlers and display the controls.

            Args:
                background_image: Image array drawn on the bottom canvas layer;
                    its first two dimensions set the canvas height and width.
            """
            # Mouse state: whether a stroke is in progress and the last position.
            self.drawing = False
            self.ix = None
            self.iy = None

            self.background_image = background_image
            self.line_width = 20
            self.radius = self.line_width / 2

            self.canvas = MultiCanvas(2, width=background_image.shape[1], height=background_image.shape[0])
            # Needed so the painted overlay can be read back in compute_mean.
            self.canvas[1].sync_image_data = True
            self.canvas[0].put_image_data(background_image, 0, 0)

            self.canvas[1].on_mouse_down(self._on_mouse_down)
            self.canvas[1].on_mouse_move(self._on_mouse_move)
            self.canvas[1].on_mouse_up(self._on_mouse_up)
            self.canvas[1].stroke_style = '#00FF00'
            self.canvas[1].fill_style = '#00FF00'
            self.canvas[1].global_alpha = 1
            # Keep stroke width in sync with the brush radius from the start,
            # without relying on the slider callback having fired.
            self.canvas[1].line_width = self.line_width

            brush_slider = interactive(self.update_brush_size,
                                       brush_size=IntSlider(value=20,
                                                            min=2,
                                                            max=20,
                                                            step=1,
                                                            continuous_update=True,
                                                            description='Brush size'))

            clear_button = Button(description='Clear')
            clear_button.on_click(self.clear)

            compute_button = Button(description='Calculate mean')
            compute_button.on_click(self.compute_mean)

            display(brush_slider)

            display(compute_button)
            display(clear_button)
            display(self.canvas)

        def update_brush_size(self, brush_size):
            """Set the brush diameter used for strokes and end-point circles."""
            self.line_width = brush_size
            self.radius = self.line_width / 2
            self.canvas[1].line_width = self.line_width

        def compute_mean(self, b):
            """Print and return the mean of the image over the painted pixels.

            A pixel counts as painted when the mean over the last axis of the
            overlay's image data is positive.

            Args:
                b: The clicked button (unused; required by ``Button.on_click``).

            Returns:
                The mean of ``background_image`` over the painted pixels, or
                NaN when nothing has been painted.
            """
            overlay = self.canvas[1].get_image_data()
            mask = np.mean(overlay, axis=-1) > 0
            n_painted = np.sum(mask)
            if n_painted == 0:
                # Avoid a 0/0 division (and its RuntimeWarning) on an empty mask.
                mean = float('nan')
            else:
                mean = np.sum(self.background_image * mask) / n_painted
            # The carriage return lets the next result overwrite this one.
            print(f'{mean:.04f}', end='\r')
            return mean

        def clear(self, b):
            """Erase everything painted on the overlay layer."""
            self.canvas[1].clear()

        def _on_mouse_down(self, x, y):
            """Start a stroke and paint a dot at the press position."""
            self.drawing = True
            self.canvas[1].fill_circle(x, y, self.radius)
            self.ix = x
            self.iy = y

        def _on_mouse_move(self, x, y):
            """Extend the current stroke to the new position, if one is active."""
            if self.drawing:
                self.canvas[1].stroke_line(self.ix, self.iy, x, y)
                # A circle at the joint rounds off the corners between segments.
                self.canvas[1].fill_circle(x, y, self.radius)
            self.ix = x
            self.iy = y

        def _on_mouse_up(self, x, y):
            """Finish the current stroke."""
            self.drawing = False
            self.ix = x
            self.iy = y