Commit 6f8cb981 authored by s193396

removed UNet2D

parent 3d734ce5
-from ._unet import UNet, UNet2D, Hyperparameters
+from ._unet import UNet, Hyperparameters
@@ -84,87 +84,6 @@ class UNet(nn.Module):
        x = self.model(x)
        return x
-class UNet2D(nn.Module):
-    """
-    2D UNet model for QIM imaging.
-
-    This class represents a 2D UNet model designed for imaging segmentation tasks.
-
-    Args:
-        size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
-        dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
-        kernel_size (int, optional): Convolution kernel size. Defaults to 3.
-        up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
-        activation (str, optional): Activation function. Defaults to 'PReLU'.
-        bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
-        adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'.
-
-    Raises:
-        ValueError: If `size` is not one of 'small', 'medium', or 'large'.
-    """
-
-    def __init__(
-        self,
-        size="medium",
-        dropout=0,
-        kernel_size=3,
-        up_kernel_size=3,
-        activation="PReLU",
-        bias=True,
-        adn_order="NDA",
-    ):
-        super().__init__()
-
-        if size not in ["small", "medium", "large"]:
-            raise ValueError(
-                f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
-            )
-
-        self.size = size
-        self.dropout = dropout
-        self.kernel_size = kernel_size
-        self.up_kernel_size = up_kernel_size
-        self.activation = activation
-        self.bias = bias
-        self.adn_order = adn_order
-
-        self.model = self._model_choice()
-
-    def _model_choice(self):
-        from monai.networks.nets import UNet as monai_UNet
-
-        if self.size == "small":
-            # 3 layers
-            self.channels = (64, 128, 256)
-        elif self.size == "medium":
-            # 5 layers
-            self.channels = (64, 128, 256, 512, 1024)
-        elif self.size == "large":
-            # 6 layers
-            self.channels = (64, 128, 256, 512, 1024, 2048)
-
-        model = monai_UNet(
-            spatial_dims=2,
-            in_channels=1,  # TODO: check if image has 1 or multiple input channels
-            out_channels=1,
-            channels=self.channels,
-            strides=(2,) * (len(self.channels) - 1),
-            kernel_size=self.kernel_size,
-            up_kernel_size=self.up_kernel_size,
-            act=self.activation,
-            dropout=self.dropout,
-            bias=self.bias,
-            adn_ordering=self.adn_order,
-        )
-        return model
-
-    def forward(self, x):
-        x = self.model(x)
-        return x
class Hyperparameters:
    """
    Hyperparameters for QIM segmentation.
......
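Since this commit removes the UNet2D wrapper, an equivalent 2D network can still be built by calling MONAI's UNet directly with the same arguments the removed class passed through. The sketch below reproduces the "medium" preset from the removed code; the input tensor size (256x256) is only an illustrative assumption.

import torch
from monai.networks.nets import UNet as monai_UNet

# "medium" preset from the removed UNet2D class: 5 encoder levels
channels = (64, 128, 256, 512, 1024)

model = monai_UNet(
    spatial_dims=2,                      # 2D variant, as in the removed class
    in_channels=1,
    out_channels=1,
    channels=channels,
    strides=(2,) * (len(channels) - 1),  # downsample by 2 between levels
    kernel_size=3,
    up_kernel_size=3,
    act="PReLU",
    dropout=0,
    bias=True,
    adn_ordering="NDA",
)

x = torch.randn(1, 1, 256, 256)  # (batch, channel, H, W); the size is an assumption
y = model(x)                     # output keeps the same spatial shape as the input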