diff --git a/qim3d/ml/__init__.py b/qim3d/ml/__init__.py index dc15f8102461cbcd1ddbed1f3bf94dd72888f5c6..3f05eb8aa132652fa632f81050e5a2a0d6e06e3d 100644 --- a/qim3d/ml/__init__.py +++ b/qim3d/ml/__init__.py @@ -1,4 +1,4 @@ from ._augmentations import Augmentation -from ._data import Dataset, prepare_dataloaders, prepare_datasets -from ._ml_utils import inference, volume_inference, model_summary, train_model +from ._data import Dataset, prepare_datasets, prepare_dataloaders +from ._ml_utils import model_summary, train_model, inference from .models import * \ No newline at end of file diff --git a/qim3d/ml/_augmentations.py b/qim3d/ml/_augmentations.py index 276e33f08304ec548bd637881a84934c4312cdac..e3b269e47e4309e4582c63e6b189bdbe6e213dd8 100644 --- a/qim3d/ml/_augmentations.py +++ b/qim3d/ml/_augmentations.py @@ -6,34 +6,30 @@ class Augmentation: Args: resize (str, optional): Specifies how the images should be reshaped to the appropriate size. - trainsform_train (str, optional): level of transformation for the training set. - transform_validation (str, optional): level of transformation for the validation set. - transform_test (str, optional): level of transformation for the test set. - mean (float, optional): The mean value for normalizing pixel intensities. - std (float, optional): The standard deviation value for normalizing pixel intensities. + trainsform_train (str, optional): Level of transformation for the training set. + transform_validation (str, optional): Level of transformation for the validation set. + transform_test (str, optional): Level of transformation for the test set. Raises: - ValueError: If the ´resize´ is neither 'crop', 'resize' or 'padding'. + ValueError: If the ´resize´ is neither 'crop', 'resize' nor 'padding'. Example: - my_augmentation = Augmentation(resize = 'crop', transform_train = 'heavy') + my_augmentation = Augmentation(resize = 'crop', transform_train = 'light') """ - def __init__(self, - resize: str = 'crop', - transform_train: str = 'moderate', - transform_validation: str | None = None, - transform_test: str | None = None, - mean: float = 0.5, - std: float = 0.5, - ): + def __init__( + self, + resize: str = 'crop', + transform_train: str = 'moderate', + transform_validation: str | None = None, + transform_test: str | None = None, + ): if resize not in ['crop', 'reshape', 'padding']: - raise ValueError(f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'.") + msg = f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'." + raise ValueError(msg) self.resize = resize - self.mean = mean - self.std = std self.transform_train = transform_train self.transform_validation = transform_validation self.transform_test = transform_test @@ -47,7 +43,8 @@ class Augmentation: level (str, optional): Level of augmentation. One of [None, 'light', 'moderate', 'heavy']. Raises: - ValueError: If `level` is neither None, light, moderate nor heavy. + ValueError: If `img_shape` is not 3D. + ValueError: If `level` is neither None, 'light', 'moderate' nor 'heavy'. """ from monai.transforms import ( Compose, RandRotate90d, RandFlipd, RandAffined, ToTensor, \ @@ -69,7 +66,7 @@ class Augmentation: # Baseline augmentations # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise) - baseline_aug = [ToTensor()] + baseline_aug = [ToTensor()] #, NormalizeIntensityd(keys=["image"])] # Resize augmentations if self.resize == 'crop':