diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py index cf1416f9c750fc9f38acfe6fde2202dfe6c4367f..e434c5c2af99c238f99e89de40830070fcf79c89 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/internal_tools.py @@ -1,22 +1,25 @@ """ Provides a collection of internal utility functions.""" -import socket +import getpass import hashlib -import outputformat as ouf -import matplotlib.pyplot as plt -import matplotlib -import numpy as np import os import shutil -import requests -import getpass -from PIL import Image +import socket from pathlib import Path -from qim3d.io.logger import log -from fastapi import FastAPI + import gradio as gr +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import outputformat as ouf +import requests +from fastapi import FastAPI +from PIL import Image +from scipy.ndimage import zoom from uvicorn import run +from qim3d.io.logger import log + def mock_plot(): """Creates a mock plot of a sine wave. @@ -306,6 +309,29 @@ def get_css(): return css_content +def downscale_img(img, max_voxels=512**3): + """ Downscale image if total number of voxels exceeds 512³. + + Args: + img (np.Array): Input image. + max_voxels (int, optional): Max number of voxels. Defaults to 512³=134217728. + + Returns: + np.Array: Downscaled image if total number of voxels exceeds 512³. + """ + + # Calculate total number of pixels in the image + total_voxels = np.prod(img.shape) + + # If total pixels is less than or equal to 512³, return original image + if total_voxels <= max_voxels: + return img + + # Calculate zoom factor + zoom_factor = (max_voxels / total_voxels) ** (1/3) + + # Downscale image + return zoom(img, zoom_factor) def scale_to_float16(arr: np.ndarray): """ @@ -335,4 +361,4 @@ def scale_to_float16(arr: np.ndarray): # Convert the scaled array to float16 data type arr = arr.astype(np.float16) - return arr + return arr \ No newline at end of file diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py index 8b56df64f350c5bdcb4f3bab4cd1a1b52be5a521..837b2fc748198aecad16f94053301fbe2530fe61 100644 --- a/qim3d/viz/k3d.py +++ b/qim3d/viz/k3d.py @@ -9,10 +9,22 @@ Volumetric visualization using K3D import k3d import numpy as np -from qim3d.utils.internal_tools import scale_to_float16 - -def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, samples="auto", **kwargs): +from qim3d.io.logger import log +from qim3d.utils.internal_tools import downscale_img, scale_to_float16 + + +def vol( + img, + aspectmode="data", + show=True, + save=False, + grid_visible=False, + cmap=None, + samples="auto", + max_voxels=412**3, + **kwargs, +): """ Visualizes a 3D volume using volumetric rendering. @@ -28,7 +40,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap= file will be saved. Defaults to False. grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. cmap (list, optional): The color map to be used for the volume rendering. Defaults to None. - samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. + samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. Lower values will render faster but with lower quality. **kwargs: Additional keyword arguments to be passed to the `k3d.plot` function. @@ -45,7 +57,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap= import qim3d vol = qim3d.examples.bone_128x128x128 - qim3d.viz.vol(vol) + qim3d.viz.vol(vol) ``` <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe> @@ -56,26 +68,35 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap= vol = qim3d.examples.bone_128x128x128 plot = qim3d.viz.vol(vol, show=False, save="plot.html") ``` - + """ pixel_count = img.shape[0] * img.shape[1] * img.shape[2] # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html if samples == "auto": - y1,x1 = 256, 16777216 # 256 samples at res 256*256*256=16.777.216 - y2,x2 = 32, 134217728 # 32 samples at res 512*512*512=134.217.728 + y1, x1 = 256, 16777216 # 256 samples at res 256*256*256=16.777.216 + y2, x2 = 32, 134217728 # 32 samples at res 512*512*512=134.217.728 # we fit linear function to the two points - a = (y1-y2)/(x1-x2) - b = y1 - a*x1 + a = (y1 - y2) / (x1 - x2) + b = y1 - a * x1 - samples = int(min(max(a*pixel_count+b,32),512)) + samples = int(min(max(a * pixel_count + b, 32), 512)) else: - samples = int(samples) # make sure it's an integer - + samples = int(samples) # make sure it's an integer if aspectmode.lower() not in ["data", "cube"]: raise ValueError("aspectmode should be either 'data' or 'cube'") + # check if image should be downsampled for visualization + original_shape = img.shape + img = downscale_img(img, max_voxels=max_voxels) + new_shape = img.shape + + if original_shape != new_shape: + log.warning( + f"Downsampled image for visualization. From {original_shape} to {new_shape}" + ) + plt_volume = k3d.volume( scale_to_float16(img), bounds=( @@ -86,7 +107,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap= color_map=cmap, samples=samples, ) - plot = k3d.plot(grid_visible=grid_visible,**kwargs) + plot = k3d.plot(grid_visible=grid_visible, **kwargs) plot += plt_volume if save: