Skip to content
Snippets Groups Projects
Commit fecd3ec6 authored by David Grundfest's avatar David Grundfest
Browse files

Axes swap to follow YZX notation

parent e5c6fc89
Branches
No related tags found
2 merge requests!117New Layered Surface Segmentation,!42Layered Surface Segmentation Feature
<<<<<<< HEAD """The qim3d library is designed to make it easier to work with 3D imaging data in Python.
"""qim3d: A Python package for 3D image processing and visualization.
=======
import qim3d.io as io
import qim3d.gui as gui
import qim3d.viz as viz
import qim3d.utils as utils
import qim3d.models as models
import qim3d.process as process
import logging
>>>>>>> 1622d378193877cb85f53b1a05207088b3f3cf0a
The qim3d library is designed to make it easier to work with 3D imaging data in Python.
It offers a range of features, including data loading and manipulation, It offers a range of features, including data loading and manipulation,
image processing and filtering, visualization of 3D data, and analysis of imaging results. image processing and filtering, visualization of 3D data, and analysis of imaging results.
......
...@@ -27,39 +27,20 @@ from .interface import BaseInterface ...@@ -27,39 +27,20 @@ from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d # from qim3d.processing import layers2d as l2d
from qim3d.processing import overlay_rgb_images, segment_layers, get_lines from qim3d.processing import overlay_rgb_images, segment_layers, get_lines
from qim3d.io import load from qim3d.io import load
from qim3d.viz.layers2d import image_with_lines from qim3d.viz.layers2d import image_with_lines, image_with_overlay
#TODO figure out how not update anything and go through processing when there are no data loaded #TODO figure out how not update anything and go through processing when there are no data loaded
# So user could play with the widgets but it doesnt throw error # So user could play with the widgets but it doesnt throw error
# Right now its only bypassed with several if statements # Right now its only bypassed with several if statements
# I opened an issue here https://github.com/gradio-app/gradio/issues/9273 # I opened an issue here https://github.com/gradio-app/gradio/issues/9273
X = 1 X = 'X'
Y = 0 Y = 'Y'
Z = 2 Z = 'Z'
AXES = {X:2, Y:1, Z:0}
DEFAULT_PLOT_TYPE = 'Segmentation mask' DEFAULT_PLOT_TYPE = 'Segmentation mask'
SEGMENTATION_COLORS = np.array([[0, 255, 255], # Cyan
class Interface(BaseInterface):
def __init__(self):
super().__init__("Layered surfaces 2D", 1080)
self.figsize = (8, 8)
self.cmap = "Greys_r"
self.axes = {'x': X, 'y':Y, 'z':Z}
self.data = None
self.x_slice = None
self.y_slice = None
self.z_slice = None
self.x_segmentation = None
self.y_segmentation = None
self.z_segmentation = None
self.plot_type = DEFAULT_PLOT_TYPE
self.segmentation_colors = np.array([[0, 255, 255], # Cyan
[255, 195, 0], # Yellow Orange [255, 195, 0], # Yellow Orange
[199, 0, 57], # Dark orange [199, 0, 57], # Dark orange
[218, 247, 166], # Light green [218, 247, 166], # Light green
...@@ -69,6 +50,22 @@ class Interface(BaseInterface): ...@@ -69,6 +50,22 @@ class Interface(BaseInterface):
[255, 0, 0], #Red [255, 0, 0], #Red
]) ])
class Interface(BaseInterface):
def __init__(self):
super().__init__("Layered surfaces 2D", 1080)
self.data = None
# It important to keep the name of the attributes like this (including the capital letter) becuase of
# accessing the attributes via __dict__
self.X_slice = None
self.Y_slice = None
self.Z_slice = None
self.X_segmentation = None
self.Y_segmentation = None
self.Z_segmentation = None
self.plot_type = DEFAULT_PLOT_TYPE
self.error = False self.error = False
...@@ -128,10 +125,10 @@ class Interface(BaseInterface): ...@@ -128,10 +125,10 @@ class Interface(BaseInterface):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
axis = gr.Radio( axis = gr.Radio(
choices = ['Y', 'X', 'Z'], choices = [Z, Y, X],
value = 'Y', value = Z,
label = 'Layer axis', label = 'Layer axis',
info = 'Specifies in which direction are the layers. Because of numpy design, Y is 0th axis, X is 1st and Z is 2nd.',) info = 'Specifies in which direction are the layers. The order of axes is ZYX',)
with gr.Row(): with gr.Row():
wrap = gr.Checkbox( wrap = gr.Checkbox(
label = "Lines start and end at the same level.", label = "Lines start and end at the same level.",
...@@ -168,7 +165,7 @@ class Interface(BaseInterface): ...@@ -168,7 +165,7 @@ class Interface(BaseInterface):
with gr.Row(): with gr.Row():
n_layers = gr.Slider( n_layers = gr.Slider(
minimum=1, minimum=1,
maximum=len(self.segmentation_colors) - 1, maximum=len(SEGMENTATION_COLORS) - 1,
value=2, value=2,
step=1, step=1,
interactive=True, interactive=True,
...@@ -292,8 +289,8 @@ class Interface(BaseInterface): ...@@ -292,8 +289,8 @@ class Interface(BaseInterface):
fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots) fn = self.change_plot_size, inputs = visibility_check_inputs, outputs = output_plots)
# for axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots): # for axis, slider, input_plot, output_plot in zip(['x','y','z'], positions, input_plots, output_plots):
for axis, slider, output_plot in zip(['x','y','z'], positions, output_plots): for axis, slider, output_plot in zip([X,Y,Z], positions, output_plots):
slider.release( slider.change(
self.process_wrapper(axis), inputs = [slider, *process_inputs]).then( self.process_wrapper(axis), inputs = [slider, *process_inputs]).then(
# self.plot_input_img_wrapper(axis), outputs = input_plot).then( # self.plot_input_img_wrapper(axis), outputs = input_plot).then(
self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot) self.plot_output_img_wrapper(axis), inputs = plotting_inputs, outputs = output_plot)
...@@ -364,9 +361,9 @@ class Interface(BaseInterface): ...@@ -364,9 +361,9 @@ class Interface(BaseInterface):
) from error_message ) from error_message
def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool): def process_all(self, x_pos:float, y_pos:float, z_pos:float, axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
self.process_wrapper('x')(x_pos, axis, inverted, delta, min_margin, n_layers, wrap) self.process_wrapper(X)(x_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper('y')(y_pos, axis, inverted, delta, min_margin, n_layers, wrap) self.process_wrapper(Y)(y_pos, axis, inverted, delta, min_margin, n_layers, wrap)
self.process_wrapper('z')(z_pos, axis, inverted, delta, min_margin, n_layers, wrap) self.process_wrapper(Z)(z_pos, axis, inverted, delta, min_margin, n_layers, wrap)
def process_wrapper(self, slicing_axis:str): def process_wrapper(self, slicing_axis:str):
""" """
...@@ -374,9 +371,9 @@ class Interface(BaseInterface): ...@@ -374,9 +371,9 @@ class Interface(BaseInterface):
Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data Thus we have this wrapper function, where we pass the slicing axis - in which axis are we indexing the data
and we return a function working in that direction and we return a function working in that direction
""" """
slice_key = F'{slicing_axis.lower()}_slice' slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis.lower()}_segmentation' seg_key = F'{slicing_axis}_segmentation'
slicing_axis_int = self.axes[slicing_axis] slicing_axis_int = AXES[slicing_axis]
def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool): def process(pos:float, segmenting_axis:str, inverted:bool, delta:float, min_margin:int, n_layers:int, wrap:bool):
""" """
...@@ -393,7 +390,7 @@ class Interface(BaseInterface): ...@@ -393,7 +390,7 @@ class Interface(BaseInterface):
slice = self.get_slice(pos, slicing_axis_int) slice = self.get_slice(pos, slicing_axis_int)
self.__dict__[slice_key] = slice self.__dict__[slice_key] = slice
if segmenting_axis.lower() == slicing_axis.lower(): if segmenting_axis == slicing_axis:
self.__dict__[seg_key] = None self.__dict__[seg_key] = None
else: else:
...@@ -408,8 +405,8 @@ class Interface(BaseInterface): ...@@ -408,8 +405,8 @@ class Interface(BaseInterface):
Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is. Checks if the desired direction of segmentation is the same if the image would be submitted to segmentation as is.
If it is not, we have to rotate it before we put it to segmentation algorithm If it is not, we have to rotate it before we put it to segmentation algorithm
""" """
remaining_axis = 'xyz'.replace(slicing_axis.lower(), '').replace(segmenting_axis.lower(), '') remaining_axis = F"{X}{Y}{Z}".replace(slicing_axis, '').replace(segmenting_axis, '')
return self.axes[segmenting_axis.lower()] > self.axes[remaining_axis] return AXES[segmenting_axis] > AXES[remaining_axis]
def get_slice(self, pos:float, axis:int): def get_slice(self, pos:float, axis:int):
idx = int(pos * (self.data.shape[axis] - 1)) idx = int(pos * (self.data.shape[axis] - 1))
...@@ -431,8 +428,8 @@ class Interface(BaseInterface): ...@@ -431,8 +428,8 @@ class Interface(BaseInterface):
# return x_plot, y_plot, z_plot # return x_plot, y_plot, z_plot
def plot_output_img_wrapper(self, slicing_axis:str): def plot_output_img_wrapper(self, slicing_axis:str):
slice_key = F'{slicing_axis.lower()}_slice' slice_key = F'{slicing_axis}_slice'
seg_key = F'{slicing_axis.lower()}_segmentation' seg_key = F'{slicing_axis}_segmentation'
def plot_output_img(segmenting_axis:str, alpha:float, line_thickness:float): def plot_output_img(segmenting_axis:str, alpha:float, line_thickness:float):
slice = self.__dict__[slice_key] slice = self.__dict__[slice_key]
...@@ -446,11 +443,12 @@ class Interface(BaseInterface): ...@@ -446,11 +443,12 @@ class Interface(BaseInterface):
seg = np.sum(seg, axis = 0) seg = np.sum(seg, axis = 0)
seg = np.repeat(seg[..., None], 3, axis = -1) seg = np.repeat(seg[..., None], 3, axis = -1)
for i in range(n_layers): for i in range(n_layers):
seg[seg[:,:,0] == i, :] = self.segmentation_colors[i] seg[seg[:,:,0] == i, :] = SEGMENTATION_COLORS[i]
if self.is_transposed(slicing_axis, segmenting_axis): if self.is_transposed(slicing_axis, segmenting_axis):
seg = np.rot90(seg, k = 3) seg = np.rot90(seg, k = 3)
return overlay_rgb_images(np.repeat(slice[..., None], 3, -1), seg, alpha) # slice = 255 * (slice/np.max(slice))
return image_with_overlay(np.repeat(slice[..., None], 3, -1), seg, alpha)
else: else:
lines = get_lines(seg) lines = get_lines(seg)
if self.is_transposed(slicing_axis, segmenting_axis): if self.is_transposed(slicing_axis, segmenting_axis):
...@@ -461,9 +459,9 @@ class Interface(BaseInterface): ...@@ -461,9 +459,9 @@ class Interface(BaseInterface):
return plot_output_img return plot_output_img
def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float): def plot_output_img_all(self, segmenting_axis:str, alpha:float, line_thickness:float):
x_output = self.plot_output_img_wrapper('x')(segmenting_axis, alpha, line_thickness) x_output = self.plot_output_img_wrapper(X)(segmenting_axis, alpha, line_thickness)
y_output = self.plot_output_img_wrapper('y')(segmenting_axis, alpha, line_thickness) y_output = self.plot_output_img_wrapper(Y)(segmenting_axis, alpha, line_thickness)
z_output = self.plot_output_img_wrapper('z')(segmenting_axis, alpha, line_thickness) z_output = self.plot_output_img_wrapper(Z)(segmenting_axis, alpha, line_thickness)
return x_output, y_output, z_output return x_output, y_output, z_output
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment