diff --git a/qim3d/utils/__init__.py b/qim3d/utils/__init__.py index 7361d8cdd03af2d818a464d498ac6c854946c7b6..9b08365be27631d622fbf6edfb29014330885fc9 100644 --- a/qim3d/utils/__init__.py +++ b/qim3d/utils/__init__.py @@ -3,6 +3,6 @@ from . import doi, internal_tools from .augmentations import Augmentation from .cc import get_3d_cc from .data import Dataset, prepare_dataloaders, prepare_datasets +from .img import overlay_rgb_images from .models import inference, model_summary, train_model from .system import Memory -from .img import overlay_rgb_images diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index a641174d1fba75ebc66847d328e92e956b3df8c0..b770c3d0a51477bc17a241812a32f18aeba12657 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,4 +1,5 @@ from .visualizations import plot_metrics from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness from .k3d import vol -from .detection import circles \ No newline at end of file +from .colormaps import objects +from .detection import circles diff --git a/qim3d/viz/colormaps.py b/qim3d/viz/colormaps.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ee1b5f7f2708ee376574b896758dff8d8f1939 --- /dev/null +++ b/qim3d/viz/colormaps.py @@ -0,0 +1,95 @@ +import colorsys + +import numpy as np +from matplotlib.colors import LinearSegmentedColormap + +from qim3d.io.logger import log + + +def objects( + nlabels, + style="bright", + first_color_background=True, + last_color_background=False, + background_color=(0.0, 0.0, 0.0), + seed=19, +): + """ + Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks + + Args: + nlabels (int): Number of labels (size of colormap) + style (str, optional): 'bright' for strong colors, 'soft' for pastel colors. Defaults to 'bright'. + first_color_background (bool, optional): Option to use first color as background. Defaults to True. + last_color_background (bool, optional): Option to use last color as background. Defaults to False. + seed (int, optional): Seed for random number generator. Defaults to 19. + + Returns: + matplotlib.colors.LinearSegmentedColormap: Colormap for matplotlib + """ + # Check style + if style not in ("bright", "soft"): + raise ValueError( + f'Please choose "bright" or "soft" for style in qim3dCmap not "{style}"' + ) + + # Translate strings to background color + color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)} + if not isinstance(background_color, tuple): + try: + background_color = color_dict[background_color] + except KeyError: + raise ValueError( + f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.' + ) + + # Add one to nlabels to include the background color + nlabels += 1 + + # Create a new random generator, to locally set seed + rng = np.random.default_rng(seed) + + # Generate color map for bright colors, based on hsv + if style == "bright": + randHSVcolors = [ + ( + rng.uniform(low=0.0, high=1), + rng.uniform(low=0.4, high=1), + rng.uniform(low=0.9, high=1), + ) + for i in range(nlabels) + ] + + # Convert HSV list to RGB + randRGBcolors = [] + for HSVcolor in randHSVcolors: + randRGBcolors.append( + colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]) + ) + + # Generate soft pastel colors, by limiting the RGB spectrum + if style == "soft": + low = 0.6 + high = 0.95 + randRGBcolors = [ + ( + rng.uniform(low=low, high=high), + rng.uniform(low=low, high=high), + rng.uniform(low=low, high=high), + ) + for i in range(nlabels) + ] + + # Set first and last color to background + if first_color_background: + randRGBcolors[0] = background_color + + if last_color_background: + randRGBcolors[-1] = background_color + + # Create colormap + objects_cmap = LinearSegmentedColormap.from_list( + "objects_cmap", randRGBcolors, N=nlabels + ) + + return objects_cmap diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 812b2028621ff508a638b78b009d4fc69a7fd139..d3cf0bd4b9eb4fe7bd6224a8b00da126b11752aa 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -4,6 +4,8 @@ Provides a collection of visualization functions. import math from typing import List, Optional, Union, Tuple + +import ipywidgets as widgets import matplotlib.pyplot as plt import numpy as np import torch @@ -11,9 +13,9 @@ from matplotlib import colormaps from matplotlib.colors import LinearSegmentedColormap import qim3d.io -import ipywidgets as widgets from qim3d.io.logger import log from qim3d.utils.cc import CC +from qim3d.viz.colormaps import objects def grid_overview( @@ -231,6 +233,7 @@ def slices( show: bool = False, show_position: bool = True, interpolation: Optional[str] = "none", + **imshow_kwargs, ) -> plt.Figure: """Displays one or several slices from a 3d volume. @@ -334,7 +337,7 @@ def slices( slice_idx = i * max_cols + j try: slice_img = vol.take(slice_idxs[slice_idx], axis=axis) - ax.imshow(slice_img, cmap=cmap, interpolation=interpolation) + ax.imshow(slice_img, cmap=cmap, interpolation=interpolation, **imshow_kwargs) if show_position: ax.text( @@ -400,6 +403,7 @@ def slicer( img_width: int = 3, show_position: bool = False, interpolation: Optional[str] = "none", + **imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. @@ -437,6 +441,7 @@ def slicer( position=position, n_slices=1, show=True, + **imshow_kwargs, ) return fig @@ -535,14 +540,13 @@ def plot_cc( components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None. max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32. overlay (optional): Overlay image. Defaults to None. - crop (bool, optional): Whether to crop the overlay image. Defaults to False. + crop (bool, optional): Whether to crop the image to the cc. Defaults to False. show (bool, optional): Whether to show the figure. Defaults to True. **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`. Returns: figs (list[plt.Figure]): List of figures, if `show=False`. """ - figs = [] # if no components are given, plot the first max_cc_to_plot=32 components if component_indexs is None: if len(connected_components) > max_cc_to_plot: @@ -552,12 +556,11 @@ def plot_cc( component_indexs = range( 1, min(max_cc_to_plot + 1, len(connected_components) + 1) ) - + + figs = [] for component in component_indexs: if overlay is not None: - assert ( - overlay.shape == connected_components.shape - ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." + assert (overlay.shape == connected_components.shape), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}." # plots overlay masked to connected component if crop: @@ -573,10 +576,12 @@ def plot_cc( overlay_crop = np.where(cc == 0, 0, overlay) fig = slices(overlay_crop, show=show, **kwargs) else: + # assigns discrete color map to each connected component if not given + if "cmap" not in kwargs: + kwargs["cmap"] = qim3dCmap(len(component_indexs)) + # Plot the connected component without overlay - fig = slices( - connected_components.get_cc(component, crop=crop), show=show, **kwargs - ) + fig = slices(connected_components.get_cc(component, crop=crop), show=show, **kwargs) figs.append(fig) diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py index c0759d92a7f7c686cc88519acaee9699594b2d90..d5f3ebbd6c9e236459df1348fa6bbe03ad1a5d56 100644 --- a/qim3d/viz/k3d.py +++ b/qim3d/viz/k3d.py @@ -11,15 +11,15 @@ import k3d import numpy as np -def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwargs): +def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, **kwargs): """ Visualizes a 3D volume using volumetric rendering. Args: img (numpy.ndarray): The input 3D image data. It should be a 3D numpy array. - aspectmode (str, optional): Determines the proportions of the scene's axes. - If "data", the axes are drawn in proportion with the axes' ranges. - If "cube", the axes are drawn as a cube, regardless of the axes' ranges. + aspectmode (str, optional): Determines the proportions of the scene's axes. + If "data", the axes are drawn in proportion with the axes' ranges. + If "cube", the axes are drawn as a cube, regardless of the axes' ranges. Defaults to "data". show (bool, optional): If True, displays the visualization inline. Defaults to True. save (bool or str, optional): If True, saves the visualization as an HTML file. @@ -62,6 +62,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwa if aspectmode.lower() == "data" else None ), + color_map=cmap, ) plot = k3d.plot(grid_visible=grid_visible, **kwargs) plot += plt_volume