Skip to content
Snippets Groups Projects

Layered Surface Segmentation Feature

Open s194066 requested to merge layered_surface_segmentation into main
+ 45
47
@@ -27,48 +27,45 @@ from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d
from qim3d.processing import overlay_rgb_images, segment_layers, get_lines
from qim3d.io import load
from qim3d.viz.layers2d import image_with_lines
from qim3d.viz.layers2d import image_with_lines, image_with_overlay
#TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error
# Right now its only bypassed with several if statements
# I opened an issue here https://github.com/gradio-app/gradio/issues/9273
X = 1
Y = 0
Z = 2
X = 'X'
Y = 'Y'
Z = 'Z'
AXES = {X:2, Y:1, Z:0}
DEFAULT_PLOT_TYPE = 'Segmentation mask'
SEGMENTATION_COLORS = np.array([[0, 255, 255], # Cyan
[255, 195, 0], # Yellow Orange
[199, 0, 57], # Dark orange
[218, 247, 166], # Light green
[255, 0, 255], # Magenta
[65, 105, 225], # Royal blue
[138, 43, 226], # Blue violet
[255, 0, 0], #Red
])
class Interface(BaseInterface):
def __init__(self):
super().__init__("Layered surfaces 2D", 1080)
self.figsize = (8, 8)
self.cmap = "Greys_r"
self.axes = {'x': X, 'y':Y, 'z':Z}
self.data = None
self.x_slice = None
self.y_slice = None
self.z_slice = None
self.x_segmentation = None
self.y_segmentation = None
self.z_segmentation = None
# It important to keep the name of the attributes like this (including the capital letter) becuase of
# accessing the attributes via __dict__
self.X_slice = None
self.Y_slice = None
self.Z_slice = None
self.X_segmentation = None
self.Y_segmentation = None
self.Z_segmentation = None
self.plot_type = DEFAULT_PLOT_TYPE
self.segmentation_colors = np.array([[0, 255, 255], # Cyan
[255, 195, 0], # Yellow Orange
[199, 0, 57], # Dark orange
[218, 247, 166], # Light green
[255, 0, 255], # Magenta
[65, 105, 225], # Royal blue
[138, 43, 226], # Blue violet
[255, 0, 0], #Red
])
self.error = False
@@ -128,10 +125,10 @@ class Interface(BaseInterface):
with gr.Group():
with gr.Row():
axis = gr.Radio(
choices = ['Y', 'X', 'Z'],
value = 'Y',
choices = [Z, Y, X],
value = Z,
label = 'Layer axis',
info = 'Specifies in which direction are the layers. Because of numpy design, Y is 0th axis, X is 1st and Z is 2nd.',)
info = 'Specifies in which direction are the layers. The order of axes is ZYX',)
with gr.Row():
wrap = gr.Checkbox(
label = "Lines start and end at the same level.",
@@ -168,7 +165,7 @@ class Interface(BaseInterface):
with gr.Row():
n_layers = gr.Slider(
minimum=1,
maximum=len(self.segmentation_colors) - 1,
maximum=len(SEGMENTATION_COLORS) - 1,
value=2,
step=1,
interactive=True,
@@ -292,8 +289,8 @@ class Interface(BaseInterface):
fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots)
# for axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots):
for axis, slider, output_plot in zip(['x','y','z'], positions, output_plots):
slider.release(
for axis, slider, output_plot in zip([X,Y,Z], positions, output_plots):
slider.change(
self.process_wrapper(axis), inputs = [slider, *process_inputs]).then(
# self.plot_input_img_wrapper(axis), outputs = input_plot).then(
self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot)
@@ -364,9 +361,9 @@ class Interface(BaseInterface):
) from error_message
def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
self.process_wrapper('x')(x_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper('y')(y_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper('z')(z_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper(X)(x_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper(Y)(y_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper(Z)(z_pos, axis, inverted, delta, min_margin, n_layers, wrap)
def process_wrapper(self, slicing_axis:str):
"""
@@ -374,9 +371,9 @@ class Interface(BaseInterface):
Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data
and we return a function working in that direction
"""
slice_key = F'{slicing_axis.lower()}_slice'
seg_key = F'{slicing_axis.lower()}_segmentation'
slicing_axis_int = self.axes[slicing_axis]
slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis}_segmentation'
slicing_axis_int = AXES[slicing_axis]
def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
"""
@@ -393,7 +390,7 @@ class Interface(BaseInterface):
slice = self.get_slice(pos, slicing_axis_int)
self.__dict__[slice_key] = slice
if segmenting_axis.lower() == slicing_axis.lower():
if segmenting_axis == slicing_axis:
self.__dict__[seg_key] = None
else:
@@ -408,8 +405,8 @@ class Interface(BaseInterface):
Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is.
If it is not, we have to rotate it before we put it to segmentation algorithm
"""
remaining_axis = 'xyz'.replace(slicing_axis.lower(), '').replace(segmenting_axis.lower(), '')
return self.axes[segmenting_axis.lower()] > self.axes[remaining_axis]
remaining_axis = F"{X}{Y}{Z}".replace(slicing_axis, '').replace(segmenting_axis, '')
return AXES[segmenting_axis] > AXES[remaining_axis]
def get_slice(self, pos:float, axis:int):
idx = int(pos * (self.data.shape[axis] - 1))
@@ -431,8 +428,8 @@ class Interface(BaseInterface):
# return x_plot, y_plot, z_plot
def plot_output_img_wrapper(self, slicing_axis:str):
slice_key = F'{slicing_axis.lower()}_slice'
seg_key = F'{slicing_axis.lower()}_segmentation'
slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis}_segmentation'
def plot_output_img(segmenting_axis:str, alpha:float, line_thickness:float):
slice = self.__dict__[slice_key]
@@ -446,11 +443,12 @@ class Interface(BaseInterface):
seg = np.sum(seg, axis = 0)
seg = np.repeat(seg[..., None], 3, axis = -1)
for i in range(n_layers):
seg[seg[:,:,0] == i, :] = self.segmentation_colors[i]
seg[seg[:,:,0] == i, :] = SEGMENTATION_COLORS[i]
if self.is_transposed(slicing_axis, segmenting_axis):
seg = np.rot90(seg, k = 3)
return overlay_rgb_images(np.repeat(slice[..., None], 3, -1), seg, alpha)
# slice = 255 * (slice/np.max(slice))
return image_with_overlay(np.repeat(slice[..., None], 3, -1), seg, alpha)
else:
lines = get_lines(seg)
if self.is_transposed(slicing_axis, segmenting_axis):
@@ -461,9 +459,9 @@ class Interface(BaseInterface):
return plot_output_img
def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float):
x_output = self.plot_output_img_wrapper('x')(segmenting_axis, alpha, line_thickness)
y_output = self.plot_output_img_wrapper('y')(segmenting_axis, alpha, line_thickness)
z_output = self.plot_output_img_wrapper('z')(segmenting_axis, alpha, line_thickness)
x_output = self.plot_output_img_wrapper(X)(segmenting_axis, alpha, line_thickness)
y_output = self.plot_output_img_wrapper(Y)(segmenting_axis, alpha, line_thickness)
z_output = self.plot_output_img_wrapper(Z)(segmenting_axis, alpha, line_thickness)
return x_output, y_output, z_output
if __name__ == "__main__":
Loading