From 9e62018acd6f9c3be20a546aec5b1230d7a87144 Mon Sep 17 00:00:00 2001 From: fima <fima@dtu.dk> Date: Mon, 17 Jun 2024 13:06:00 +0200 Subject: [PATCH] Import speed refactoring --- docs/releases.md | 4 + qim3d/__init__.py | 10 +- qim3d/gui/annotation_tool.py | 6 +- qim3d/io/loading.py | 13 +- qim3d/io/saving.py | 94 ++++++++------ qim3d/models/unet.py | 180 ++++++++++++++++---------- qim3d/processing/cc.py | 1 - qim3d/processing/detection.py | 3 +- qim3d/processing/local_thickness_.py | 14 +- qim3d/processing/operations.py | 6 +- qim3d/processing/structure_tensor_.py | 2 +- qim3d/tests/viz/test_img.py | 2 +- qim3d/utils/augmentations.py | 6 +- qim3d/utils/cli.py | 28 ++-- qim3d/utils/data.py | 5 +- qim3d/viz/colormaps.py | 63 ++++----- qim3d/viz/k3d.py | 3 +- requirements.txt | 44 +++---- setup.py | 2 +- 19 files changed, 263 insertions(+), 223 deletions(-) diff --git a/docs/releases.md b/docs/releases.md index b6e15863..257c7d49 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -9,6 +9,10 @@ As the library is still in its early development stages, **there may be breaking And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that you have the latest features! +### v0.3.7 (17/06/2024) +- Performance improvements when importing +- Refactoring for blob detection + ### v0.3.6 (30/05/2024) - Refactoring for performance improvement - Welcome message for the CLI diff --git a/qim3d/__init__.py b/qim3d/__init__.py index c89b6fd0..ebb4e7b2 100644 --- a/qim3d/__init__.py +++ b/qim3d/__init__.py @@ -8,20 +8,14 @@ Documentation available at https://platform.qim.dk/qim3d/ """ -__version__ = "0.3.6" - -import logging - -logging.basicConfig(level=logging.ERROR) +__version__ = "0.3.7" from . import io from . import gui from . import viz from . import utils from . import processing - -# Commenting out models because it takes too long to import -# from . import models +from . import models examples = io.ImgExamples() io.logger.set_level_info() diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py index da664528..40b3fda8 100644 --- a/qim3d/gui/annotation_tool.py +++ b/qim3d/gui/annotation_tool.py @@ -23,16 +23,11 @@ app = annotation_tool.launch(vol[0]) import getpass import os import tempfile -import time import gradio as gr import numpy as np -import tifffile -from PIL import Image - import qim3d.utils from qim3d.io import load, save -from qim3d.io.logger import log class Session: @@ -100,6 +95,7 @@ class Interface: return gr.update(visible=True) def create_interface(self, img=None): + from PIL import Image if img is not None: custom_css = "annotation-tool" diff --git a/qim3d/io/loading.py b/qim3d/io/loading.py index e45eb121..28f5abe4 100644 --- a/qim3d/io/loading.py +++ b/qim3d/io/loading.py @@ -17,11 +17,7 @@ from pathlib import Path import dask import dask.array as da -import h5py -import nibabel as nib import numpy as np -import olefile -import pydicom import tifffile from dask import delayed from PIL import Image, UnidentifiedImageError @@ -122,6 +118,7 @@ class DataLoader: ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file ValueError: If no datasets are found in the file. """ + import h5py # Read file f = h5py.File(path, "r") @@ -256,6 +253,7 @@ class DataLoader: Raises: ValueError: If the dxchange library is not installed """ + import olefile try: import dxchange @@ -323,6 +321,7 @@ class DataLoader: If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object If 'self.return_metadata' is True, returns a tuple (volume, metadata). """ + import nibabel as nib data = nib.load(path) @@ -557,6 +556,8 @@ class DataLoader: Args: path (str): Path to file """ + import pydicom + dcm_data = pydicom.dcmread(path) if self.return_metadata: @@ -570,6 +571,8 @@ class DataLoader: Args: path (str): Directory path """ + import pydicom + if not self.contains: raise ValueError( "Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'" @@ -709,6 +712,8 @@ class DataLoader: def _get_h5_dataset_keys(f): + import h5py + keys = [] f.visit(lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None) return keys diff --git a/qim3d/io/saving.py b/qim3d/io/saving.py index 64a6a1ca..675fa20e 100644 --- a/qim3d/io/saving.py +++ b/qim3d/io/saving.py @@ -21,18 +21,12 @@ Example: ``` """ + import datetime import os - -import h5py -import nibabel as nib import numpy as np import PIL -import pydicom import tifffile -from pydicom.dataset import FileDataset, FileMetaDataset -from pydicom.uid import UID - from qim3d.io.logger import log from qim3d.utils.internal_tools import sizeof, stringify_path @@ -116,24 +110,30 @@ class DataSaver: filepath = os.path.join(path, filename) self.save_tiff(filepath, sliced) - pattern_string = filepath[:-(len(extension)+zfill_val)] + "-"*zfill_val + extension + pattern_string = ( + filepath[: -(len(extension) + zfill_val)] + "-" * zfill_val + extension + ) - log.info(f"Total of {no_slices} files saved following the pattern '{pattern_string}'") + log.info( + f"Total of {no_slices} files saved following the pattern '{pattern_string}'" + ) def save_nifti(self, path, data): - """ Save data to a NIfTI file to the given path. + """Save data to a NIfTI file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved """ + import nibabel as nib + # Create header header = nib.Nifti1Header() header.set_data_dtype(data.dtype) # Create NIfTI image object img = nib.Nifti1Image(data, np.eye(4), header) - + # nib does automatically compress if filetype ends with .gz if self.compression and not path.endswith(".gz"): path += ".gz" @@ -141,13 +141,15 @@ class DataSaver: if not self.compression and path.endswith(".gz"): path = path[:-3] - log.warning("File extension '.gz' is ignored since compression is disabled.") + log.warning( + "File extension '.gz' is ignored since compression is disabled." + ) # Save image nib.save(img, path) def save_vol(self, path, data): - """ Save data to a VOL file to the given path. + """Save data to a VOL file to the given path. Args: path (str): The path to save file to @@ -155,15 +157,21 @@ class DataSaver: """ # No support for compression yet if self.compression: - raise NotImplementedError("Saving compressed .vol files is not yet supported") + raise NotImplementedError( + "Saving compressed .vol files is not yet supported" + ) # Create custom .vgi metadata file metadata = "" - metadata += "{volume1}\n" # .vgi organization - metadata += "[file1]\n" # .vgi organization - metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format - metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string - metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension + metadata += "{volume1}\n" # .vgi organization + metadata += "[file1]\n" # .vgi organization + metadata += "Size = {} {} {}\n".format( + data.shape[1], data.shape[2], data.shape[0] + ) # Swap axes to match .vol format + metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string + metadata += "Name = {}.vol\n".format( + path.rsplit("/", 1)[-1][:-4] + ) # Get filename without extension # Save metadata with open(path[:-4] + ".vgi", "w") as f: @@ -173,39 +181,45 @@ class DataSaver: data.tofile(path[:-4] + ".vol") def save_h5(self, path, data): - """ Save data to a HDF5 file to the given path. + """Save data to a HDF5 file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved """ + import h5py with h5py.File(path, "w") as f: - f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None) - + f.create_dataset( + "dataset", data=data, compression="gzip" if self.compression else None + ) + def save_dicom(self, path, data): - """ Save data to a DICOM file to the given path. + """Save data to a DICOM file to the given path. Args: path (str): The path to save file to data (numpy.ndarray): The data to be saved """ + import pydicom + from pydicom.dataset import FileDataset, FileMetaDataset + from pydicom.uid import UID + # based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html # Populate required values for file meta information file_meta = FileMetaDataset() - file_meta.MediaStorageSOPClassUID = UID('1.2.840.10008.5.1.4.1.1.2') + file_meta.MediaStorageSOPClassUID = UID("1.2.840.10008.5.1.4.1.1.2") file_meta.MediaStorageSOPInstanceUID = UID("1.2.3") file_meta.ImplementationClassUID = UID("1.2.3.4") # Create the FileDataset instance (initially no data elements, but file_meta # supplied) - ds = FileDataset(path, {}, - file_meta=file_meta, preamble=b"\0" * 128) + ds = FileDataset(path, {}, file_meta=file_meta, preamble=b"\0" * 128) ds.PatientName = "Test^Firstname" ds.PatientID = "123456" - ds.StudyInstanceUID = "1.2.3.4.5" + ds.StudyInstanceUID = "1.2.3.4.5" ds.SamplesPerPixel = 1 ds.PixelRepresentation = 0 ds.BitsStored = 16 @@ -220,8 +234,8 @@ class DataSaver: # Set creation date/time dt = datetime.datetime.now() - ds.ContentDate = dt.strftime('%Y%m%d') - timeStr = dt.strftime('%H%M%S.%f') # long format with micro seconds + ds.ContentDate = dt.strftime("%Y%m%d") + timeStr = dt.strftime("%H%M%S.%f") # long format with micro seconds ds.ContentTime = timeStr # Needs to be here because of bug in pydicom ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian @@ -234,10 +248,9 @@ class DataSaver: ds.PixelData = data_bytes ds.save_as(path) - - + def save_PIL(self, path, data): - """ Save data to a PIL file to the given path. + """Save data to a PIL file to the given path. Args: path (str): The path to save file to @@ -246,7 +259,7 @@ class DataSaver: # No support for compression yet if self.compression and path.endswith(".png"): raise NotImplementedError("png does not support compression") - elif not self.compression and path.endswith((".jpeg",".jpg")): + elif not self.compression and path.endswith((".jpeg", ".jpg")): raise NotImplementedError("jpeg does not support no compression") # Convert to PIL image @@ -255,7 +268,6 @@ class DataSaver: # Save image img.save(path) - def save(self, path, data): """Save data to the given path. @@ -320,17 +332,19 @@ class DataSaver: if path.endswith((".tif", ".tiff")): return self.save_tiff(path, data) - elif path.endswith((".nii","nii.gz")): + elif path.endswith((".nii", "nii.gz")): return self.save_nifti(path, data) - elif path.endswith(("TXRM","XRM","TXM")): - raise NotImplementedError("Saving TXRM files is not yet supported") + elif path.endswith(("TXRM", "XRM", "TXM")): + raise NotImplementedError( + "Saving TXRM files is not yet supported" + ) elif path.endswith((".h5")): return self.save_h5(path, data) - elif path.endswith((".vol",".vgi")): + elif path.endswith((".vol", ".vgi")): return self.save_vol(path, data) - elif path.endswith((".dcm",".DCM")): + elif path.endswith((".dcm", ".DCM")): return self.save_dicom(path, data) - elif path.endswith((".jpeg",".jpg", ".png")): + elif path.endswith((".jpeg", ".jpg", ".png")): return self.save_PIL(path, data) else: raise ValueError("Unsupported file format") diff --git a/qim3d/models/unet.py b/qim3d/models/unet.py index c85648fd..6ca19bcb 100644 --- a/qim3d/models/unet.py +++ b/qim3d/models/unet.py @@ -1,14 +1,10 @@ """UNet model and Hyperparameters class.""" -from monai.networks.nets import UNet as monai_UNet -from monai.losses import FocalLoss, DiceLoss, DiceCELoss - import torch.nn as nn -from torch.nn import BCEWithLogitsLoss -from torch.optim import Adam, SGD, RMSprop from qim3d.io.logger import log + class UNet(nn.Module): """ 2D UNet model for QIM imaging. @@ -32,20 +28,23 @@ class UNet(nn.Module): model = UNet(size='large') ``` """ - def __init__(self, size = 'medium', - dropout = 0, - kernel_size = 3, - up_kernel_size = 3, - activation = 'PReLU', - bias = True, - adn_order = 'NDA' - ): + + def __init__( + self, + size="medium", + dropout=0, + kernel_size=3, + up_kernel_size=3, + activation="PReLU", + bias=True, + adn_order="NDA", + ): super().__init__() - if size not in ['small','medium','large']: + if size not in ["small", "medium", "large"]: raise ValueError( f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'." ) - + self.size = size self.dropout = dropout self.kernel_size = kernel_size @@ -53,21 +52,22 @@ class UNet(nn.Module): self.activation = activation self.bias = bias self.adn_order = adn_order - + self.model = self._model_choice() - + def _model_choice(self): - - if self.size == 'small': + from monai.networks.nets import UNet as monai_UNet + + if self.size == "small": self.channels = (64, 128, 256) - elif self.size == 'medium': + elif self.size == "medium": self.channels = (64, 128, 256, 512, 1024) - elif self.size == 'large': + elif self.size == "large": self.channels = (64, 128, 256, 512, 1024, 2048) model = monai_UNet( spatial_dims=2, - in_channels=1, #TODO: check if image has 1 or multiple input channels + in_channels=1, # TODO: check if image has 1 or multiple input channels out_channels=1, channels=self.channels, strides=(2,) * (len(self.channels) - 1), @@ -76,12 +76,11 @@ class UNet(nn.Module): act=self.activation, dropout=self.dropout, bias=self.bias, - adn_ordering=self.adn_order + adn_ordering=self.adn_order, ) return model - - - def forward(self,x): + + def forward(self, x): x = self.model(x) return x @@ -114,29 +113,37 @@ class Hyperparameters: n_epochs = params_dict['n_epochs'] ``` """ - def __init__(self, - model, - n_epochs = 10, - learning_rate = 1e-3, - optimizer = 'Adam', - momentum = 0, - weight_decay = 0, - loss_function = 'Focal'): + def __init__( + self, + model, + n_epochs=10, + learning_rate=1e-3, + optimizer="Adam", + momentum=0, + weight_decay=0, + loss_function="Focal", + ): # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation. - if loss_function not in ['BCE','Dice','Focal','DiceCE']: - raise ValueError(f"Invalid loss function: {loss_function}. Loss criterion must " - "be one of the following: 'BCE','Dice','Focal','DiceCE'.") + if loss_function not in ["BCE", "Dice", "Focal", "DiceCE"]: + raise ValueError( + f"Invalid loss function: {loss_function}. Loss criterion must " + "be one of the following: 'BCE','Dice','Focal','DiceCE'." + ) # TODO: implement custom optimizer? and add check to see if valid. - if optimizer not in ['Adam','SGD','RMSprop']: - raise ValueError(f"Invalid optimizer: {optimizer}. Optimizer must " - "be one of the following: 'Adam', 'SGD', 'RMSprop'.") - - if (momentum != 0) and optimizer == 'Adam': - log.info("Momentum isn't an input in the 'Adam' optimizer. " - "Change optimizer to 'SGD' or 'RMSprop' to use momentum.") - + if optimizer not in ["Adam", "SGD", "RMSprop"]: + raise ValueError( + f"Invalid optimizer: {optimizer}. Optimizer must " + "be one of the following: 'Adam', 'SGD', 'RMSprop'." + ) + + if (momentum != 0) and optimizer == "Adam": + log.info( + "Momentum isn't an input in the 'Adam' optimizer. " + "Change optimizer to 'SGD' or 'RMSprop' to use momentum." + ) + self.model = model self.n_epochs = n_epochs self.learning_rate = learning_rate @@ -146,41 +153,72 @@ class Hyperparameters: self.loss_function = loss_function def __call__(self): - return self.model_params(self.model, self.n_epochs, self.optimizer, self.learning_rate, - self.weight_decay, self.momentum, self.loss_function) + return self.model_params( + self.model, + self.n_epochs, + self.optimizer, + self.learning_rate, + self.weight_decay, + self.momentum, + self.loss_function, + ) - def model_params(self, model, n_epochs, optimizer, learning_rate, weight_decay, momentum, loss_function): + def model_params( + self, + model, + n_epochs, + optimizer, + learning_rate, + weight_decay, + momentum, + loss_function, + ): optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum) criterion = self._loss_functions(loss_function) - hyper_dict = {'optimizer': optim, - 'criterion': criterion, - 'n_epochs' : n_epochs, - } + hyper_dict = { + "optimizer": optim, + "criterion": criterion, + "n_epochs": n_epochs, + } return hyper_dict # selecting the optimizer def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum): - if optimizer == 'Adam': - optim = Adam(model.parameters(), lr = learning_rate, - weight_decay = weight_decay) - elif optimizer == 'SGD': - optim = SGD(model.parameters(), lr = learning_rate, - momentum = momentum, weight_decay = weight_decay) - elif optimizer == 'RMSprop': - optim = RMSprop(model.parameters(),lr = learning_rate, - weight_decay = weight_decay, momentum = momentum) + from torch.optim import Adam, SGD, RMSprop + + if optimizer == "Adam": + optim = Adam( + model.parameters(), lr=learning_rate, weight_decay=weight_decay + ) + elif optimizer == "SGD": + optim = SGD( + model.parameters(), + lr=learning_rate, + momentum=momentum, + weight_decay=weight_decay, + ) + elif optimizer == "RMSprop": + optim = RMSprop( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + momentum=momentum, + ) return optim # selecting the loss function - def _loss_functions(self,loss_function): - if loss_function =='BCE': - criterion = BCEWithLogitsLoss(reduction='mean') - elif loss_function == 'Dice': - criterion = DiceLoss(sigmoid=True,reduction='mean') - elif loss_function == 'Focal': - criterion = FocalLoss(reduction='mean') - elif loss_function == 'DiceCE': - criterion = DiceCELoss(sigmoid=True,reduction='mean') - return criterion \ No newline at end of file + def _loss_functions(self, loss_function): + from monai.losses import FocalLoss, DiceLoss, DiceCELoss + from torch.nn import BCEWithLogitsLoss + + if loss_function == "BCE": + criterion = BCEWithLogitsLoss(reduction="mean") + elif loss_function == "Dice": + criterion = DiceLoss(sigmoid=True, reduction="mean") + elif loss_function == "Focal": + criterion = FocalLoss(reduction="mean") + elif loss_function == "DiceCE": + criterion = DiceCELoss(sigmoid=True, reduction="mean") + return criterion diff --git a/qim3d/processing/cc.py b/qim3d/processing/cc.py index d6ec0e85..48cc004a 100644 --- a/qim3d/processing/cc.py +++ b/qim3d/processing/cc.py @@ -1,7 +1,6 @@ import numpy as np import torch from scipy.ndimage import find_objects, label - from qim3d.io.logger import log diff --git a/qim3d/processing/detection.py b/qim3d/processing/detection.py index c5743a15..b967e844 100644 --- a/qim3d/processing/detection.py +++ b/qim3d/processing/detection.py @@ -2,8 +2,6 @@ import numpy as np from qim3d.io.logger import log -from skimage.feature import blob_dog - def blob_detection( vol: np.ndarray, @@ -63,6 +61,7 @@ def blob_detection( ```  """ + from skimage.feature import blob_dog if background == "bright": log.info("Bright background selected, volume will be inverted.") diff --git a/qim3d/processing/local_thickness_.py b/qim3d/processing/local_thickness_.py index f0de7e99..7a9e30aa 100644 --- a/qim3d/processing/local_thickness_.py +++ b/qim3d/processing/local_thickness_.py @@ -1,13 +1,11 @@ """Wrapper for the local thickness function from the localthickness package including visualization functions.""" -import localthickness as lt import numpy as np from typing import Optional -from skimage.filters import threshold_otsu from qim3d.io.logger import log -#from qim3d.viz import local_thickness as viz_local_thickness import qim3d + def local_thickness( image: np.ndarray, scale: float = 1, @@ -17,10 +15,10 @@ def local_thickness( ) -> np.ndarray: """Wrapper for the local thickness function from the [local thickness package](https://github.com/vedranaa/local-thickness) - The "Fast Local Thickness" by Vedrana Andersen Dahl and Anders Bjorholm Dahl from the Technical University of Denmark is a efficient algorithm for computing local thickness in 2D and 3D images. - Their method significantly reduces computation time compared to traditional algorithms by utilizing iterative dilation with small structuring elements, rather than the large ones typically used. - This approach allows the local thickness to be determined much faster, making it feasible for high-resolution volumetric data that are common in contemporary 3D microscopy. - + The "Fast Local Thickness" by Vedrana Andersen Dahl and Anders Bjorholm Dahl from the Technical University of Denmark is a efficient algorithm for computing local thickness in 2D and 3D images. + Their method significantly reduces computation time compared to traditional algorithms by utilizing iterative dilation with small structuring elements, rather than the large ones typically used. + This approach allows the local thickness to be determined much faster, making it feasible for high-resolution volumetric data that are common in contemporary 3D microscopy. + Testing against conventional methods and other Python-based tools like PoreSpy shows that the new algorithm is both accurate and faster, offering significant improvements in processing time for large datasets. @@ -79,6 +77,8 @@ def local_thickness( """ + import localthickness as lt + from skimage.filters import threshold_otsu # Check if input is binary if np.unique(image).size > 2: diff --git a/qim3d/processing/operations.py b/qim3d/processing/operations.py index dfaeb42e..1efaf9c9 100644 --- a/qim3d/processing/operations.py +++ b/qim3d/processing/operations.py @@ -1,7 +1,4 @@ import numpy as np -import scipy -import skimage - import qim3d.processing.filters as filters from qim3d.io.logger import log @@ -86,6 +83,9 @@ def watershed(  """ + import skimage + import scipy + # Compute distance transform of binary volume distance= scipy.ndimage.distance_transform_edt(bin_vol) diff --git a/qim3d/processing/structure_tensor_.py b/qim3d/processing/structure_tensor_.py index a02e376b..1a1bad3f 100644 --- a/qim3d/processing/structure_tensor_.py +++ b/qim3d/processing/structure_tensor_.py @@ -2,7 +2,6 @@ from typing import Tuple import numpy as np -import structure_tensor as st from qim3d.viz.structure_tensor import vectors @@ -74,6 +73,7 @@ def structure_tensor( ``` """ + import structure_tensor as st if vol.ndim != 3: raise ValueError("The input volume must be 3D") diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index 1179a15f..84dd4ee5 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -62,7 +62,7 @@ def test_slices_torch_tensor_input(): def test_slices_wrong_input_format(): input = 'not_a_volume' - with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'): + with pytest.raises(ValueError, match = 'Data type not supported'): qim3d.viz.slices(input) def test_slices_not_volume(): diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py index 19154b70..ea81e53a 100644 --- a/qim3d/utils/augmentations.py +++ b/qim3d/utils/augmentations.py @@ -1,6 +1,4 @@ """Class for choosing the level of data augmentations with albumentations""" -import albumentations as A -from albumentations.pytorch import ToTensorV2 class Augmentation: """ @@ -54,7 +52,9 @@ class Augmentation: Raises: ValueError: If `level` is neither None, light, moderate nor heavy. """ - + import albumentations as A + from albumentations.pytorch import ToTensorV2 + # Check if one of standard augmentation levels if level not in [None,'light','moderate','heavy']: raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.") diff --git a/qim3d/utils/cli.py b/qim3d/utils/cli.py index 92f01e33..e4a64b9a 100644 --- a/qim3d/utils/cli.py +++ b/qim3d/utils/cli.py @@ -5,8 +5,15 @@ from qim3d.gui import annotation_tool, data_explorer, iso3d, local_thickness from qim3d.io.loading import DataLoader from qim3d.utils import image_preview from qim3d import __version__ as version +import outputformat as ouf import qim3d +QIM_TITLE = ouf.rainbow( + f"\n _ _____ __ \n ____ _(_)___ ___ |__ /____/ / \n / __ `/ / __ `__ \ /_ </ __ / \n/ /_/ / / / / / / /__/ / /_/ / \n\__, /_/_/ /_/ /_/____/\__,_/ \n /_/ v{version}\n\n", + return_str=True, + cmap="hot", +) + def main(): parser = argparse.ArgumentParser(description="Qim3d command-line interface.") @@ -139,29 +146,14 @@ def main(): ) elif args.subcommand is None: + print(QIM_TITLE) welcome_text = ( - "\n" - " _ _____ __ \n" - " ____ _(_)___ ___ |__ /____/ / \n" - " / __ `/ / __ `__ \ /_ </ __ / \n" - "/ /_/ / / / / / / /__/ / /_/ / \n" - "\__, /_/_/ /_/ /_/____/\__,_/ \n" - " /_/ \n" - "\n" - "--- Welcome to qim3d command-line interface ---\n" - "qim3d is a Python package for 3D image processing and visualization.\n" - "For more information, please visit: https://platform.qim.dk/qim3d/\n" - f"Current version of qim3d: {version}\n" - " \n" - "The qim3d command-line interface provides the following subcommands:\n" - "- gui: Graphical User Interfaces\n" - "- viz: Volumetric visualizations of volumes\n" - "- preview: Preview of an volume directly in the terminal\n" + "\nqim3d is a Python package for 3D image processing and visualization.\n" + f"For more information, please visit {ouf.c('https://platform.qim.dk/qim3d/', color='orange', return_str=True)}\n" " \n" "For more information on each subcommand, type 'qim3d <subcommand> --help'.\n" ) print(welcome_text) - print("--- Help page for qim3d command-line interface shown below ---\n") parser.print_help() print("\n") diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py index 332856e4..ae910179 100644 --- a/qim3d/utils/data.py +++ b/qim3d/utils/data.py @@ -2,8 +2,6 @@ from pathlib import Path from PIL import Image from qim3d.io.logger import log -from torch.utils.data import DataLoader - import torch import numpy as np @@ -187,7 +185,8 @@ def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train num_workers (int, optional): Defines how many processes should be run in parallel. pin_memory (bool, optional): Loads the datasets as CUDA tensors. """ - + from torch.utils.data import DataLoader + train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory) val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) test_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) diff --git a/qim3d/viz/colormaps.py b/qim3d/viz/colormaps.py index 547e10fb..e99b0fd8 100644 --- a/qim3d/viz/colormaps.py +++ b/qim3d/viz/colormaps.py @@ -1,16 +1,16 @@ """ This module provides a collection of colormaps useful for 3D visualization. """ - + import colorsys from typing import Union, Tuple import numpy as np import math from matplotlib.colors import LinearSegmentedColormap from matplotlib import colormaps -from skimage import color -def rearrange_colors(randRGBcolors_old, min_dist = 0.5): + +def rearrange_colors(randRGBcolors_old, min_dist=0.5): # Create new list for re-arranged colors randRGBcolors_new = [randRGBcolors_old.pop(0)] @@ -32,6 +32,7 @@ def rearrange_colors(randRGBcolors_old, min_dist = 0.5): return randRGBcolors_new + def objects( nlabels: int, style: str = "bright", @@ -66,12 +67,12 @@ def objects( cmap_earth = qim3d.viz.colormaps.objects(nlabels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8) cmap_ocean = qim3d.viz.colormaps.objects(nlabels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9) - display(cmap_bright) + display(cmap_bright) display(cmap_soft) display(cmap_earth) display(cmap_ocean) ``` -  +  ```python import qim3d @@ -83,14 +84,16 @@ def objects( cmap = qim3d.viz.colormaps.objects(num_labels, style = 'bright') qim3d.viz.slicer(labeled_volume, axis = 1, cmap=cmap) ``` -  - +  + Tip: - The `min_dist` parameter can be used to control the distance between neighboring colors. -  - + The `min_dist` parameter can be used to control the distance between neighboring colors. +  + """ + from skimage import color + # Check style if style not in ("bright", "soft", "earth", "ocean"): raise ValueError( @@ -148,9 +151,9 @@ def objects( if style == "earth": randLABColors = [ ( - rng.uniform(low=25, high=110), - rng.uniform(low=-120, high=70), - rng.uniform(low=-70, high=70), + rng.uniform(low=25, high=110), + rng.uniform(low=-120, high=70), + rng.uniform(low=-70, high=70), ) for i in range(nlabels) ] @@ -158,17 +161,15 @@ def objects( # Convert LAB list to RGB randRGBcolors = [] for LabColor in randLABColors: - randRGBcolors.append( - color.lab2rgb([[LabColor]])[0][0].tolist() - ) + randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist()) # Generate color map for ocean colors, based on LAB if style == "ocean": randLABColors = [ ( - rng.uniform(low=0, high=110), - rng.uniform(low=-128, high=160), - rng.uniform(low=-128, high=0), + rng.uniform(low=0, high=110), + rng.uniform(low=-128, high=160), + rng.uniform(low=-128, high=0), ) for i in range(nlabels) ] @@ -176,10 +177,8 @@ def objects( # Convert LAB list to RGB randRGBcolors = [] for LabColor in randLABColors: - randRGBcolors.append( - color.lab2rgb([[LabColor]])[0][0].tolist() - ) - + randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist()) + # Re-arrange colors to have a minimum distance between neighboring colors randRGBcolors = rearrange_colors(randRGBcolors, min_dist) @@ -191,16 +190,18 @@ def objects( randRGBcolors[-1] = background_color # Create colormap - objects = LinearSegmentedColormap.from_list( - "objects", randRGBcolors, N=nlabels - ) + objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=nlabels) return objects -qim = LinearSegmentedColormap.from_list('qim', - [(0.6, 0.0, 0.0), #990000 - (1.0, 0.6, 0.0), #ff9900 - ]) + +qim = LinearSegmentedColormap.from_list( + "qim", + [ + (0.6, 0.0, 0.0), # 990000 + (1.0, 0.6, 0.0), # ff9900 + ], +) """ Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'``` @@ -213,4 +214,4 @@ Example: ```  """ -colormaps.register(qim) \ No newline at end of file +colormaps.register(qim) diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py index e612568b..fc3a2cf4 100644 --- a/qim3d/viz/k3d.py +++ b/qim3d/viz/k3d.py @@ -7,9 +7,7 @@ Volumetric visualization using K3D """ -import k3d import numpy as np - from qim3d.io.logger import log from qim3d.utils.internal_tools import downscale_img, scale_to_float16 @@ -73,6 +71,7 @@ def vol( ``` """ + import k3d pixel_count = img.shape[0] * img.shape[1] * img.shape[2] # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html diff --git a/requirements.txt b/requirements.txt index c22d3ed7..0f8cd33d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,25 @@ -albumentations>=1.3.1, -gradio>=4.27.0, -h5py>=3.9.0, -localthickness>=0.1.2, -matplotlib>=3.8.0, -monai>=1.2.0, -numpy>=1.26.0, -outputformat>=0.1.3, -Pillow>=10.0.1, -plotly>=5.14.1, -scipy>=1.11.2, -seaborn>=0.12.2, -pydicom>=2.4.4, -setuptools>=68.0.0, -tifffile>=2023.4.12, -torch>=2.0.1, -torchvision>=0.15.2, -torchinfo>=1.8.0, -tqdm>=4.65.0, -nibabel>=5.2.0, -ipywidgets>=8.1.2, -dask>=2023.6.0, +albumentations>=1.3.1 +gradio>=4.27.0 +h5py>=3.9.0 +localthickness>=0.1.2 +matplotlib>=3.8.0 +monai>=1.2.0 +numpy>=1.26.0 +outputformat>=0.1.3 +Pillow>=10.0.1 +plotly>=5.14.1 +scipy>=1.11.2 +seaborn>=0.12.2 +pydicom>=2.4.4 +setuptools>=68.0.0 +tifffile>=2023.4.12 +torch>=2.0.1 +torchvision>=0.15.2 +torchinfo>=1.8.0 +tqdm>=4.65.0 +nibabel>=5.2.0 +ipywidgets>=8.1.2 +dask>=2023.6.0 k3d>=2.16.1 olefile>=0.46 psutil>=5.9.0 diff --git a/setup.py b/setup.py index 017d277a..5efe335a 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f: setup( name="qim3d", - version="0.3.6", + version="0.3.7", author="Felipe Delestro", author_email="fima@dtu.dk", description="QIM tools and user interfaces", -- GitLab