diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index cf1416f9c750fc9f38acfe6fde2202dfe6c4367f..e434c5c2af99c238f99e89de40830070fcf79c89 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -1,22 +1,25 @@
 """ Provides a collection of internal utility functions."""
 
-import socket
+import getpass
 import hashlib
-import outputformat as ouf
-import matplotlib.pyplot as plt
-import matplotlib
-import numpy as np
 import os
 import shutil
-import requests
-import getpass
-from PIL import Image
+import socket
 from pathlib import Path
-from qim3d.io.logger import log
-from fastapi import FastAPI
+
 import gradio as gr
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import outputformat as ouf
+import requests
+from fastapi import FastAPI
+from PIL import Image
+from scipy.ndimage import zoom
 from uvicorn import run
 
+from qim3d.io.logger import log
+
 
 def mock_plot():
     """Creates a mock plot of a sine wave.
@@ -306,6 +309,29 @@ def get_css():
 
     return css_content
 
+def downscale_img(img, max_voxels=512**3):
+    """ Downscale image if total number of voxels exceeds 512³.
+
+    Args:
+        img (np.Array): Input image.
+        max_voxels (int, optional): Max number of voxels. Defaults to 512³=134217728.
+
+    Returns:
+        np.Array: Downscaled image if total number of voxels exceeds 512³.
+    """
+
+    # Calculate total number of pixels in the image
+    total_voxels = np.prod(img.shape)
+
+    # If total pixels is less than or equal to 512³, return original image
+    if total_voxels <= max_voxels:
+        return img
+
+    # Calculate zoom factor
+    zoom_factor = (max_voxels / total_voxels) ** (1/3)
+
+    # Downscale image
+    return zoom(img, zoom_factor)
 
 def scale_to_float16(arr: np.ndarray):
     """
@@ -335,4 +361,4 @@ def scale_to_float16(arr: np.ndarray):
     # Convert the scaled array to float16 data type
     arr = arr.astype(np.float16)
 
-    return arr
+    return arr
\ No newline at end of file
diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py
index 8b56df64f350c5bdcb4f3bab4cd1a1b52be5a521..837b2fc748198aecad16f94053301fbe2530fe61 100644
--- a/qim3d/viz/k3d.py
+++ b/qim3d/viz/k3d.py
@@ -9,10 +9,22 @@ Volumetric visualization using K3D
 
 import k3d
 import numpy as np
-from qim3d.utils.internal_tools import scale_to_float16
 
-
-def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=None, samples="auto", **kwargs):
+from qim3d.io.logger import log
+from qim3d.utils.internal_tools import downscale_img, scale_to_float16
+
+
+def vol(
+    img,
+    aspectmode="data",
+    show=True,
+    save=False,
+    grid_visible=False,
+    cmap=None,
+    samples="auto",
+    max_voxels=412**3,
+    **kwargs,
+):
     """
     Visualizes a 3D volume using volumetric rendering.
 
@@ -28,7 +40,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=
             file will be saved. Defaults to False.
         grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
         cmap (list, optional): The color map to be used for the volume rendering. Defaults to None.
-        samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. 
+        samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512.
             Lower values will render faster but with lower quality.
         **kwargs: Additional keyword arguments to be passed to the `k3d.plot` function.
 
@@ -45,7 +57,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=
         import qim3d
 
         vol = qim3d.examples.bone_128x128x128
-        qim3d.viz.vol(vol) 
+        qim3d.viz.vol(vol)
         ```
         <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>
 
@@ -56,26 +68,35 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=
         vol = qim3d.examples.bone_128x128x128
         plot = qim3d.viz.vol(vol, show=False, save="plot.html")
         ```
-        
+
     """
     pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
     # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
     if samples == "auto":
-        y1,x1 = 256, 16777216 # 256 samples at res 256*256*256=16.777.216
-        y2,x2 = 32, 134217728 # 32 samples at res 512*512*512=134.217.728
+        y1, x1 = 256, 16777216  # 256 samples at res 256*256*256=16.777.216
+        y2, x2 = 32, 134217728  # 32 samples at res 512*512*512=134.217.728
 
         # we fit linear function to the two points
-        a = (y1-y2)/(x1-x2)
-        b = y1 - a*x1
+        a = (y1 - y2) / (x1 - x2)
+        b = y1 - a * x1
 
-        samples = int(min(max(a*pixel_count+b,32),512))
+        samples = int(min(max(a * pixel_count + b, 32), 512))
     else:
-        samples = int(samples) # make sure it's an integer
-
+        samples = int(samples)  # make sure it's an integer
 
     if aspectmode.lower() not in ["data", "cube"]:
         raise ValueError("aspectmode should be either 'data' or 'cube'")
 
+    # check if image should be downsampled for visualization
+    original_shape = img.shape
+    img = downscale_img(img, max_voxels=max_voxels)
+    new_shape = img.shape
+
+    if original_shape != new_shape:
+        log.warning(
+            f"Downsampled image for visualization. From {original_shape} to {new_shape}"
+        )
+
     plt_volume = k3d.volume(
         scale_to_float16(img),
         bounds=(
@@ -86,7 +107,7 @@ def vol(img, aspectmode="data", show=True, save=False, grid_visible=False, cmap=
         color_map=cmap,
         samples=samples,
     )
-    plot = k3d.plot(grid_visible=grid_visible,**kwargs)
+    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
     plot += plt_volume
 
     if save: