Skip to content
Snippets Groups Projects
Commit ae49d74c authored by s193396's avatar s193396
Browse files

3D UNet implementatioN

parent 7f406282
No related branches found
No related tags found
No related merge requests found
from ._unet import UNet2D, Hyperparameters
from ._unet import UNet, UNet2D, Hyperparameters
......@@ -4,6 +4,85 @@ import torch.nn as nn
from qim3d.utils import log
class UNet(nn.Module):
"""
3D UNet model for QIM imaging.
This class represents a 3D UNet model designed for imaging segmentation tasks.
Args:
size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
kernel_size (int, optional): Convolution kernel size. Defaults to 3.
up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
activation (str, optional): Activation function. Defaults to 'PReLU'.
bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'.
Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
"""
def __init__(
self,
size="medium",
dropout=0,
kernel_size=3,
up_kernel_size=3,
activation="PReLU",
bias=True,
adn_order="NDA",
):
super().__init__()
if size not in ["small", "medium", "large"]:
raise ValueError(
f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
)
self.size = size
self.dropout = dropout
self.kernel_size = kernel_size
self.up_kernel_size = up_kernel_size
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.model = self._model_choice()
def _model_choice(self):
from monai.networks.nets import UNet as monai_UNet
if self.size == "small":
# 3 layers
self.channels = (64, 128, 256)
elif self.size == "medium":
# 5 layers
self.channels = (64, 128, 256, 512, 1024)
elif self.size == "large":
# 6 layers
self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet(
spatial_dims=3,
in_channels=1, # TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
kernel_size=self.kernel_size,
up_kernel_size=self.up_kernel_size,
act=self.activation,
dropout=self.dropout,
bias=self.bias,
adn_ordering=self.adn_order,
)
return model
def forward(self, x):
x = self.model(x)
return x
class UNet2D(nn.Module):
"""
......@@ -54,10 +133,15 @@ class UNet2D(nn.Module):
from monai.networks.nets import UNet as monai_UNet
if self.size == "small":
# 3 layers
self.channels = (64, 128, 256)
elif self.size == "medium":
# 5 layers
self.channels = (64, 128, 256, 512, 1024)
elif self.size == "large":
# 6 layers
self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment