Skip to content
Snippets Groups Projects
Commit ba974ff9 authored by ofhkr's avatar ofhkr
Browse files

small changes to update UNet channels

parent 353731d4
Branches
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
......@@ -35,7 +35,8 @@ class UNet(nn.Module):
up_kernel_size = 3,
activation = 'PReLU',
bias = True,
adn_order = 'NDA'
adn_order = 'NDA',
img_channels = 1
):
super().__init__()
if size not in ['small','medium','large']:
......@@ -50,8 +51,9 @@ class UNet(nn.Module):
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.img_channels = img_channels
self.model = self._model_choice()
self._model_choice()
def _model_choice(self):
......@@ -64,7 +66,7 @@ class UNet(nn.Module):
model = monai_UNet(
spatial_dims=2,
in_channels=1, #TODO: check if image has 1 or multiple input channels
in_channels=self.img_channels, #TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
......@@ -75,7 +77,11 @@ class UNet(nn.Module):
bias=self.bias,
adn_ordering=self.adn_order
)
return model
self.model = model
def update_params(self):
self._model_choice()
def forward(self,x):
......
......@@ -144,8 +144,10 @@ def model_summary(dataloader,model):
print(summary)
"""
images,_ = next(iter(dataloader))
batch_size = tuple(images.shape)
model_s = summary(model,batch_size,depth = torch.inf)
model_s = summary(model,
input_size = images.size(),
depth = torch.inf
)
return model_s
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment