diff --git a/qim3d/models/unet.py b/qim3d/models/unet.py index 7912923664a648882ad7f4b03f68335a07f7702e..6e0f8a279aef132ddf5c93898ab31da16a9004c6 100644 --- a/qim3d/models/unet.py +++ b/qim3d/models/unet.py @@ -35,7 +35,8 @@ class UNet(nn.Module): up_kernel_size = 3, activation = 'PReLU', bias = True, - adn_order = 'NDA' + adn_order = 'NDA', + img_channels = 1 ): super().__init__() if size not in ['small','medium','large']: @@ -50,8 +51,9 @@ class UNet(nn.Module): self.activation = activation self.bias = bias self.adn_order = adn_order + self.img_channels = img_channels - self.model = self._model_choice() + self._model_choice() def _model_choice(self): @@ -64,7 +66,7 @@ class UNet(nn.Module): model = monai_UNet( spatial_dims=2, - in_channels=1, #TODO: check if image has 1 or multiple input channels + in_channels=self.img_channels, #TODO: check if image has 1 or multiple input channels out_channels=1, channels=self.channels, strides=(2,) * (len(self.channels) - 1), @@ -75,8 +77,12 @@ class UNet(nn.Module): bias=self.bias, adn_ordering=self.adn_order ) - return model + + self.model = model + def update_params(self): + self._model_choice() + def forward(self,x): x = self.model(x) diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py index 19a2a844fe225932e03b041c922328129aa771e8..72f5a591a5c9dbf1a3fd32964ff9039f912a811b 100644 --- a/qim3d/utils/models.py +++ b/qim3d/utils/models.py @@ -144,8 +144,10 @@ def model_summary(dataloader,model): print(summary) """ images,_ = next(iter(dataloader)) - batch_size = tuple(images.shape) - model_s = summary(model,batch_size,depth = torch.inf) + model_s = summary(model, + input_size = images.size(), + depth = torch.inf + ) return model_s