Skip to content
Snippets Groups Projects
Commit 330f4837 authored by s233039's avatar s233039 Committed by fima
Browse files

Automatic objects colormap

parent 73ea1c27
No related branches found
No related tags found
1 merge request!130Automatic objects colormap
......@@ -91,6 +91,9 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i
import skimage
import scipy
if len(np.unique(bin_vol)) > 2:
raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.")
# Compute distance transform of binary volume
distance = scipy.ndimage.distance_transform_edt(bin_vol)
......
......@@ -86,11 +86,17 @@ def objects(
```
![colormap objects](assets/screenshots/viz-colormaps-objects.gif)
Tip:
It can be easily used when calling visualization functions as
```python
qim3d.viz.slices(segmented_volume, cmap = 'objects')
```
which automatically detects number of unique classes
and creates the colormap object with defualt arguments.
Tip:
The `min_dist` parameter can be used to control the distance between neighboring colors.
![colormap objects mind_dist](assets/screenshots/viz-colormaps-min_dist.gif)
"""
from skimage import color
......
......@@ -19,7 +19,6 @@ import seaborn as sns
import qim3d
def slices(
vol: np.ndarray,
axis: int = 0,
......@@ -33,7 +32,7 @@ def slices(
img_width: int = 2,
show: bool = False,
show_position: bool = True,
interpolation: Optional[str] = "none",
interpolation: Optional[str] = None,
img_size=None,
cbar: bool = False,
cbar_style: str = "small",
......@@ -85,6 +84,11 @@ def slices(
img_height = img_size
img_width = img_size
# If we pass python None to the imshow function, it will set to
# default value 'antialiased'
if interpolation is None:
interpolation = 'none'
# Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, da.core.Array)):
raise ValueError("Data type not supported")
......@@ -107,6 +111,19 @@ def slices(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
)
if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects':
num_labels = len(np.unique(vol))
if cmap == 'objects':
cmap = qim3d.viz.colormaps.objects(num_labels)
# If vmin and vmax are not set like this, then in case the
# number of objects changes on new slice, objects might change
# colors. So when using a slider, the same object suddently
# changes color (flickers), which is confusing and annoying.
vmin = 0
vmax = num_labels
# Get total number of slices in the specified dimension
n_total = vol.shape[axis]
......@@ -152,9 +169,10 @@ def slices(
vol = vol.compute()
if cbar:
# In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin else np.min(vol)
new_vmax = vmax if vmax else np.max(vol)
# In this case, we want the vrange to be constant across the
# slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin is not None else np.min(vol)
new_vmax = vmax if vmax is not None else np.max(vol)
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
......@@ -164,8 +182,9 @@ def slices(
slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
# If vmin is higher than the highest value in the
# image ValueError is raised. We don't want to
# override the values because next slices might be okay
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
......@@ -277,7 +296,7 @@ def slicer(
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
interpolation: Optional[str] = "none",
interpolation: Optional[str] = None,
img_size=None,
cbar: bool = False,
**imshow_kwargs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment