From 06f92e02c79896ff13d9c82e9199291df370233e Mon Sep 17 00:00:00 2001
From: s214735 <s214735@student.dtu.dk>
Date: Wed, 20 Nov 2024 13:57:24 +0100
Subject: [PATCH] Added cbar_style as option for slices()

---
 qim3d/viz/explore.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py
index 23425112..0526c95b 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/explore.py
@@ -36,6 +36,7 @@ def slices(
     interpolation: Optional[str] = "none",
     img_size=None,
     cbar: bool = False,
+    cbar_style: str = "small",
     **imshow_kwargs,
 ) -> plt.Figure:
     """Displays one or several slices from a 3d volume.
@@ -59,6 +60,7 @@ def slices(
         show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
         cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
+        cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
 
     Returns:
         fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
@@ -68,6 +70,7 @@ def slices(
         ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
         ValueError: If the file or array is not a volume with at least 3 dimensions.
         ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
+        ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'.
 
     Example:
         ```python
@@ -91,6 +94,10 @@ def slices(
             "The provided object is not a volume as it has less than 3 dimensions."
         )
 
+    cbar_style_options = ["small", "large"]
+    if cbar_style not in cbar_style_options:
+        raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.")
+    
     if isinstance(vol, da.core.Array):
         vol = vol.compute()
 
@@ -215,16 +222,27 @@ def slices(
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
             fig.tight_layout()
+
         norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
         mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
 
-        # Figure coordinates of top-right axis
-        tr_pos = np.atleast_1d(axs[0])[-1].get_position()
-        # The width is divided by ncols to make it the same relative size to the images
-        cbar_ax = fig.add_axes(
-            [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
-        )
-        fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
+        if cbar_style =="small":
+            # Figure coordinates of top-right axis
+            tr_pos = np.atleast_1d(axs[0])[-1].get_position()
+            # The width is divided by ncols to make it the same relative size to the images
+            cbar_ax = fig.add_axes(
+                [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
+            )
+            fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
+        elif cbar_style == "large":
+            # Figure coordinates of bottom- and top-right axis
+            br_pos = np.atleast_1d(axs[-1])[-1].get_position()
+            tr_pos = np.atleast_1d(axs[0])[-1].get_position()
+            # The width is divided by ncols to make it the same relative size to the images
+            cbar_ax = fig.add_axes(
+                [br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015]
+            )
+            fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
 
     if show:
         plt.show()
-- 
GitLab