diff --git a/docs/assets/screenshots/GUI-layers.png b/docs/assets/screenshots/GUI-layers.png
new file mode 100644
index 0000000000000000000000000000000000000000..0285fa99cca778942b5665e8f252c94096a8d7fb
Binary files /dev/null and b/docs/assets/screenshots/GUI-layers.png differ
diff --git a/docs/assets/screenshots/layers.png b/docs/assets/screenshots/layers.png
new file mode 100644
index 0000000000000000000000000000000000000000..0715964a369d4f40cd198b58bf7cc815dda3ecb5
Binary files /dev/null and b/docs/assets/screenshots/layers.png differ
diff --git a/docs/assets/screenshots/segmented_layers.png b/docs/assets/screenshots/segmented_layers.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f9a261aed7b11514ce1d3f3c56900f8152a0a75
Binary files /dev/null and b/docs/assets/screenshots/segmented_layers.png differ
diff --git a/docs/cli.md b/docs/cli.md
index 4b0ed8a8ce388db4700db38a5aeb0c0c730b7cd6..a682f685c465080e93891b57e8be553d4b0281d1 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -30,6 +30,8 @@ This offers quick interactions, making it ideal for tasks that require efficienc
 | `--data-explorer` | Starts the Data Explorer |
 | `--iso3d` | Starts the 3D Isosurfaces visualization |
 | `--local-thickness` | Starts the Local thickness tool |
+| `--anotation-tool` | Starts the annotation tool |
+| `--layers` | Starts the tool for segmenting layers |
 | `--host` | Desired host for the server. By default runs on `0.0.0.0`  |
 | `--platform` | Uses the Qim platform API for a unique path and port depending on the username |
 
diff --git a/docs/gui.md b/docs/gui.md
index 5401cb85bd40f32e9d74fadfc531c3fa0dec8879..ef996896db91d26a02037f28f288ece44d0892d3 100644
--- a/docs/gui.md
+++ b/docs/gui.md
@@ -34,5 +34,9 @@ For details see [here](cli.md#qim3d-gui).
 ![Iso3d GUI](assets/screenshots/GUI-iso3d.png)
 
 ::: qim3d.gui.annotation_tool
+    options:
+        members: False
+
+::: qim3d.gui.layers2d
     options:
         members: False
\ No newline at end of file
diff --git a/docs/processing.md b/docs/processing.md
index d1d0d071f2b161844ce7ce2b4035417027b748c1..37f35ba49d9515a89c604f7a59b54ee76e1c7ebd 100644
--- a/docs/processing.md
+++ b/docs/processing.md
@@ -14,6 +14,8 @@ Here, we provide functionalities designed specifically for 3D image analysis and
             - maximum
             - minimum
             - tophat
+            - get_lines
+            - segment_layers
 
 ::: qim3d.processing.Pipeline
     options:
diff --git a/qim3d/gui/interface.py b/qim3d/gui/interface.py
index 78830582e052d4a72f303530a48bfc522bca88aa..d528d2153e4b5a9f535d850d7a3f86c767119e90 100644
--- a/qim3d/gui/interface.py
+++ b/qim3d/gui/interface.py
@@ -47,6 +47,9 @@ class BaseInterface(ABC):
 
     def set_invisible(self):
         return gr.update(visible=False)
+    
+    def change_visibility(self, is_visible):
+        return gr.update(visible = is_visible)
 
     def launch(self, img=None, force_light_mode: bool = True, **kwargs):
         """
diff --git a/qim3d/gui/layers2d.py b/qim3d/gui/layers2d.py
index 1126442fa711c517521e6d478c7f5334507546c0..043f24064a62eef2f0274c54e09d125dec73e010 100644
--- a/qim3d/gui/layers2d.py
+++ b/qim3d/gui/layers2d.py
@@ -1,36 +1,75 @@
+"""
+The GUI can be launched directly from the command line:
+
+```bash
+qim3d gui --layers
+```
+
+Or launched from a python script
+
+```python
+import qim3d
+
+layers = qim3d.gui.layers2d.Interface()
+app = layers.launch()
+```
+![gui-layers](assets/screenshots/GUI-layers.png)
+
+"""
+
 import os
 
 import gradio as gr
-import matplotlib.pyplot as plt
+import numpy as np
 
 from .interface import BaseInterface
 
-from qim3d.processing import layers2d as l2d
+# from qim3d.processing import layers2d as l2d
+from qim3d.processing import overlay_rgb_images, segment_layers, get_lines
 from qim3d.io import load
-import qim3d.viz
+from qim3d.viz.layers2d import image_with_lines
 
 #TODO figure out how not update anything and go through processing when there are no data loaded
 # So user could play with the widgets but it doesnt throw error
 # Right now its only bypassed with several if statements
 # I opened an issue here https://github.com/gradio-app/gradio/issues/9273
 
+X = 1
+Y = 0
+Z = 2
+
+DEFAULT_PLOT_TYPE = 'Segmentation mask'
+
 class Interface(BaseInterface):
     def __init__(self):
         super().__init__("Layered surfaces 2D", 1080)
 
-        self.l2d_obj_x = l2d.Layers2d()
-        self.l2d_obj_y = l2d.Layers2d()
-        self.l2d_obj_z = l2d.Layers2d()
-
         self.figsize = (8, 8)
         self.cmap = "Greys_r"
+        self.axes = {'x': X, 'y':Y, 'z':Z}
 
         self.data = None
 
-        self.virtual_stack = True #TODO ask why
-        self.dataset_name = '' #TODO check if necessary to even have
+        self.x_slice = None
+        self.y_slice = None
+        self.z_slice = None
+        self.x_segmentation = None
+        self.y_segmentation = None
+        self.z_segmentation = None
+
+        self.plot_type = DEFAULT_PLOT_TYPE
 
-        self.error_state = False
+        self.segmentation_colors = np.array([[0, 255, 255], # Cyan
+                                             [255, 195, 0], # Yellow Orange
+                                             [199, 0, 57], # Dark orange
+                                             [218, 247, 166], # Light green
+                                             [255, 0, 255], # Magenta
+                                             [65, 105, 225], # Royal blue
+                                             [138, 43, 226], # Blue violet
+                                             [255, 0, 0], #Red
+                                             ])
+
+        self.error = False
 
 
 
@@ -58,153 +97,229 @@ class Interface(BaseInterface):
                     interactive=True,
                     height = 230,
                 )
-                
-                # Parameters sliders and checkboxes
-                with gr.Row():
-                    delta = gr.Slider(
-                        minimum=0.5,
-                        maximum=1.0,
-                        value=0.75,
-                        step=0.01,
-                        label="Delta value",
-                        info="The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute.", 
-                    )
-                    
-                with gr.Row():
-                    min_margin = gr.Slider(
-                        minimum=1, 
-                        maximum=50, 
-                        value=10, 
-                        step=1, 
-                        label="Min margin",
-                        info="Minimum margin between layers to be detected in the image.",
-                    )
 
                 with gr.Row():
-                    n_layers = gr.Slider(
-                        minimum=1,
-                        maximum=10,
-                        value=2,
-                        step=1,
-                        label="Number of layers",
-                        info="Number of layers to be detected in the image",
-                    )
+                    with gr.Group():
+                        plot_type = gr.Radio(
+                            choices= (DEFAULT_PLOT_TYPE, 'Segmentation lines',),
+                            value = DEFAULT_PLOT_TYPE,
+                            interactive = True,
+                            show_label=False
+                        )
+                        alpha = gr.Slider(
+                            minimum=0,
+                            maximum = 1,
+                            step = 0.01,
+                            label = 'Alpha value',
+                            show_label=True,
+                            value = 0.5,
+                            visible = True,
+                            interactive=True)
+                        line_thickness = gr.Slider(
+                            minimum=0.1,
+                            maximum = 5,
+                            value = 2,
+                            label = 'Line thickness',
+                            show_label = True,
+                            visible = False,
+                            interactive = True
+                            )
+                        
+                with gr.Group():
+                    with gr.Row():
+                        axis = gr.Radio(
+                            choices = ['Y', 'X', 'Z'],
+                            value = 'Y', 
+                            label = 'Layer axis',
+                            info = 'Specifies in which direction are the layers. Because of numpy design, Y is 0th axis, X is 1st and Z is 2nd.',)
+                    with gr.Row():
+                        wrap = gr.Checkbox(
+                            label = "Lines start and end at the same level.",
+                            info = "Used when segmenting layers of unfolded image."
+                        )
+                        
+                        is_inverted = gr.Checkbox(
+                            label="Invert image before processing",
+                            info="The algorithm effectively flips the gradient.",
+                        ) 
+                    
+                    with gr.Row():
+                        delta = gr.Slider(
+                            minimum=0,
+                            maximum=5,
+                            value=0.75,
+                            step=0.01,
+                            interactive = True,
+                            label="Delta value",
+                            info="The lower the delta is, the more accurate the gradient calculation will be. However, the calculation takes longer to execute. Delta above 1 is rounded down to closest lower integer", 
+                        )
+                        
+                    with gr.Row():
+                        min_margin = gr.Slider(
+                            minimum=1, 
+                            maximum=50, 
+                            value=10, 
+                            step=1, 
+                            interactive = True,
+                            label="Min margin",
+                            info="Minimum margin between layers to be detected in the image.",
+                        )
 
-                with gr.Row():
-                    is_inverted = gr.Checkbox(
-                        label="Is inverted",
-                        info="To invert the image before processing, click this box. By inverting the source image before processing, the algorithm effectively flips the gradient.",
-                    )                    
+                    with gr.Row():
+                        n_layers = gr.Slider(
+                            minimum=1,
+                            maximum=len(self.segmentation_colors) - 1,
+                            value=2,
+                            step=1,
+                            interactive=True,
+                            label="Number of layers",
+                            info="Number of layers to be detected in the image",
+                        )                 
 
-                with gr.Row():
-                    btn_run = gr.Button("Run Layers2D", variant = 'primary')
+                # with gr.Row():
+                #     btn_run = gr.Button("Run Layers2D", variant = 'primary')
 
             # Output panel: Plots
-            with gr.Column(scale=2):
-                with gr.Row(): # Source image outputs
-                    input_plot_x = gr.Plot(
-                        show_label=True,
-                        label="Slice along X-axis",
-                        visible=True,
-                    )
-                    input_plot_y = gr.Plot(
-                        show_label=True,
-                        label="Slice along Y-axis",
-                        visible=True,
-                    )
-                    input_plot_z = gr.Plot(
-                        show_label=True,
-                        label="Slice along Z-axis",
-                        visible=True,
-                    )
-                with gr.Row(): # Detected layers outputs
-                    output_plot_x = gr.Plot(
-                        show_label=True,
-                        label="Detected layers X-axis",
-                        visible=True,
+            """
+            60em if plot is alone
+            30em if two of them
+            20em if all of them are visible
 
-                    )
-                    output_plot_y = gr.Plot(
-                        show_label=True,
-                        label="Detected layers Y-axis",
-                        visible=True,
+            When one slicing axis is made unvisible we want the other two images to be bigger
+            For some reason, gradio changes their width but not their height. So we have to 
+            change their height manually
+            """
 
+            self.heights = ['60em', '30em', '20em'] # em units are relative to the parent, 
+
+
+            with gr.Column(scale=2,):
+                # with gr.Row(): # Source image outputs
+                #     input_image_kwargs = lambda axis: dict(
+                #         show_label = True,
+                #         label = F'Slice along {axis}-axis', 
+                #         visible = True, 
+                #         height = self.heights[2]
+                #     )
+
+                #     input_plot_x = gr.Image(**input_image_kwargs('X'))
+                #     input_plot_y = gr.Image(**input_image_kwargs('Y'))
+                #     input_plot_z = gr.Image(**input_image_kwargs('Z'))
+
+                with gr.Row(): # Detected layers outputs
+                    output_image_kwargs = lambda axis: dict(
+                        show_label = True,
+                        label = F'Detected layers {axis}-axis',
+                        visible = True,
+                        height = self.heights[2]
                     )
-                    output_plot_z = gr.Plot(
-                        show_label=True,
-                        label="Detected layers Z-axis",
-                        visible=True,
-                    )
+                    output_plot_x = gr.Image(**output_image_kwargs('X'))
+                    output_plot_y = gr.Image(**output_image_kwargs('Y'))
+                    output_plot_z = gr.Image(**output_image_kwargs('Z'))
                     
                 with gr.Row(): # Axis position sliders
-                    x_pos = gr.Slider(
-                        minimum=0,
-                        maximum=1,
-                        value=0.5,
-                        step=0.01,
-                        label="X position",
-                        info="The 3D image is sliced along the X-axis.",
+                    slider_kwargs = lambda axis: dict(
+                        minimum = 0,
+                        maximum = 1,
+                        value = 0.5,
+                        step = 0.01,
+                        label = F'{axis} position',
+                        info = F'The 3D image is sliced along {axis}-axis'
                     )
-                    y_pos = gr.Slider(
-                        minimum=0,
-                        maximum=1,
-                        value=0.5,
-                        step=0.01,
-                        label="Y position",
-                        info="The 3D image is sliced along the Y-axis.",
-                    )
-                    z_pos = gr.Slider(
-                        minimum=0,
-                        maximum=1,
-                        value=0.5,
-                        step=0.01,
-                        label="Z position",
-                        info="The 3D image is sliced along the Z-axis.",
-                    )
-        
+                    
+                    x_pos = gr.Slider(**slider_kwargs('X'))                    
+                    y_pos = gr.Slider(**slider_kwargs('Y'))
+                    z_pos = gr.Slider(**slider_kwargs('Z'))
+
+                with gr.Row():
+                    x_check = gr.Checkbox(value = True, interactive=True, label = 'Show X slice')
+                    y_check = gr.Checkbox(value = True, interactive=True, label = 'Show Y slice')
+                    z_check = gr.Checkbox(value = True, interactive=True, label = 'Show Z slice')
+                with gr.Row():
+                    btn_run = gr.Button("Run Layers2D", variant = 'primary')
+
+
         positions = [x_pos, y_pos, z_pos]
-        process_inputs = [is_inverted, delta, min_margin, n_layers]
-        input_plots = [input_plot_x, input_plot_y, input_plot_z]
+        process_inputs = [axis, is_inverted, delta, min_margin, n_layers, wrap]
+        plotting_inputs = [axis, alpha, line_thickness]
+        # input_plots = [input_plot_x, input_plot_y, input_plot_z]
         output_plots = [output_plot_x, output_plot_y, output_plot_z]
+        visibility_check_inputs = [x_check, y_check, z_check]
 
         spinner_loading = gr.Text("Loading data...", visible=False)
         spinner_running = gr.Text("Running pipeline...", visible=False)
-        spinner_updating = gr.Text("Updating layers...", visible=False)
 
-        # fmt: off
         reload_base_path.click(
             fn=self.update_explorer,inputs=base_path, outputs=explorer)
+        
+        plot_type.change(
+            self.change_plot_type, inputs = plot_type, outputs = [alpha, line_thickness]).then(
+            fn = self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots
+            )
+        
+        gr.on(
+            triggers = [alpha.release, line_thickness.release],
+            fn = self.plot_output_img_all, inputs = plotting_inputs, outputs = output_plots
+        )
+
+        """
+        Difference between btn_run.click and the other triggers below is only loading the data.
+        To make it easier to maintain, I created 'update_component' variable. Its value is completely
+        unimportant. It exists only to be changed after loading the data which triggers further processing
+        which is the same for button click and the other triggers
+        """
+
+        update_component = gr.State(True)
 
         btn_run.click(
             fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run).then(
             fn=self.load_data, inputs = [base_path, explorer]).then(
-            fn=self.set_spinner, inputs=spinner_running, outputs=btn_run).then(
-            fn=self.process_all, inputs = [*positions, *process_inputs]).then(
-            fn=self.plot_input_img_all, inputs = positions, outputs = input_plots, show_progress='hidden').then(
-            fn=self.plot_output_all, outputs = output_plots, show_progress='hidden').then(
-            fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
+            fn = lambda state: not state, inputs = update_component, outputs = update_component)
         
         gr.on(
-            triggers=[delta.change, min_margin.change, n_layers.change, is_inverted.change],
-            fn=self.set_spinner, inputs=spinner_updating, outputs=btn_run).then(
+            triggers= (axis.change, is_inverted.change, delta.release, min_margin.release, n_layers.release, update_component.change, wrap.change),
+            fn=self.set_spinner, inputs = spinner_running, outputs=btn_run).then(
             fn=self.process_all, inputs = [*positions, *process_inputs]).then(
-            fn=self.plot_output_all, outputs = output_plots, show_progress='hidden').then(
+            # fn=self.plot_input_img_all, outputs = input_plots, show_progress='hidden').then(
+            fn=self.plot_output_img_all, inputs =  plotting_inputs, outputs = output_plots, show_progress='hidden').then(
             fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
-                    
-        slider_change_arguments = (
-            (x_pos, self.process_x, self.plot_input_img_x, input_plot_x, self.l2d_obj_x, output_plot_x),
-            (y_pos, self.process_y, self.plot_input_img_y, input_plot_y, self.l2d_obj_y, output_plot_y),
-            (z_pos, self.process_z, self.plot_input_img_z, input_plot_z, self.l2d_obj_z, output_plot_z))
         
-        for slider, process_func, plot_input_func, input_plot, l2d_obj, output_plot in slider_change_arguments:
-            slider.change(
-                fn=self.set_spinner, inputs=spinner_updating, outputs=btn_run).then(
-                fn=process_func, inputs=[slider, *process_inputs]).then(
-                fn=plot_input_func, inputs=slider, outputs=input_plot, show_progress='hidden').then(
-                fn=self.plot_output_wrapper(l2d_obj), outputs=output_plot, show_progress='hidden').then(
-                fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
+        # Chnages visibility and sizes of the plots - gives user the option to see only some of the images and in bigger scale
+        gr.on(
+            triggers=[x_check.change, y_check.change, z_check.change],
+            fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = positions).then(
+            # fn = self.change_row_visibility, inputs = visibility_check_inputs, outputs = input_plots).then(
+            fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots)
+        
+        # for  axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots):
+        for  axis, slider, output_plot in zip(['x','y','z'], positions, output_plots):
+            slider.release(
+                self.process_wrapper(axis), inputs = [slider, *process_inputs]).then( 
+                # self.plot_input_img_wrapper(axis), outputs = input_plot).then(
+                self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot)
+            
+        
 
+    def change_plot_type(self, plot_type, ):
+        self.plot_type = plot_type
+        if plot_type == 'Segmentation lines':
+            return gr.update(visible = False), gr.update(visible = True)
+        else:  
+            return gr.update(visible = True), gr.update(visible = False)
         
+    def change_plot_size(self, x_check, y_check, z_check):
+        """
+        Based on how many plots are we displaying (controlled by checkboxes in the bottom) we define
+        also their height because gradio doesn't do it automatically. The values of heights were set just by eye.
+        They are defines before defining the plot in 'define_interface'
+        """
+        index = x_check + y_check + z_check - 1
+        height = self.heights[index] # also used to define heights of plots in the begining
+        return gr.update(height = height, visible= x_check), gr.update(height = height, visible = y_check), gr.update(height = height, visible = z_check)
+
+    def change_row_visibility(self, x_check, y_check, z_check):
+        return self.change_visibility(x_check), self.change_visibility(y_check), self.change_visibility(z_check)
+    
     def update_explorer(self, new_path):
         # Refresh the file explorer object
         new_path = os.path.expanduser(new_path)
@@ -222,20 +337,13 @@ class Interface(BaseInterface):
             raise ValueError("Invalid path")
 
     def set_relaunch_button(self):
-        # Sets the button to relaunch
-        return gr.update(
-            value=f"Relaunch",
-            interactive=True,
-        )
+        return gr.update(value=f"Relaunch", interactive=True)
 
     def set_spinner(self, message):
-        if not self.error_state:
+        if self.error:
             return gr.Button()
         # spinner icon/shows the user something is happeing
-        return gr.update(
-            value=f"{message}",
-            interactive=False,
-        )
+        return gr.update(value=f"{message}", interactive=False)
     
     def load_data(self, base_path, explorer):
         if base_path and os.path.isfile(base_path):
@@ -248,107 +356,114 @@ class Interface(BaseInterface):
         try:
             self.data = load(
                 file_path,
-                virtual_stack=self.virtual_stack,
-                dataset_name=self.dataset_name,
+                progress_bar=False
             )
         except Exception as error_message:
             raise gr.Error(
                 f"Failed to load the image: {error_message}"
             ) from error_message
         
-    def idx(self, pos, axis):
-        return int(pos * (self.data.shape[axis] - 1))
-    
-
-    # PROCESSING FUNCTIONS
-
-    def process(self, l2d_obj:l2d.Layers2d, slice, is_inverted, delta, min_margin, n_layers):
-        l2d_obj.prepare_update(
-            data = slice,
-            is_inverted = is_inverted,
-            delta = delta,
-            min_margin = min_margin,
-            n_layers = n_layers,
-        )
-        l2d_obj.update()
-
-    def process_x(self, x_pos, is_inverted, delta, min_margin, n_layers):
-        if self.data is not None:
-            slice = self.data[self.idx(x_pos, 0), :, :]
-            self.process(self.l2d_obj_x, slice, is_inverted, delta, min_margin, n_layers)
-    
-    def process_y(self, y_pos, is_inverted, delta, min_margin, n_layers):
-        if self.data is not None:
-            slice = self.data[:, self.idx(y_pos, 1), :]
-            self.process(self.l2d_obj_y, slice, is_inverted, delta, min_margin, n_layers)
-
-    def process_z(self, z_pos, is_inverted, delta, min_margin, n_layers):
-        if self.data is not None:
-            slice = self.data[:, :, self.idx(z_pos, 2)]
-            self.process(self.l2d_obj_z, slice, is_inverted, delta, min_margin, n_layers)
-
-    def process_all(self, x_pos, y_pos, z_pos, is_inverted, delta, min_margin, n_layers):
-        self.process_x(x_pos, is_inverted, delta, min_margin, n_layers)
-        self.process_y(y_pos, is_inverted, delta, min_margin, n_layers)
-        self.process_z(z_pos, is_inverted, delta, min_margin, n_layers)
+    def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
+        self.process_wrapper('x')(x_pos, axis, inverted, delta, min_margin, n_layers, wrap)
+        self.process_wrapper('y')(y_pos, axis, inverted, delta, min_margin, n_layers, wrap)
+        self.process_wrapper('z')(z_pos, axis, inverted, delta, min_margin, n_layers, wrap)
+
+    def process_wrapper(self, slicing_axis:str):
+        """
+        The function behaves the same in all 3 directions, however we have to know in which direction we are now.
+        Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data
+            and we return a function working in that direction
+        """
+        slice_key = F'{slicing_axis.lower()}_slice'
+        seg_key = F'{slicing_axis.lower()}_segmentation'
+        slicing_axis_int = self.axes[slicing_axis]
+
+        def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
+            """
+            Parameters:
+            -----------
+            pos: Relative position of a slice from data
+            segmenting_axis: In which direction we want to detect layers
+            inverted: If we want use inverted gradient
+            delta: Smoothness parameter
+            min_margin: What is the minimum distance between layers. If it was 0, all layers would be the same
+            n_layers: How many layer boarders we want to find
+            wrap: If True, the starting point and end point will be at the same level. Useful when segmenting unfolded images.
+            """
+            slice = self.get_slice(pos, slicing_axis_int)
+            self.__dict__[slice_key] = slice
+
+            if segmenting_axis.lower() == slicing_axis.lower():
+                self.__dict__[seg_key] = None
+            else:
+                
+                if self.is_transposed(slicing_axis, segmenting_axis):
+                    slice = np.rot90(slice)
+                self.__dict__[seg_key] = segment_layers(slice, inverted = inverted, n_layers = n_layers, delta = delta, min_margin = min_margin, wrap = wrap)
         
-
-    # PLOTTING FUNCTIONS
-
-    def plot_input_img(self, slice):
-        plt.close()
-        fig, ax = plt.subplots(figsize=self.figsize)
-        ax.imshow(slice, interpolation="nearest", cmap = self.cmap)
-
-        # Adjustments
-        ax.axis("off")
-        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
-
-        return fig
+        return process
+
+    def is_transposed(self, slicing_axis:str, segmenting_axis:str):
+        """
+        Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is. 
+        If it is not, we have to rotate it before we put it to segmentation algorithm
+        """
+        remaining_axis = 'xyz'.replace(slicing_axis.lower(), '').replace(segmenting_axis.lower(), '')
+        return self.axes[segmenting_axis.lower()] > self.axes[remaining_axis]
     
-    def plot_input_img_x(self, x_pos):
-        if self.data is None:
-            return gr.Plot()
-        slice = self.data[self.idx(x_pos, 0), :, :]
-        return self.plot_input_img(slice)
+    def get_slice(self, pos:float, axis:int):
+        idx = int(pos * (self.data.shape[axis] - 1))
+        return np.take(self.data, idx, axis = axis)
     
-    def plot_input_img_y(self, y_pos):
-        if self.data is None:
-            return gr.Plot()
-        slice = self.data[:, self.idx(y_pos, 1), :]
-        return self.plot_input_img(slice)
+    # def plot_input_img_wrapper(self, axis:str):
+    #     slice_key = F'{axis.lower()}_slice'
+    #     def plot_input_img():
+    #         slice = self.__dict__[slice_key]
+    #         slice = slice + np.abs(np.min(slice))
+    #         slice = slice / np.max(slice)
+    #         return slice
+    #     return plot_input_img
+
+    # def plot_input_img_all(self):
+    #     x_plot = self.plot_input_img_wrapper('x')()
+    #     y_plot = self.plot_input_img_wrapper('y')()
+    #     z_plot = self.plot_input_img_wrapper('z')()
+    #     return x_plot, y_plot, z_plot
     
-    def plot_input_img_z(self, z_pos):
-        if self.data is None:
-            return gr.Plot()
-        slice = self.data[:, :, self.idx(z_pos, 2)]
-        return self.plot_input_img(slice)
+    def plot_output_img_wrapper(self, slicing_axis:str):
+        slice_key = F'{slicing_axis.lower()}_slice'
+        seg_key = F'{slicing_axis.lower()}_segmentation'
+
+        def plot_output_img(segmenting_axis:str, alpha:float, line_thickness:float):
+            slice = self.__dict__[slice_key]
+            seg = self.__dict__[seg_key]
+
+            if seg is None: # In case segmenting axis si the same as slicing axis
+                return slice
+            
+            if self.plot_type == DEFAULT_PLOT_TYPE:
+                n_layers = len(seg) + 1
+                seg = np.sum(seg, axis = 0)
+                seg = np.repeat(seg[..., None], 3, axis = -1)
+                for i in range(n_layers):
+                    seg[seg[:,:,0] == i, :] = self.segmentation_colors[i]
+
+                if self.is_transposed(slicing_axis, segmenting_axis):
+                    seg = np.rot90(seg, k = 3)
+                return overlay_rgb_images(np.repeat(slice[..., None], 3, -1), seg, alpha) 
+            else:
+                lines = get_lines(seg)
+                if self.is_transposed(slicing_axis, segmenting_axis):
+                    return image_with_lines(np.rot90(slice), lines, line_thickness).rotate(270, expand = True)
+                else:
+                    return image_with_lines(slice, lines, line_thickness)
+            
+        return plot_output_img
     
-    def plot_input_img_all(self, x_pos, y_pos, z_pos):
-        x_plot = self.plot_input_img_x(x_pos)
-        y_plot = self.plot_input_img_y(y_pos)
-        z_plot = self.plot_input_img_z(z_pos)
-        return x_plot, y_plot, z_plot
-
-    def plot_output_wrapper(self, l2d_obj:l2d.Layers2d):
-        def plot_l2d_output():
-            if self.data is None:
-                return gr.Plot()
-            fig, ax = qim3d.viz.layers2d.create_plot_of_2d_array(l2d_obj.get_data_not_inverted())
-
-            for line in l2d_obj.segmentation_lines:
-                qim3d.viz.layers2d.add_line_to_plot(ax, line)
-
-            ax.axis("off")
-            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
-
-            return fig
-        return plot_l2d_output
-
-    def plot_output_all(self):
-        x_output = self.plot_output_wrapper(self.l2d_obj_x)()
-        y_output = self.plot_output_wrapper(self.l2d_obj_y)()
-        z_output = self.plot_output_wrapper(self.l2d_obj_z)()
+    def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float):
+        x_output = self.plot_output_img_wrapper('x')(segmenting_axis, alpha, line_thickness)
+        y_output = self.plot_output_img_wrapper('y')(segmenting_axis, alpha, line_thickness)
+        z_output = self.plot_output_img_wrapper('z')(segmenting_axis, alpha, line_thickness)
         return x_output, y_output, z_output
 
 if __name__ == "__main__":
diff --git a/qim3d/processing/__init__.py b/qim3d/processing/__init__.py
index 58bdbadfe1b5ca21e0cc18e63fb47a1bf65f2399..aabf2a7d05f144c936edd3eb0426c407f13c4289 100644
--- a/qim3d/processing/__init__.py
+++ b/qim3d/processing/__init__.py
@@ -4,3 +4,4 @@ from .detection import blob_detection
 from .filters import *
 from .operations import *
 from .cc import get_3d_cc
+from .layers2d import segment_layers, get_lines
diff --git a/qim3d/processing/layers2d.py b/qim3d/processing/layers2d.py
index d1e5457013e81c76d2a031668a9826eded8fc8d5..d91ef0a42f1ba8715793a674321219cbbf1d634b 100644
--- a/qim3d/processing/layers2d.py
+++ b/qim3d/processing/layers2d.py
@@ -1,410 +1,97 @@
-"""Class for layered surface segmentation in 2D images."""
 import numpy as np
-from scipy import signal as sig
-import os
 from slgbuilder import GraphObject 
 from slgbuilder import MaxflowBuilder
 
-class Layers2d:
+def segment_layers(data:np.ndarray, inverted:bool = False, n_layers:int = 1, delta:float = 1, min_margin:int = 10, max_margin:int = None, wrap:bool = False):
     """
-    Create an object to store graphs for layered surface segmentations.
+    Works on 2D and 3D data.
+    Light one function wrapper around slgbuilder https://github.com/Skielex/slgbuilder to do layer segmentation
+    Now uses only MaxflowBuilder for solving.
 
     Args:
-        data (numpy.ndarray, optional): 2D image data.
-        n_layers (int, optional): Number of layers. Defaults to 1.
-        delta (int, optional): Smoothness parameter. Defaults to 1.
-        min_margin (int, optional): Minimum margin between layers. Defaults to 10.
-        inverted (bool, optional): Choose inverted data for segmentation. Defaults to False.
+        data: 2D or 3D array on which it will be computed
+        inverted: if True, it will invert the brightness of the image
+        n_layers: How many layers are we looking for (result in a layer and background)
+        delta: Smoothness parameter
+        min_margin: If we want more layers, we have to have a margin otherwise they are all going to be exactly the same
+        max_margin: Maximum margin between layers
+        wrap: If True, starting and ending point of the border between layers are at the same level
 
-    Raises:
-        TypeError: If `data` is not numpy.ndarray.
-    
-    Example:
-        layers2d = Layers2d(data = np_arr, n_layers = 3, delta = 5, min_margin = 20)
-    """
-        
-    def __init__(self, 
-                 data = None, 
-                 is_inverted = False,
-                 n_layers = 1, 
-                 delta = 1, 
-                 min_margin = 10
-                 ):
-        '''
-        Create an object to store graphs for layered surface segmentations.\n
-        - 'Data' must be a numpy.ndarray.\n
-        - 'is_inverted' is a boolean which decides if the data is inverted or not.\n
-        - 'n_layers' is the number of layers.\n
-        - 'delta' is the smoothness parameter.\n
-        - 'min_margin' is the minimum margin between layers.\n
-        - 'data_not_inverted' is the original data.\n
-        - 'data_inverted' is the inverted data.\n
-        - 'layers' is a list of GraphObject objects.\n
-        - 'helper' is a MaxflowBuilder object.\n
-        - 'flow' is the result of the maxflow algorithm on the helper.\n
-        - 'segmentations' is a list of segmentations.\n
-        - 'segmentation_lines' is a list of segmentation lines.\n
-        '''
-        if data is not None:
-            if not isinstance(data, np.ndarray):
-                raise TypeError("Data must be a numpy.ndarray.")
-            self.data = data.astype(np.int32)
-        
-        self.is_inverted = is_inverted
-        self.n_layers = n_layers
-        self.delta = delta
-        self.min_margin = min_margin
-        
-        self.data_not_inverted = None
-        self.data_inverted = None        
-        self.layers = []
-        self.helper = MaxflowBuilder()
-        self.flow = None
-        self.segmentations = []
-        self.segmentation_lines = []
-
-    def get_data(self):
-        return self.data    
-    
-    def set_data(self, data):
-        '''
-        Sets data.\n
-        - Data must be a numpy.ndarray.
-        '''
-        if not isinstance(data, np.ndarray):
-            raise TypeError("Data must be a numpy.ndarray.")
-        self.data = data.astype(np.int32)
-    
-    def get_is_inverted(self):
-        return self.is_inverted
-    
-    def set_is_inverted(self, is_inverted):
-        self.is_inverted = is_inverted
-        
-    def get_delta(self):
-        return self.delta
-    
-    def set_delta(self, delta):
-        self.delta = delta
-    
-    def get_min_margin(self):
-        return self.min_margin
-    
-    def set_min_margin(self, min_margin):
-        self.min_margin = min_margin    
-    
-    def get_data_not_inverted(self):
-        return self.data_not_inverted
-    
-    def set_data_not_inverted(self, data_not_inverted):
-        self.data_not_inverted = data_not_inverted
-    
-    def get_data_inverted(self):
-        return self.data_inverted
-    
-    def set_data_inverted(self, data_inverted):
-        self.data_inverted = data_inverted
-    
-    def update_data_not_inverted(self):
-        self.set_data_not_inverted(self.get_data())
-    
-    def update_data_inverted(self):
-        if self.get_data() is not None:
-            self.set_data_inverted(~self.get_data())
-        else:
-            self.set_data_inverted(None)
-    
-    def update_data(self):
-        '''
-        Updates data:\n
-        - If 'is_inverted' is True, data is set to 'data_inverted'.\n
-        - If 'is_inverted' is False, data is set to 'data_not_inverted'.
-        '''
-        if self.get_is_inverted():
-            self.set_data(self.get_data_inverted())
-        else:
-            self.set_data(self.get_data_not_inverted())        
-    
-    def get_n_layers(self):
-        return self.n_layers
-    
-    def set_n_layers(self, n_layers):
-        self.n_layers = n_layers
-    
-    def get_layers(self):
-        return self.layers
-    
-    def set_layers(self, layers):
-        self.layers = layers
-    
-    def add_layer_to_layers(self):
-        '''
-        Append a layer to layers.\n
-        - Data must be set and not Nonetype before adding a layer.\n
-        '''
-        if self.get_data() is None:
-            raise ValueError("Data must be set before adding a layer.")
-        self.get_layers().append(GraphObject(self.get_data()))
-    
-    def add_n_layers_to_layers(self):
-        '''
-        Append n_layers to layers.
-        '''
-        for i in range(self.get_n_layers()):
-            self.add_layer_to_layers()
-    
-    def update_layers(self):
-        '''
-        Updates layers:\n
-        - Resets layers to empty list.\n
-        - Appends n_layers to layers.
-        '''
-        self.set_layers([])
-        self.add_n_layers_to_layers()
-    
-    def get_helper(self):
-        return self.helper
-    
-    def set_helper(self, helper):
-        self.helper = helper
-    
-    def create_new_helper(self):
-        '''
-        Creates a new helper MaxflowBuilder object.
-        '''
-        self.set_helper(MaxflowBuilder())
-    
-    def add_objects_to_helper(self):
-        '''
-        Adds layers as objects to the helper.
-        '''
-        self.get_helper().add_objects(self.get_layers())
-    
-    def add_layered_boundary_cost_to_helper(self):
-        '''
-        Adds layered boundary cost to the helper.
-        '''
-        self.get_helper().add_layered_boundary_cost()
-    
-    def add_layered_smoothness_to_helper(self):
-        '''
-        Adds layered smoothness to the helper.
-        '''
-        self.get_helper().add_layered_smoothness(delta = self.get_delta())
-    
-    def add_a_layered_containment_to_helper(self, outer_object, inner_object):
-        '''
-        Adds a layered containment to the helper.
-        '''
-        self.get_helper().add_layered_containment(
-                outer_object = outer_object, 
-                inner_object = inner_object, 
-                min_margin = self.get_min_margin()
-            )
-    
-    def add_all_layered_containments_to_helper(self):
-        '''
-        Adds all layered containments to the helper.\n
-        n_layers most be at least 1.
-        '''
-        if len(self.get_layers()) < 1:
-            raise ValueError("There must be at least 1 layer to add containment.")
-        
-        for i in range(self.get_n_layers()-1):
-            self.add_a_layered_containment_to_helper(
-                    outer_object  = self.get_layers()[i], 
-                    inner_object = self.get_layers()[i + 1]
-                )
-    
-    def get_flow(self):
-        return self.flow
-    
-    def set_flow(self, flow):
-        self.flow = flow
-    
-    def solve_helper(self):
-        '''
-        Solves maxflow of the helper and stores the result in self.flow.
-        '''
-        self.set_flow(self.get_helper().solve())
+    Returns:
+        segmentations: list of numpy arrays, even if n_layers == 1, each array is only 0s and 1s, 1s segmenting this specific layer
 
-    def update_helper(self):
-        '''
-        Updates helper MaxflowBuilder object:\n
-        - Adds to helper:
-            - objects\n 
-            - layered boundary cost\n 
-            - layered smoothness\n 
-            - all layered containments\n
-        - Finally solves maxflow of the helper.
-        '''
-        self.add_objects_to_helper()
-        self.add_layered_boundary_cost_to_helper()
-        self.add_layered_smoothness_to_helper()
-        self.add_all_layered_containments_to_helper()
-        self.solve_helper()
+    Raises:
+        TypeError: If Data is not np.array, if n_layers is not integer.
+        ValueError: If n_layers is less than 1, if delta is negative or zero
 
-    def get_segmentations(self):
-        return self.segmentations
+    Example:
+        Example is only shown on 2D image, but segment_layers can also take 3D structures.
+        ```python
+        import qim3d
 
-    def set_segmentations(self, segmentations):
-        self.segmentations = segmentations
+        layers_image = qim3d.io.load('layers3d.tif')[:,:,0]
+        layers = qim3d.processing.segment_layers(layers_image, n_layers = 2)
+        layer_lines = qim3d.processing.get_lines(layers)
 
-    def add_segmentation_to_segmentations(self, layer, type = np.int32):
-        '''
-        Adds a segmentation of a layer to segmentations.\n
-        '''
-        self.get_segmentations().append(self.get_helper().what_segments(layer).astype(type))
+        from matplotlib import pyplot as plt
 
-    def add_all_segmentations_to_segmentations(self, type = np.int32):
-        '''
-        Adds all segmentations to segmentations.\n
-        - Resets segmentations to empty list.\n
-        - Appends segmentations of all layers to segmentations.
-        '''
-        self.set_segmentations([])
-        for l in self.get_layers():
-            self.add_segmentation_to_segmentations(l, type = type)
+        plt.imshow(layers_image, cmap='gray')
+        plt.axis('off')
+        for layer_line in layer_lines:
+            plt.plot(layer_line, linewidth = 3)
+        ```
+        ![layer_segmentation](assets/screenshots/layers.png)
+        ![layer_segmentation](assets/screenshots/segmented_layers.png)
 
-    def get_segmentation_lines(self):
-        return self.segmentation_lines
-    
-    def set_segmentation_lines(self, segmentation_lines):
-        self.segmentation_lines = segmentation_lines
-        
-    def add_segmentation_line_to_segmentation_lines(self, segmentation):
-        '''
-        Adds a segmentation line to segmentation_lines.\n
-        - A segmentation line is the minimum values along a given axis of a segmentation.\n
-        - Each segmentation line is shifted by 0.5 to be in the middle of the pixel.
-        '''
-        self.get_segmentation_lines().append(sig.medfilt(
-            np.argmin(segmentation, axis = 0), kernel_size = 3))
+    """
+    if isinstance(data, np.ndarray):
+        data = data.astype(np.int32)
+        if inverted:
+            data = ~data
+    else:
+        raise TypeError(F"Data has to be type np.ndarray. Your data is of type {type(data)}")
     
-    def add_all_segmentation_lines_to_segmentation_lines(self):
-        '''
-        Adds all segmentation lines to segmentation_lines.\n
-        - Resets segmentation_lines to an empty list.\n
-        - Appends segmentation lines of all segmentations to segmentation_lines.
-        '''
-        self.set_segmentation_lines([])
-        for s in self.get_segmentations():
-            self.add_segmentation_line_to_segmentation_lines(s)
-        
-    def update_semgmentations_and_semgmentation_lines(self, type = np.int32):
-        '''
-        Updates segmentations and segmentation_lines:\n
-        - Adds all segmentations to segmentations.\n
-        - Adds all segmentation lines to segmentation_lines.
-        '''
-        self.add_all_segmentations_to_segmentations(type = type)
-        self.add_all_segmentation_lines_to_segmentation_lines()
-
-    def prepare_update(self, 
-                       data = None,
-                       is_inverted = None,
-                       n_layers = None,
-                       delta = None,
-                       min_margin = None
-                       ):
-        '''
-        Prepare update of all fields of the object.\n
-        - If a field is None, it is not updated.\n
-        - If a field is not None, it is updated.
-        '''
-        if data is not None:
-            self.set_data(data)
-        if is_inverted is not None:
-            self.set_is_inverted(is_inverted)
-        if n_layers is not None:
-            self.set_n_layers(n_layers)
-        if delta is not None:
-            self.set_delta(delta)
-        if min_margin is not None:
-            self.set_min_margin(min_margin)
+    helper = MaxflowBuilder()
+    if not isinstance(n_layers, int):
+        raise TypeError(F"Number of layers has to be positive integer. You passed {type(n_layers)}")
     
-    def update(self, type = np.int32):
-        '''
-        Update all fields of the object.
-        '''
-        self.update_data_not_inverted()
-        self.update_data_inverted()
-        self.update_data()
-        self.update_layers()
-        self.create_new_helper()
-        self.update_helper()
-        self.update_semgmentations_and_semgmentation_lines(type = type)
+    if n_layers == 1:
+        layer = GraphObject(data)
+        helper.add_object(layer)
+    elif n_layers > 1:
+        layers = [GraphObject(data) for _ in range(n_layers)]
+        helper.add_objects(layers)
+        for i in range(len(layers)-1):
+            helper.add_layered_containment(layers[i], layers[i+1], min_margin=min_margin, max_margin=max_margin) 
 
-    def __repr__(self):
-        '''
-        Returns string representation of all fields of the object.
-        '''
-        return "data: %s\n, \nis_inverted: %s, \nn_layers: %s, \ndelta: %s, \nmin_margin: %s, \ndata_not_inverted: %s, \ndata_inverted: %s, \nlayers: %s, \nhelper: %s, \nflow: %s, \nsegmentations: %s, \nsegmentations_lines: %s" % (
-            self.get_data(),
-            self.get_is_inverted(),
-            self.get_n_layers(),
-            self.get_delta(),
-            self.get_min_margin(),
-            self.get_data_not_inverted(),
-            self.get_data_inverted(),
-            self.get_layers(),
-            self.get_helper(),
-            self.get_flow(),
-            self.get_segmentations(),
-            self.get_segmentation_lines()
-        )
-
-
-
-import matplotlib.pyplot as plt
-from skimage.io import imread
-from qim3d.io import load
+    else:
+        raise ValueError(F"Number of layers has to be positive integer. You passed {n_layers}")
+    
+    helper.add_layered_boundary_cost()
+
+    if delta > 1:
+        delta = int(delta)
+    elif delta <= 0:
+        raise ValueError(F'Delta has to be positive number. You passed {delta}')
+    helper.add_layered_smoothness(delta=delta, wrap = bool(wrap))
+    helper.solve()
+    if n_layers == 1:
+        segmentations =[helper.what_segments(layer)]
+    else:
+        segmentations = [helper.what_segments(l).astype(np.int32) for l in layers]
 
-if __name__ == "__main__":        
-    # Draw results.
-    def visulise(l2d = None):
-        plt.figure(figsize = (10, 10))
-        ax = plt.subplot(1, 3, 1)
-        ax.imshow(l2d.get_data(), cmap = "gray")
+    return segmentations
 
-        ax = plt.subplot(1, 3, 2)
-        ax.imshow(np.sum(l2d.get_segmentations(), axis = 0))
+def get_lines(segmentations:list|np.ndarray) -> list:
+    """
+    Expects list of arrays where each array is 2D segmentation with only 2 classes. This function gets the border between those two
+    so it could be plotted. Used with qim3d.processing.segment_layers
 
-        ax = plt.subplot(1, 3, 3)
-        ax.imshow(data, cmap = "gray")
-        for line in l2d.get_segmentation_lines():
-            ax.plot(line)
-        plt.show()
-    
-    # Data input
-    d_switch = False    
-    if d_switch:
-        path = os.path.join(os.getcwd(), "qim3d", "img_examples", "slice_218x193.png")
-        data = imread(path).astype(np.int32)
-    else:
-        path = os.path.join(os.getcwd(), "qim3d", "img_examples", "bone_128x128x128.tif")
-        data3D = load(
-                    path,
-                    virtual_stack=True,
-                    dataset_name="",
-                )
-    
-        x = data3D.shape[0]
-        y = data3D.shape[1]
-        z = data3D.shape[2]
-    
-        data = data3D[x//2, :, :] 
-        data = data3D[:, y//2, :]
-        data = data3D[:, :, z//2] 
+    Args:
+        segmentations: list of arrays where each array is 2D segmentation with only 2 classes
 
-    layers2d = Layers2d(data = data, n_layers = 3, delta = 1, min_margin = 10)
-    layers2d.update()
-    visulise(layers2d)
-    
-    layers2d.prepare_update(n_layers = 1)
-    layers2d.update()
-    visulise(layers2d)
-    
-    layers2d.prepare_update(is_inverted = True)
-    layers2d.update()
-    visulise(layers2d)
\ No newline at end of file
+    Returns:
+        segmentation_lines: list of 1D numpy arrays
+    """
+    segmentation_lines = [np.argmin(s, axis=0) - 0.5 for s in segmentations]
+    return segmentation_lines
\ No newline at end of file
diff --git a/qim3d/viz/layers2d.py b/qim3d/viz/layers2d.py
index 478583145ce5f6372f6ec1f6b1238ccbb80dea97..682d2ce016e78f84bf9328f2782e942f2615c2f5 100644
--- a/qim3d/viz/layers2d.py
+++ b/qim3d/viz/layers2d.py
@@ -1,162 +1,143 @@
 """ Provides a collection of visualisation functions for the Layers2d class."""
+import io
+
 import matplotlib.pyplot as plt
 import numpy as np
-<<<<<<< HEAD
+
 from qim3d.processing import layers2d as l2d
-=======
-from qim3d.process import layers2d as l2d
->>>>>>> 1622d378193877cb85f53b1a05207088b3f3cf0a
 
-def create_subplot_of_2d_arrays(data, m_rows = 1, n_cols = 1, figsize = None):
-    '''
-    Creates a `m x n` grid subplot from a collection of 2D arrays.
-    
-    Args:
-        `data` (list of 2D numpy.ndarray): A list of 2d numpy.ndarray.
-        `m_rows` (int): The number of rows in the subplot grid.
-        `n_cols` (int): The number of columns in the subplot grid.
+from PIL import Image
+def image_with_overlay(image:np.ndarray, overlay:np.ndarray, alpha:int|float|np.ndarray = 125) -> Image:
+    #TODO : also accepts Image type
+    # We want to accept as many different values as possible to make convenient for the user.
+    """
+    Takes image and puts a transparent segmentation mask on it.
 
-    Raises:
-        ValueError: If the product of m_rows and n_cols is not equal to the number of 2d arrays in data.
+    Parameters:
+    -----------
+    Image: Can be grayscale or colorful, accepts all kinds of shapes, color has to be the last axis
+    Overlay: If has its own alpha channel, alpha argument is ignored. 
+    Alpha:  Can be ansolute value as int, relative vlaue as float or an array so different parts can differ with the transparency.
 
-    Notes:
-    - Subplots are organized in a m rows x n columns Grid.
-    - The total number of subplots is equal to the product of m_rows and n_cols.
-    
     Returns:
-        A tuple of (`fig`, `ax_list`), where fig is a matplotlib.pyplot.figure and ax_list is a list of matplotlib.pyplot.axes.
-    '''
-    total = m_rows * n_cols
-    
-    if total != len(data):
-        raise ValueError("The product of m_rows and n_cols must be equal to the number of 2D arrays in data.\nCurrently, m_rows * n_cols = {}, while arrays in data = {}".format(m_rows * n_cols, len(data)))
-    
-    pos_idx = range(1, total + 1)
-    
-    if figsize is None:
-        figsize = (m_rows * 10, n_cols * 10)
-    fig = plt.figure(figsize = figsize)
+    ---------
+    Image: PIL.Image in the original size as the image array
+
+    Raises:
+    --------
+    ValueError: If there is a missmatch of shapes or alpha has an invalid value.
+    """
+    def check_dtype(image:np.ndarray):
+        if image.dtype != np.uint8:
+            minimal = np.min(image)
+            if minimal < 0:
+                image = image + minimal
+            maximum = np.max(image)
+            if maximum > 255:
+                image = (image/maximum)*255
+            elif maximum <= 1:
+                image = image*255  
+            image = np.uint8(image)
+        return image
+
+    image = check_dtype(image)
+    overlay = check_dtype(overlay)
     
-    ax_list = []
+    if image.shape[0] != overlay.shape[0] or image.shape[1] != overlay.shape[1]:
+        raise ValueError(F"The first two dimensions of overlay image must match those of background image.\nYour background image: {image.shape}\nYour overlay image: {overlay.shape}")
     
-    for k in range(total):
-        ax_list.append(fig.add_subplot(m_rows, n_cols, pos_idx[k]))
-        ax_list[k].imshow(data[k], cmap = "gray")
     
-    plt.tight_layout()
-    return fig, ax_list
+    if image.ndim == 3:
+        if image.shape[2] < 3:
+            image = np.repeat(image[:,:,:1], 3, -1)
+        elif image.shape[2] > 4:
+            image = image[:,:,:4]
+
+    elif image.ndim == 2:
+        image = np.repeat(image[..., None], 3, -1)
 
-def create_plot_of_2d_array(data, figsize = (10, 10)):
-    '''
-    Creates a plot of a 2D array.
+    else:
+        raise ValueError(F"Background image must have 2 or 3 dimensions. Yours have {image.ndim}")
     
-    Args:
-        `data` (list of 2D numpy.ndarray): A list of 2d numpy.ndarray.
-        `figsize` (tuple of int): The figure size.
-    Notes:
-        - If data is not a list, it is converted to a list.
-    Returns:
-        A tuple of (`fig`, `ax`), where fig is a matplotlib.pyplot.figure and ax is a matplotlib.pyplot.axes.
-    '''
-    if not isinstance(data, list):
-        data = [data]
     
-    fig, ax_list = create_subplot_of_2d_arrays(data, figsize = figsize)
-    return fig, ax_list[0]
     
-def merge_multiple_segmentations_2d(segmentations):
-    '''
-    Merges multiple segmentations of a 2D image into a single image.
+    if isinstance(alpha, (float, int)):
+        if alpha<0:
+            raise ValueError(F"Alpha can not be negative. You passed {alpha}")
+        elif alpha<=1:
+            alpha = int(255*alpha)
+        elif alpha> 255:
+            alpha = 255
+        else:
+            alpha = int(alpha)
+
+    elif isinstance(alpha, np.ndarray):
+        if alpha.ndim == 3:
+            alpha = alpha[..., :1] # Making sure it is only one layer
+        elif alpha.ndim == 2:
+            alpha = alpha[..., None] # Making sure it has 3 dimensions
+        else:
+            raise ValueError(F"If alpha is numpy array, it must have 2 or 3 dimensions. Your have {alpha.ndim}")
+        
+        # We have not checked ndims of overlay
+        try:
+            if alpha.shape[0] != overlay.shape[0] or alpha.shape[1] != overlay.shape[1]:
+                raise ValueError(F"The first two dimensions of alpha must match those of overlay image.\nYour alpha: {alpha.shape}\nYour overlay: {overlay.shape}")
+        except IndexError:
+            raise ValueError(F"Overlay image must have 2 or 3 dimensions. Yours have {overlay.ndim}")
+        
+
+    if overlay.ndim == 3:
+        if overlay.shape[2] < 3:
+            overlay = np.repeat(overlay[..., :1], 4, -1)
+            if alpha is None:
+                raise ValueError("Alpha can not be None if overlay image doesn't have alpha channel")
+            overlay[..., 3] = alpha
+        elif overlay.shape[2] == 3:
+            if isinstance(alpha, int):
+                overlay = np.concatenate((overlay, np.full((overlay.shape[0], overlay.shape[1], 1,), alpha, dtype = np.uint8)), axis = -1)
+            elif isinstance(alpha, np.ndarray):
+                overlay = np.concatenate((overlay, alpha), axis = -1)
+
+        elif overlay.shape[2]>4:
+            raise ValueError(F"Overlay image can not have more than 4 channels. Yours have {overlay.shape[2]}")
+
+    elif overlay.ndim == 2:
+        overlay = np.repeat(overlay[..., None], 4, axis = -1)
+        overlay[..., 3] = alpha
+    else:
+        raise ValueError(F"Overlay image must have 2 or 3 dimensions. Yours have {overlay.ndim}")
     
-    Args:
-        `segmenations` (list of numpy.ndarray): A list of 2D numpy.ndarray.
+    background = Image.fromarray(image)
+    overlay = Image.fromarray(overlay)
+    background.paste(overlay, mask = overlay)
+    return background
+
+
+def image_with_lines(image:np.ndarray, lines: list, line_thickness:float|int) -> Image:
+    """
+    Plots the image and plots the lines on top of it. Then extracts it as PIL.Image and in the same size as the input image was.
+    Paramters:
+    -----------
+    image: Image on which we put the lines
+    lines: list of 1D arrays to be plotted on top of the image
+    line_thickness: how thick is the line supposed to be
+
     Returns:
-        A 2D numpy.ndarray representing the merged segmentations.
-    '''
-    if len(segmentations) == 0:
-        raise ValueError("Segmentations must contain at least one segmentation.")
-    if len(segmentations) == 1:
-        return segmentations[0]
-    else:
-        return np.sum(segmentations, axis = 0)
+    ----------
+    image_with_lines: 
+    """
+    fig, ax = plt.subplots()
+    ax.imshow(image, cmap = 'gray')
+    ax.axis('off')
 
-def add_line_to_plot(axes, line, line_color = None):
-    '''
-    Adds a line to plot.
-    
-    Args:
-        `axes` (matplotlib.pyplot.axes): A matplotlib.pyplot.axes.
-        `line` (numpy.ndarray): A 1D numpy.ndarray.
-    
-    Notes:
-        - The line is added on top of to the plot.
-    '''
-    if line_color is None:
-        axes.plot(line)
-    else:
-        axes.plot(line, color = line_color)
+    for line in lines:
+        ax.plot(line, linewidth = line_thickness)
 
-def add_lines_to_plot(axes, lines, line_colors = None):
-    '''
-    Adds multiple lines to plot.
-    
-    Args:
-        `axes` (matplotlib.pyplot.axes): A matplotlib.pyplot.axes.
-        `lines` (list of numpy.ndarray): A list of 1D numpy.ndarray.
-    
-    Notes:
-        - The lines are added on top of to the plot.
-    '''
-    if line_colors is None:
-        for line in lines:
-            axes.plot(line)
-    else:
-        for i in range(len(lines)):
-            axes.plot(lines[i], color = line_colors[i])
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+    plt.close()
 
-import os
-from skimage.io import imread
+    buf.seek(0)
+    return Image.open(buf).resize(size = image.squeeze().shape[::-1])
 
-if __name__ == "__main__":
-    path = os.path.join(os.getcwd(), "qim3d", "img_examples", "slice_218x193.png")
-    data = imread(path).astype(np.int32)
-    
-    l2d_obj = l2d.Layers2d()
-    l2d_obj.prepare_update(
-        data = data, 
-        is_inverted=False,
-        delta=1,
-        min_margin=10,
-        n_layers=4,
-        ) 
-    l2d_obj.update()    
-        
-    # Show how create_plot_from_2d_arrays works:
-    fig1, ax1 = create_plot_of_2d_array(l2d_obj.get_data())
-    
-    data_lines = []
-    for i in range(len(l2d_obj.get_segmentation_lines())):
-        data_lines.append(l2d_obj.get_segmentation_lines()[i])
-    
-    # Show how add_line_to_plot works:
-    add_line_to_plot(ax1, data_lines[3])
-    
-    # Show how merge_multiple_segmentations_2d works:
-    data_seg = []
-    for i in range(len(l2d_obj.get_segmentations())):
-        data_seg.append(merge_multiple_segmentations_2d(l2d_obj.get_segmentations()[:i+1]))
-    
-    # Show how create_subplot_of_2d_arrays works:
-    fig2, ax_list = create_subplot_of_2d_arrays(
-            data_seg, 
-            m_rows = 1, 
-            n_cols = len(l2d_obj.get_segmentations())
-            # m_rows = len(l2d_obj.get_segmentations()), 
-            # n_cols = 1
-        )
-    
-    # Show how add_lines_to_plot works:
-    add_lines_to_plot(ax_list[1], data_lines[0:3])
-    
-    plt.show()
-    
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 58a1a3ddae03bb830662981ad98fa9d483308501..d28a8e0113da17fc95f2c09e4f368532a84412c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,7 @@ Pillow>=10.0.1
 plotly>=5.14.1
 scipy>=1.11.2
 seaborn>=0.12.2
-pydicom>=2.4.4
+pydicom==2.4.4
 setuptools>=68.0.0
 tifffile>=2023.4.12
 torch>=2.0.1
diff --git a/setup.py b/setup.py
index 0838da68652a2331ff134eef4e222034382f48b5..8ccf01bcdde46d5d149629f0de422862ce70eb3e 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@ setup(
         "h5py>=3.9.0",
         "localthickness>=0.1.2",
         "matplotlib>=3.8.0",
-        "pydicom>=2.4.4",
+        "pydicom==2.4.4",
         "numpy>=1.26.0",
         "outputformat>=0.1.3",
         "Pillow>=10.0.1",