diff --git a/qim3d/ml/__init__.py b/qim3d/ml/__init__.py
index dc15f8102461cbcd1ddbed1f3bf94dd72888f5c6..3f05eb8aa132652fa632f81050e5a2a0d6e06e3d 100644
--- a/qim3d/ml/__init__.py
+++ b/qim3d/ml/__init__.py
@@ -1,4 +1,4 @@
 from ._augmentations import Augmentation
-from ._data import Dataset, prepare_dataloaders, prepare_datasets
-from ._ml_utils import inference, volume_inference, model_summary, train_model
+from ._data import Dataset, prepare_datasets, prepare_dataloaders
+from ._ml_utils import model_summary, train_model, inference
 from .models import *
\ No newline at end of file
diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py
index 276e33f08304ec548bd637881a84934c4312cdac..e3b269e47e4309e4582c63e6b189bdbe6e213dd8 100644
--- a/qim3d/ml/_augmentations.py
+++ b/qim3d/ml/_augmentations.py
@@ -6,34 +6,30 @@ class Augmentation:
         
     Args:
         resize (str, optional): Specifies how the images should be reshaped to the appropriate size.
-        trainsform_train (str, optional): level of transformation for the training set.
-        transform_validation (str, optional): level of transformation for the validation set.
-        transform_test (str, optional): level of transformation for the test set.
-        mean (float, optional): The mean value for normalizing pixel intensities.
-        std (float, optional): The standard deviation value for normalizing pixel intensities.
+        trainsform_train (str, optional): Level of transformation for the training set.
+        transform_validation (str, optional): Level of transformation for the validation set.
+        transform_test (str, optional): Level of transformation for the test set.
 
     Raises:
-        ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'.
+        ValueError: If the ´resize´ is neither 'crop', 'resize' nor 'padding'.
     
     Example:
-        my_augmentation = Augmentation(resize = 'crop', transform_train = 'heavy')
+        my_augmentation = Augmentation(resize = 'crop', transform_train = 'light')
     """
     
-    def __init__(self, 
-                 resize: str = 'crop', 
-                 transform_train: str = 'moderate', 
-                 transform_validation: str | None = None,
-                 transform_test: str | None = None,
-                 mean: float = 0.5, 
-                 std: float = 0.5,
-                ):
+    def __init__(
+            self, 
+            resize: str = 'crop', 
+            transform_train: str = 'moderate', 
+            transform_validation: str | None = None,
+            transform_test: str | None = None,
+        ):
 
         if resize not in ['crop', 'reshape', 'padding']:
-            raise ValueError(f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'.")
+            msg = f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'."
+            raise ValueError(msg)
 
         self.resize = resize
-        self.mean = mean
-        self.std = std
         self.transform_train = transform_train
         self.transform_validation = transform_validation
         self.transform_test = transform_test
@@ -47,7 +43,8 @@ class Augmentation:
             level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy'].
 
         Raises:
-            ValueError: If `level` is neither None, light, moderate nor heavy.
+            ValueError: If `img_shape` is not 3D.
+            ValueError: If `level` is neither None, 'light', 'moderate' nor 'heavy'.
         """
         from monai.transforms import (
             Compose, RandRotate90d, RandFlipd, RandAffined, ToTensor, \
@@ -69,7 +66,7 @@ class Augmentation:
 
         # Baseline augmentations
         # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
-        baseline_aug = [ToTensor()]
+        baseline_aug = [ToTensor()] #, NormalizeIntensityd(keys=["image"])]
 
         # Resize augmentations
         if self.resize == 'crop':