Skip to content
Snippets Groups Projects

Layered Surface Segmentation Feature

1 file
+ 96
11
Compare changes
  • Side-by-side
  • Inline
+ 96
11
@@ -5,6 +5,7 @@ from qim3d.io import DataLoader
@@ -5,6 +5,7 @@ from qim3d.io import DataLoader
from qim3d.io.logger import log
from qim3d.io.logger import log
from qim3d.process import layers2d as l2d
from qim3d.process import layers2d as l2d
from qim3d.io import load
from qim3d.io import load
 
import qim3d.viz
# matplotlib.use("Agg")
# matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
@@ -18,6 +19,9 @@ class Session:
@@ -18,6 +19,9 @@ class Session:
self.base_path = None
self.base_path = None
self.explorer = None
self.explorer = None
self.file_path = None
self.file_path = None
 
self.delta = 1
 
self.min_margin = 10
 
self.n_layers = 4
class Interface:
class Interface:
@@ -46,6 +50,9 @@ class Interface:
@@ -46,6 +50,9 @@ class Interface:
session = Session()
session = Session()
session.base_path = args[0]
session.base_path = args[0]
session.explorer = args[1]
session.explorer = args[1]
 
session.delta = args[2]
 
session.min_margin = args[3]
 
session.n_layers = args[4]
# Get the file path from the explorer or base path
# Get the file path from the explorer or base path
if session.base_path and os.path.isfile(session.base_path):
if session.base_path and os.path.isfile(session.base_path):
@@ -72,13 +79,26 @@ class Interface:
@@ -72,13 +79,26 @@ class Interface:
else:
else:
raise ValueError("Invalid path")
raise ValueError("Invalid path")
 
def set_spinner(self, message):
 
return gr.update(
 
elem_classes="btn btn-spinner",
 
value=f"{message}",
 
interactive=False,
 
)
 
 
def set_relaunch_button(self):
 
return gr.update(
 
elem_classes="btn btn-run",
 
value=f"Relaunch",
 
interactive=True,
 
)
 
def create_interface(self):
def create_interface(self):
with gr.Blocks(css=self.css_path) as gradio_interface:
with gr.Blocks(css=self.css_path) as gradio_interface:
gr.Markdown(f"# {self.title}")
gr.Markdown(f"# {self.title}")
with gr.Row():
with gr.Row():
with gr.Column(scale=1, min_width=320):
with gr.Column(scale=1, min_width=320):
gr.Markdown("### File selection")
with gr.Row():
with gr.Row():
with gr.Column(scale=99, min_width=128):
with gr.Column(scale=99, min_width=128):
base_path = gr.Textbox(
base_path = gr.Textbox(
@@ -102,36 +122,65 @@ class Interface:
@@ -102,36 +122,65 @@ class Interface:
elem_classes="h-256 hide-overflow",
elem_classes="h-256 hide-overflow",
)
)
# Run button
# TODO Add description for parameters in the interface
 
with gr.Row():
 
delta = gr.Slider(
 
minimum=0.5,
 
maximum=1.0,
 
value=1,
 
step=0.01,
 
label="Delta value",
 
)
 
with gr.Row():
 
min_margin = gr.Slider(
 
minimum=1, maximum=50, value=10, step=1, label="Min margin"
 
)
with gr.Row():
with gr.Row():
btn_run = gr.Button("Run Layers2D", elem_classes="btn btn-run")
n_layers = gr.Slider(
 
minimum=1,
 
maximum=8,
 
value=4,
 
step=1,
 
label="Number of layers",
 
)
 
with gr.Row():
 
