diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py index 9152ab5434fafb55cc7b2cb6ba6f3227e47e219b..4624be5fb111dac4b820b52f5ab072bf744c9154 100644 --- a/qim3d/ml/models/__init__.py +++ b/qim3d/ml/models/__init__.py @@ -1 +1 @@ -from ._unet import UNet, UNet2D, Hyperparameters +from ._unet import UNet, Hyperparameters diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py index 5d47b61c26d163137f8674fa942b0c04c6c0857b..b2950a2b21e939aaf35f038347f0e438e05a7a5b 100644 --- a/qim3d/ml/models/_unet.py +++ b/qim3d/ml/models/_unet.py @@ -84,87 +84,6 @@ class UNet(nn.Module): x = self.model(x) return x - -class UNet2D(nn.Module): - """ - 2D UNet model for QIM imaging. - - This class represents a 2D UNet model designed for imaging segmentation tasks. - - Args: - size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'. - dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0. - kernel_size (int, optional): Convolution kernel size. Defaults to 3. - up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3. - activation (str, optional): Activation function. Defaults to 'PReLU'. - bias (bool, optional): Whether to include bias in convolutions. Defaults to True. - adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'. - - Raises: - ValueError: If `size` is not one of 'small', 'medium', or 'large'. - """ - - def __init__( - self, - size="medium", - dropout=0, - kernel_size=3, - up_kernel_size=3, - activation="PReLU", - bias=True, - adn_order="NDA", - ): - super().__init__() - if size not in ["small", "medium", "large"]: - raise ValueError( - f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'." - ) - - self.size = size - self.dropout = dropout - self.kernel_size = kernel_size - self.up_kernel_size = up_kernel_size - self.activation = activation - self.bias = bias - self.adn_order = adn_order - - self.model = self._model_choice() - - def _model_choice(self): - from monai.networks.nets import UNet as monai_UNet - - if self.size == "small": - # 3 layers - self.channels = (64, 128, 256) - - elif self.size == "medium": - # 5 layers - self.channels = (64, 128, 256, 512, 1024) - - elif self.size == "large": - # 6 layers - self.channels = (64, 128, 256, 512, 1024, 2048) - - model = monai_UNet( - spatial_dims=2, - in_channels=1, # TODO: check if image has 1 or multiple input channels - out_channels=1, - channels=self.channels, - strides=(2,) * (len(self.channels) - 1), - kernel_size=self.kernel_size, - up_kernel_size=self.up_kernel_size, - act=self.activation, - dropout=self.dropout, - bias=self.bias, - adn_ordering=self.adn_order, - ) - return model - - def forward(self, x): - x = self.model(x) - return x - - class Hyperparameters: """ Hyperparameters for QIM segmentation.