From 9e62018acd6f9c3be20a546aec5b1230d7a87144 Mon Sep 17 00:00:00 2001
From: fima <fima@dtu.dk>
Date: Mon, 17 Jun 2024 13:06:00 +0200
Subject: [PATCH] Import speed refactoring
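
Heavy third-party dependencies (h5py, nibabel, pydicom, olefile, monai,
torch optimizers and loss functions, albumentations, scikit-image, scipy,
structure_tensor, localthickness, k3d, and PIL.Image in the annotation
tool) are now imported lazily inside the functions that use them instead
of at module level, so `import qim3d` no longer pays their full import
cost up front. For example, `DataSaver.save_h5` in qim3d/io/saving.py now
reads (docstring omitted for brevity):

    def save_h5(self, path, data):
        import h5py  # deferred: only loaded when an HDF5 file is actually saved

        with h5py.File(path, "w") as f:
            f.create_dataset(
                "dataset", data=data, compression="gzip" if self.compression else None
            )

With the import cost moved out of module scope, `qim3d.models` is imported
at the top level again and the package version is bumped to 0.3.7. The CLI
welcome banner is simplified using `outputformat`, and the stray trailing
commas in requirements.txt are removed.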

---
 docs/releases.md                      |   4 +
 qim3d/__init__.py                     |  10 +-
 qim3d/gui/annotation_tool.py          |   6 +-
 qim3d/io/loading.py                   |  13 +-
 qim3d/io/saving.py                    |  94 ++++++++------
 qim3d/models/unet.py                  | 180 ++++++++++++++++----------
 qim3d/processing/cc.py                |   1 -
 qim3d/processing/detection.py         |   3 +-
 qim3d/processing/local_thickness_.py  |  14 +-
 qim3d/processing/operations.py        |   6 +-
 qim3d/processing/structure_tensor_.py |   2 +-
 qim3d/tests/viz/test_img.py           |   2 +-
 qim3d/utils/augmentations.py          |   6 +-
 qim3d/utils/cli.py                    |  28 ++--
 qim3d/utils/data.py                   |   5 +-
 qim3d/viz/colormaps.py                |  63 ++++-----
 qim3d/viz/k3d.py                      |   3 +-
 requirements.txt                      |  44 +++----
 setup.py                              |   2 +-
 19 files changed, 263 insertions(+), 223 deletions(-)

diff --git a/docs/releases.md b/docs/releases.md
index b6e15863..257c7d49 100644
--- a/docs/releases.md
+++ b/docs/releases.md
@@ -9,6 +9,10 @@ As the library is still in its early development stages, **there may be breaking
 
 And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that you have the latest features!
 
+### v0.3.7 (17/06/2024)
+- Performance improvements when importing the library
+- Refactoring for blob detection
+
 ### v0.3.6 (30/05/2024)
 - Refactoring for performance improvement
 - Welcome message for the CLI
diff --git a/qim3d/__init__.py b/qim3d/__init__.py
index c89b6fd0..ebb4e7b2 100644
--- a/qim3d/__init__.py
+++ b/qim3d/__init__.py
@@ -8,20 +8,14 @@ Documentation available at https://platform.qim.dk/qim3d/
 
 """
 
-__version__ = "0.3.6"
-
-import logging
-
-logging.basicConfig(level=logging.ERROR)
+__version__ = "0.3.7"
 
 from . import io
 from . import gui
 from . import viz
 from . import utils
 from . import processing
-
-# Commenting out models because it takes too long to import
-# from . import models
+from . import models
 
 examples = io.ImgExamples()
 io.logger.set_level_info()
diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py
index da664528..40b3fda8 100644
--- a/qim3d/gui/annotation_tool.py
+++ b/qim3d/gui/annotation_tool.py
@@ -23,16 +23,11 @@ app = annotation_tool.launch(vol[0])
 import getpass
 import os
 import tempfile
-import time
 
 import gradio as gr
 import numpy as np
-import tifffile
-from PIL import Image
-
 import qim3d.utils
 from qim3d.io import load, save
-from qim3d.io.logger import log
 
 
 class Session:
@@ -100,6 +95,7 @@ class Interface:
         return gr.update(visible=True)
 
     def create_interface(self, img=None):
+        from PIL import Image
 
         if img is not None:
             custom_css = "annotation-tool"
diff --git a/qim3d/io/loading.py b/qim3d/io/loading.py
index e45eb121..28f5abe4 100644
--- a/qim3d/io/loading.py
+++ b/qim3d/io/loading.py
@@ -17,11 +17,7 @@ from pathlib import Path
 
 import dask
 import dask.array as da
-import h5py
-import nibabel as nib
 import numpy as np
-import olefile
-import pydicom
 import tifffile
 from dask import delayed
 from PIL import Image, UnidentifiedImageError
@@ -122,6 +118,7 @@ class DataLoader:
             ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file
             ValueError: If no datasets are found in the file.
         """
+        import h5py
 
         # Read file
         f = h5py.File(path, "r")
@@ -256,6 +253,7 @@ class DataLoader:
         Raises:
             ValueError: If the dxchange library is not installed
         """
+        import olefile
 
         try:
             import dxchange
@@ -323,6 +321,7 @@ class DataLoader:
                 If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object
                 If 'self.return_metadata' is True, returns a tuple (volume, metadata).
         """
+        import nibabel as nib
 
         data = nib.load(path)
 
@@ -557,6 +556,8 @@ class DataLoader:
         Args:
             path (str): Path to file
         """
+        import pydicom
+
         dcm_data = pydicom.dcmread(path)
 
         if self.return_metadata:
@@ -570,6 +571,8 @@ class DataLoader:
         Args:
             path (str): Directory path
         """
+        import pydicom
+
         if not self.contains:
             raise ValueError(
                 "Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'"
@@ -709,6 +712,8 @@ class DataLoader:
 
 
 def _get_h5_dataset_keys(f):
+    import h5py
+
     keys = []
     f.visit(lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None)
     return keys
diff --git a/qim3d/io/saving.py b/qim3d/io/saving.py
index 64a6a1ca..675fa20e 100644
--- a/qim3d/io/saving.py
+++ b/qim3d/io/saving.py
@@ -21,18 +21,12 @@ Example:
     ```
 
 """
+
 import datetime
 import os
-
-import h5py
-import nibabel as nib
 import numpy as np
 import PIL
-import pydicom
 import tifffile
-from pydicom.dataset import FileDataset, FileMetaDataset
-from pydicom.uid import UID
-
 from qim3d.io.logger import log
 from qim3d.utils.internal_tools import sizeof, stringify_path
 
@@ -116,24 +110,30 @@ class DataSaver:
                 filepath = os.path.join(path, filename)
                 self.save_tiff(filepath, sliced)
 
-            pattern_string = filepath[:-(len(extension)+zfill_val)] + "-"*zfill_val + extension
+            pattern_string = (
+                filepath[: -(len(extension) + zfill_val)] + "-" * zfill_val + extension
+            )
 
-            log.info(f"Total of {no_slices} files saved following the pattern '{pattern_string}'")
+            log.info(
+                f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
+            )
 
     def save_nifti(self, path, data):
-        """ Save data to a NIfTI file to the given path.
+        """Save data to a NIfTI file to the given path.
 
         Args:
             path (str): The path to save file to
             data (numpy.ndarray): The data to be saved
         """
+        import nibabel as nib
+
         # Create header
         header = nib.Nifti1Header()
         header.set_data_dtype(data.dtype)
 
         # Create NIfTI image object
         img = nib.Nifti1Image(data, np.eye(4), header)
-        
+
         # nib does automatically compress if filetype ends with .gz
         if self.compression and not path.endswith(".gz"):
             path += ".gz"
@@ -141,13 +141,15 @@ class DataSaver:
 
         if not self.compression and path.endswith(".gz"):
             path = path[:-3]
-            log.warning("File extension '.gz' is ignored since compression is disabled.")
+            log.warning(
+                "File extension '.gz' is ignored since compression is disabled."
+            )
 
         # Save image
         nib.save(img, path)
 
     def save_vol(self, path, data):
-        """ Save data to a VOL file to the given path.
+        """Save data to a VOL file to the given path.
 
         Args:
             path (str): The path to save file to
@@ -155,15 +157,21 @@ class DataSaver:
         """
         # No support for compression yet
         if self.compression:
-            raise NotImplementedError("Saving compressed .vol files is not yet supported")
+            raise NotImplementedError(
+                "Saving compressed .vol files is not yet supported"
+            )
 
         # Create custom .vgi metadata file
         metadata = ""
-        metadata += "{volume1}\n" # .vgi organization 
-        metadata += "[file1]\n" # .vgi organization 
-        metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format
-        metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
-        metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension
+        metadata += "{volume1}\n"  # .vgi organization
+        metadata += "[file1]\n"  # .vgi organization
+        metadata += "Size = {} {} {}\n".format(
+            data.shape[1], data.shape[2], data.shape[0]
+        )  # Swap axes to match .vol format
+        metadata += "Datatype = {}\n".format(str(data.dtype))  # Get datatype as string
+        metadata += "Name = {}.vol\n".format(
+            path.rsplit("/", 1)[-1][:-4]
+        )  # Get filename without extension
 
         # Save metadata
         with open(path[:-4] + ".vgi", "w") as f:
@@ -173,39 +181,45 @@ class DataSaver:
         data.tofile(path[:-4] + ".vol")
 
     def save_h5(self, path, data):
-        """ Save data to a HDF5 file to the given path.
+        """Save data to a HDF5 file to the given path.
 
         Args:
             path (str): The path to save file to
             data (numpy.ndarray): The data to be saved
         """
+        import h5py
 
         with h5py.File(path, "w") as f:
-            f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None)
-        
+            f.create_dataset(
+                "dataset", data=data, compression="gzip" if self.compression else None
+            )
+
     def save_dicom(self, path, data):
-        """ Save data to a DICOM file to the given path.
+        """Save data to a DICOM file to the given path.
 
         Args:
             path (str): The path to save file to
             data (numpy.ndarray): The data to be saved
         """
+        import pydicom
+        from pydicom.dataset import FileDataset, FileMetaDataset
+        from pydicom.uid import UID
+
         # based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html
 
         # Populate required values for file meta information
         file_meta = FileMetaDataset()
-        file_meta.MediaStorageSOPClassUID = UID('1.2.840.10008.5.1.4.1.1.2')
+        file_meta.MediaStorageSOPClassUID = UID("1.2.840.10008.5.1.4.1.1.2")
         file_meta.MediaStorageSOPInstanceUID = UID("1.2.3")
         file_meta.ImplementationClassUID = UID("1.2.3.4")
 
         # Create the FileDataset instance (initially no data elements, but file_meta
         # supplied)
-        ds = FileDataset(path, {},
-                        file_meta=file_meta, preamble=b"\0" * 128)
+        ds = FileDataset(path, {}, file_meta=file_meta, preamble=b"\0" * 128)
 
         ds.PatientName = "Test^Firstname"
         ds.PatientID = "123456"
-        ds.StudyInstanceUID = "1.2.3.4.5" 
+        ds.StudyInstanceUID = "1.2.3.4.5"
         ds.SamplesPerPixel = 1
         ds.PixelRepresentation = 0
         ds.BitsStored = 16
@@ -220,8 +234,8 @@ class DataSaver:
 
         # Set creation date/time
         dt = datetime.datetime.now()
-        ds.ContentDate = dt.strftime('%Y%m%d')
-        timeStr = dt.strftime('%H%M%S.%f')  # long format with micro seconds
+        ds.ContentDate = dt.strftime("%Y%m%d")
+        timeStr = dt.strftime("%H%M%S.%f")  # long format with micro seconds
         ds.ContentTime = timeStr
         # Needs to be here because of bug in pydicom
         ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian
@@ -234,10 +248,9 @@ class DataSaver:
         ds.PixelData = data_bytes
 
         ds.save_as(path)
-        
-    
+
     def save_PIL(self, path, data):
-        """ Save data to a PIL file to the given path.
+        """Save data to a PIL file to the given path.
 
         Args:
             path (str): The path to save file to
@@ -246,7 +259,7 @@ class DataSaver:
         # No support for compression yet
         if self.compression and path.endswith(".png"):
             raise NotImplementedError("png does not support compression")
-        elif not self.compression and path.endswith((".jpeg",".jpg")):
+        elif not self.compression and path.endswith((".jpeg", ".jpg")):
             raise NotImplementedError("jpeg does not support no compression")
 
         # Convert to PIL image
@@ -255,7 +268,6 @@ class DataSaver:
         # Save image
         img.save(path)
 
-        
     def save(self, path, data):
         """Save data to the given path.
 
@@ -320,17 +332,19 @@ class DataSaver:
 
                     if path.endswith((".tif", ".tiff")):
                         return self.save_tiff(path, data)
-                    elif path.endswith((".nii","nii.gz")):
+                    elif path.endswith((".nii", "nii.gz")):
                         return self.save_nifti(path, data)
-                    elif path.endswith(("TXRM","XRM","TXM")):
-                        raise NotImplementedError("Saving TXRM files is not yet supported")
+                    elif path.endswith(("TXRM", "XRM", "TXM")):
+                        raise NotImplementedError(
+                            "Saving TXRM files is not yet supported"
+                        )
                     elif path.endswith((".h5")):
                         return self.save_h5(path, data)
-                    elif path.endswith((".vol",".vgi")):
+                    elif path.endswith((".vol", ".vgi")):
                         return self.save_vol(path, data)
-                    elif path.endswith((".dcm",".DCM")):
+                    elif path.endswith((".dcm", ".DCM")):
                         return self.save_dicom(path, data)
-                    elif path.endswith((".jpeg",".jpg", ".png")):
+                    elif path.endswith((".jpeg", ".jpg", ".png")):
                         return self.save_PIL(path, data)
                     else:
                         raise ValueError("Unsupported file format")
diff --git a/qim3d/models/unet.py b/qim3d/models/unet.py
index c85648fd..6ca19bcb 100644
--- a/qim3d/models/unet.py
+++ b/qim3d/models/unet.py
@@ -1,14 +1,10 @@
 """UNet model and Hyperparameters class."""
 
-from monai.networks.nets import UNet as monai_UNet
-from monai.losses import FocalLoss, DiceLoss, DiceCELoss
-
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam, SGD, RMSprop
 
 from qim3d.io.logger import log
 
+
 class UNet(nn.Module):
     """
     2D UNet model for QIM imaging.
@@ -32,20 +28,23 @@ class UNet(nn.Module):
         model = UNet(size='large')
         ```
     """
-    def __init__(self, size = 'medium',
-                 dropout = 0,
-                 kernel_size = 3,
-                 up_kernel_size = 3,
-                 activation = 'PReLU',
-                 bias = True,
-                 adn_order = 'NDA'
-                ):
+
+    def __init__(
+        self,
+        size="medium",
+        dropout=0,
+        kernel_size=3,
+        up_kernel_size=3,
+        activation="PReLU",
+        bias=True,
+        adn_order="NDA",
+    ):
         super().__init__()
-        if size not in ['small','medium','large']:
+        if size not in ["small", "medium", "large"]:
             raise ValueError(
                 f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
             )
-        
+
         self.size = size
         self.dropout = dropout
         self.kernel_size = kernel_size
@@ -53,21 +52,22 @@ class UNet(nn.Module):
         self.activation = activation
         self.bias = bias
         self.adn_order = adn_order
-        
+
         self.model = self._model_choice()
-    
+
     def _model_choice(self):
-        
-        if self.size == 'small':
+        from monai.networks.nets import UNet as monai_UNet
+
+        if self.size == "small":
             self.channels = (64, 128, 256)
-        elif self.size == 'medium':
+        elif self.size == "medium":
             self.channels = (64, 128, 256, 512, 1024)
-        elif self.size == 'large':
+        elif self.size == "large":
             self.channels = (64, 128, 256, 512, 1024, 2048)
 
         model = monai_UNet(
             spatial_dims=2,
-            in_channels=1, #TODO: check if image has 1 or multiple input channels
+            in_channels=1,  # TODO: check if image has 1 or multiple input channels
             out_channels=1,
             channels=self.channels,
             strides=(2,) * (len(self.channels) - 1),
@@ -76,12 +76,11 @@ class UNet(nn.Module):
             act=self.activation,
             dropout=self.dropout,
             bias=self.bias,
-            adn_ordering=self.adn_order
+            adn_ordering=self.adn_order,
         )
         return model
-    
-    
-    def forward(self,x):
+
+    def forward(self, x):
         x = self.model(x)
         return x
 
@@ -114,29 +113,37 @@ class Hyperparameters:
         n_epochs  = params_dict['n_epochs']
         ```
     """
-    def __init__(self,
-                 model,
-                 n_epochs = 10,
-                 learning_rate = 1e-3,
-                 optimizer = 'Adam',
-                 momentum = 0,
-                 weight_decay = 0,
-                 loss_function = 'Focal'):
 
+    def __init__(
+        self,
+        model,
+        n_epochs=10,
+        learning_rate=1e-3,
+        optimizer="Adam",
+        momentum=0,
+        weight_decay=0,
+        loss_function="Focal",
+    ):
 
         # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
-        if loss_function not in ['BCE','Dice','Focal','DiceCE']:
-            raise ValueError(f"Invalid loss function: {loss_function}. Loss criterion must "
-                             "be one of the following: 'BCE','Dice','Focal','DiceCE'.")
+        if loss_function not in ["BCE", "Dice", "Focal", "DiceCE"]:
+            raise ValueError(
+                f"Invalid loss function: {loss_function}. Loss criterion must "
+                "be one of the following: 'BCE','Dice','Focal','DiceCE'."
+            )
         # TODO: implement custom optimizer? and add check to see if valid.
-        if optimizer not in ['Adam','SGD','RMSprop']:
-            raise ValueError(f"Invalid optimizer: {optimizer}. Optimizer must "
-                             "be one of the following: 'Adam', 'SGD', 'RMSprop'.")
-        
-        if (momentum != 0) and optimizer == 'Adam':
-            log.info("Momentum isn't an input in the 'Adam' optimizer. "
-                        "Change optimizer to 'SGD' or 'RMSprop' to use momentum.")          
-        
+        if optimizer not in ["Adam", "SGD", "RMSprop"]:
+            raise ValueError(
+                f"Invalid optimizer: {optimizer}. Optimizer must "
+                "be one of the following: 'Adam', 'SGD', 'RMSprop'."
+            )
+
+        if (momentum != 0) and optimizer == "Adam":
+            log.info(
+                "Momentum isn't an input in the 'Adam' optimizer. "
+                "Change optimizer to 'SGD' or 'RMSprop' to use momentum."
+            )
+
         self.model = model
         self.n_epochs = n_epochs
         self.learning_rate = learning_rate
@@ -146,41 +153,72 @@ class Hyperparameters:
         self.loss_function = loss_function
 
     def __call__(self):
-        return self.model_params(self.model, self.n_epochs, self.optimizer, self.learning_rate,
-                                 self.weight_decay, self.momentum, self.loss_function)
+        return self.model_params(
+            self.model,
+            self.n_epochs,
+            self.optimizer,
+            self.learning_rate,
+            self.weight_decay,
+            self.momentum,
+            self.loss_function,
+        )
 
-    def model_params(self, model, n_epochs, optimizer, learning_rate, weight_decay, momentum, loss_function):
+    def model_params(
+        self,
+        model,
+        n_epochs,
+        optimizer,
+        learning_rate,
+        weight_decay,
+        momentum,
+        loss_function,
+    ):
 
         optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
         criterion = self._loss_functions(loss_function)
 
-        hyper_dict = {'optimizer': optim,
-                      'criterion': criterion,
-                      'n_epochs' : n_epochs,
-                     }
+        hyper_dict = {
+            "optimizer": optim,
+            "criterion": criterion,
+            "n_epochs": n_epochs,
+        }
         return hyper_dict
 
     # selecting the optimizer
     def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
-        if optimizer == 'Adam':
-            optim = Adam(model.parameters(), lr = learning_rate,
-                         weight_decay = weight_decay)
-        elif optimizer == 'SGD':
-            optim = SGD(model.parameters(), lr = learning_rate,
-                        momentum = momentum, weight_decay = weight_decay)
-        elif optimizer == 'RMSprop':
-            optim = RMSprop(model.parameters(),lr = learning_rate,
-                            weight_decay = weight_decay, momentum = momentum)
+        from torch.optim import Adam, SGD, RMSprop
+
+        if optimizer == "Adam":
+            optim = Adam(
+                model.parameters(), lr=learning_rate, weight_decay=weight_decay
+            )
+        elif optimizer == "SGD":
+            optim = SGD(
+                model.parameters(),
+                lr=learning_rate,
+                momentum=momentum,
+                weight_decay=weight_decay,
+            )
+        elif optimizer == "RMSprop":
+            optim = RMSprop(
+                model.parameters(),
+                lr=learning_rate,
+                weight_decay=weight_decay,
+                momentum=momentum,
+            )
         return optim
 
     # selecting the loss function
-    def _loss_functions(self,loss_function):
-        if loss_function =='BCE':
-            criterion = BCEWithLogitsLoss(reduction='mean')
-        elif loss_function == 'Dice':
-            criterion = DiceLoss(sigmoid=True,reduction='mean')
-        elif loss_function == 'Focal':
-            criterion = FocalLoss(reduction='mean')
-        elif loss_function == 'DiceCE':
-            criterion = DiceCELoss(sigmoid=True,reduction='mean')
-        return criterion
\ No newline at end of file
+    def _loss_functions(self, loss_function):
+        from monai.losses import FocalLoss, DiceLoss, DiceCELoss
+        from torch.nn import BCEWithLogitsLoss
+
+        if loss_function == "BCE":
+            criterion = BCEWithLogitsLoss(reduction="mean")
+        elif loss_function == "Dice":
+            criterion = DiceLoss(sigmoid=True, reduction="mean")
+        elif loss_function == "Focal":
+            criterion = FocalLoss(reduction="mean")
+        elif loss_function == "DiceCE":
+            criterion = DiceCELoss(sigmoid=True, reduction="mean")
+        return criterion
diff --git a/qim3d/processing/cc.py b/qim3d/processing/cc.py
index d6ec0e85..48cc004a 100644
--- a/qim3d/processing/cc.py
+++ b/qim3d/processing/cc.py
@@ -1,7 +1,6 @@
 import numpy as np
 import torch
 from scipy.ndimage import find_objects, label
-
 from qim3d.io.logger import log
 
 
diff --git a/qim3d/processing/detection.py b/qim3d/processing/detection.py
index c5743a15..b967e844 100644
--- a/qim3d/processing/detection.py
+++ b/qim3d/processing/detection.py
@@ -2,8 +2,6 @@
 
 import numpy as np
 from qim3d.io.logger import log
-from skimage.feature import blob_dog
-
 
 def blob_detection(
     vol: np.ndarray,
@@ -63,6 +61,7 @@ def blob_detection(
             ```
             ![blob detection](assets/screenshots/blob_get_mask.gif)
     """
+    from skimage.feature import blob_dog
 
     if background == "bright":
         log.info("Bright background selected, volume will be inverted.")
diff --git a/qim3d/processing/local_thickness_.py b/qim3d/processing/local_thickness_.py
index f0de7e99..7a9e30aa 100644
--- a/qim3d/processing/local_thickness_.py
+++ b/qim3d/processing/local_thickness_.py
@@ -1,13 +1,11 @@
 """Wrapper for the local thickness function from the localthickness package including visualization functions."""
 
-import localthickness as lt
 import numpy as np
 from typing import Optional
-from skimage.filters import threshold_otsu
 from qim3d.io.logger import log
-#from qim3d.viz import local_thickness as viz_local_thickness
 import qim3d
 
+
 def local_thickness(
     image: np.ndarray,
     scale: float = 1,
@@ -17,10 +15,10 @@ def local_thickness(
 ) -> np.ndarray:
     """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness)
 
-    The "Fast Local Thickness" by Vedrana Andersen Dahl and Anders Bjorholm Dahl from the Technical University of Denmark is a efficient algorithm for computing local thickness in 2D and 3D images. 
-    Their method significantly reduces computation time compared to traditional algorithms by utilizing iterative dilation with small structuring elements, rather than the large ones typically used. 
-    This approach allows the local thickness to be determined much faster, making it feasible for high-resolution volumetric data that are common in contemporary 3D microscopy. 
-    
+    The "Fast Local Thickness" by Vedrana Andersen Dahl and Anders Bjorholm Dahl from the Technical University of Denmark is a efficient algorithm for computing local thickness in 2D and 3D images.
+    Their method significantly reduces computation time compared to traditional algorithms by utilizing iterative dilation with small structuring elements, rather than the large ones typically used.
+    This approach allows the local thickness to be determined much faster, making it feasible for high-resolution volumetric data that are common in contemporary 3D microscopy.
+
     Testing against conventional methods and other Python-based tools like PoreSpy shows that the new algorithm is both accurate and faster, offering significant improvements in processing time for large datasets.
 
 
@@ -79,6 +77,8 @@ def local_thickness(
 
 
     """
+    import localthickness as lt
+    from skimage.filters import threshold_otsu
 
     # Check if input is binary
     if np.unique(image).size > 2:
diff --git a/qim3d/processing/operations.py b/qim3d/processing/operations.py
index dfaeb42e..1efaf9c9 100644
--- a/qim3d/processing/operations.py
+++ b/qim3d/processing/operations.py
@@ -1,7 +1,4 @@
 import numpy as np
-import scipy
-import skimage
-
 import qim3d.processing.filters as filters
 from qim3d.io.logger import log
 
@@ -86,6 +83,9 @@ def watershed(
         ![operations-watershed_after](assets/screenshots/operations-watershed_after.png)  
 
     """
+    import skimage
+    import scipy
+
     # Compute distance transform of binary volume
     distance= scipy.ndimage.distance_transform_edt(bin_vol)
 
diff --git a/qim3d/processing/structure_tensor_.py b/qim3d/processing/structure_tensor_.py
index a02e376b..1a1bad3f 100644
--- a/qim3d/processing/structure_tensor_.py
+++ b/qim3d/processing/structure_tensor_.py
@@ -2,7 +2,6 @@
 
 from typing import Tuple
 import numpy as np
-import structure_tensor as st 
 from qim3d.viz.structure_tensor import vectors
 
 
@@ -74,6 +73,7 @@ def structure_tensor(
         ```
 
     """
+    import structure_tensor as st
 
     if vol.ndim != 3:
         raise ValueError("The input volume must be 3D")
diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py
index 1179a15f..84dd4ee5 100644
--- a/qim3d/tests/viz/test_img.py
+++ b/qim3d/tests/viz/test_img.py
@@ -62,7 +62,7 @@ def test_slices_torch_tensor_input():
 
 def test_slices_wrong_input_format():
     input = 'not_a_volume'
-    with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'):
+    with pytest.raises(ValueError, match = 'Data type not supported'):
         qim3d.viz.slices(input)
 
 def test_slices_not_volume():
diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py
index 19154b70..ea81e53a 100644
--- a/qim3d/utils/augmentations.py
+++ b/qim3d/utils/augmentations.py
@@ -1,6 +1,4 @@
 """Class for choosing the level of data augmentations with albumentations"""
-import albumentations as A
-from albumentations.pytorch import ToTensorV2
 
 class Augmentation:
     """
@@ -54,7 +52,9 @@ class Augmentation:
         Raises:
             ValueError: If `level` is neither None, light, moderate nor heavy.
         """
-        
+        import albumentations as A
+        from albumentations.pytorch import ToTensorV2
+
         # Check if one of standard augmentation levels
         if level not in [None,'light','moderate','heavy']:
             raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
diff --git a/qim3d/utils/cli.py b/qim3d/utils/cli.py
index 92f01e33..e4a64b9a 100644
--- a/qim3d/utils/cli.py
+++ b/qim3d/utils/cli.py
@@ -5,8 +5,15 @@ from qim3d.gui import annotation_tool, data_explorer, iso3d, local_thickness
 from qim3d.io.loading import DataLoader
 from qim3d.utils import image_preview
 from qim3d import __version__ as version
+import outputformat as ouf
 import qim3d
 
+QIM_TITLE = ouf.rainbow(
+    f"\n         _          _____     __ \n  ____ _(_)___ ___ |__  /____/ / \n / __ `/ / __ `__ \ /_ </ __  /  \n/ /_/ / / / / / / /__/ / /_/ /   \n\__, /_/_/ /_/ /_/____/\__,_/    \n  /_/                 v{version}\n\n",
+    return_str=True,
+    cmap="hot",
+)
+
 
 def main():
     parser = argparse.ArgumentParser(description="Qim3d command-line interface.")
@@ -139,29 +146,14 @@ def main():
         )
 
     elif args.subcommand is None:
+        print(QIM_TITLE)
         welcome_text = (
-            "\n"
-            "         _          _____     __ \n"
-            "  ____ _(_)___ ___ |__  /____/ / \n"
-            " / __ `/ / __ `__ \ /_ </ __  /  \n"
-            "/ /_/ / / / / / / /__/ / /_/ /   \n"
-            "\__, /_/_/ /_/ /_/____/\__,_/    \n"
-            "  /_/                            \n"
-            "\n"
-            "--- Welcome to qim3d command-line interface ---\n"
-            "qim3d is a Python package for 3D image processing and visualization.\n"
-            "For more information, please visit: https://platform.qim.dk/qim3d/\n"
-            f"Current version of qim3d: {version}\n"
-            " \n"
-            "The qim3d command-line interface provides the following subcommands:\n"
-            "- gui: Graphical User Interfaces\n"
-            "- viz: Volumetric visualizations of volumes\n"
-            "- preview: Preview of an volume directly in the terminal\n"
+            "\nqim3d is a Python package for 3D image processing and visualization.\n"
+            f"For more information, please visit {ouf.c('https://platform.qim.dk/qim3d/', color='orange', return_str=True)}\n"
             " \n"
             "For more information on each subcommand, type 'qim3d <subcommand> --help'.\n"
         )
         print(welcome_text)
-        print("--- Help page for qim3d command-line interface shown below ---\n")
         parser.print_help()
         print("\n")
 
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index 332856e4..ae910179 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -2,8 +2,6 @@
 from pathlib import Path
 from PIL import Image
 from qim3d.io.logger import log
-from torch.utils.data import DataLoader
-
 import torch
 import numpy as np
 
@@ -187,7 +185,8 @@ def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train
         num_workers (int, optional): Defines how many processes should be run in parallel.
         pin_memory (bool, optional): Loads the datasets as CUDA tensors.
     """
-    
+    from torch.utils.data import DataLoader
+
     train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory)
     val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
     test_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
diff --git a/qim3d/viz/colormaps.py b/qim3d/viz/colormaps.py
index 547e10fb..e99b0fd8 100644
--- a/qim3d/viz/colormaps.py
+++ b/qim3d/viz/colormaps.py
@@ -1,16 +1,16 @@
 """
 This module provides a collection of colormaps useful for 3D visualization.
 """
-   
+
 import colorsys
 from typing import Union, Tuple
 import numpy as np
 import math
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib import colormaps
-from skimage import color
 
-def rearrange_colors(randRGBcolors_old, min_dist = 0.5):
+
+def rearrange_colors(randRGBcolors_old, min_dist=0.5):
     # Create new list for re-arranged colors
     randRGBcolors_new = [randRGBcolors_old.pop(0)]
 
@@ -32,6 +32,7 @@ def rearrange_colors(randRGBcolors_old, min_dist = 0.5):
 
     return randRGBcolors_new
 
+
 def objects(
     nlabels: int,
     style: str = "bright",
@@ -66,12 +67,12 @@ def objects(
         cmap_earth = qim3d.viz.colormaps.objects(nlabels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8)
         cmap_ocean = qim3d.viz.colormaps.objects(nlabels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9)
 
-        display(cmap_bright) 
+        display(cmap_bright)
         display(cmap_soft)
         display(cmap_earth)
         display(cmap_ocean)
         ```
-        ![colormap objects](assets/screenshots/viz-colormaps-objects-all.png)  
+        ![colormap objects](assets/screenshots/viz-colormaps-objects-all.png)
 
         ```python
         import qim3d
@@ -83,14 +84,16 @@ def objects(
         cmap = qim3d.viz.colormaps.objects(num_labels, style = 'bright')
         qim3d.viz.slicer(labeled_volume, axis = 1, cmap=cmap)
         ```
-        ![colormap objects](assets/screenshots/viz-colormaps-objects.gif) 
-    
+        ![colormap objects](assets/screenshots/viz-colormaps-objects.gif)
+
     Tip:
-        The `min_dist` parameter can be used to control the distance between neighboring colors. 
-        ![colormap objects mind_dist](assets/screenshots/viz-colormaps-min_dist.gif) 
-    
+        The `min_dist` parameter can be used to control the distance between neighboring colors.
+        ![colormap objects mind_dist](assets/screenshots/viz-colormaps-min_dist.gif)
+
 
     """
+    from skimage import color
+
     # Check style
     if style not in ("bright", "soft", "earth", "ocean"):
         raise ValueError(
@@ -148,9 +151,9 @@ def objects(
     if style == "earth":
         randLABColors = [
             (
-                rng.uniform(low=25, high=110),  
-                rng.uniform(low=-120, high=70),  
-                rng.uniform(low=-70, high=70),  
+                rng.uniform(low=25, high=110),
+                rng.uniform(low=-120, high=70),
+                rng.uniform(low=-70, high=70),
             )
             for i in range(nlabels)
         ]
@@ -158,17 +161,15 @@ def objects(
         # Convert LAB list to RGB
         randRGBcolors = []
         for LabColor in randLABColors:
-            randRGBcolors.append(
-                color.lab2rgb([[LabColor]])[0][0].tolist()
-            )
+            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())
 
     # Generate color map for ocean colors, based on LAB
     if style == "ocean":
         randLABColors = [
             (
-                rng.uniform(low=0, high=110), 
-                rng.uniform(low=-128, high=160),  
-                rng.uniform(low=-128, high=0), 
+                rng.uniform(low=0, high=110),
+                rng.uniform(low=-128, high=160),
+                rng.uniform(low=-128, high=0),
             )
             for i in range(nlabels)
         ]
@@ -176,10 +177,8 @@ def objects(
         # Convert LAB list to RGB
         randRGBcolors = []
         for LabColor in randLABColors:
-            randRGBcolors.append(
-                color.lab2rgb([[LabColor]])[0][0].tolist()
-                )
-            
+            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())
+
     # Re-arrange colors to have a minimum distance between neighboring colors
     randRGBcolors = rearrange_colors(randRGBcolors, min_dist)
 
@@ -191,16 +190,18 @@ def objects(
         randRGBcolors[-1] = background_color
 
     # Create colormap
-    objects = LinearSegmentedColormap.from_list(
-        "objects", randRGBcolors, N=nlabels
-    )
+    objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=nlabels)
 
     return objects
 
-qim = LinearSegmentedColormap.from_list('qim', 
-                                        [(0.6, 0.0, 0.0), #990000
-                                         (1.0, 0.6, 0.0), #ff9900
-                                         ])
+
+qim = LinearSegmentedColormap.from_list(
+    "qim",
+    [
+        (0.6, 0.0, 0.0),  # 990000
+        (1.0, 0.6, 0.0),  # ff9900
+    ],
+)
 """
 Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'```
 
@@ -213,4 +214,4 @@ Example:
     ```
     ![colormap objects](assets/screenshots/viz-colormaps-qim.png)
 """
-colormaps.register(qim)
\ No newline at end of file
+colormaps.register(qim)
diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py
index e612568b..fc3a2cf4 100644
--- a/qim3d/viz/k3d.py
+++ b/qim3d/viz/k3d.py
@@ -7,9 +7,7 @@ Volumetric visualization using K3D
 
 """
 
-import k3d
 import numpy as np
-
 from qim3d.io.logger import log
 from qim3d.utils.internal_tools import downscale_img, scale_to_float16
 
@@ -73,6 +71,7 @@ def vol(
         ```
 
     """
+    import k3d
 
     pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
     # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
diff --git a/requirements.txt b/requirements.txt
index c22d3ed7..0f8cd33d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,25 +1,25 @@
-albumentations>=1.3.1,
-gradio>=4.27.0,
-h5py>=3.9.0,
-localthickness>=0.1.2,
-matplotlib>=3.8.0,
-monai>=1.2.0,
-numpy>=1.26.0,
-outputformat>=0.1.3,
-Pillow>=10.0.1,
-plotly>=5.14.1,
-scipy>=1.11.2,
-seaborn>=0.12.2,
-pydicom>=2.4.4,
-setuptools>=68.0.0,
-tifffile>=2023.4.12,
-torch>=2.0.1,
-torchvision>=0.15.2,
-torchinfo>=1.8.0,
-tqdm>=4.65.0,
-nibabel>=5.2.0,
-ipywidgets>=8.1.2,
-dask>=2023.6.0,
+albumentations>=1.3.1
+gradio>=4.27.0
+h5py>=3.9.0
+localthickness>=0.1.2
+matplotlib>=3.8.0
+monai>=1.2.0
+numpy>=1.26.0
+outputformat>=0.1.3
+Pillow>=10.0.1
+plotly>=5.14.1
+scipy>=1.11.2
+seaborn>=0.12.2
+pydicom>=2.4.4
+setuptools>=68.0.0
+tifffile>=2023.4.12
+torch>=2.0.1
+torchvision>=0.15.2
+torchinfo>=1.8.0
+tqdm>=4.65.0
+nibabel>=5.2.0
+ipywidgets>=8.1.2
+dask>=2023.6.0
 k3d>=2.16.1
 olefile>=0.46
 psutil>=5.9.0
diff --git a/setup.py b/setup.py
index 017d277a..5efe335a 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
 
 setup(
     name="qim3d",
-    version="0.3.6",
+    version="0.3.7",
     author="Felipe Delestro",
     author_email="fima@dtu.dk",
     description="QIM tools and user interfaces",
-- 
GitLab