Skip to content
Snippets Groups Projects
Commit d8f37a54 authored by fima's avatar fima :beers:
Browse files

Merge branch 'automatic-objects-colormap' into 'main'

Automatic objects colormap

See merge request !130
parents 73ea1c27 330f4837
No related branches found
No related tags found
1 merge request!130Automatic objects colormap
...@@ -91,6 +91,9 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i ...@@ -91,6 +91,9 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i
import skimage import skimage
import scipy import scipy
if len(np.unique(bin_vol)) > 2:
raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.")
# Compute distance transform of binary volume # Compute distance transform of binary volume
distance = scipy.ndimage.distance_transform_edt(bin_vol) distance = scipy.ndimage.distance_transform_edt(bin_vol)
......
...@@ -86,11 +86,17 @@ def objects( ...@@ -86,11 +86,17 @@ def objects(
``` ```
![colormap objects](assets/screenshots/viz-colormaps-objects.gif) ![colormap objects](assets/screenshots/viz-colormaps-objects.gif)
Tip:
It can be easily used when calling visualization functions as
```python
qim3d.viz.slices(segmented_volume, cmap = 'objects')
```
which automatically detects number of unique classes
and creates the colormap object with defualt arguments.
Tip: Tip:
The `min_dist` parameter can be used to control the distance between neighboring colors. The `min_dist` parameter can be used to control the distance between neighboring colors.
![colormap objects mind_dist](assets/screenshots/viz-colormaps-min_dist.gif) ![colormap objects mind_dist](assets/screenshots/viz-colormaps-min_dist.gif)
""" """
from skimage import color from skimage import color
......
...@@ -19,7 +19,6 @@ import seaborn as sns ...@@ -19,7 +19,6 @@ import seaborn as sns
import qim3d import qim3d
def slices( def slices(
vol: np.ndarray, vol: np.ndarray,
axis: int = 0, axis: int = 0,
...@@ -33,7 +32,7 @@ def slices( ...@@ -33,7 +32,7 @@ def slices(
img_width: int = 2, img_width: int = 2,
show: bool = False, show: bool = False,
show_position: bool = True, show_position: bool = True,
interpolation: Optional[str] = "none", interpolation: Optional[str] = None,
img_size=None, img_size=None,
cbar: bool = False, cbar: bool = False,
cbar_style: str = "small", cbar_style: str = "small",
...@@ -85,6 +84,11 @@ def slices( ...@@ -85,6 +84,11 @@ def slices(
img_height = img_size img_height = img_size
img_width = img_size img_width = img_size
# If we pass python None to the imshow function, it will set to
# default value 'antialiased'
if interpolation is None:
interpolation = 'none'
# Numpy array or Torch tensor input # Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, da.core.Array)): if not isinstance(vol, (np.ndarray, da.core.Array)):
raise ValueError("Data type not supported") raise ValueError("Data type not supported")
...@@ -107,6 +111,19 @@ def slices( ...@@ -107,6 +111,19 @@ def slices(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}." f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
) )
if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects':
num_labels = len(np.unique(vol))
if cmap == 'objects':
cmap = qim3d.viz.colormaps.objects(num_labels)
# If vmin and vmax are not set like this, then in case the
# number of objects changes on new slice, objects might change
# colors. So when using a slider, the same object suddently
# changes color (flickers), which is confusing and annoying.
vmin = 0
vmax = num_labels
# Get total number of slices in the specified dimension # Get total number of slices in the specified dimension
n_total = vol.shape[axis] n_total = vol.shape[axis]
...@@ -152,9 +169,10 @@ def slices( ...@@ -152,9 +169,10 @@ def slices(
vol = vol.compute() vol = vol.compute()
if cbar: if cbar:
# In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar. # In this case, we want the vrange to be constant across the
new_vmin = vmin if vmin else np.min(vol) # slices, which makes them all comparable to a single cbar.
new_vmax = vmax if vmax else np.max(vol) new_vmin = vmin if vmin is not None else np.min(vol)
new_vmax = vmax if vmax is not None else np.max(vol)
# Run through each ax of the grid # Run through each ax of the grid
for i, ax_row in enumerate(axs): for i, ax_row in enumerate(axs):
...@@ -164,8 +182,9 @@ def slices( ...@@ -164,8 +182,9 @@ def slices(
slice_img = vol.take(slice_idxs[slice_idx], axis=axis) slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
if not cbar: if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised # If vmin is higher than the highest value in the
# We don't want to override the values because next slices might be okay # image ValueError is raised. We don't want to
# override the values because next slices might be okay
new_vmin = ( new_vmin = (
None None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
...@@ -277,7 +296,7 @@ def slicer( ...@@ -277,7 +296,7 @@ def slicer(
img_height: int = 3, img_height: int = 3,
img_width: int = 3, img_width: int = 3,
show_position: bool = False, show_position: bool = False,
interpolation: Optional[str] = "none", interpolation: Optional[str] = None,
img_size=None, img_size=None,
cbar: bool = False, cbar: bool = False,
**imshow_kwargs, **imshow_kwargs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment