Skip to content
Snippets Groups Projects
Commit 9278a1ef authored by fima's avatar fima :beers:
Browse files

Hotfix: Torch dependency

parent aff52c48
No related branches found
No related tags found
No related merge requests found
...@@ -6,14 +6,14 @@ import matplotlib.pyplot as plt ...@@ -6,14 +6,14 @@ import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps from matplotlib import colormaps
from qim3d.utils._logger import log from qim3d.utils._logger import log
import torch
import matplotlib import matplotlib
def plot_metrics( def plot_metrics(
*metrics: tuple[dict[str, float]], *metrics: tuple[dict[str, float]],
linestyle: str = "-", linestyle: str = "-",
batch_linestyle: str = "dotted", batch_linestyle: str = "dotted",
labels: list|None = None, labels: list | None = None,
figsize: tuple = (16, 6), figsize: tuple = (16, 6),
show: bool = False show: bool = False
): ):
...@@ -81,13 +81,13 @@ def plot_metrics( ...@@ -81,13 +81,13 @@ def plot_metrics(
def grid_overview( def grid_overview(
data: list|torch.utils.data.Dataset, data: list,
num_images: int = 7, num_images: int = 7,
cmap_im: str = "gray", cmap_im: str = "gray",
cmap_segm: str = "viridis", cmap_segm: str = "viridis",
alpha: float = 0.5, alpha: float = 0.5,
show: bool = False show: bool = False,
)-> matplotlib.figure.Figure: ) -> matplotlib.figure.Figure:
"""Displays an overview grid of images, labels, and masks (if they exist). """Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations Labels are the annotated target segmentations
...@@ -121,6 +121,7 @@ def grid_overview( ...@@ -121,6 +121,7 @@ def grid_overview(
and the length of the data. and the length of the data.
- The grid layout and dimensions vary based on the presence of a mask. - The grid layout and dimensions vary based on the presence of a mask.
""" """
import torch
# Check if data has a mask # Check if data has a mask
has_mask = len(data[0]) > 2 and data[0][-1] is not None has_mask = len(data[0]) > 2 and data[0][-1] is not None
...@@ -187,7 +188,7 @@ def grid_pred( ...@@ -187,7 +188,7 @@ def grid_pred(
cmap_segm: str = "viridis", cmap_segm: str = "viridis",
alpha: float = 0.5, alpha: float = 0.5,
show: bool = False, show: bool = False,
)-> matplotlib.figure.Figure: ) -> matplotlib.figure.Figure:
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison. """Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations. Displays a grid of subplots representing different aspects of the input images and segmentations.
...@@ -290,7 +291,9 @@ def grid_pred( ...@@ -290,7 +291,9 @@ def grid_pred(
return fig return fig
def vol_masked(vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int=128) -> np.ndarray: def vol_masked(
vol: np.ndarray, vol_mask: np.ndarray, viz_delta: int = 128
) -> np.ndarray:
""" """
Applies masking to a volume based on a binary volume mask. Applies masking to a volume based on a binary volume mask.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment