Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 3D_UNet
  • main
  • notebooksv1
  • scaleZYX_mean
  • notebooks
  • convert_tiff_folders
  • test
  • notebook_update
  • threshold-exploration
  • optimize_scaleZYXdask
  • layered_surface_segmentation
  • conv_zarr_tiff_folders
  • 3d_watershed
  • tr_val_te_splits
  • save_files_function
  • memmap_txrm
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results

Target

Select target project
No results found
Select Git revision
  • 3D_UNet
  • main
  • notebooksv1
  • scaleZYX_mean
  • notebooks
  • convert_tiff_folders
  • test
  • notebook_update
  • threshold-exploration
  • optimize_scaleZYXdask
  • layered_surface_segmentation
  • conv_zarr_tiff_folders
  • 3d_watershed
  • tr_val_te_splits
  • save_files_function
  • memmap_txrm
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes

Commits on Source 6

12 files
+ 321
38
Compare changes
  • Side-by-side
  • Inline

Files

Original line number Original line Diff line number Diff line
import qim3d.io as io
# import qim3d.io as io
import qim3d.gui as gui
# import qim3d.gui as gui
import qim3d.viz as viz
# import qim3d.viz as viz
import qim3d.utils as utils
# import qim3d.utils as utils
import qim3d.models as models
# import qim3d.models as models
import qim3d.processing as processing
# import qim3d.processing as processing
from . import io, gui, viz, utils, models, processing
import logging
import logging


__version__ = '0.3.2'
__version__ = '0.3.2'
Original line number Original line Diff line number Diff line
@@ -15,19 +15,21 @@ app.launch()
```
```
"""
"""


import gradio as gr
import numpy as np
import os
from qim3d.utils import internal_tools
from qim3d.io import load
from qim3d.io.logger import log
import tifffile
import outputformat as ouf
import datetime
import datetime
import os

import gradio as gr
import matplotlib
import matplotlib


# matplotlib.use("Agg")
# matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import outputformat as ouf
import tifffile

from qim3d.io import load
from qim3d.io.logger import log
from qim3d.utils import internal_tools




class Interface:
class Interface:
@@ -45,7 +47,6 @@ class Interface:
            "Z min projection",
            "Z min projection",
            "Intensity histogram",
            "Intensity histogram",
            "Data summary",
            "Data summary",

        ]
        ]
        # CSS path
        # CSS path
        current_dir = os.path.dirname(os.path.abspath(__file__))
        current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -118,7 +119,7 @@ class Interface:
                                value="", elem_classes="btn-html h-36"
                                value="", elem_classes="btn-html h-36"
                            )
                            )
                    explorer = gr.FileExplorer(
                    explorer = gr.FileExplorer(
                        glob="{*/,}{*.*}",
                        ignore_glob="*/.*", # ignores hidden files
                        root_dir=os.getcwd(),
                        root_dir=os.getcwd(),
                        label=os.getcwd(),
                        label=os.getcwd(),
                        render=True,
                        render=True,
@@ -406,8 +407,10 @@ class Pipeline:
                virtual_stack=session.virtual_stack,
                virtual_stack=session.virtual_stack,
                dataset_name=session.dataset_name,
                dataset_name=session.dataset_name,
            )
            )
            if session.vol.ndim != 3:
                raise ValueError("Invalid data shape should be 3 dimensional, not shape: ", session.vol.shape)
        except Exception as error_message:
        except Exception as error_message:
            raise ValueError(
            raise gr.Error(
                f"Failed to load the image: {error_message}"
                f"Failed to load the image: {error_message}"
            ) from error_message
            ) from error_message


Original line number Original line Diff line number Diff line
from .filters import *
from .filters import *
from .local_thickness import local_thickness
from .local_thickness import local_thickness
from .detection import *
+99 −0
Original line number Original line Diff line number Diff line
import numpy as np
from qim3d.io.logger import log
from skimage.feature import blob_dog

__all__ = ["Blob"]


class Blob:
    def __init__(
        self,
        background="dark",
        min_sigma=1,
        max_sigma=50,
        sigma_ratio=1.6,
        threshold=0.5,
        overlap=0.5,
        threshold_rel=None,
        exclude_border=False,
    ):
        """
        Initialize the blob detection object
        Args:
            background: 'dark' if background is darker than the blobs, 'bright' if background is lighter than the blobs
            min_sigma: The minimum standard deviation for Gaussian kernel
            max_sigma: The maximum standard deviation for Gaussian kernel
            sigma_ratio: The ratio between the standard deviation of Gaussian Kernels
            threshold: The absolute lower bound for scale space maxima. Reduce this to detect blobs with lower intensities.
            overlap: The fraction of area of two blobs that overlap
            threshold_rel: The relative lower bound for scale space maxima
            exclude_border: If True, exclude blobs that are too close to the border of the image
        """
        self.background = background
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma
        self.sigma_ratio = sigma_ratio
        self.threshold = threshold
        self.overlap = overlap
        self.threshold_rel = threshold_rel
        self.exclude_border = exclude_border
        self.vol_shape = None
        self.blobs = None

    def detect(self, vol):
        """
        Detect blobs in the volume
        Args:
            vol: The volume to detect blobs in
        Returns:
            blobs: The blobs found in the volume as (p, r, c, radius)
        """
        self.vol_shape = vol.shape
        if self.background == "bright":
            log.info("Bright background selected, volume will be inverted.")
            vol = np.invert(vol)

        blobs = blob_dog(
            vol,
            min_sigma=self.min_sigma,
            max_sigma=self.max_sigma,
            sigma_ratio=self.sigma_ratio,
            threshold=self.threshold,
            overlap=self.overlap,
            threshold_rel=self.threshold_rel,
            exclude_border=self.exclude_border,
        )
        blobs[:, 3] = blobs[:, 3] * np.sqrt(3)  # Change sigma to radius
        self.blobs = blobs
        return self.blobs
    
    def get_mask(self):
        '''
        Retrieve a binary volume with the blobs marked as True
        Returns:
            binary_volume: A binary volume with the blobs marked as True
        '''
        binary_volume = np.zeros(self.vol_shape, dtype=bool)

        for z, y, x, radius in self.blobs:
            # Calculate the bounding box around the blob
            z_start = max(0, int(z - radius))
            z_end = min(self.vol_shape[0], int(z + radius) + 1)
            y_start = max(0, int(y - radius))
            y_end = min(self.vol_shape[1], int(y + radius) + 1)
            x_start = max(0, int(x - radius))
            x_end = min(self.vol_shape[2], int(x + radius) + 1)

            z_indices, y_indices, x_indices = np.indices((z_end - z_start, y_end - y_start, x_end - x_start))
            z_indices += z_start
            y_indices += y_start
            x_indices += x_start

            # Calculate distances from the center of the blob to voxels within the bounding box
            dist = np.sqrt((x_indices - x)**2 + (y_indices - y)**2 + (z_indices - z)**2)

            binary_volume[z_start:z_end, y_start:y_end, x_start:x_end][dist <= radius] = True

        return binary_volume

Original line number Original line Diff line number Diff line
@@ -3,6 +3,6 @@ from . import doi, internal_tools
from .augmentations import Augmentation
from .augmentations import Augmentation
from .cc import get_3d_cc
from .cc import get_3d_cc
from .data import Dataset, prepare_dataloaders, prepare_datasets
from .data import Dataset, prepare_dataloaders, prepare_datasets
from .img import overlay_rgb_images
from .models import inference, model_summary, train_model
from .models import inference, model_summary, train_model
from .system import Memory
from .system import Memory
from .img import overlay_rgb_images
Original line number Original line Diff line number Diff line
from .visualizations import plot_metrics
from .visualizations import plot_metrics
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
from .img import grid_pred, grid_overview, slices, slicer, orthogonal, plot_cc, local_thickness
from .k3d import vol
from .k3d import vol
from .colormaps import objects
from .detection import circles

qim3d/viz/colormaps.py

0 → 100644
+95 −0
Original line number Original line Diff line number Diff line
import colorsys

import numpy as np
from matplotlib.colors import LinearSegmentedColormap

from qim3d.io.logger import log


def objects(
    nlabels,
    style="bright",
    first_color_background=True,
    last_color_background=False,
    background_color=(0.0, 0.0, 0.0),
    seed=19,
):
    """
    Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks

    Args:
        nlabels (int): Number of labels (size of colormap)
        style (str, optional): 'bright' for strong colors, 'soft' for pastel colors. Defaults to 'bright'.
        first_color_background (bool, optional): Option to use first color as background. Defaults to True.
        last_color_background (bool, optional): Option to use last color as background. Defaults to False.
        seed (int, optional): Seed for random number generator. Defaults to 19.

    Returns:
        matplotlib.colors.LinearSegmentedColormap: Colormap for matplotlib
    """
    # Check style
    if style not in ("bright", "soft"):
        raise ValueError(
            f'Please choose "bright" or "soft" for style in qim3dCmap not "{style}"'
        )

    # Translate strings to background color
    color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)}
    if not isinstance(background_color, tuple):
        try:
            background_color = color_dict[background_color]
        except KeyError:
            raise ValueError(
                f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.'
            )

    # Add one to nlabels to include the background color
    nlabels += 1

    # Create a new random generator, to locally set seed
    rng = np.random.default_rng(seed)

    # Generate color map for bright colors, based on hsv
    if style == "bright":
        randHSVcolors = [
            (
                rng.uniform(low=0.0, high=1),
                rng.uniform(low=0.4, high=1),
                rng.uniform(low=0.9, high=1),
            )
            for i in range(nlabels)
        ]

        # Convert HSV list to RGB
        randRGBcolors = []
        for HSVcolor in randHSVcolors:
            randRGBcolors.append(
                colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])
            )

    # Generate soft pastel colors, by limiting the RGB spectrum
    if style == "soft":
        low = 0.6
        high = 0.95
        randRGBcolors = [
            (
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
            )
            for i in range(nlabels)
        ]

    # Set first and last color to background
    if first_color_background:
        randRGBcolors[0] = background_color

    if last_color_background:
        randRGBcolors[-1] = background_color

    # Create colormap
    objects_cmap = LinearSegmentedColormap.from_list(
        "objects_cmap", randRGBcolors, N=nlabels
    )

    return objects_cmap

qim3d/viz/detection.py

0 → 100644
+76 −0
Original line number Original line Diff line number Diff line
import matplotlib.pyplot as plt
from qim3d.viz import slices
from qim3d.io.logger import log
import numpy as np
import ipywidgets as widgets
from IPython.display import clear_output, display


def circles(blobs, vol, alpha=0.5, color="#ff9900", **kwargs):
    """
    Plots the blobs found on a slice of the volume.

    This function takes in a 3D volume and a list of blobs (detected features)
    and plots the blobs on a specified slice of the volume. If no slice is specified,
    it defaults to the middle slice of the volume.

    Args:
        blobs (array-like): An array-like object of blobs, where each blob is represented
            as a 4-tuple (p, r, c, radius). Usally the result of qim3d.processing.detection.Blob()
        vol (array-like): The 3D volume on which to plot the blobs.
        z_slice (int, optional): The index of the slice to plot. If not provided, the middle slice is used.
        **kwargs: Arbitrary keyword arguments for the `slices` function.

    Returns:
        matplotlib.figure.Figure: The resulting figure after adding the blobs to the slice.

    """

    def _slicer(z_slice):
        clear_output(wait=True)
        fig = slices(
            vol,
            n_slices=1,
            position=z_slice,
            img_height=3,
            img_width=3,
            cmap="gray",
            show_position=False,
        )
        # Add circles from deteced blobs
        for detected in blobs:
            z, y, x, s = detected
            if abs(z - z_slice) < s:  # The blob is in the slice

                # Adjust the radius based on the distance from the center of the sphere
                distance_from_center = abs(z - z_slice)
                angle = (
                    np.pi / 2 * (distance_from_center / s)
                )  # Angle varies from 0 at the center to pi/2 at the edge
                adjusted_radius = s * np.cos(angle)  # Radius follows a cosine curve

                if adjusted_radius > 0.5:
                    c = plt.Circle(
                        (x, y),
                        adjusted_radius,
                        color=color,
                        linewidth=0,
                        fill=True,
                        alpha=alpha,
                    )
                    fig.get_axes()[0].add_patch(c)

        display(fig)
        return fig

    position_slider = widgets.IntSlider(
        value=vol.shape[0] // 2,
        min=0,
        max=vol.shape[0] - 1,
        description="Slice",
        continuous_update=True,
    )
    slicer_obj = widgets.interactive(_slicer, z_slice=position_slider)
    slicer_obj.layout = widgets.Layout(align_items="flex-start")

    return slicer_obj
+17 −12
Original line number Original line Diff line number Diff line
@@ -4,6 +4,8 @@ Provides a collection of visualization functions.


import math
import math
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Union, Tuple

import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import torch
import torch
@@ -11,9 +13,9 @@ from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import LinearSegmentedColormap


import qim3d.io
import qim3d.io
import ipywidgets as widgets
from qim3d.io.logger import log
from qim3d.io.logger import log
from qim3d.utils.cc import CC
from qim3d.utils.cc import CC
from qim3d.viz.colormaps import objects




def grid_overview(
def grid_overview(
@@ -231,6 +233,7 @@ def slices(
    show: bool = False,
    show: bool = False,
    show_position: bool = True,
    show_position: bool = True,
    interpolation: Optional[str] = "none",
    interpolation: Optional[str] = "none",
    **imshow_kwargs,
) -> plt.Figure:
) -> plt.Figure:
    """Displays one or several slices from a 3d volume.
    """Displays one or several slices from a 3d volume.


@@ -334,7 +337,7 @@ def slices(
            slice_idx = i * max_cols + j
            slice_idx = i * max_cols + j
            try:
            try:
                slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
                slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
                ax.imshow(slice_img, cmap=cmap, interpolation=interpolation)
                ax.imshow(slice_img, cmap=cmap, interpolation=interpolation, **imshow_kwargs)


                if show_position:
                if show_position:
                    ax.text(
                    ax.text(
@@ -399,7 +402,8 @@ def slicer(
    img_height: int = 3,
    img_height: int = 3,
    img_width: int = 3,
    img_width: int = 3,
    show_position: bool = False,
    show_position: bool = False,
    interpolation: Optional[str] = None,
    interpolation: Optional[str] = "none",
    **imshow_kwargs,
) -> widgets.interactive:
) -> widgets.interactive:
    """Interactive widget for visualizing slices of a 3D volume.
    """Interactive widget for visualizing slices of a 3D volume.


@@ -437,6 +441,7 @@ def slicer(
            position=position,
            position=position,
            n_slices=1,
            n_slices=1,
            show=True,
            show=True,
            **imshow_kwargs,
        )
        )
        return fig
        return fig


@@ -535,14 +540,13 @@ def plot_cc(
        components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
        components (list | tuple, optional): The components to plot. If None the first max_cc_to_plot=32 components will be plotted. Defaults to None.
        max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
        max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
        overlay (optional): Overlay image. Defaults to None.
        overlay (optional): Overlay image. Defaults to None.
        crop (bool, optional): Whether to crop the overlay image. Defaults to False.
        crop (bool, optional): Whether to crop the image to the cc. Defaults to False.
        show (bool, optional): Whether to show the figure. Defaults to True.
        show (bool, optional): Whether to show the figure. Defaults to True.
        **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.
        **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.


    Returns:
    Returns:
        figs (list[plt.Figure]): List of figures, if `show=False`.
        figs (list[plt.Figure]): List of figures, if `show=False`.
    """
    """
    figs = []
    # if no components are given, plot the first max_cc_to_plot=32 components
    # if no components are given, plot the first max_cc_to_plot=32 components
    if component_indexs is None:
    if component_indexs is None:
        if len(connected_components) > max_cc_to_plot:
        if len(connected_components) > max_cc_to_plot:
@@ -553,11 +557,10 @@ def plot_cc(
            1, min(max_cc_to_plot + 1, len(connected_components) + 1)
            1, min(max_cc_to_plot + 1, len(connected_components) + 1)
        )
        )
        
        
    figs = []
    for component in component_indexs:
    for component in component_indexs:
        if overlay is not None:
        if overlay is not None:
            assert (
            assert (overlay.shape == connected_components.shape), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."
                overlay.shape == connected_components.shape
            ), f"Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}."


            # plots overlay masked to connected component
            # plots overlay masked to connected component
            if crop:
            if crop:
@@ -573,10 +576,12 @@ def plot_cc(
                overlay_crop = np.where(cc == 0, 0, overlay)
                overlay_crop = np.where(cc == 0, 0, overlay)
                fig = slices(overlay_crop, show=show, **kwargs)
                fig = slices(overlay_crop, show=show, **kwargs)
        else:
        else:
            # assigns discrete color map to each connected component if not given 
            if "cmap" not in kwargs:
                kwargs["cmap"] = qim3dCmap(len(component_indexs))
        
            # Plot the connected component without overlay
            # Plot the connected component without overlay
            fig = slices(
            fig = slices(connected_components.get_cc(component, crop=crop), show=show, **kwargs)
                connected_components.get_cc(component, crop=crop), show=show, **kwargs
            )


        figs.append(fig)
        figs.append(fig)


Original line number Original line Diff line number Diff line
@@ -11,7 +11,7 @@ import k3d
import numpy as np
import numpy as np




def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwargs):
def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, **kwargs):
    """
    """
    Visualizes a 3D volume using volumetric rendering.
    Visualizes a 3D volume using volumetric rendering.


@@ -62,6 +62,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, **kwa
            if aspectmode.lower() == "data"
            if aspectmode.lower() == "data"
            else None
            else None
        ),
        ),
        color_map=cmap,
    )
    )
    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
    plot += plt_volume
    plot += plt_volume
+2 −2
Original line number Original line Diff line number Diff line
from setuptools import setup, find_packages
import os
import os


from setuptools import find_packages, setup


# Read the contents of your README file
# Read the contents of your README file
with open("README.md", "r", encoding="utf-8") as f:
with open("README.md", "r", encoding="utf-8") as f:
@@ -38,7 +38,7 @@ setup(
    python_requires=">=3.10",
    python_requires=">=3.10",
    install_requires=[
    install_requires=[
        "albumentations>=1.3.1",
        "albumentations>=1.3.1",
        "gradio>=4.15.0",
        "gradio>=4.22.0",
        "h5py>=3.9.0",
        "h5py>=3.9.0",
        "localthickness>=0.1.2",
        "localthickness>=0.1.2",
        "matplotlib>=3.8.0",
        "matplotlib>=3.8.0",