diff --git a/qim3d/viz/_data_exploration.py b/qim3d/viz/_data_exploration.py index e03f08526b176ff61d878cf11876473c7a36577c..1d8c0175de297a066a1049c10aa817b9e440eb86 100644 --- a/qim3d/viz/_data_exploration.py +++ b/qim3d/viz/_data_exploration.py @@ -327,7 +327,7 @@ def slicer( display_positions: bool = False, interpolation: Optional[str] = None, image_size: int = None, - color_bar: bool = False, + color_bar: str = None, **matplotlib_imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. @@ -342,7 +342,7 @@ def slicer( image_width (int, optional): Width of the figure. Defaults to 3. display_positions (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. - color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. + color_bar (str, optional): Controls the options for color bar. If None, no color bar is included. If 'volume', the color map range is constant for each slice. If 'slices', the color map range changes dynamically according to the slice. Defaults to None. Returns: slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. @@ -361,14 +361,34 @@ def slicer( image_height = image_size image_width = image_size + color_bar_options = [None, 'slices', 'volume'] + if color_bar not in color_bar_options: + raise ValueError( + f"Unrecognized value '{color_bar}' for parameter color_bar. " + f"Expected one of {color_bar_options}." + ) + show_color_bar = color_bar is not None + if color_bar == 'slices': + # Precompute the minimum and maximum along each slice for faster widget sliding. + non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis) + slice_mins = np.min(volume, axis=non_slice_axes) + slice_maxs = np.max(volume, axis=non_slice_axes) + # Create the interactive widget def _slicer(slice_positions): + if color_bar == 'slices': + dynamic_min = slice_mins[slice_positions] + dynamic_max = slice_maxs[slice_positions] + else: + dynamic_min = value_min + dynamic_max = value_max + fig = slices_grid( volume, slice_axis=slice_axis, color_map=color_map, - value_min=value_min, - value_max=value_max, + value_min=dynamic_min, + value_max=dynamic_max, image_height=image_height, image_width=image_width, display_positions=display_positions, @@ -376,7 +396,7 @@ def slicer( slice_positions=slice_positions, num_slices=1, display_figure=True, - color_bar=color_bar, + color_bar=show_color_bar, **matplotlib_imshow_kwargs, ) return fig