diff --git a/qim3d/models/unet.py b/qim3d/models/unet.py
index 7912923664a648882ad7f4b03f68335a07f7702e..6e0f8a279aef132ddf5c93898ab31da16a9004c6 100644
--- a/qim3d/models/unet.py
+++ b/qim3d/models/unet.py
@@ -35,7 +35,8 @@ class UNet(nn.Module):
                  up_kernel_size = 3,
                  activation = 'PReLU',
                  bias = True,
-                 adn_order = 'NDA'
+                 adn_order = 'NDA',
+                 img_channels = 1
                 ):
         super().__init__()
         if size not in ['small','medium','large']:
@@ -50,8 +51,9 @@ class UNet(nn.Module):
         self.activation = activation
         self.bias = bias
         self.adn_order = adn_order
+        self.img_channels = img_channels
         
-        self.model = self._model_choice()
+        self._model_choice()
     
     def _model_choice(self):
         
@@ -64,7 +66,7 @@ class UNet(nn.Module):
 
         model = monai_UNet(
             spatial_dims=2,
-            in_channels=1, #TODO: check if image has 1 or multiple input channels
+            in_channels=self.img_channels,
             out_channels=1,
             channels=self.channels,
             strides=(2,) * (len(self.channels) - 1),
@@ -75,8 +77,13 @@
             bias=self.bias,
             adn_ordering=self.adn_order
         )
-        return model
+
+        self.model = model
     
+    def update_params(self):
+        """Rebuild the underlying MONAI model so edits to stored hyperparameters take effect."""
+        self._model_choice()
+
     
     def forward(self,x):
         x = self.model(x)
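
Usage note (not part of the patch): a minimal sketch of how the new
img_channels parameter and the update_params() hook are meant to
interact. It assumes qim3d exposes the class as qim3d.models.unet.UNet
and that the size argument accepts 'small'.

    from qim3d.models.unet import UNet

    # Build a 2D UNet for 3-channel (e.g. RGB) input slices.
    model = UNet(size='small', img_channels=3)

    # Mutating a stored hyperparameter alone does not touch the MONAI
    # model; update_params() rebuilds it so the change takes effect.
    model.img_channels = 1
    model.update_params()
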
diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py
index 19a2a844fe225932e03b041c922328129aa771e8..72f5a591a5c9dbf1a3fd32964ff9039f912a811b 100644
--- a/qim3d/utils/models.py
+++ b/qim3d/utils/models.py
@@ -144,8 +144,10 @@ def model_summary(dataloader,model):
         print(summary)
     """
     images,_ = next(iter(dataloader)) 
-    batch_size = tuple(images.shape)
-    model_s = summary(model,batch_size,depth = torch.inf)
+    model_s = summary(model,
+        input_size=images.size(),
+        depth=torch.inf
+    )
     
     return model_s
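
Usage note (not part of the patch): model_summary now derives
input_size directly from the shape of the first batch, so any
DataLoader whose batches match the model's expected input works.
The tensor shapes below are illustrative, not taken from the repo.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from qim3d.models.unet import UNet
    from qim3d.utils.models import model_summary

    # Dummy 2D grayscale batch: (batch, channels, height, width).
    images = torch.rand(4, 1, 64, 64)
    targets = torch.rand(4, 1, 64, 64)
    loader = DataLoader(TensorDataset(images, targets), batch_size=4)

    model = UNet(size='small')
    print(model_summary(loader, model))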