Skip to content
Snippets Groups Projects
Commit 67d42af8 authored by s204159's avatar s204159 :sunglasses: Committed by fima
Browse files

Rand cmap

parent 7466989f
No related branches found
No related tags found
1 merge request!64Rand cmap
...@@ -3,6 +3,6 @@ from . import doi, internal_tools ...@@ -3,6 +3,6 @@ from . import doi, internal_tools
from .augmentations import Augmentation from .augmentations import Augmentation
from .cc import get_3d_cc from .cc import get_3d_cc
from .data import Dataset, prepare_dataloaders, prepare_datasets from .data import Dataset, prepare_dataloaders, prepare_datasets
from .img import overlay_rgb_images
from .models import inference, model_summary, train_model from .models import inference, model_summary, train_model
from .system import Memory from .system import Memory
from .img import overlay_rgb_images
from .visualizations import plot_metrics from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
from .k3d import vol from .k3d import vol
from .colormaps import objects
from .detection import circles from .detection import circles
import colorsys
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from qim3d.io.logger import log
def objects(
nlabels,
style="bright",
first_color_background=True,
last_color_background=False,
background_color=(0.0, 0.0, 0.0),
seed=19,
):
"""
Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
Args:
nlabels (int): Number of labels (size of colormap)
style (str, optional): 'bright' for strong colors, 'soft' for pastel colors. Defaults to 'bright'.
first_color_background (bool, optional): Option to use first color as background. Defaults to True.
last_color_background (bool, optional): Option to use last color as background. Defaults to False.
seed (int, optional): Seed for random number generator. Defaults to 19.
Returns:
matplotlib.colors.LinearSegmentedColormap: Colormap for matplotlib
"""
# Check style
if style not in ("bright", "soft"):
raise ValueError(
f'Please choose "bright" or "soft" for style in qim3dCmap not "{style}"'
)
# Translate strings to background color
color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)}
if not isinstance(background_color, tuple):
try:
background_color = color_dict[background_color]
except KeyError:
raise ValueError(
f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.'
)
# Add one to nlabels to include the background color
nlabels += 1
# Create a new random generator, to locally set seed
rng = np.random.default_rng(seed)
# Generate color map for bright colors, based on hsv
if style == "bright":
randHSVcolors = [
(
rng.uniform(low=0.0, high=1),
rng.uniform(low=0.4, high=1),
rng.uniform(low=0.9, high=1),
)
for i in range(nlabels)
]
# Convert HSV list to RGB
randRGBcolors = []
for HSVcolor in randHSVcolors:
randRGBcolors.append(
colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])
)
# Generate soft pastel colors, by limiting the RGB spectrum
if style == "soft":
low = 0.6
high = 0.95
randRGBcolors = [
(
rng.uniform(low=low, high=high),
rng.uniform(low=low, high=high),
rng.uniform(low=low, high=high),
)
for i in range(nlabels)
]
# Set first and last color to background
if first_color_background:
randRGBcolors[0] = background_color
if last_color_background:
randRGBcolors[-1] = background_color
# Create colormap
objects_cmap = LinearSegmentedColormap.from_list(
"objects_cmap", randRGBcolors, N=nlabels
)
return objects_cmap
...@@ -4,6 +4,8 @@ Provides a collection of visualization functions. ...@@ -4,6 +4,8 @@ Provides a collection of visualization functions.
import math import math
from typing import List, Optional, Union, Tuple from typing import List, Optional, Union, Tuple
import ipywidgets as widgets
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
...@@ -11,9 +13,9 @@ from matplotlib import colormaps ...@@ -11,9 +13,9 @@ from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
import qim3d.io import qim3d.io
import ipywidgets as widgets
from qim3d.io.logger import log from qim3d.io.logger import log
from qim3d.utils.cc import CC from qim3d.utils.cc import CC
from qim3d.viz.colormaps import objects
def grid_overview( def grid_overview(
...@@ -231,6 +233,7 @@ def slices( ...@@ -231,6 +233,7 @@ def slices(
show: bool = False, show: bool = False,
show_position: bool = True, show_position: bool = True,
interpolation: Optional[str] = "none", interpolation: Optional[str] = "none",
**imshow_kwargs,
) -> plt.Figure: ) -> plt.Figure:
"""Displays one or several slices from a 3d volume. """Displays one or several slices from a 3d volume.
...@@ -334,7 +337,7 @@ def slices( ...@@ -334,7 +337,7 @@ def slices(
slice_idx = i * max_cols + j slice_idx = i * max_cols + j
try: try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis) slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
ax.imshow(slice_img, cmap=cmap, interpolation=interpolation) ax.imshow(slice_img, cmap=cmap, interpolation=interpolation, **imshow_kwargs)
if show_position: if show_position:
ax.text( ax.text(
...@@ -400,6 +403,7 @@ def slicer( ...@@ -400,6 +403,7 @@ def slicer(
img_width: int = 3, img_width: int = 3,
show_position: bool = False, show_position: bool = False,
interpolation: Optional[str] = "none", interpolation: Optional[str] = "none",
**imshow_kwargs,
) -> widgets.interactive: ) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume. """Interactive widget for visualizing slices of a 3D volume.
...@@ -437,6 +441,7 @@ def slicer( ...@@ -437,6 +441,7 @@ def slicer(
position=position, position=position,
n_slices=1, n_slices=1,
show=True, show=True,
**imshow_kwargs,
) )
return fig return fig
...@@ -535,14 +540,13 @@ def plot_cc( ...@@ -535,14 +540,13 @@ def plot_cc(
components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None. components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32. max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
overlay (optional): Overlay image. Defaults to None. overlay (optional): Overlay image. Defaults to None.
crop (bool, optional): Whether to crop the overlay image. Defaults to False. crop (bool, optional): Whether to crop the image to the cc. Defaults to False.
show (bool, optional): Whether to show the figure. Defaults to True. show (bool, optional): Whether to show the figure. Defaults to True.
**kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`. **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.
Returns: Returns:
figs (list[plt.Figure]): List of figures, if `show=False`. figs (list[plt.Figure]): List of figures, if `show=False`.
""" """
figs = []
# if no components are given, plot the first max_cc_to_plot=32 components # if no components are given, plot the first max_cc_to_plot=32 components
if component_indexs is None: if component_indexs is None:
if len(connected_components) > max_cc_to_plot: if len(connected_components) > max_cc_to_plot:
...@@ -553,11 +557,10 @@ def plot_cc( ...@@ -553,11 +557,10 @@ def plot_cc(
1, min(max_cc_to_plot + 1, len(connected_components) + 1) 1, min(max_cc_to_plot + 1, len(connected_components) + 1)
) )
figs = []
for component in component_indexs: for component in component_indexs:
if overlay is not None: if overlay is not None:
assert ( assert (overlay.shape == connected_components.shape), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
overlay.shape == connected_components.shape
), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
# plots overlay masked to connected component # plots overlay masked to connected component
if crop: if crop:
...@@ -573,10 +576,12 @@ def plot_cc( ...@@ -573,10 +576,12 @@ def plot_cc(
overlay_crop = np.where(cc == 0, 0, overlay) overlay_crop = np.where(cc == 0, 0, overlay)
fig = slices(overlay_crop, show=show, **kwargs) fig = slices(overlay_crop, show=show, **kwargs)
else: else:
# assigns discrete color map to each connected component if not given
if "cmap" not in kwargs:
kwargs["cmap"] = qim3dCmap(len(component_indexs))
# Plot the connected component without overlay # Plot the connected component without overlay
fig = slices( fig = slices(connected_components.get_cc(component, crop=crop), show=show, **kwargs)
connected_components.get_cc(component, crop=crop), show=show, **kwargs
)
figs.append(fig) figs.append(fig)
......
...@@ -11,7 +11,7 @@ import k3d ...@@ -11,7 +11,7 @@ import k3d
import numpy as np import numpy as np
def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwargs): def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, **kwargs):
""" """
Visualizes a 3D volume using volumetric rendering. Visualizes a 3D volume using volumetric rendering.
...@@ -62,6 +62,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwa ...@@ -62,6 +62,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwa
if aspectmode.lower() == "data" if aspectmode.lower() == "data"
else None else None
), ),
color_map=cmap,
) )
plot = k3d.plot(grid_visible=grid_visible, **kwargs) plot = k3d.plot(grid_visible=grid_visible, **kwargs)
plot += plt_volume plot += plt_volume
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment