Skip to content
Snippets Groups Projects
Commit ba974ff9 authored by ofhkr's avatar ofhkr
Browse files

small changes to update UNet channels

parent 353731d4
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
......@@ -35,7 +35,8 @@ class UNet(nn.Module):
up_kernel_size = 3,
activation = 'PReLU',
bias = True,
adn_order = 'NDA'
adn_order = 'NDA',
img_channels = 1
):
super().__init__()
if size not in ['small','medium','large']:
......@@ -50,8 +51,9 @@ class UNet(nn.Module):
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.img_channels = img_channels
self.model = self._model_choice()
self._model_choice()
def _model_choice(self):
......@@ -64,7 +66,7 @@ class UNet(nn.Module):
model = monai_UNet(
spatial_dims=2,
in_channels=1, #TODO: check if image has 1 or multiple input channels
in_channels=self.img_channels, #TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
......@@ -75,7 +77,11 @@ class UNet(nn.Module):
bias=self.bias,
adn_ordering=self.adn_order
)
return model
self.model = model
def update_params(self):
self._model_choice()
def forward(self,x):
......
......@@ -144,8 +144,10 @@ def model_summary(dataloader,model):
print(summary)
"""
images,_ = next(iter(dataloader))
batch_size = tuple(images.shape)
model_s = summary(model,batch_size,depth = torch.inf)
model_s = summary(model,
input_size = images.size(),
depth = torch.inf
)
return model_s
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment