diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py
index 65493e60d531f65687895dddff1c38fea11c5271..9152ab5434fafb55cc7b2cb6ba6f3227e47e219b 100644
--- a/qim3d/ml/models/__init__.py
+++ b/qim3d/ml/models/__init__.py
@@ -1 +1 @@
-from ._unet import UNet2D, Hyperparameters
+from ._unet import UNet, UNet2D, Hyperparameters
diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py
index 40006ad4ffa8f5aa2814747dfb7957ea5cdcf94e..22e7794b3d5f58133956842067a5b77a68f9d1fa 100644
--- a/qim3d/ml/models/_unet.py
+++ b/qim3d/ml/models/_unet.py
@@ -4,6 +4,94 @@ import torch.nn as nn
 
 from qim3d.utils import log
 
+
+class UNet(nn.Module):
+    """
+    3D UNet model for QIM imaging.
+
+    This class represents a 3D UNet model designed for volumetric image segmentation tasks.
+
+    Args:
+        size (str, optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
+        dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
+        kernel_size (int, optional): Convolution kernel size. Defaults to 3.
+        up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
+        activation (str, optional): Activation function. Defaults to 'PReLU'.
+        bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
+        adn_order (str, optional): Ordering of the activation (A), dropout (D), and normalization (N) layers in each convolution block. Defaults to 'NDA'.
+
+    Raises:
+        ValueError: If `size` is not one of 'small', 'medium', or 'large'.
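+
+    Example:
+        Minimal construction sketch (requires the `monai` package, which is imported when the model is built):
+
+            from qim3d.ml.models import UNet
+
+            model = UNet(size='medium')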
+    """
+
+    def __init__(
+        self,
+        size="medium",
+        dropout=0,
+        kernel_size=3,
+        up_kernel_size=3,
+        activation="PReLU",
+        bias=True,
+        adn_order="NDA",
+    ):
+        super().__init__()
+        if size not in ["small", "medium", "large"]:
+            raise ValueError(
+                f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
+            )
+
+        self.size = size
+        self.dropout = dropout
+        self.kernel_size = kernel_size
+        self.up_kernel_size = up_kernel_size
+        self.activation = activation
+        self.bias = bias
+        self.adn_order = adn_order
+
+        self.model = self._model_choice()
+
+    def _model_choice(self):
+        from monai.networks.nets import UNet as monai_UNet
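+        # Imported here so monai is only required when a model is instantiated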
+
+        if self.size == "small":
+            # 3 layers
+            self.channels = (64, 128, 256)
+
+        elif self.size == "medium":
+            # 5 layers
+            self.channels = (64, 128, 256, 512, 1024)
+
+        elif self.size == "large":
+            # 6 layers
+            self.channels = (64, 128, 256, 512, 1024, 2048)
+
+        model = monai_UNet(
+            spatial_dims=3,
+            in_channels=1,  # TODO: check if image has 1 or multiple input channels
+            out_channels=1,
+            channels=self.channels,
+            strides=(2,) * (len(self.channels) - 1),  # one stride-2 downsampling step between each pair of levels
+            kernel_size=self.kernel_size,
+            up_kernel_size=self.up_kernel_size,
+            act=self.activation,
+            dropout=self.dropout,
+            bias=self.bias,
+            adn_ordering=self.adn_order,
+        )
+        return model
+
+    def forward(self, x):
+        x = self.model(x)
+        return x
+
 
 class UNet2D(nn.Module):
     """
@@ -54,10 +142,15 @@ class UNet2D(nn.Module):
         from monai.networks.nets import UNet as monai_UNet
 
         if self.size == "small":
+            # 3 layers
             self.channels = (64, 128, 256)
+
         elif self.size == "medium":
+            # 5 layers
             self.channels = (64, 128, 256, 512, 1024)
+
         elif self.size == "large":
+            # 6 layers
             self.channels = (64, 128, 256, 512, 1024, 2048)
 
         model = monai_UNet(