Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results

Target

Select target project
  • QIM/tools/qim3d
1 result
Select Git revision
  • 3D_UNet
  • 3d_watershed
  • conv_zarr_tiff_folders
  • convert_tiff_folders
  • layered_surface_segmentation
  • main
  • memmap_txrm
  • notebook_update
  • notebooks
  • notebooksv1
  • optimize_scaleZYXdask
  • save_files_function
  • scaleZYX_mean
  • test
  • threshold-exploration
  • tr_val_te_splits
  • v0.2.0
  • v0.3.0
  • v0.3.1
  • v0.3.2
  • v0.3.3
  • v0.3.9
  • v0.4.0
  • v0.4.1
24 results
Show changes
Showing with 678 additions and 494 deletions
import ipywidgets as widgets
import matplotlib.pyplot as plt
from qim3d.utils._logger import log
import numpy as np
import ipywidgets as widgets
from IPython.display import clear_output, display
import qim3d
def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float = 0.5, color: str = "#ff9900", **kwargs)-> widgets.interactive:
def circles(
blobs: tuple[float, float, float, float],
vol: np.ndarray,
alpha: float = 0.5,
color: str = '#ff9900',
**kwargs,
) -> widgets.interactive:
"""
Plots the blobs found on a slice of the volume.
......@@ -47,6 +53,7 @@ def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float
qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue')
```
![blob detection](../../assets/screenshots/blob_detection.gif)
"""
def _slicer(z_slice):
......@@ -54,16 +61,15 @@ def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float
fig = qim3d.viz.slices_grid(
vol[z_slice : z_slice + 1],
num_slices=1,
color_map="gray",
color_map='gray',
display_figure=False,
display_positions=False,
**kwargs
**kwargs,
)
# Add circles from deteced blobs
for detected in blobs:
z, y, x, s = detected
if abs(z - z_slice) < s: # The blob is in the slice
# Adjust the radius based on the distance from the center of the sphere
distance_from_center = abs(z - z_slice)
angle = (
......@@ -89,10 +95,10 @@ def circles(blobs: tuple[float,float,float,float], vol: np.ndarray, alpha: float
value=vol.shape[0] // 2,
min=0,
max=vol.shape[0] - 1,
description="Slice",
description='Slice',
continuous_update=True,
)
slicer_obj = widgets.interactive(_slicer, z_slice=position_slider)
slicer_obj.layout = widgets.Layout(align_items="flex-start")
slicer_obj.layout = widgets.Layout(align_items='flex-start')
return slicer_obj
......@@ -7,16 +7,20 @@ Volumetric visualization using K3D
"""
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Colormap
from qim3d.utils._logger import log
from qim3d.utils._misc import downscale_img, scale_to_float16
from pygel3d import hmesh
from pygel3d import jupyter_display as jd
import k3d
from typing import Optional
def volumetric(
img: np.ndarray,
aspectmode: str = "data",
aspectmode: str = 'data',
show: bool = True,
save: bool = False,
grid_visible: bool = False,
......@@ -24,9 +28,9 @@ def volumetric(
constant_opacity: bool = False,
vmin: float | None = None,
vmax: float | None = None,
samples: int|str = "auto",
samples: int | str = 'auto',
max_voxels: int = 512**3,
data_type: str = "scaled_float16",
data_type: str = 'scaled_float16',
**kwargs,
):
"""
......@@ -81,11 +85,10 @@ def volumetric(
```
"""
import k3d
pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
# target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
if samples == "auto":
if samples == 'auto':
y1, x1 = 256, 16777216 # 256 samples at res 256*256*256=16.777.216
y2, x2 = 32, 134217728 # 32 samples at res 512*512*512=134.217.728
......@@ -97,7 +100,7 @@ def volumetric(
else:
samples = int(samples) # make sure it's an integer
if aspectmode.lower() not in ["data", "cube"]:
if aspectmode.lower() not in ['data', 'cube']:
raise ValueError("aspectmode should be either 'data' or 'cube'")
# check if image should be downsampled for visualization
original_shape = img.shape
......@@ -107,7 +110,7 @@ def volumetric(
if original_shape != new_shape:
log.warning(
f"Downsampled image for visualization, from {original_shape} to {new_shape}"
f'Downsampled image for visualization, from {original_shape} to {new_shape}'
)
# Scale the image to float16 if needed
......@@ -115,8 +118,7 @@ def volumetric(
# When saving, we need float64
img = img.astype(np.float64)
else:
if data_type == "scaled_float16":
if data_type == 'scaled_float16':
img = scale_to_float16(img)
else:
img = img.astype(data_type)
......@@ -151,7 +153,7 @@ def volumetric(
img,
bounds=(
[0, img.shape[2], 0, img.shape[1], 0, img.shape[0]]
if aspectmode.lower() == "data"
if aspectmode.lower() == 'data'
else None
),
color_map=color_map,
......@@ -164,7 +166,7 @@ def volumetric(
plot += plt_volume
if save:
# Save html to disk
with open(str(save), "w", encoding="utf-8") as fp:
with open(str(save), 'w', encoding='utf-8') as fp:
fp.write(plot.get_snapshot())
if show:
......@@ -172,80 +174,94 @@ def volumetric(
else:
return plot
def mesh(
verts: np.ndarray,
faces: np.ndarray,
mesh,
backend: str = "pygel3d",
wireframe: bool = True,
flat_shading: bool = True,
grid_visible: bool = False,
show: bool = True,
save: bool = False,
**kwargs,
):
"""
Visualizes a 3D mesh using K3D.
)-> Optional[k3d.Plot]:
"""Visualize a 3D mesh using `pygel3d` or `k3d`.
Args:
verts (numpy.ndarray): A 2D array (Nx3) containing the vertices of the mesh.
faces (numpy.ndarray): A 2D array (Mx3) containing the indices of the mesh faces.
wireframe (bool, optional): If True, the mesh is rendered as a wireframe. Defaults to True.
flat_shading (bool, optional): If True, flat shading is applied to the mesh. Defaults to True.
grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
show (bool, optional): If True, displays the visualization inline. Defaults to True.
mesh (pygel3d.hmesh.Manifold): The input mesh object.
backend (str, optional): The visualization backend to use.
Choose between `pygel3d` (default) and `k3d`.
wireframe (bool, optional): If True, displays the mesh as a wireframe.
Works both with `pygel3d` and `k3d`. Defaults to True.
flat_shading (bool, optional): If True, applies flat shading to the mesh.
Works only with `k3d`. Defaults to True.
grid_visible (bool, optional): If True, shows a grid in the visualization.
Works only with `k3d`. Defaults to False.
show (bool, optional): If True, displays the visualization inline.
Works only with `k3d`. Defaults to True.
save (bool or str, optional): If True, saves the visualization as an HTML file.
If a string is provided, it's interpreted as the file path where the HTML
file will be saved. Defaults to False.
**kwargs (Any): Additional keyword arguments to be passed to the `k3d.plot` function.
file will be saved. Works only with `k3d`. Defaults to False.
**kwargs (Any): Additional keyword arguments specific to the chosen backend:
- `k3d.plot` kwargs: Arguments that customize the [`k3d.plot`](https://k3d-jupyter.org/reference/factory.plot.html) visualization.
- `pygel3d.display` kwargs: Arguments that customize the [`pygel3d.display`](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/jupyter_display.html#display) visualization.
Returns:
plot (k3d.plot): If `show=False`, returns the K3D plot object.
k3d.Plot or None:
- If `backend="k3d"`, returns a `k3d.Plot` object.
- If `backend="pygel3d"`, the function displays the mesh but does not return a plot object.
Raises:
ValueError: If `backend` is not `pygel3d` or `k3d`.
Example:
```python
import qim3d
vol = qim3d.generate.noise_object(base_shape=(128,128,128),
final_shape=(128,128,128),
noise_scale=0.03,
order=1,
gamma=1,
max_value=255,
threshold=0.5,
dtype='uint8'
)
mesh = qim3d.mesh.from_volume(vol, step_size=3)
qim3d.viz.mesh(mesh.vertices, mesh.faces)
synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015)
mesh = qim3d.mesh.from_volume(synthetic_blob)
qim3d.viz.mesh(mesh, backend="pygel3d") # or qim3d.viz.mesh(mesh, backend="k3d")
```
<iframe src="https://platform.qim.dk/k3d/mesh_visualization.html" width="100%" height="500" frameborder="0"></iframe>
![pygel3d_visualization](../../assets/screenshots/pygel3d_visualization.png)
"""
import k3d
# Validate the inputs
if verts.shape[1] != 3:
raise ValueError("Vertices array must have shape (N, 3)")
if faces.shape[1] != 3:
raise ValueError("Faces array must have shape (M, 3)")
# Ensure the correct data types and memory layout
verts = np.ascontiguousarray(
verts.astype(np.float32)
) # Cast and ensure C-contiguous layout
faces = np.ascontiguousarray(
faces.astype(np.uint32)
) # Cast and ensure C-contiguous layout
# Create the mesh plot
plt_mesh = k3d.mesh(
vertices=verts,
indices=faces,
if backend not in ["k3d", "pygel3d"]:
raise ValueError("Invalid backend. Choose 'pygel3d' or 'k3d'.")
# Extract vertex positions and face indices
face_indices = list(mesh.faces())
vertices_array = np.array(mesh.positions())
# Extract face vertex indices
face_vertices = [
list(mesh.circulate_face(int(fid), mode="v"))[:3] for fid in face_indices
]
face_vertices = np.array(face_vertices, dtype=np.uint32)
# Validate the mesh structure
if vertices_array.shape[1] != 3 or face_vertices.shape[1] != 3:
raise ValueError("Vertices must have shape (N, 3) and faces (M, 3)")
# Separate valid kwargs for each backend
valid_k3d_kwargs = {k: v for k, v in kwargs.items() if k not in ["smooth", "data"]}
valid_pygel_kwargs = {k: v for k, v in kwargs.items() if k in ["smooth", "data"]}
if backend == "k3d":
vertices_array = np.ascontiguousarray(vertices_array.astype(np.float32))
face_vertices = np.ascontiguousarray(face_vertices)
mesh_plot = k3d.mesh(
vertices=vertices_array,
indices=face_vertices,
wireframe=wireframe,
flat_shading=flat_shading,
)
# Create plot
plot = k3d.plot(grid_visible=grid_visible, **kwargs)
plot += plt_mesh
plot = k3d.plot(grid_visible=grid_visible, **valid_k3d_kwargs)
plot += mesh_plot
if save:
# Save html to disk
......@@ -256,3 +272,8 @@ def mesh(
plot.display()
else:
return plot
elif backend == "pygel3d":
jd.set_export_mode(True)
return jd.display(mesh, wireframe=wireframe, **valid_pygel_kwargs)
"""Provides a collection of visualisation functions for the Layers2d class."""
import io
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
......@@ -17,9 +16,10 @@ def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> I
lines: list of 1D arrays to be plotted on top of the image
line_thickness: how thick is the line supposed to be
Returns:
----------
Returns
-------
image_with_lines:
"""
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
......@@ -34,4 +34,3 @@ def image_with_lines(image: np.ndarray, lines: list, line_thickness: float) -> I
buf.seek(0)
return Image.open(buf).resize(size=image.squeeze().shape[::-1])
from qim3d.utils._logger import log
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Union, Tuple
from typing import Optional, Tuple, Union
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from qim3d.utils._logger import log
def local_thickness(
image: np.ndarray,
......@@ -13,7 +16,8 @@ def local_thickness(
show: bool = False,
figsize: Tuple[int, int] = (15, 5),
) -> Union[plt.Figure, widgets.interactive]:
"""Visualizes the local thickness of a 2D or 3D image.
"""
Visualizes the local thickness of a 2D or 3D image.
Args:
image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
......@@ -56,24 +60,24 @@ def local_thickness(
image = image.take(slice_idx, axis=axis)
image_lt = image_lt.take(slice_idx, axis=axis)
fig, axs = plt.subplots(1, 3, figsize=figsize, layout="constrained")
fig, axs = plt.subplots(1, 3, figsize=figsize, layout='constrained')
axs[0].imshow(image, cmap="gray")
axs[0].set_title("Original image")
axs[0].axis("off")
axs[0].imshow(image, cmap='gray')
axs[0].set_title('Original image')
axs[0].axis('off')
axs[1].imshow(image_lt, cmap="viridis")
axs[1].set_title("Local thickness")
axs[1].axis("off")
axs[1].imshow(image_lt, cmap='viridis')
axs[1].set_title('Local thickness')
axs[1].axis('off')
plt.colorbar(
axs[1].imshow(image_lt, cmap="viridis"), ax=axs[1], orientation="vertical"
axs[1].imshow(image_lt, cmap='viridis'), ax=axs[1], orientation='vertical'
)
axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor="black")
axs[2].set_title("Local thickness histogram")
axs[2].set_xlabel("Local thickness")
axs[2].set_ylabel("Count")
axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor='black')
axs[2].set_title('Local thickness histogram')
axs[2].set_xlabel('Local thickness')
axs[2].set_ylabel('Count')
if show:
plt.show()
......@@ -87,7 +91,7 @@ def local_thickness(
if max_projection:
if slice_idx is not None:
log.warning(
"slice_idx is not used for max_projection. It will be ignored."
'slice_idx is not used for max_projection. It will be ignored.'
)
image = image.max(axis=axis)
image_lt = image_lt.max(axis=axis)
......@@ -98,7 +102,7 @@ def local_thickness(
elif isinstance(slice_idx, float):
if slice_idx < 0 or slice_idx > 1:
raise ValueError(
"Values of slice_idx of float type must be between 0 and 1."
'Values of slice_idx of float type must be between 0 and 1.'
)
slice_idx = int(slice_idx * image.shape[0]) - 1
slide_idx_slider = widgets.IntSlider(
......@@ -106,8 +110,8 @@ def local_thickness(
max=image.shape[axis] - 1,
step=1,
value=slice_idx,
description="Slice index",
layout=widgets.Layout(width="450px"),
description='Slice index',
layout=widgets.Layout(width='450px'),
)
widget_obj = widgets.interactive(
_local_thickness,
......@@ -118,15 +122,15 @@ def local_thickness(
axis=widgets.fixed(axis),
slice_idx=slide_idx_slider,
)
widget_obj.layout = widgets.Layout(align_items="center")
widget_obj.layout = widgets.Layout(align_items='center')
if show:
display(widget_obj)
return widget_obj
else:
if max_projection:
log.warning(
"max_projection is only used for 3D images. It will be ignored."
'max_projection is only used for 3D images. It will be ignored.'
)
if slice_idx is not None:
log.warning("slice_idx is only used for 3D images. It will be ignored.")
log.warning('slice_idx is only used for 3D images. It will be ignored.')
return _local_thickness(image, image_lt, show, figsize)
"""Visualization tools"""
import matplotlib
import matplotlib.figure
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import torch
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
from qim3d.utils._logger import log
import matplotlib
def plot_metrics(
*metrics: tuple[dict[str, float]],
linestyle: str = "-",
batch_linestyle: str = "dotted",
linestyle: str = '-',
batch_linestyle: str = 'dotted',
labels: list | None = None,
figsize: tuple = (16, 6),
show: bool = False
show: bool = False,
):
"""
Plots the metrics over epochs and batches.
......@@ -35,6 +38,7 @@ def plot_metrics(
train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
"""
import seaborn as snb
......@@ -44,9 +48,9 @@ def plot_metrics(
raise ValueError("The number of metrics doesn't match the number of labels.")
# plotting parameters
snb.set_style("darkgrid")
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
plt.rcParams["lines.linewidth"] = 2
plt.rcParams['lines.linewidth'] = 2
fig = plt.figure(figsize=figsize)
......@@ -68,10 +72,10 @@ def plot_metrics(
plt.legend()
plt.ylabel(metric_name)
plt.xlabel("epoch")
plt.xlabel('epoch')
# reset plotting parameters
snb.set_style("white")
snb.set_style('white')
if show:
plt.show()
......@@ -81,14 +85,15 @@ def plot_metrics(
def grid_overview(
data: list,
data: list | torch.utils.data.Dataset,
num_images: int = 7,
cmap_im: str = "gray",
cmap_segm: str = "viridis",
cmap_im: str = 'gray',
cmap_segm: str = 'viridis',
alpha: float = 0.5,
show: bool = False,
) -> matplotlib.figure.Figure:
"""Displays an overview grid of images, labels, and masks (if they exist).
"""
Displays an overview grid of images, labels, and masks (if they exist).
Labels are the annotated target segmentations
Masks are applied to the output and target prior to the loss calculation in case of
......@@ -120,6 +125,7 @@ def grid_overview(
- The number of displayed images is limited to the minimum between `num_images`
and the length of the data.
- The grid layout and dimensions vary based on the presence of a mask.
"""
import torch
......@@ -128,12 +134,12 @@ def grid_overview(
# Check if image data is RGB and inform the user if it's the case
if len(data[0][0].squeeze().shape) > 2:
log.info("Input images are RGB: color map is ignored")
log.info('Input images are RGB: color map is ignored')
# Check if dataset have at least specified number of images
if len(data) < num_images:
log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d",
'Not enough images in the dataset. Changing num_images=%d to num_images=%d',
num_images,
len(data),
)
......@@ -142,14 +148,14 @@ def grid_overview(
# Adapt segmentation cmap so that background is transparent
colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm)
# Check if data have the right format
if not isinstance(data[0], tuple):
raise ValueError("Data elements must be tuples")
raise ValueError('Data elements must be tuples')
# Define row titles
row_titles = ["Input images", "Ground truth segmentation", "Mask"]
row_titles = ['Input images', 'Ground truth segmentation', 'Mask']
# Make new list such that possible augmentations remain identical for all three rows
plot_data = [data[idx] for idx in range(num_images)]
......@@ -169,10 +175,10 @@ def grid_overview(
if row in [1, 2]: # Ground truth segmentation and mask
ax.imshow(plot_data[col][0].squeeze(), cmap=cmap_im)
ax.imshow(plot_data[col][row].squeeze(), cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
else:
ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
ax.axis("off")
ax.axis('off')
if show:
plt.show()
......@@ -184,12 +190,13 @@ def grid_overview(
def grid_pred(
in_targ_preds: tuple[np.ndarray, np.ndarray, np.ndarray],
num_images: int = 7,
cmap_im: str = "gray",
cmap_segm: str = "viridis",
cmap_im: str = 'gray',
cmap_segm: str = 'viridis',
alpha: float = 0.5,
show: bool = False,
) -> matplotlib.figure.Figure:
"""Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
"""
Displays a grid of input images, predicted segmentations, ground truth segmentations, and their comparison.
Displays a grid of subplots representing different aspects of the input images and segmentations.
The grid includes the following rows:
......@@ -221,25 +228,26 @@ def grid_pred(
model = MySegmentationModel()
in_targ_preds = qim3d.ml.inference(dataset,model)
qim3d.viz.grid_pred(in_targ_preds, cmap_im='viridis', alpha=0.5)
"""
import torch
# Check if dataset have at least specified number of images
if len(in_targ_preds[0]) < num_images:
log.warning(
"Not enough images in the dataset. Changing num_images=%d to num_images=%d",
'Not enough images in the dataset. Changing num_images=%d to num_images=%d',
num_images,
len(in_targ_preds[0]),
)
num_images = len(in_targ_preds[0])
# Take only the number of images from in_targ_preds
inputs, targets, preds = [items[:num_images] for items in in_targ_preds]
inputs, targets, preds = (items[:num_images] for items in in_targ_preds)
# Adapt segmentation cmap so that background is transparent
colors_segm = colormaps.get_cmap(cmap_segm)(np.linspace(0, 1, 256))
colors_segm[:128, 3] = 0
custom_cmap = LinearSegmentedColormap.from_list("CustomCmap", colors_segm)
custom_cmap = LinearSegmentedColormap.from_list('CustomCmap', colors_segm)
N = num_images
H = inputs[0].shape[-2]
......@@ -251,10 +259,10 @@ def grid_pred(
comp_rgb[:, 3, :, :] = targets.logical_or(preds)
row_titles = [
"Input images",
"Predicted segmentation",
"Ground truth segmentation",
"True vs. predicted segmentation",
'Input images',
'Predicted segmentation',
'Ground truth segmentation',
'True vs. predicted segmentation',
]
fig = plt.figure(figsize=(2 * num_images, 10), constrained_layout=True)
......@@ -269,20 +277,20 @@ def grid_pred(
for col, ax in enumerate(np.atleast_1d(axs)):
if row == 0:
ax.imshow(inputs[col], cmap=cmap_im)
ax.axis("off")
ax.axis('off')
elif row == 1: # Predicted segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(preds[col], cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
elif row == 2: # Ground truth segmentation
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(targets[col], cmap=custom_cmap, alpha=alpha)
ax.axis("off")
ax.axis('off')
else:
ax.imshow(inputs[col], cmap=cmap_im)
ax.imshow(comp_rgb[col].permute(1, 2, 0), alpha=alpha)
ax.axis("off")
ax.axis('off')
if show:
plt.show()
......@@ -315,8 +323,8 @@ def vol_masked(
"""
background = (vol.astype("float") + viz_delta) * (1 - vol_mask) * -1
foreground = (vol.astype("float") + viz_delta) * vol_mask
background = (vol.astype('float') + viz_delta) * (1 - vol_mask) * -1
foreground = (vol.astype('float') + viz_delta) * vol_mask
vol_masked_result = background + foreground
return vol_masked_result
......@@ -7,7 +7,7 @@ X_STRIDE = 4
Y_STRIDE = 8
BACK_TO_NORMAL = "\u001b[0m"
BACK_TO_NORMAL = '\u001b[0m'
END_MARKER = -10
"""
......@@ -18,75 +18,113 @@ like in a field 4x8.
BITMAPS = [
# Block graphics
# 0xffff0000, 0x2580, // upper 1/2; redundant with inverse lower 1/2
0x00000000, '\u00a0',
0x0000000f, '\u2581', # lower 1/8
0x000000ff, '\u2582', # lower 1/4
0x00000fff, '\u2583',
0x0000ffff, '\u2584', # lower 1/2
0x000fffff, '\u2585',
0x00ffffff, '\u2586', # lower 3/4
0x0fffffff, '\u2587',
0x00000000,
'\u00a0',
0x0000000F,
'\u2581', # lower 1/8
0x000000FF,
'\u2582', # lower 1/4
0x00000FFF,
'\u2583',
0x0000FFFF,
'\u2584', # lower 1/2
0x000FFFFF,
'\u2585',
0x00FFFFFF,
'\u2586', # lower 3/4
0x0FFFFFFF,
'\u2587',
# 0xffffffff, 0x2588, # full; redundant with inverse space
0xeeeeeeee, '\u258a', # left 3/4
0xcccccccc, '\u258c', # left 1/2
0x88888888, '\u258e', # left 1/4
0x0000cccc, '\u2596', # quadrant lower left
0x00003333, '\u2597', # quadrant lower right
0xcccc0000, '\u2598', # quadrant upper left
0xEEEEEEEE,
'\u258a', # left 3/4
0xCCCCCCCC,
'\u258c', # left 1/2
0x88888888,
'\u258e', # left 1/4
0x0000CCCC,
'\u2596', # quadrant lower left
0x00003333,
'\u2597', # quadrant lower right
0xCCCC0000,
'\u2598', # quadrant upper left
# 0xccccffff, 0x2599, # 3/4 redundant with inverse 1/4
0xcccc3333, '\u259a', # diagonal 1/2
0xCCCC3333,
'\u259a', # diagonal 1/2
# 0xffffcccc, 0x259b, # 3/4 redundant
# 0xffff3333, 0x259c, # 3/4 redundant
0x33330000, '\u259d', # quadrant upper right
0x33330000,
'\u259d', # quadrant upper right
# 0x3333cccc, 0x259e, # 3/4 redundant
# 0x3333ffff, 0x259f, # 3/4 redundant
# Line drawing subset: no double lines, no complex light lines
0x000ff000, '\u2501', # Heavy horizontal
0x66666666, '\u2503', # Heavy vertical
0x00077666, '\u250f', # Heavy down and right
0x000ee666, '\u2513', # Heavy down and left
0x66677000, '\u2517', # Heavy up and right
0x666ee000, '\u251b', # Heavy up and left
0x66677666, '\u2523', # Heavy vertical and right
0x666ee666, '\u252b', # Heavy vertical and left
0x000ff666, '\u2533', # Heavy down and horizontal
0x666ff000, '\u253b', # Heavy up and horizontal
0x666ff666, '\u254b', # Heavy cross
0x000cc000, '\u2578', # Bold horizontal left
0x00066000, '\u2579', # Bold horizontal up
0x00033000, '\u257a', # Bold horizontal right
0x00066000, '\u257b', # Bold horizontal down
0x06600660, '\u254f', # Heavy double dash vertical
0x000f0000, '\u2500', # Light horizontal
0x0000f000, '\u2500', #
0x44444444, '\u2502', # Light vertical
0x22222222, '\u2502',
0x000e0000, '\u2574', # light left
0x0000e000, '\u2574', # light left
0x44440000, '\u2575', # light up
0x22220000, '\u2575', # light up
0x00030000, '\u2576', # light right
0x00003000, '\u2576', # light right
0x00004444, '\u2577', # light down
0x00002222, '\u2577', # light down
0x11224488, '\u2571', # diagonals
0x88442211, '\u2572',
0x99666699, '\u2573',
0, END_MARKER, 0 # End marker
0x000FF000,
'\u2501', # Heavy horizontal
0x66666666,
'\u2503', # Heavy vertical
0x00077666,
'\u250f', # Heavy down and right
0x000EE666,
'\u2513', # Heavy down and left
0x66677000,
'\u2517', # Heavy up and right
0x666EE000,
'\u251b', # Heavy up and left
0x66677666,
'\u2523', # Heavy vertical and right
0x666EE666,
'\u252b', # Heavy vertical and left
0x000FF666,
'\u2533', # Heavy down and horizontal
0x666FF000,
'\u253b', # Heavy up and horizontal
0x666FF666,
'\u254b', # Heavy cross
0x000CC000,
'\u2578', # Bold horizontal left
0x00066000,
'\u2579', # Bold horizontal up
0x00033000,
'\u257a', # Bold horizontal right
0x00066000,
'\u257b', # Bold horizontal down
0x06600660,
'\u254f', # Heavy double dash vertical
0x000F0000,
'\u2500', # Light horizontal
0x0000F000,
'\u2500', #
0x44444444,
'\u2502', # Light vertical
0x22222222,
'\u2502',
0x000E0000,
'\u2574', # light left
0x0000E000,
'\u2574', # light left
0x44440000,
'\u2575', # light up
0x22220000,
'\u2575', # light up
0x00030000,
'\u2576', # light right
0x00003000,
'\u2576', # light right
0x00004444,
'\u2577', # light down
0x00002222,
'\u2577', # light down
0x11224488,
'\u2571', # diagonals
0x88442211,
'\u2572',
0x99666699,
'\u2573',
0,
END_MARKER,
0, # End marker
]
class Color:
def __init__(self, red: int, green: int, blue: int):
self.check_value(red)
......@@ -97,15 +135,17 @@ class Color:
self.blue = blue
def check_value(sel, value: int):
assert isinstance(value, int), F"Color value has to be integer, this is {type(value)}"
assert value < 256, F"Color value has to be between 0 and 255, this is {value}"
assert value >= 0, F"Color value has to be between 0 and 255, this is {value}"
assert isinstance(
value, int
), f'Color value has to be integer, this is {type(value)}'
assert value < 256, f'Color value has to be between 0 and 255, this is {value}'
assert value >= 0, f'Color value has to be between 0 and 255, this is {value}'
def __str__(self):
"""
Returns the string in ansi color format
"""
return F"{self.red};{self.green};{self.blue}"
return f'{self.red};{self.green};{self.blue}'
def chardata(unicodeChar: str, character_color: Color, background_color: Color) -> str:
......@@ -117,7 +157,8 @@ def chardata(unicodeChar: str, character_color:Color, background_color:Color) ->
assert isinstance(character_color, Color)
assert isinstance(background_color, Color)
assert isinstance(unicodeChar, str)
return F"\033[38;2;{character_color}m\033[48;2;{background_color}m{unicodeChar}"
return f'\033[38;2;{character_color}m\033[48;2;{background_color}m{unicodeChar}'
def get_best_unicode_pattern(bitmap: int) -> tuple[int, str, bool]:
"""
......@@ -125,19 +166,20 @@ def get_best_unicode_pattern(bitmap:int) -> tuple[int, str, bool]:
It computes the difference by counting 1s after XORing the two. If they are identical, the count will be 0.
This character will be printed
Parameters:
-----------
Parameters
----------
- bitmap (int): int representing the bitmap the image segment.
Returns:
----------
Returns
-------
- best_pattern (int): int representing the pattern that was the best match, is then used to calculate colors
- unicode (str): the unicode character that represents the given bitmap the best and is then printed
- inverse (bool): The list does't contain unicode characters that are inverse of each other. The match can be achieved by simply using
the inversed bitmap. But then we need to know if we have to switch background and foreground color.
"""
best_diff = 8
best_pattern = 0x0000ffff
best_pattern = 0x0000FFFF
unicode = '\u2584'
inverse = False
......@@ -159,6 +201,7 @@ def get_best_unicode_pattern(bitmap:int) -> tuple[int, str, bool]:
return best_pattern, unicode, inverse
def int_bitmap_from_ndarray(array_bitmap: np.ndarray) -> int:
"""
Flattens the array
......@@ -166,7 +209,8 @@ def int_bitmap_from_ndarray(array_bitmap:np.ndarray)->int:
Creates a string representing binary number
Casts it to integer
"""
return int(F"0b{''.join([str(i) for i in array_bitmap.flatten()])}", base = 2)
return int(f"0b{''.join([str(i) for i in array_bitmap.flatten()])}", base=2)
def ndarray_from_int_bitmap(bitmap: int, shape: tuple = (8, 4)) -> np.ndarray:
"""
......@@ -179,15 +223,17 @@ def ndarray_from_int_bitmap(bitmap:int, shape:tuple = (8, 4))-> np.ndarray:
string = str(bin(bitmap))[2:].zfill(shape[0] * shape[1])
return np.array([int(i) for i in string]).reshape(shape)
def create_bitmap(image_segment: np.ndarray) -> int:
"""
Parameters:
------------
Parameters
----------
image_segment: np.ndarray of shape (x, y, 3)
Returns:
----------
Returns
-------
bitmap: int, each bit says if the unicode character should cover this bit or not
"""
max_color = np.max(np.max(image_segment, axis=0), axis=0)
......@@ -199,44 +245,56 @@ def create_bitmap(image_segment:np.ndarray)->int:
split_threshold = rng[max_index] / 2 + min_color[max_index]
bitmap = np.array(image_segment[:, :, max_index] <= split_threshold, dtype=int)
return int_bitmap_from_ndarray(bitmap)
def get_color(image_segment: np.ndarray, char_array: np.ndarray) -> Color:
"""
Computes the average color of the segment from pixels specified in charr_array
The color is then average over the part then unicode character covers or the background
Parameters:
-----------
Parameters
----------
- image_segment: 4x8 part of the image with the original values so average color can be calculated
- char_array: indices saying which pixels out of the 4x8 should be used for color calculation
Returns:
---------
Returns
-------
- color: containing the average color over defined pixels
"""
colors = []
for channel_index in range(image_segment.shape[2]):
channel = image_segment[:, :, channel_index]
colors.append(int(np.average(channel[char_array])))
return Color(colors[0], colors[1], colors[2]) if len(colors) == 3 else Color(colors[0], colors[0], colors[0])
return (
Color(colors[0], colors[1], colors[2])
if len(colors) == 3
else Color(colors[0], colors[0], colors[0])
)
def get_colors(image_segment:np.ndarray, char_array:np.ndarray) -> tuple[Color, Color]:
def get_colors(
image_segment: np.ndarray, char_array: np.ndarray
) -> tuple[Color, Color]:
"""
Parameters:
Parameters
----------
- image_segment
- char_array
Returns:
----------
Returns
-------
- Foreground color
- Background color
"""
return get_color(image_segment, char_array == 1), get_color(image_segment, char_array == 0)
return get_color(image_segment, char_array == 1), get_color(
image_segment, char_array == 0
)
def segment_string(image_segment: np.ndarray) -> str:
"""
......@@ -257,23 +315,24 @@ def segment_string(image_segment:np.ndarray)-> str:
bg_color, fg_color = fg_color, bg_color
return chardata(unicode, fg_color, bg_color)
def image_ansi_string(image: np.ndarray) -> str:
"""
For each segment 4x8 finds the string with colored unicode character
Create the string for whole image
Parameters:
-----------
Parameters
----------
- image: image to be displayed in terminal
Returns:
----------
Returns
-------
- ansi_string: when printed, will render the image
"""
string = []
for y in range(0, image.shape[0], Y_STRIDE):
for x in range(0, image.shape[1], X_STRIDE):
this_segment = image[y : y + Y_STRIDE, x : x + X_STRIDE, :]
if this_segment.shape[0] != Y_STRIDE:
segment = np.zeros((Y_STRIDE, X_STRIDE, this_segment.shape[2]))
......@@ -281,17 +340,16 @@ def image_ansi_string(image:np.ndarray) -> str:
this_segment = segment
string.append(segment_string(this_segment))
string.append(F"{BACK_TO_NORMAL}\n")
string.append(f'{BACK_TO_NORMAL}\n')
return ''.join(string)
###################################################################
# Image preparation
###################################################################
def rescale_image(image: np.ndarray, size: tuple) -> np.ndarray:
"""
The unicode bitmaps are hardcoded for 4x8 segments, they cannot be scaled
......@@ -310,33 +368,37 @@ def check_and_adjust_image_dims(image:np.ndarray) -> np.ndarray:
if image.ndim == 2:
image = np.expand_dims(image, 2)
elif image.ndim == 3:
if image.shape[2] == 1: # grayscale image
pass
elif image.shape[2] == 3: # colorful image
if image.shape[2] == 1 or image.shape[2] == 3: # grayscale image
pass
elif image.shape[2] == 4: # contains alpha channel
image = image[:, :, :3]
elif image.shape[0] == 3: # torch images have color channels as the first axis
image = np.moveaxis(image, 0, -1)
else:
raise ValueError(F"Image must have 2 (grayscale) or 3 (colorful) dimensions. Yours has {image.ndim}")
raise ValueError(
f'Image must have 2 (grayscale) or 3 (colorful) dimensions. Yours has {image.ndim}'
)
return image
def check_and_adjust_values(image:np.ndarray, relative_intensity:bool = True) -> np.ndarray:
def check_and_adjust_values(
image: np.ndarray, relative_intensity: bool = True
) -> np.ndarray:
"""
Checks if the values are between 0 and 255
If not, normalizes the values so they are in that interval
Parameters:
-------------
Parameters
----------
- image
- relative_intensity: If maximum values are pretty low, they will be barely visible. If true, it normalizes
the values, so that the maximum is at 255
Returns:
-----------
Returns
-------
- adjusted_image
"""
m = np.max(image)
......@@ -351,6 +413,7 @@ def check_and_adjust_values(image:np.ndarray, relative_intensity:bool = True) ->
return image
def choose_slice(image: np.ndarray, axis: int = None, slice: int = None):
"""
Preview give the possibility to choose axis to be sliced and slice to be displayed
......@@ -365,11 +428,19 @@ def choose_slice(image:np.ndarray, axis:int = None, slice:int = None):
slice = image.shape[2] - 1
return image[:, :, slice]
###################################################################
# Main function
###################################################################
def image_preview(image:np.ndarray, image_width:int = 80, axis:int = None, slice:int = None, relative_intensity:bool = True):
def image_preview(
image: np.ndarray,
image_width: int = 80,
axis: int = None,
slice: int = None,
relative_intensity: bool = True,
):
if image.ndim == 3 and image.shape[2] > 4:
image = choose_slice(image, axis, slice)
image = check_and_adjust_image_dims(image)
......@@ -377,4 +448,3 @@ def image_preview(image:np.ndarray, image_width:int = 80, axis:int = None, slice
image = check_and_adjust_values(image, relative_intensity)
image = rescale_image(image, (X_STRIDE * image_width, int(ratio * image.shape[0])))
print(image_ansi_string(image))
import numpy as np
from typing import Optional, Union, Tuple
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ipywidgets as widgets
import logging
from qim3d.utils._logger import log
from typing import Tuple, Union
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from qim3d.utils._logger import log
previous_logging_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.CRITICAL)
import structure_tensor as st
logging.getLogger().setLevel(previous_logging_level)
......@@ -94,10 +93,9 @@ def vectors(
if grid_size < min_grid_size or grid_size > max_grid_size:
# Adjust grid size as little as possible to be within the limits
grid_size = min(max(min_grid_size, grid_size), max_grid_size)
log.warning(f"Adjusting grid size to {grid_size} as it is out of bounds.")
log.warning(f'Adjusting grid size to {grid_size} as it is out of bounds.')
def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show):
# Choose the appropriate slice based on the specified dimension
if axis == 0:
data_slice = volume[slice_idx, :, :]
......@@ -118,10 +116,10 @@ def vectors(
vectors_slice_z = vec[0, :, :, slice_idx]
else:
raise ValueError("Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.")
raise ValueError('Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.')
# Create three subplots
fig, ax = plt.subplots(1, 3, figsize=figsize, layout="constrained")
fig, ax = plt.subplots(1, 3, figsize=figsize, layout='constrained')
blend_hue_saturation = (
lambda hue, sat: hue * (1 - sat) + 0.5 * sat
......@@ -164,7 +162,7 @@ def vectors(
vectors_slice_x[g, g],
vectors_slice_y[g, g],
color=rgba_quiver_flat,
angles="xy",
angles='xy',
)
ax[0].quiver(
ymesh[g, g],
......@@ -172,14 +170,14 @@ def vectors(
-vectors_slice_x[g, g],
-vectors_slice_y[g, g],
color=rgba_quiver_flat,
angles="xy",
angles='xy',
)
ax[0].imshow(data_slice, cmap=volume_cmap, vmin=vmin, vmax=vmax)
ax[0].set_title(
f"Orientation vectors (slice {slice_idx})"
f'Orientation vectors (slice {slice_idx})'
if not interactive
else "Orientation vectors"
else 'Orientation vectors'
)
ax[0].set_axis_off()
......@@ -218,14 +216,14 @@ def vectors(
)
ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
ax[1].set_xlabel("Angle [radians]")
ax[1].set_xlabel('Angle [radians]')
ax[1].set_xlim([0, np.pi])
ax[1].set_aspect(np.pi / ax[1].get_ylim()[1])
ax[1].set_xticks([0, np.pi / 2, np.pi])
ax[1].set_xticklabels(["0", "$\\frac{\\pi}{2}$", "$\\pi$"])
ax[1].set_xticklabels(['0', '$\\frac{\\pi}{2}$', '$\\pi$'])
ax[1].set_yticks([])
ax[1].set_ylabel("Frequency")
ax[1].set_title(f"Histogram over orientation angles")
ax[1].set_ylabel('Frequency')
ax[1].set_title('Histogram over orientation angles')
# ----- Subplot 3: Image slice colored according to orientation ----- #
# Calculate z-component (saturation)
......@@ -240,13 +238,13 @@ def vectors(
# Grayscale image slice blended with orientation colors
data_slice_orientation_colored = (
blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
).astype("uint8")
).astype('uint8')
ax[2].imshow(data_slice_orientation_colored)
ax[2].set_title(
f"Colored orientations (slice {slice_idx})"
f'Colored orientations (slice {slice_idx})'
if not interactive
else "Colored orientations"
else 'Colored orientations'
)
ax[2].set_axis_off()
......@@ -260,7 +258,7 @@ def vectors(
if vec.ndim == 5:
vec = vec[0, ...]
log.warning(
"Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used."
'Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used.'
)
if slice_idx is None:
......@@ -269,7 +267,7 @@ def vectors(
elif isinstance(slice_idx, float):
if slice_idx < 0 or slice_idx > 1:
raise ValueError(
"Values of slice_idx of float type must be between 0 and 1."
'Values of slice_idx of float type must be between 0 and 1.'
)
slice_idx = int(slice_idx * volume.shape[0]) - 1
......@@ -279,8 +277,8 @@ def vectors(
max=volume.shape[axis] - 1,
step=1,
value=slice_idx,
description="Slice index",
layout=widgets.Layout(width="450px"),
description='Slice index',
layout=widgets.Layout(width='450px'),
)
grid_size_slider = widgets.IntSlider(
......@@ -288,8 +286,8 @@ def vectors(
max=max_grid_size,
step=1,
value=grid_size,
description="Grid size",
layout=widgets.Layout(width="450px"),
description='Grid size',
layout=widgets.Layout(width='450px'),
)
widget_obj = widgets.interactive(
......@@ -305,7 +303,7 @@ def vectors(
# Arrange sliders horizontally
sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider])
widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]])
widget_obj.layout.align_items = "center"
widget_obj.layout.align_items = 'center'
if show:
display(widget_obj)
......
from ._segmentation import segmentation
from ._qim_colors import qim
from ._segmentation import segmentation
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap
qim = LinearSegmentedColormap.from_list(
"qim",
'qim',
[
(0.6, 0.0, 0.0), # 990000
(1.0, 0.6, 0.0), # ff9900
......
......@@ -3,9 +3,10 @@ This module provides a collection of colormaps useful for 3D visualization.
"""
import colorsys
from typing import Union, Tuple
import numpy as np
import math
from typing import Tuple, Union
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
......@@ -34,7 +35,7 @@ def rearrange_colors(randRGBcolors_old, min_dist=0.5):
def segmentation(
num_labels: int,
style: str = "bright",
style: str = 'bright',
first_color_background: bool = True,
last_color_background: bool = False,
background_color: Union[Tuple[float, float, float], str] = (0.0, 0.0, 0.0),
......@@ -96,17 +97,18 @@ def segmentation(
Tip:
The `min_dist` parameter can be used to control the distance between neighboring colors.
![colormap objects mind_dist](../../assets/screenshots/viz-colormaps-min_dist.gif)
"""
from skimage import color
# Check style
if style not in ("bright", "soft", "earth", "ocean"):
if style not in ('bright', 'soft', 'earth', 'ocean'):
raise ValueError(
f'Please choose "bright", "soft", "earth" or "ocean" for style in qim3dCmap not "{style}"'
)
# Translate strings to background color
color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)}
color_dict = {'black': (0.0, 0.0, 0.0), 'white': (1.0, 1.0, 1.0)}
if not isinstance(background_color, tuple):
try:
background_color = color_dict[background_color]
......@@ -122,7 +124,7 @@ def segmentation(
rng = np.random.default_rng(seed)
# Generate color map for bright colors, based on hsv
if style == "bright":
if style == 'bright':
randHSVcolors = [
(
rng.uniform(low=0.0, high=1),
......@@ -140,7 +142,7 @@ def segmentation(
)
# Generate soft pastel colors, by limiting the RGB spectrum
if style == "soft":
if style == 'soft':
low = 0.6
high = 0.95
randRGBcolors = [
......@@ -153,7 +155,7 @@ def segmentation(
]
# Generate color map for earthy colors, based on LAB
if style == "earth":
if style == 'earth':
randLABColors = [
(
rng.uniform(low=25, high=110),
......@@ -169,7 +171,7 @@ def segmentation(
randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())
# Generate color map for ocean colors, based on LAB
if style == "ocean":
if style == 'ocean':
randLABColors = [
(
rng.uniform(low=0, high=110),
......@@ -195,8 +197,6 @@ def segmentation(
randRGBcolors[-1] = background_color
# Create colormap
objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=num_labels)
objects = LinearSegmentedColormap.from_list('objects', randRGBcolors, N=num_labels)
return objects
from pathlib import Path
import os
import platform
from pathlib import Path
from typing import Callable
import qim3d
class NotInstalledError(Exception): pass
SOURCE_FNM = "fnm env --use-on-cd | Out-String | Invoke-Expression;"
class NotInstalledError(Exception):
pass
SOURCE_FNM = 'fnm env --use-on-cd | Out-String | Invoke-Expression;'
LINUX = 'Linux'
WINDOWS = 'Windows'
MAC = 'Darwin'
def get_itk_dir() -> Path:
qim_dir = Path(qim3d.__file__).parents[0] # points to .../qim3d/qim3d/
dir = qim_dir.joinpath("viz/itk_vtk_viewer")
dir = qim_dir.joinpath('viz/itk_vtk_viewer')
return dir
def get_nvm_dir(dir: Path = None) -> Path:
if platform.system() in [LINUX, MAC]:
following_folder = ".nvm"
following_folder = '.nvm'
elif platform.system() == WINDOWS:
following_folder = ''
return dir.joinpath(following_folder) if dir is not None else get_qim_dir().joinpath(following_folder)
return (
dir.joinpath(following_folder)
if dir is not None
else get_qim_dir().joinpath(following_folder)
)
def get_node_binaries_dir(nvm_dir: Path = None) -> Path:
"""
......@@ -32,13 +42,17 @@ def get_node_binaries_dir(nvm_dir:Path = None) -> Path:
For windows we have to pass the argument nvm_dir and it is the itk-vtk_dir
"""
if platform.system() in [LINUX, MAC]:
following_folder = "versions/node"
following_folder = 'versions/node'
binaries_folder = 'bin'
elif platform.system() == WINDOWS:
following_folder = 'node-versions'
binaries_folder = 'installation'
node_folder = nvm_dir.joinpath(following_folder) if nvm_dir is not None else get_nvm_dir().joinpath(following_folder)
node_folder = (
nvm_dir.joinpath(following_folder)
if nvm_dir is not None
else get_nvm_dir().joinpath(following_folder)
)
# We don't wanna throw an error
# Instead we return None and check the returned value in run.py
......@@ -51,9 +65,15 @@ def get_node_binaries_dir(nvm_dir:Path = None) -> Path:
if os.path.isdir(path):
return path.joinpath(binaries_folder)
def get_viewer_dir(dir: Path = None) -> Path:
following_folder = "viewer_app"
return dir.joinpath(following_folder) if dir is not None else get_qim_dir().joinpath(following_folder)
following_folder = 'viewer_app'
return (
dir.joinpath(following_folder)
if dir is not None
else get_qim_dir().joinpath(following_folder)
)
def get_viewer_binaries(viewer_dir: Path = None) -> Path:
following_folder1 = 'node_modules'
......@@ -62,7 +82,10 @@ def get_viewer_binaries(viewer_dir:Path = None) -> Path:
viewer_dir = get_viewer_dir()
return viewer_dir.joinpath(following_folder1).joinpath(following_folder2)
def run_for_platform(linux_func:Callable, windows_func:Callable, macos_func:Callable):
def run_for_platform(
linux_func: Callable, windows_func: Callable, macos_func: Callable
):
this_platform = platform.system()
if this_platform == LINUX:
return linux_func()
......@@ -71,5 +94,6 @@ def run_for_platform(linux_func:Callable, windows_func:Callable, macos_func:Call
elif this_platform == MAC:
return macos_func()
def lambda_raise(err):
raise err
from pathlib import Path
import subprocess
import os
import platform
import subprocess
from pathlib import Path
from .helpers import get_itk_dir, get_nvm_dir, get_node_binaries_dir, get_viewer_dir, SOURCE_FNM, NotInstalledError, run_for_platform
from .helpers import (
SOURCE_FNM,
NotInstalledError,
get_itk_dir,
get_node_binaries_dir,
get_nvm_dir,
get_viewer_dir,
run_for_platform,
)
class Installer:
"""
Implements installation procedure of itk-vtk-viewer for each OS.
Also goes for minimal installation: checking if the necessary binaries aren't already installed
"""
def __init__(self):
self.platform = platform.system()
self.install_functions = (self.install_node_manager, self.install_node, self.install_viewer)
self.install_functions = (
self.install_node_manager,
self.install_node,
self.install_viewer,
)
self.dir = get_itk_dir() # itk_vtk_viewer folder within qim3d.viz
......@@ -32,11 +45,14 @@ class Installer:
"""
Checks for global and local installation of nvm (Node Version Manager)
"""
def _linux() -> bool:
command_f = lambda nvmsh: F'/bin/bash -c "source {nvmsh} && nvm"'
command_f = lambda nvmsh: f'/bin/bash -c "source {nvmsh} && nvm"'
if self.os_nvm_dir is not None:
nvmsh = self.os_nvm_dir.joinpath('nvm.sh')
output = subprocess.run(command_f(nvmsh), shell = True, capture_output = True)
output = subprocess.run(
command_f(nvmsh), shell=True, capture_output=True
)
if not output.stderr:
self.nvm_dir = self.os_nvm_dir
return True
......@@ -44,24 +60,31 @@ class Installer:
nvmsh = self.qim_nvm_dir.joinpath('nvm.sh')
output = subprocess.run(command_f(nvmsh), shell=True, capture_output=True)
self.nvm_dir = self.qim_nvm_dir
return not bool(output.stderr) # If there is an error running the above command then it is not installed (not in expected location)
return not bool(
output.stderr
) # If there is an error running the above command then it is not installed (not in expected location)
def _windows() -> bool:
output = subprocess.run(['powershell.exe', 'fnm --version'], capture_output=True)
output = subprocess.run(
['powershell.exe', 'fnm --version'], capture_output=True
)
return not bool(output.stderr)
return run_for_platform(linux_func=_linux, windows_func=_windows,macos_func= _linux)
return run_for_platform(
linux_func=_linux, windows_func=_windows, macos_func=_linux
)
@property
def is_node_already_installed(self) -> bool:
"""
Checks for global and local installation of Node.js and npm (Node Package Manager)
"""
def _linux() -> bool:
# get_node_binaries_dir might return None if the folder is not there
# In that case there is 'None' added to the PATH, thats not a problem
# the command will return an error to the output and it will be evaluated as not installed
command = F'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm version'
command = f'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm version'
output = subprocess.run(command, shell=True, capture_output=True)
return not bool(output.stderr)
......@@ -70,9 +93,9 @@ class Installer:
# Didn't figure out how to install the viewer and run it properly when using global npm
return False
return run_for_platform(linux_func=_linux,windows_func= _windows,macos_func= _linux)
return run_for_platform(
linux_func=_linux, windows_func=_windows, macos_func=_linux
)
def install(self):
"""
......@@ -82,10 +105,10 @@ class Installer:
"""
if self.is_node_manager_already_installed:
self.install_status = 1
print("Node manager already installed")
print('Node manager already installed')
if self.is_node_already_installed:
self.install_status = 2
print("Node.js already installed")
print('Node.js already installed')
else:
self.install_status = 0
......@@ -95,17 +118,28 @@ class Installer:
def install_node_manager(self):
def _linux():
print(F'Installing Node manager into {self.nvm_dir}...')
_ = subprocess.run([F'export NVM_DIR={self.nvm_dir} && bash {self.dir.joinpath("install_nvm.sh")}'], shell = True, capture_output=True)
print(f'Installing Node manager into {self.nvm_dir}...')
_ = subprocess.run(
[
f'export NVM_DIR={self.nvm_dir} && bash {self.dir.joinpath("install_nvm.sh")}'
],
shell=True,
capture_output=True,
)
def _windows():
print("Installing node manager...")
subprocess.run(["powershell.exe", F'$env:XDG_DATA_HOME = "{self.dir}";', "winget install Schniz.fnm"])
print('Installing node manager...')
subprocess.run(
[
'powershell.exe',
f'$env:XDG_DATA_HOME = "{self.dir}";',
'winget install Schniz.fnm',
]
)
# self._run_for_platform(_linux, None, _windows)
run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux)
print("Node manager installed")
print('Node manager installed')
def install_node(self):
def _linux():
......@@ -114,10 +148,10 @@ class Installer:
We have to source that file either way, to be able to call nvm function
If it was't installed before, we need to export NVM_DIR in order to install npm to correct location
"""
print(F'Installing node.js into {self.nvm_dir}...')
print(f'Installing node.js into {self.nvm_dir}...')
if self.install_status == 0:
nvm_dir = self.nvm_dir
prefix = F'export NVM_DIR={nvm_dir} && '
prefix = f'export NVM_DIR={nvm_dir} && '
elif self.install_status == 1:
nvm_dir = self.os_nvm_dir
......@@ -128,35 +162,49 @@ class Installer:
output = subprocess.run(command, shell=True, capture_output=True)
def _windows():
subprocess.run(["powershell.exe",F'$env:XDG_DATA_HOME = "{self.dir}";', SOURCE_FNM, F"fnm use --fnm-dir {self.dir} --install-if-missing 22"])
print(F'Installing node.js...')
subprocess.run(
[
'powershell.exe',
f'$env:XDG_DATA_HOME = "{self.dir}";',
SOURCE_FNM,
f'fnm use --fnm-dir {self.dir} --install-if-missing 22',
]
)
print('Installing node.js...')
run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux)
print("Node.js installed")
print('Node.js installed')
def install_viewer(self):
def _linux():
# Adds local binaries to the path in case we had to install node first (locally into qim folder), but shouldnt interfere even if
# npm is installed globally
command = F'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm install --prefix {self.viewer_dir} itk-vtk-viewer'
command = f'export PATH="$PATH:{get_node_binaries_dir(self.nvm_dir)}" && npm install --prefix {self.viewer_dir} itk-vtk-viewer'
output = subprocess.run([command], shell=True, capture_output=True)
# print(output.stderr)
def _windows():
try:
node_bin = get_node_binaries_dir(self.dir)
print(F'Installing into {self.viewer_dir}')
subprocess.run(["powershell.exe", F'$env:PATH=$env:PATH + \';{node_bin}\';', F"npm install --prefix {self.viewer_dir} itk-vtk-viewer"], capture_output=True)
print(f'Installing into {self.viewer_dir}')
subprocess.run(
[
'powershell.exe',
f"$env:PATH=$env:PATH + ';{node_bin}';",
f'npm install --prefix {self.viewer_dir} itk-vtk-viewer',
],
capture_output=True,
)
except NotInstalledError: # Not installed in qim
subprocess.run(["powershell.exe", SOURCE_FNM, F"npm install itk-vtk-viewer"], capture_output=True)
subprocess.run(
['powershell.exe', SOURCE_FNM, 'npm install itk-vtk-viewer'],
capture_output=True,
)
self.viewer_dir = get_viewer_dir(self.dir)
if not os.path.isdir(self.viewer_dir):
os.mkdir(self.viewer_dir)
print(F"Installing itk-vtk-viewer...")
print('Installing itk-vtk-viewer...')
run_for_platform(linux_func=_linux, windows_func=_windows, macos_func=_linux)
print("Itk-vtk-viewer installed")
\ No newline at end of file
print('Itk-vtk-viewer installed')
import subprocess
from pathlib import Path
import os
import webbrowser
import subprocess
import threading
import time
import webbrowser
from pathlib import Path
import qim3d.utils
from qim3d.utils._logger import log
......@@ -11,10 +11,8 @@ from qim3d.utils._logger import log
from .helpers import *
from .installation import Installer
# Start viewer
START_COMMAND = "itk-vtk-viewer -s"
START_COMMAND = 'itk-vtk-viewer -s'
# Lock, so two threads can safely read and write to is_installed
c = threading.Condition()
......@@ -23,12 +21,12 @@ is_installed = True
def run_global(port=3000):
linux_func = lambda: subprocess.run(
START_COMMAND+f" -p {port}", shell=True, stderr=subprocess.DEVNULL
START_COMMAND + f' -p {port}', shell=True, stderr=subprocess.DEVNULL
)
# First sourcing the node.js, if sourcing via fnm doesnt help and user would have to do it any other way, it would throw an error and suggest to install viewer to qim library
windows_func = lambda: subprocess.run(
["powershell.exe", SOURCE_FNM, START_COMMAND+f" -p {port}"],
['powershell.exe', SOURCE_FNM, START_COMMAND + f' -p {port}'],
shell=True,
stderr=subprocess.DEVNULL,
)
......@@ -48,7 +46,7 @@ def run_within_qim_dir(port=3000):
node_bin = get_node_binaries_dir(get_nvm_dir(dir))
if node_bin is None:
# Didn't find node binaries there so it looks for enviroment variable to tell it where is nvm folder
node_bin = get_node_binaries_dir(Path(str(os.getenv("NVM_DIR"))))
node_bin = get_node_binaries_dir(Path(str(os.getenv('NVM_DIR'))))
if node_bin is not None:
subprocess.run(
......@@ -62,9 +60,9 @@ def run_within_qim_dir(port=3000):
if node_bin is not None:
subprocess.run(
[
"powershell.exe",
'powershell.exe',
f"$env:PATH = $env:PATH + ';{viewer_bin};{node_bin}';",
START_COMMAND+f" -p {port}",
START_COMMAND + f' -p {port}',
],
stderr=subprocess.DEVNULL,
)
......@@ -169,7 +167,6 @@ def try_opening_itk_vtk(
global is_installed
c.acquire()
if is_installed:
# Normalize the filename. This is necessary for trailing slashes by the end of the path
filename_norm = os.path.normpath(os.path.abspath(filename))
......@@ -178,12 +175,12 @@ def try_opening_itk_vtk(
os.path.dirname(filename_norm), port=file_server_port
)
viz_url = f"http://localhost:{viewer_port}/?rotate=false&fileToLoad=http://localhost:{file_server_port}/{os.path.basename(filename_norm)}"
viz_url = f'http://localhost:{viewer_port}/?rotate=false&fileToLoad=http://localhost:{file_server_port}/{os.path.basename(filename_norm)}'
if open_browser:
webbrowser.open_new_tab(viz_url)
log.info(f"\nVisualization url:\n{viz_url}\n")
log.info(f'\nVisualization url:\n{viz_url}\n')
c.release()
# Start the delayed open in a separate thread
......@@ -214,7 +211,7 @@ def itk_vtk(
filename: str = None,
open_browser: bool = True,
file_server_port: int = 8042,
viewer_port: int = 3000
viewer_port: int = 3000,
):
"""
Command to run in cli/__init__.py. Tries to run the vizualization,
......@@ -224,18 +221,22 @@ def itk_vtk(
"""
try:
try_opening_itk_vtk(filename,
try_opening_itk_vtk(
filename,
open_browser=open_browser,
file_server_port=file_server_port,
viewer_port = viewer_port)
viewer_port=viewer_port,
)
except NotInstalledError:
message = "Itk-vtk-viewer is not installed or qim3d can not find it.\nYou can either:\n\to Use 'qim3d viz SOURCE -m k3d' to display data using different method\n\to Install itk-vtk-viewer yourself following https://kitware.github.io/itk-vtk-viewer/docs/cli.html#Installation\n\to Let qim3D install itk-vtk-viewer now (it will also install node.js in qim3d library)\nDo you want qim3D to install itk-vtk-viewer now?"
print(message)
answer = input("[Y/n]:")
if answer in "Yy":
answer = input('[Y/n]:')
if answer in 'Yy':
Installer().install()
try_opening_itk_vtk(filename,
try_opening_itk_vtk(
filename,
open_browser=open_browser,
file_server_port=file_server_port,
viewer_port = viewer_port)
\ No newline at end of file
viewer_port=viewer_port,
)
......@@ -12,8 +12,8 @@ scipy>=1.11.2
seaborn>=0.12.2
pydicom==2.4.4
setuptools>=68.0.0
imagecodecs==2023.7.10
tifffile==2023.8.12
imagecodecs>=2024.12.30
tifffile>=2025.1.10
torch>=2.0.1
torchvision>=0.15.2
torchinfo>=1.8.0
......@@ -31,4 +31,4 @@ ome_zarr>=0.9.0
dask-image>=2024.5.3
trimesh>=4.4.9
slgbuilder>=0.2.1
testbook>=0.4.2
PyGEL3D>=0.5.2
\ No newline at end of file
pre-commit>=4.1.0
ruff>=0.9.3
testbook>=0.4.2
\ No newline at end of file
import os
import re
from setuptools import find_packages, setup
# Read the contents of your README file
with open("README.md", "r", encoding="utf-8") as f:
with open('README.md', encoding='utf-8') as f:
long_description = f.read()
# Read the version from the __init__.py file
def read_version():
with open(os.path.join("qim3d", "__init__.py"), "r", encoding="utf-8") as f:
with open(os.path.join('qim3d', '__init__.py'), encoding='utf-8') as f:
version_file = f.read()
version_match = re.search(r'^__version__ = ["\']([^"\']*)["\']', version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
raise RuntimeError('Unable to find version string.')
setup(
name="qim3d",
name='qim3d',
version=read_version(),
author="Felipe Delestro",
author_email="fima@dtu.dk",
description="QIM tools and user interfaces for volumetric imaging",
author='Felipe Delestro',
author_email='fima@dtu.dk',
description='QIM tools and user interfaces for volumetric imaging',
long_description=long_description,
long_description_content_type="text/markdown",
url="https://platform.qim.dk/qim3d",
long_description_content_type='text/markdown',
url='https://platform.qim.dk/qim3d',
packages=find_packages(),
include_package_data=True,
entry_points = {
'console_scripts': [
'qim3d=qim3d.cli:main'
]
},
entry_points={'console_scripts': ['qim3d=qim3d.cli:main']},
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Image Processing",
"Topic :: Scientific/Engineering :: Visualization",
"Topic :: Software Development :: User Interfaces",
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'Natural Language :: English',
'Operating System :: OS Independent',
'Topic :: Scientific/Engineering :: Image Processing',
'Topic :: Scientific/Engineering :: Visualization',
'Topic :: Software Development :: User Interfaces',
],
python_requires=">=3.10",
python_requires='>=3.10',
install_requires=[
"gradio==4.44",
"h5py>=3.9.0",
......@@ -57,8 +56,8 @@ setup(
"scipy>=1.11.2",
"seaborn>=0.12.2",
"setuptools>=68.0.0",
"tifffile==2023.8.12",
"imagecodecs==2023.7.10",
"tifffile>=2025.1.10",
"imagecodecs>=2024.12.30",
"tqdm>=4.65.0",
"nibabel>=5.2.0",
"ipywidgets>=8.1.2",
......@@ -72,7 +71,8 @@ setup(
"ome_zarr>=0.9.0",
"dask-image>=2024.5.3",
"scikit-image>=0.24.0",
"trimesh>=4.4.9"
"trimesh>=4.4.9",
"PyGEL3D>=0.5.2"
],
extras_require={
"deep-learning": [
......@@ -81,6 +81,9 @@ setup(
"torchvision>=0.15.2",
"torchinfo>=1.8.0",
"monai>=1.2.0",
],
'test': [
'testbook>=0.4.2'
]
}
)