Skip to content
Snippets Groups Projects
Commit bdc2c8d7 authored by fima's avatar fima :beers:
Browse files

Merge branch 'viz-explore-fix' into 'main'

Added vmin and vmax arguments to viz functions

See merge request !120
parents 053f9395 8fbed9e3
No related branches found
No related tags found
1 merge request!120Added vmin and vmax arguments to viz functions
...@@ -12,6 +12,9 @@ def plot_cc( ...@@ -12,6 +12,9 @@ def plot_cc(
overlay=None, overlay=None,
crop=False, crop=False,
show=True, show=True,
cmap:str = 'viridis',
vmin:float = None,
vmax:float = None,
**kwargs, **kwargs,
) -> list[plt.Figure]: ) -> list[plt.Figure]:
""" """
...@@ -24,6 +27,9 @@ def plot_cc( ...@@ -24,6 +27,9 @@ def plot_cc(
overlay (optional): Overlay image. Defaults to None. overlay (optional): Overlay image. Defaults to None.
crop (bool, optional): Whether to crop the image to the cc. Defaults to False. crop (bool, optional): Whether to crop the image to the cc. Defaults to False.
show (bool, optional): Whether to show the figure. Defaults to True. show (bool, optional): Whether to show the figure. Defaults to True.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
**kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`. **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.
Returns: Returns:
...@@ -66,11 +72,10 @@ def plot_cc( ...@@ -66,11 +72,10 @@ def plot_cc(
overlay_crop = overlay[bb] overlay_crop = overlay[bb]
# use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers # use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers
overlay_crop = np.where(cc == 0, 0, overlay_crop) overlay_crop = np.where(cc == 0, 0, overlay_crop)
fig = qim3d.viz.slices(overlay_crop, show=show, **kwargs)
else: else:
cc = connected_components.get_cc(component, crop=False) cc = connected_components.get_cc(component, crop=False)
overlay_crop = np.where(cc == 0, 0, overlay) overlay_crop = np.where(cc == 0, 0, overlay)
fig = qim3d.viz.slices(overlay_crop, show=show, **kwargs) fig = qim3d.viz.slices(overlay_crop, show=show, cmap = cmap, vmin = vmin, vmax = vmax, **kwargs)
else: else:
# assigns discrete color map to each connected component if not given # assigns discrete color map to each connected component if not given
if "cmap" not in kwargs: if "cmap" not in kwargs:
......
...@@ -20,6 +20,8 @@ def slices( ...@@ -20,6 +20,8 @@ def slices(
n_slices: int = 5, n_slices: int = 5,
max_cols: int = 5, max_cols: int = 5,
cmap: str = "viridis", cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
img_height: int = 2, img_height: int = 2,
img_width: int = 2, img_width: int = 2,
show: bool = False, show: bool = False,
...@@ -41,6 +43,8 @@ def slices( ...@@ -41,6 +43,8 @@ def slices(
n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5. n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5. max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. img_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure. img_width (int, optional): Width of the figure.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
...@@ -136,8 +140,13 @@ def slices( ...@@ -136,8 +140,13 @@ def slices(
slice_idx = i * max_cols + j slice_idx = i * max_cols + j
try: try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis) slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
ax.imshow( ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation, **imshow_kwargs slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
) )
if show_position: if show_position:
...@@ -200,6 +209,8 @@ def slicer( ...@@ -200,6 +209,8 @@ def slicer(
vol: np.ndarray, vol: np.ndarray,
axis: int = 0, axis: int = 0,
cmap: str = "viridis", cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
img_height: int = 3, img_height: int = 3,
img_width: int = 3, img_width: int = 3,
show_position: bool = False, show_position: bool = False,
...@@ -213,6 +224,8 @@ def slicer( ...@@ -213,6 +224,8 @@ def slicer(
vol (np.ndarray): The 3D volume to be sliced. vol (np.ndarray): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. Defaults to 3. img_height (int, optional): Height of the figure. Defaults to 3.
img_width (int, optional): Width of the figure. Defaults to 3. img_width (int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False. show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
...@@ -241,6 +254,8 @@ def slicer( ...@@ -241,6 +254,8 @@ def slicer(
vol, vol,
axis=axis, axis=axis,
cmap=cmap, cmap=cmap,
vmin = vmin,
vmax = vmax,
img_height=img_height, img_height=img_height,
img_width=img_width, img_width=img_width,
show_position=show_position, show_position=show_position,
...@@ -268,6 +283,8 @@ def slicer( ...@@ -268,6 +283,8 @@ def slicer(
def orthogonal( def orthogonal(
vol: np.ndarray, vol: np.ndarray,
cmap: str = "viridis", cmap: str = "viridis",
vmin:float = None,
vmax:float = None,
img_height: int = 3, img_height: int = 3,
img_width: int = 3, img_width: int = 3,
show_position: bool = False, show_position: bool = False,
...@@ -279,6 +296,8 @@ def orthogonal( ...@@ -279,6 +296,8 @@ def orthogonal(
Args: Args:
vol (np.ndarray): The 3D volume to be sliced. vol (np.ndarray): The 3D volume to be sliced.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. img_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure. img_width (int, optional): Width of the figure.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False. show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
...@@ -301,34 +320,22 @@ def orthogonal( ...@@ -301,34 +320,22 @@ def orthogonal(
img_height = img_size img_height = img_size
img_width = img_size img_width = img_size
z_slicer = slicer( get_slicer_for_axis = lambda axis: slicer(
vol,
axis=0,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
y_slicer = slicer(
vol, vol,
axis=1, axis = axis,
cmap=cmap,
img_height=img_height,
img_width=img_width,
show_position=show_position,
interpolation=interpolation,
)
x_slicer = slicer(
vol,
axis=2,
cmap = cmap, cmap = cmap,
vmin = vmin,
vmax = vmax,
img_height = img_height, img_height = img_height,
img_width = img_width, img_width = img_width,
show_position = show_position, show_position = show_position,
interpolation = interpolation, interpolation = interpolation,
) )
z_slicer = get_slicer_for_axis(axis = 0)
y_slicer = get_slicer_for_axis(axis = 1)
x_slicer = get_slicer_for_axis(axis = 2)
z_slicer.children[0].description = "Z" z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y" y_slicer.children[0].description = "Y"
x_slicer.children[0].description = "X" x_slicer.children[0].description = "X"
...@@ -336,7 +343,7 @@ def orthogonal( ...@@ -336,7 +343,7 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer]) return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask(vol: np.ndarray, axis: int = 0): def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', vmin:float = None, vmax:float = None):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume. """Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask. This can be used to select the best parameters before applying the mask.
...@@ -344,6 +351,9 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0): ...@@ -344,6 +351,9 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
Args: Args:
vol (np.ndarray): The volume to apply edge fading to. vol (np.ndarray): The volume to apply edge fading to.
axis (int, optional): The axis along which to apply the fading. Defaults to 0. axis (int, optional): The axis along which to apply the fading. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
Example: Example:
```python ```python
...@@ -359,7 +369,13 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0): ...@@ -359,7 +369,13 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
def _slicer(position, decay_rate, ratio, geometry, invert): def _slicer(position, decay_rate, ratio, geometry, invert):
fig, axes = plt.subplots(1, 3, figsize=(9, 3)) fig, axes = plt.subplots(1, 3, figsize=(9, 3))
axes[0].imshow(vol[position, :, :], cmap="viridis") slice_img = vol[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
axes[0].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax)
axes[0].set_title("Original") axes[0].set_title("Original")
axes[0].axis("off") axes[0].axis("off")
...@@ -371,7 +387,7 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0): ...@@ -371,7 +387,7 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
axis=axis, axis=axis,
invert=invert, invert=invert,
) )
axes[1].imshow(mask[position, :, :], cmap="viridis") axes[1].imshow(mask[position, :, :], cmap=cmap)
axes[1].set_title("Mask") axes[1].set_title("Mask")
axes[1].axis("off") axes[1].axis("off")
...@@ -383,15 +399,20 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0): ...@@ -383,15 +399,20 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0):
axis=axis, axis=axis,
invert=invert, invert=invert,
) )
axes[2].imshow(masked_vol[position, :, :], cmap="viridis") # If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :]
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
axes[2].imshow(slice_img, cmap=cmap, vmin = new_vmin, vmax = new_vmax)
axes[2].set_title("Masked") axes[2].set_title("Masked")
axes[2].axis("off") axes[2].axis("off")
return fig return fig
shape_dropdown = widgets.Dropdown( shape_dropdown = widgets.Dropdown(
options=["sphere", "cilinder"], options=["spherical", "cylindrical"],
value="sphere", # default value value="spherical", # default value
description="Geometry", description="Geometry",
) )
......
...@@ -14,13 +14,13 @@ from qim3d.utils.misc import downscale_img, scale_to_float16 ...@@ -14,13 +14,13 @@ from qim3d.utils.misc import downscale_img, scale_to_float16
def vol( def vol(
img, img,
vmin=None,
vmax=None,
aspectmode="data", aspectmode="data",
show=True, show=True,
save=False, save=False,
grid_visible=False, grid_visible=False,
cmap=None, cmap=None,
vmin=None,
vmax=None,
samples="auto", samples="auto",
max_voxels=512**3, max_voxels=512**3,
data_type="scaled_float16", data_type="scaled_float16",
...@@ -41,8 +41,12 @@ def vol( ...@@ -41,8 +41,12 @@ def vol(
file will be saved. Defaults to False. file will be saved. Defaults to False.
grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
cmap (list, optional): The color map to be used for the volume rendering. Defaults to None. cmap (list, optional): The color map to be used for the volume rendering. Defaults to None.
vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None
samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512.
Lower values will render faster but with lower quality. Lower values will render faster but with lower quality.
max_voxels (int, optional): Defaults to 512^3.
data_type (str, optional): Default to 'scaled_float16'.
**kwargs: Additional keyword arguments to be passed to the `k3d.plot` function. **kwargs: Additional keyword arguments to be passed to the `k3d.plot` function.
Returns: Returns:
......
...@@ -18,6 +18,9 @@ def vectors( ...@@ -18,6 +18,9 @@ def vectors(
volume: np.ndarray, volume: np.ndarray,
vec: np.ndarray, vec: np.ndarray,
axis: int = 0, axis: int = 0,
volume_cmap:str = 'grey',
vmin:float = None,
vmax:float = None,
slice_idx: Optional[Union[int, float]] = None, slice_idx: Optional[Union[int, float]] = None,
grid_size: int = 10, grid_size: int = 10,
interactive: bool = True, interactive: bool = True,
...@@ -31,6 +34,9 @@ def vectors( ...@@ -31,6 +34,9 @@ def vectors(
volume (np.ndarray): The 3D volume to be sliced. volume (np.ndarray): The 3D volume to be sliced.
vec (np.ndarray): The eigenvectors of the structure tensor. vec (np.ndarray): The eigenvectors of the structure tensor.
axis (int, optional): The axis along which to visualize the orientation. Defaults to 0. axis (int, optional): The axis along which to visualize the orientation. Defaults to 0.
volume_cmap (str, optional): Defines colormap for display of the volume
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
slice_idx (int or float, optional): The initial slice to be visualized. The slice index slice_idx (int or float, optional): The initial slice to be visualized. The slice index
can afterwards be changed. If value is an integer, it will be the index of the slice can afterwards be changed. If value is an integer, it will be the index of the slice
to be visualized. If value is a float between 0 and 1, it will be multiplied by the to be visualized. If value is a float between 0 and 1, it will be multiplied by the
...@@ -169,7 +175,7 @@ def vectors( ...@@ -169,7 +175,7 @@ def vectors(
angles="xy", angles="xy",
) )
ax[0].imshow(data_slice, cmap=plt.cm.gray) ax[0].imshow(data_slice, cmap = volume_cmap, vmin = vmin, vmax = vmax)
ax[0].set_title( ax[0].set_title(
f"Orientation vectors (slice {slice_idx})" f"Orientation vectors (slice {slice_idx})"
if not interactive if not interactive
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment