From ae49d74cc3e6a272791c49f70f7d8499f887d1ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anna=20B=C3=B8gevang=20Ekner?= <s193396@dtu.dk> Date: Tue, 4 Feb 2025 15:46:27 +0100 Subject: [PATCH] 3D UNet implementatioN --- qim3d/ml/models/__init__.py | 2 +- qim3d/ml/models/_unet.py | 84 +++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py index 65493e60..9152ab54 100644 --- a/qim3d/ml/models/__init__.py +++ b/qim3d/ml/models/__init__.py @@ -1 +1 @@ -from ._unet import UNet2D, Hyperparameters +from ._unet import UNet, UNet2D, Hyperparameters diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py index 40006ad4..22e7794b 100644 --- a/qim3d/ml/models/_unet.py +++ b/qim3d/ml/models/_unet.py @@ -4,6 +4,85 @@ import torch.nn as nn from qim3d.utils import log +class UNet(nn.Module): + """ + 3D UNet model for QIM imaging. + + This class represents a 3D UNet model designed for imaging segmentation tasks. + + Args: + size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'. + dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0. + kernel_size (int, optional): Convolution kernel size. Defaults to 3. + up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3. + activation (str, optional): Activation function. Defaults to 'PReLU'. + bias (bool, optional): Whether to include bias in convolutions. Defaults to True. + adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'. + + Raises: + ValueError: If `size` is not one of 'small', 'medium', or 'large'. + """ + + def __init__( + self, + size="medium", + dropout=0, + kernel_size=3, + up_kernel_size=3, + activation="PReLU", + bias=True, + adn_order="NDA", + ): + super().__init__() + if size not in ["small", "medium", "large"]: + raise ValueError( + f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'." + ) + + self.size = size + self.dropout = dropout + self.kernel_size = kernel_size + self.up_kernel_size = up_kernel_size + self.activation = activation + self.bias = bias + self.adn_order = adn_order + + self.model = self._model_choice() + + def _model_choice(self): + from monai.networks.nets import UNet as monai_UNet + + if self.size == "small": + # 3 layers + self.channels = (64, 128, 256) + + elif self.size == "medium": + # 5 layers + self.channels = (64, 128, 256, 512, 1024) + + elif self.size == "large": + # 6 layers + self.channels = (64, 128, 256, 512, 1024, 2048) + + model = monai_UNet( + spatial_dims=3, + in_channels=1, # TODO: check if image has 1 or multiple input channels + out_channels=1, + channels=self.channels, + strides=(2,) * (len(self.channels) - 1), + kernel_size=self.kernel_size, + up_kernel_size=self.up_kernel_size, + act=self.activation, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_order, + ) + return model + + def forward(self, x): + x = self.model(x) + return x + class UNet2D(nn.Module): """ @@ -54,10 +133,15 @@ class UNet2D(nn.Module): from monai.networks.nets import UNet as monai_UNet if self.size == "small": + # 3 layers self.channels = (64, 128, 256) + elif self.size == "medium": + # 5 layers self.channels = (64, 128, 256, 512, 1024) + elif self.size == "large": + # 6 layers self.channels = (64, 128, 256, 512, 1024, 2048) model = monai_UNet( -- GitLab