Skip to content
Snippets Groups Projects
Commit 6f8cb981 authored by s193396's avatar s193396
Browse files

removed UNet2D

parent 3d734ce5
No related branches found
No related tags found
No related merge requests found
from ._unet import UNet, UNet2D, Hyperparameters
from ._unet import UNet, Hyperparameters
......@@ -84,87 +84,6 @@ class UNet(nn.Module):
x = self.model(x)
return x
class UNet2D(nn.Module):
"""
2D UNet model for QIM imaging.
This class represents a 2D UNet model designed for imaging segmentation tasks.
Args:
size ('small' or 'medium' or 'large', optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
kernel_size (int, optional): Convolution kernel size. Defaults to 3.
up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
activation (str, optional): Activation function. Defaults to 'PReLU'.
bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'.
Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
"""
def __init__(
self,
size="medium",
dropout=0,
kernel_size=3,
up_kernel_size=3,
activation="PReLU",
bias=True,
adn_order="NDA",
):
super().__init__()
if size not in ["small", "medium", "large"]:
raise ValueError(
f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
)
self.size = size
self.dropout = dropout
self.kernel_size = kernel_size
self.up_kernel_size = up_kernel_size
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.model = self._model_choice()
def _model_choice(self):
from monai.networks.nets import UNet as monai_UNet
if self.size == "small":
# 3 layers
self.channels = (64, 128, 256)
elif self.size == "medium":
# 5 layers
self.channels = (64, 128, 256, 512, 1024)
elif self.size == "large":
# 6 layers
self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet(
spatial_dims=2,
in_channels=1, # TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
kernel_size=self.kernel_size,
up_kernel_size=self.up_kernel_size,
act=self.activation,
dropout=self.dropout,
bias=self.bias,
adn_ordering=self.adn_order,
)
return model
def forward(self, x):
x = self.model(x)
return x
class Hyperparameters:
"""
Hyperparameters for QIM segmentation.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment