diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py
index 4624be5fb111dac4b820b52f5ab072bf744c9154..65493e60d531f65687895dddff1c38fea11c5271 100644
--- a/qim3d/ml/models/__init__.py
+++ b/qim3d/ml/models/__init__.py
@@ -1 +1 @@
-from ._unet import UNet, Hyperparameters
+from ._unet import UNet2D, Hyperparameters
diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py
index fa41c0214d3dedccc916f79812f08e06b1eb03a2..40006ad4ffa8f5aa2814747dfb7957ea5cdcf94e 100644
--- a/qim3d/ml/models/_unet.py
+++ b/qim3d/ml/models/_unet.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 from qim3d.utils import log
 
 
-class UNet(nn.Module):
+class UNet2D(nn.Module):
     """
     2D UNet model for QIM imaging.