btn_run = gr.Button(
 
"Run Layers2D", elem_classes="btn btn-html btn-run"
 
)
with gr.Column(scale=4):
with gr.Column(scale=2):
with gr.Row():
with gr.Row():
input_plot = gr.Plot(
input_plot = gr.Plot(
show_label=True,
show_label=True,
label="Source image",
label="Source image",
visible=True,
visible=True,
elem_classes="plot",
elem_classes="rounded",
)
)
output_plot = gr.Plot(
output_plot = gr.Plot(
show_label=True,
show_label=True,
label="Detected layers",
label="Detected layers",
visible=True,
visible=True,
elem_classes="plot",
elem_classes="rounded",
)
)
# Session
# Session
session = gr.State([])
session = gr.State([])
pipeline = Pipeline()
pipeline = Pipeline()
inputs = [base_path, explorer]
inputs = [base_path, explorer, delta, min_margin, n_layers]
 
spinner_loading = gr.Text("Loading data...", visible=False)
 
spinner_running = gr.Text("Running pipeline...", visible=False)
# fmt: off
# fmt: off
reload_base_path.click(fn=self.update_explorer,inputs=base_path, outputs=explorer)
reload_base_path.click(fn=self.update_explorer,inputs=base_path, outputs=explorer)
btn_run.click(
btn_run.click(
fn=self.start_session, inputs=inputs, outputs=session).then(
fn=self.start_session, inputs=inputs, outputs=session).then(
 
fn=self.set_spinner, inputs=spinner_loading, outputs=btn_run).then(
fn=pipeline.load_data, inputs=session, outputs=session).then(
fn=pipeline.load_data, inputs=session, outputs=session).then(
fn=pipeline.plot_input_img, inputs=session, outputs=input_plot)
fn=self.set_spinner, inputs=spinner_running, outputs=btn_run).then(
 
fn=pipeline.plot_input_img, inputs=session, outputs=input_plot).then(
 
fn=pipeline.process_l2d, inputs=session, outputs=session).then(
 
fn=pipeline.plot_l2d_output, inputs=session, outputs=output_plot).then(
 
fn=self.set_relaunch_button, inputs=[], outputs=btn_run)
# fmt: on
# fmt: on
return gradio_interface
return gradio_interface
@@ -139,9 +188,9 @@ class Interface:
@@ -139,9 +188,9 @@ class Interface:
class Pipeline:
class Pipeline:
def __init__(self):
def __init__(self):
self.figsize = (6, 6)
self.figsize = (8, 8)
def plot_input_img(self, session, cmap="viridis"):
def plot_input_img(self, session, cmap="Greys_r"):
plt.close()
plt.close()
fig, ax = plt.subplots(figsize=self.figsize)
fig, ax = plt.subplots(figsize=self.figsize)
@@ -153,6 +202,43 @@ class Pipeline:
@@ -153,6 +202,43 @@ class Pipeline:
return fig
return fig
 
def plot_l2d_output(self, session):
 
l2d_obj = session.l2d_obj
 
fig, ax = qim3d.viz.layers2d.create_plot_of_2d_array(l2d_obj.get_data())
 
 
data_lines = []
 
for i in range(len(l2d_obj.get_segmentation_lines())):
 
data_lines.append(l2d_obj.get_segmentation_lines()[i])
 
 
# Show how add_line_to_plot works:
 
for line in data_lines:
 
qim3d.viz.layers2d.add_line_to_plot(ax, line)
 
 
# Adjustments
 
ax.axis("off")
 
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
 
 
return fig
 
 
def process_l2d(self, session):
 
data = session.data
 
# TODO Add here some checks to be usre data is 2D
 
 
# TODO: Handle "is_inverted" from gradio
 
l2d_obj = l2d.Layers2d()
 
l2d_obj.prepare_update(
 
data=data,
 
is_inverted=False,
 
delta=session.delta,
 
min_margin=session.min_margin,
 
n_layers=session.n_layers,
 
)
 
l2d_obj.update()
 
 
session.l2d_obj = l2d_obj
 
 
return session
 
def load_data(self, session):
def load_data(self, session):
try:
try:
session.data = load(
session.data = load(
@@ -165,7 +251,6 @@ class Pipeline:
@@ -165,7 +251,6 @@ class Pipeline:
f"Failed to load the image: {error_message}"
f"Failed to load the image: {error_message}"
) from error_message
) from error_message
print(session.data)
return session
return session
Loading