Skip to content
Snippets Groups Projects
Commit ba974ff9 authored by ofhkr's avatar ofhkr
Browse files

small changes to update UNet channels

parent 353731d4
No related branches found
No related tags found
1 merge request!47(Work in progress) Implementation of adaptive Dataset class which adapts to different data structures
...@@ -35,7 +35,8 @@ class UNet(nn.Module): ...@@ -35,7 +35,8 @@ class UNet(nn.Module):
up_kernel_size = 3, up_kernel_size = 3,
activation = 'PReLU', activation = 'PReLU',
bias = True, bias = True,
adn_order = 'NDA' adn_order = 'NDA',
img_channels = 1
): ):
super().__init__() super().__init__()
if size not in ['small','medium','large']: if size not in ['small','medium','large']:
...@@ -50,8 +51,9 @@ class UNet(nn.Module): ...@@ -50,8 +51,9 @@ class UNet(nn.Module):
self.activation = activation self.activation = activation
self.bias = bias self.bias = bias
self.adn_order = adn_order self.adn_order = adn_order
self.img_channels = img_channels
self.model = self._model_choice() self._model_choice()
def _model_choice(self): def _model_choice(self):
...@@ -64,7 +66,7 @@ class UNet(nn.Module): ...@@ -64,7 +66,7 @@ class UNet(nn.Module):
model = monai_UNet( model = monai_UNet(
spatial_dims=2, spatial_dims=2,
in_channels=1, #TODO: check if image has 1 or multiple input channels in_channels=self.img_channels, #TODO: check if image has 1 or multiple input channels
out_channels=1, out_channels=1,
channels=self.channels, channels=self.channels,
strides=(2,) * (len(self.channels) - 1), strides=(2,) * (len(self.channels) - 1),
...@@ -75,8 +77,12 @@ class UNet(nn.Module): ...@@ -75,8 +77,12 @@ class UNet(nn.Module):
bias=self.bias, bias=self.bias,
adn_ordering=self.adn_order adn_ordering=self.adn_order
) )
return model
self.model = model
def update_params(self):
self._model_choice()
def forward(self,x): def forward(self,x):
x = self.model(x) x = self.model(x)
......
...@@ -144,8 +144,10 @@ def model_summary(dataloader,model): ...@@ -144,8 +144,10 @@ def model_summary(dataloader,model):
print(summary) print(summary)
""" """
images,_ = next(iter(dataloader)) images,_ = next(iter(dataloader))
batch_size = tuple(images.shape) model_s = summary(model,
model_s = summary(model,batch_size,depth = torch.inf) input_size = images.size(),
depth = torch.inf
)
return model_s return model_s
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment