Skip to content
Snippets Groups Projects
Commit 9e62018a authored by fima's avatar fima :beers:
Browse files

Import speed refactoring

parent a05e2fe7
No related branches found
No related tags found
1 merge request!101Import speed refactoring
Showing
with 263 additions and 223 deletions
...@@ -9,6 +9,10 @@ As the library is still in its early development stages, **there may be breaking ...@@ -9,6 +9,10 @@ As the library is still in its early development stages, **there may be breaking
And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that you have the latest features! And remember to keep your pip installation [up to date](/qim3d/#upgrade) so that you have the latest features!
### v0.3.7 (17/06/2024)
- Performance improvements when importing
- Refactoring for blob detection
### v0.3.6 (30/05/2024) ### v0.3.6 (30/05/2024)
- Refactoring for performance improvement - Refactoring for performance improvement
- Welcome message for the CLI - Welcome message for the CLI
......
...@@ -8,20 +8,14 @@ Documentation available at https://platform.qim.dk/qim3d/ ...@@ -8,20 +8,14 @@ Documentation available at https://platform.qim.dk/qim3d/
""" """
__version__ = "0.3.6" __version__ = "0.3.7"
import logging
logging.basicConfig(level=logging.ERROR)
from . import io from . import io
from . import gui from . import gui
from . import viz from . import viz
from . import utils from . import utils
from . import processing from . import processing
from . import models
# Commenting out models because it takes too long to import
# from . import models
examples = io.ImgExamples() examples = io.ImgExamples()
io.logger.set_level_info() io.logger.set_level_info()
...@@ -23,16 +23,11 @@ app = annotation_tool.launch(vol[0]) ...@@ -23,16 +23,11 @@ app = annotation_tool.launch(vol[0])
import getpass import getpass
import os import os
import tempfile import tempfile
import time
import gradio as gr import gradio as gr
import numpy as np import numpy as np
import tifffile
from PIL import Image
import qim3d.utils import qim3d.utils
from qim3d.io import load, save from qim3d.io import load, save
from qim3d.io.logger import log
class Session: class Session:
...@@ -100,6 +95,7 @@ class Interface: ...@@ -100,6 +95,7 @@ class Interface:
return gr.update(visible=True) return gr.update(visible=True)
def create_interface(self, img=None): def create_interface(self, img=None):
from PIL import Image
if img is not None: if img is not None:
custom_css = "annotation-tool" custom_css = "annotation-tool"
......
...@@ -17,11 +17,7 @@ from pathlib import Path ...@@ -17,11 +17,7 @@ from pathlib import Path
import dask import dask
import dask.array as da import dask.array as da
import h5py
import nibabel as nib
import numpy as np import numpy as np
import olefile
import pydicom
import tifffile import tifffile
from dask import delayed from dask import delayed
from PIL import Image, UnidentifiedImageError from PIL import Image, UnidentifiedImageError
...@@ -122,6 +118,7 @@ class DataLoader: ...@@ -122,6 +118,7 @@ class DataLoader:
ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file
ValueError: If no datasets are found in the file. ValueError: If no datasets are found in the file.
""" """
import h5py
# Read file # Read file
f = h5py.File(path, "r") f = h5py.File(path, "r")
...@@ -256,6 +253,7 @@ class DataLoader: ...@@ -256,6 +253,7 @@ class DataLoader:
Raises: Raises:
ValueError: If the dxchange library is not installed ValueError: If the dxchange library is not installed
""" """
import olefile
try: try:
import dxchange import dxchange
...@@ -323,6 +321,7 @@ class DataLoader: ...@@ -323,6 +321,7 @@ class DataLoader:
If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object If 'self.virtual_stack' is True, returns a nibabel.arrayproxy.ArrayProxy object
If 'self.return_metadata' is True, returns a tuple (volume, metadata). If 'self.return_metadata' is True, returns a tuple (volume, metadata).
""" """
import nibabel as nib
data = nib.load(path) data = nib.load(path)
...@@ -557,6 +556,8 @@ class DataLoader: ...@@ -557,6 +556,8 @@ class DataLoader:
Args: Args:
path (str): Path to file path (str): Path to file
""" """
import pydicom
dcm_data = pydicom.dcmread(path) dcm_data = pydicom.dcmread(path)
if self.return_metadata: if self.return_metadata:
...@@ -570,6 +571,8 @@ class DataLoader: ...@@ -570,6 +571,8 @@ class DataLoader:
Args: Args:
path (str): Directory path path (str): Directory path
""" """
import pydicom
if not self.contains: if not self.contains:
raise ValueError( raise ValueError(
"Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'" "Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'"
...@@ -709,6 +712,8 @@ class DataLoader: ...@@ -709,6 +712,8 @@ class DataLoader:
def _get_h5_dataset_keys(f): def _get_h5_dataset_keys(f):
import h5py
keys = [] keys = []
f.visit(lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None) f.visit(lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None)
return keys return keys
......
...@@ -21,18 +21,12 @@ Example: ...@@ -21,18 +21,12 @@ Example:
``` ```
""" """
import datetime import datetime
import os import os
import h5py
import nibabel as nib
import numpy as np import numpy as np
import PIL import PIL
import pydicom
import tifffile import tifffile
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
from qim3d.io.logger import log from qim3d.io.logger import log
from qim3d.utils.internal_tools import sizeof, stringify_path from qim3d.utils.internal_tools import sizeof, stringify_path
...@@ -116,9 +110,13 @@ class DataSaver: ...@@ -116,9 +110,13 @@ class DataSaver:
filepath = os.path.join(path, filename) filepath = os.path.join(path, filename)
self.save_tiff(filepath, sliced) self.save_tiff(filepath, sliced)
pattern_string = filepath[:-(len(extension)+zfill_val)] + "-"*zfill_val + extension pattern_string = (
filepath[: -(len(extension) + zfill_val)] + "-" * zfill_val + extension
)
log.info(f"Total of {no_slices} files saved following the pattern '{pattern_string}'") log.info(
f"Total of {no_slices} files saved following the pattern '{pattern_string}'"
)
def save_nifti(self, path, data): def save_nifti(self, path, data):
"""Save data to a NIfTI file to the given path. """Save data to a NIfTI file to the given path.
...@@ -127,6 +125,8 @@ class DataSaver: ...@@ -127,6 +125,8 @@ class DataSaver:
path (str): The path to save file to path (str): The path to save file to
data (numpy.ndarray): The data to be saved data (numpy.ndarray): The data to be saved
""" """
import nibabel as nib
# Create header # Create header
header = nib.Nifti1Header() header = nib.Nifti1Header()
header.set_data_dtype(data.dtype) header.set_data_dtype(data.dtype)
...@@ -141,7 +141,9 @@ class DataSaver: ...@@ -141,7 +141,9 @@ class DataSaver:
if not self.compression and path.endswith(".gz"): if not self.compression and path.endswith(".gz"):
path = path[:-3] path = path[:-3]
log.warning("File extension '.gz' is ignored since compression is disabled.") log.warning(
"File extension '.gz' is ignored since compression is disabled."
)
# Save image # Save image
nib.save(img, path) nib.save(img, path)
...@@ -155,15 +157,21 @@ class DataSaver: ...@@ -155,15 +157,21 @@ class DataSaver:
""" """
# No support for compression yet # No support for compression yet
if self.compression: if self.compression:
raise NotImplementedError("Saving compressed .vol files is not yet supported") raise NotImplementedError(
"Saving compressed .vol files is not yet supported"
)
# Create custom .vgi metadata file # Create custom .vgi metadata file
metadata = "" metadata = ""
metadata += "{volume1}\n" # .vgi organization metadata += "{volume1}\n" # .vgi organization
metadata += "[file1]\n" # .vgi organization metadata += "[file1]\n" # .vgi organization
metadata += "Size = {} {} {}\n".format(data.shape[1], data.shape[2], data.shape[0]) # Swap axes to match .vol format metadata += "Size = {} {} {}\n".format(
data.shape[1], data.shape[2], data.shape[0]
) # Swap axes to match .vol format
metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string metadata += "Datatype = {}\n".format(str(data.dtype)) # Get datatype as string
metadata += "Name = {}.vol\n".format(path.rsplit('/', 1)[-1][:-4]) # Get filename without extension metadata += "Name = {}.vol\n".format(
path.rsplit("/", 1)[-1][:-4]
) # Get filename without extension
# Save metadata # Save metadata
with open(path[:-4] + ".vgi", "w") as f: with open(path[:-4] + ".vgi", "w") as f:
...@@ -179,9 +187,12 @@ class DataSaver: ...@@ -179,9 +187,12 @@ class DataSaver:
path (str): The path to save file to path (str): The path to save file to
data (numpy.ndarray): The data to be saved data (numpy.ndarray): The data to be saved
""" """
import h5py
with h5py.File(path, "w") as f: with h5py.File(path, "w") as f:
f.create_dataset("dataset", data=data, compression="gzip" if self.compression else None) f.create_dataset(
"dataset", data=data, compression="gzip" if self.compression else None
)
def save_dicom(self, path, data): def save_dicom(self, path, data):
"""Save data to a DICOM file to the given path. """Save data to a DICOM file to the given path.
...@@ -190,18 +201,21 @@ class DataSaver: ...@@ -190,18 +201,21 @@ class DataSaver:
path (str): The path to save file to path (str): The path to save file to
data (numpy.ndarray): The data to be saved data (numpy.ndarray): The data to be saved
""" """
import pydicom
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
# based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html # based on https://pydicom.github.io/pydicom/stable/auto_examples/input_output/plot_write_dicom.html
# Populate required values for file meta information # Populate required values for file meta information
file_meta = FileMetaDataset() file_meta = FileMetaDataset()
file_meta.MediaStorageSOPClassUID = UID('1.2.840.10008.5.1.4.1.1.2') file_meta.MediaStorageSOPClassUID = UID("1.2.840.10008.5.1.4.1.1.2")
file_meta.MediaStorageSOPInstanceUID = UID("1.2.3") file_meta.MediaStorageSOPInstanceUID = UID("1.2.3")
file_meta.ImplementationClassUID = UID("1.2.3.4") file_meta.ImplementationClassUID = UID("1.2.3.4")
# Create the FileDataset instance (initially no data elements, but file_meta # Create the FileDataset instance (initially no data elements, but file_meta
# supplied) # supplied)
ds = FileDataset(path, {}, ds = FileDataset(path, {}, file_meta=file_meta, preamble=b"\0" * 128)
file_meta=file_meta, preamble=b"\0" * 128)
ds.PatientName = "Test^Firstname" ds.PatientName = "Test^Firstname"
ds.PatientID = "123456" ds.PatientID = "123456"
...@@ -220,8 +234,8 @@ class DataSaver: ...@@ -220,8 +234,8 @@ class DataSaver:
# Set creation date/time # Set creation date/time
dt = datetime.datetime.now() dt = datetime.datetime.now()
ds.ContentDate = dt.strftime('%Y%m%d') ds.ContentDate = dt.strftime("%Y%m%d")
timeStr = dt.strftime('%H%M%S.%f') # long format with micro seconds timeStr = dt.strftime("%H%M%S.%f") # long format with micro seconds
ds.ContentTime = timeStr ds.ContentTime = timeStr
# Needs to be here because of bug in pydicom # Needs to be here because of bug in pydicom
ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian ds.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian
...@@ -235,7 +249,6 @@ class DataSaver: ...@@ -235,7 +249,6 @@ class DataSaver:
ds.save_as(path) ds.save_as(path)
def save_PIL(self, path, data): def save_PIL(self, path, data):
"""Save data to a PIL file to the given path. """Save data to a PIL file to the given path.
...@@ -255,7 +268,6 @@ class DataSaver: ...@@ -255,7 +268,6 @@ class DataSaver:
# Save image # Save image
img.save(path) img.save(path)
def save(self, path, data): def save(self, path, data):
"""Save data to the given path. """Save data to the given path.
...@@ -323,7 +335,9 @@ class DataSaver: ...@@ -323,7 +335,9 @@ class DataSaver:
elif path.endswith((".nii", "nii.gz")): elif path.endswith((".nii", "nii.gz")):
return self.save_nifti(path, data) return self.save_nifti(path, data)
elif path.endswith(("TXRM", "XRM", "TXM")): elif path.endswith(("TXRM", "XRM", "TXM")):
raise NotImplementedError("Saving TXRM files is not yet supported") raise NotImplementedError(
"Saving TXRM files is not yet supported"
)
elif path.endswith((".h5")): elif path.endswith((".h5")):
return self.save_h5(path, data) return self.save_h5(path, data)
elif path.endswith((".vol", ".vgi")): elif path.endswith((".vol", ".vgi")):
......
"""UNet model and Hyperparameters class.""" """UNet model and Hyperparameters class."""
from monai.networks.nets import UNet as monai_UNet
from monai.losses import FocalLoss, DiceLoss, DiceCELoss
import torch.nn as nn import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam, SGD, RMSprop
from qim3d.io.logger import log from qim3d.io.logger import log
class UNet(nn.Module): class UNet(nn.Module):
""" """
2D UNet model for QIM imaging. 2D UNet model for QIM imaging.
...@@ -32,16 +28,19 @@ class UNet(nn.Module): ...@@ -32,16 +28,19 @@ class UNet(nn.Module):
model = UNet(size='large') model = UNet(size='large')
``` ```
""" """
def __init__(self, size = 'medium',
def __init__(
self,
size="medium",
dropout=0, dropout=0,
kernel_size=3, kernel_size=3,
up_kernel_size=3, up_kernel_size=3,
activation = 'PReLU', activation="PReLU",
bias=True, bias=True,
adn_order = 'NDA' adn_order="NDA",
): ):
super().__init__() super().__init__()
if size not in ['small','medium','large']: if size not in ["small", "medium", "large"]:
raise ValueError( raise ValueError(
f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'." f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
) )
...@@ -57,12 +56,13 @@ class UNet(nn.Module): ...@@ -57,12 +56,13 @@ class UNet(nn.Module):
self.model = self._model_choice() self.model = self._model_choice()
def _model_choice(self): def _model_choice(self):
from monai.networks.nets import UNet as monai_UNet
if self.size == 'small': if self.size == "small":
self.channels = (64, 128, 256) self.channels = (64, 128, 256)
elif self.size == 'medium': elif self.size == "medium":
self.channels = (64, 128, 256, 512, 1024) self.channels = (64, 128, 256, 512, 1024)
elif self.size == 'large': elif self.size == "large":
self.channels = (64, 128, 256, 512, 1024, 2048) self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet( model = monai_UNet(
...@@ -76,11 +76,10 @@ class UNet(nn.Module): ...@@ -76,11 +76,10 @@ class UNet(nn.Module):
act=self.activation, act=self.activation,
dropout=self.dropout, dropout=self.dropout,
bias=self.bias, bias=self.bias,
adn_ordering=self.adn_order adn_ordering=self.adn_order,
) )
return model return model
def forward(self, x): def forward(self, x):
x = self.model(x) x = self.model(x)
return x return x
...@@ -114,28 +113,36 @@ class Hyperparameters: ...@@ -114,28 +113,36 @@ class Hyperparameters:
n_epochs = params_dict['n_epochs'] n_epochs = params_dict['n_epochs']
``` ```
""" """
def __init__(self,
def __init__(
self,
model, model,
n_epochs=10, n_epochs=10,
learning_rate=1e-3, learning_rate=1e-3,
optimizer = 'Adam', optimizer="Adam",
momentum=0, momentum=0,
weight_decay=0, weight_decay=0,
loss_function = 'Focal'): loss_function="Focal",
):
# TODO: implement custom loss_functions? then add a check to see if loss works for segmentation. # TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
if loss_function not in ['BCE','Dice','Focal','DiceCE']: if loss_function not in ["BCE", "Dice", "Focal", "DiceCE"]:
raise ValueError(f"Invalid loss function: {loss_function}. Loss criterion must " raise ValueError(
"be one of the following: 'BCE','Dice','Focal','DiceCE'.") f"Invalid loss function: {loss_function}. Loss criterion must "
"be one of the following: 'BCE','Dice','Focal','DiceCE'."
)
# TODO: implement custom optimizer? and add check to see if valid. # TODO: implement custom optimizer? and add check to see if valid.
if optimizer not in ['Adam','SGD','RMSprop']: if optimizer not in ["Adam", "SGD", "RMSprop"]:
raise ValueError(f"Invalid optimizer: {optimizer}. Optimizer must " raise ValueError(
"be one of the following: 'Adam', 'SGD', 'RMSprop'.") f"Invalid optimizer: {optimizer}. Optimizer must "
"be one of the following: 'Adam', 'SGD', 'RMSprop'."
)
if (momentum != 0) and optimizer == 'Adam': if (momentum != 0) and optimizer == "Adam":
log.info("Momentum isn't an input in the 'Adam' optimizer. " log.info(
"Change optimizer to 'SGD' or 'RMSprop' to use momentum.") "Momentum isn't an input in the 'Adam' optimizer. "
"Change optimizer to 'SGD' or 'RMSprop' to use momentum."
)
self.model = model self.model = model
self.n_epochs = n_epochs self.n_epochs = n_epochs
...@@ -146,41 +153,72 @@ class Hyperparameters: ...@@ -146,41 +153,72 @@ class Hyperparameters:
self.loss_function = loss_function self.loss_function = loss_function
def __call__(self): def __call__(self):
return self.model_params(self.model, self.n_epochs, self.optimizer, self.learning_rate, return self.model_params(
self.weight_decay, self.momentum, self.loss_function) self.model,
self.n_epochs,
self.optimizer,
self.learning_rate,
self.weight_decay,
self.momentum,
self.loss_function,
)
def model_params(self, model, n_epochs, optimizer, learning_rate, weight_decay, momentum, loss_function): def model_params(
self,
model,
n_epochs,
optimizer,
learning_rate,
weight_decay,
momentum,
loss_function,
):
optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum) optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
criterion = self._loss_functions(loss_function) criterion = self._loss_functions(loss_function)
hyper_dict = {'optimizer': optim, hyper_dict = {
'criterion': criterion, "optimizer": optim,
'n_epochs' : n_epochs, "criterion": criterion,
"n_epochs": n_epochs,
} }
return hyper_dict return hyper_dict
# selecting the optimizer # selecting the optimizer
def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum): def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
if optimizer == 'Adam': from torch.optim import Adam, SGD, RMSprop
optim = Adam(model.parameters(), lr = learning_rate,
weight_decay = weight_decay) if optimizer == "Adam":
elif optimizer == 'SGD': optim = Adam(
optim = SGD(model.parameters(), lr = learning_rate, model.parameters(), lr=learning_rate, weight_decay=weight_decay
momentum = momentum, weight_decay = weight_decay) )
elif optimizer == 'RMSprop': elif optimizer == "SGD":
optim = RMSprop(model.parameters(),lr = learning_rate, optim = SGD(
weight_decay = weight_decay, momentum = momentum) model.parameters(),
lr=learning_rate,
momentum=momentum,
weight_decay=weight_decay,
)
elif optimizer == "RMSprop":
optim = RMSprop(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
momentum=momentum,
)
return optim return optim
# selecting the loss function # selecting the loss function
def _loss_functions(self, loss_function): def _loss_functions(self, loss_function):
if loss_function =='BCE': from monai.losses import FocalLoss, DiceLoss, DiceCELoss
criterion = BCEWithLogitsLoss(reduction='mean') from torch.nn import BCEWithLogitsLoss
elif loss_function == 'Dice':
criterion = DiceLoss(sigmoid=True,reduction='mean') if loss_function == "BCE":
elif loss_function == 'Focal': criterion = BCEWithLogitsLoss(reduction="mean")
criterion = FocalLoss(reduction='mean') elif loss_function == "Dice":
elif loss_function == 'DiceCE': criterion = DiceLoss(sigmoid=True, reduction="mean")
criterion = DiceCELoss(sigmoid=True,reduction='mean') elif loss_function == "Focal":
criterion = FocalLoss(reduction="mean")
elif loss_function == "DiceCE":
criterion = DiceCELoss(sigmoid=True, reduction="mean")
return criterion return criterion
import numpy as np import numpy as np
import torch import torch
from scipy.ndimage import find_objects, label from scipy.ndimage import find_objects, label
from qim3d.io.logger import log from qim3d.io.logger import log
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import numpy as np import numpy as np
from qim3d.io.logger import log from qim3d.io.logger import log
from skimage.feature import blob_dog
def blob_detection( def blob_detection(
vol: np.ndarray, vol: np.ndarray,
...@@ -63,6 +61,7 @@ def blob_detection( ...@@ -63,6 +61,7 @@ def blob_detection(
``` ```
![blob detection](assets/screenshots/blob_get_mask.gif) ![blob detection](assets/screenshots/blob_get_mask.gif)
""" """
from skimage.feature import blob_dog
if background == "bright": if background == "bright":
log.info("Bright background selected, volume will be inverted.") log.info("Bright background selected, volume will be inverted.")
......
"""Wrapper for the local thickness function from the localthickness package including visualization functions.""" """Wrapper for the local thickness function from the localthickness package including visualization functions."""
import localthickness as lt
import numpy as np import numpy as np
from typing import Optional from typing import Optional
from skimage.filters import threshold_otsu
from qim3d.io.logger import log from qim3d.io.logger import log
#from qim3d.viz import local_thickness as viz_local_thickness
import qim3d import qim3d
def local_thickness( def local_thickness(
image: np.ndarray, image: np.ndarray,
scale: float = 1, scale: float = 1,
...@@ -79,6 +77,8 @@ def local_thickness( ...@@ -79,6 +77,8 @@ def local_thickness(
""" """
import localthickness as lt
from skimage.filters import threshold_otsu
# Check if input is binary # Check if input is binary
if np.unique(image).size > 2: if np.unique(image).size > 2:
......
import numpy as np import numpy as np
import scipy
import skimage
import qim3d.processing.filters as filters import qim3d.processing.filters as filters
from qim3d.io.logger import log from qim3d.io.logger import log
...@@ -86,6 +83,9 @@ def watershed( ...@@ -86,6 +83,9 @@ def watershed(
![operations-watershed_after](assets/screenshots/operations-watershed_after.png) ![operations-watershed_after](assets/screenshots/operations-watershed_after.png)
""" """
import skimage
import scipy
# Compute distance transform of binary volume # Compute distance transform of binary volume
distance= scipy.ndimage.distance_transform_edt(bin_vol) distance= scipy.ndimage.distance_transform_edt(bin_vol)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
import structure_tensor as st
from qim3d.viz.structure_tensor import vectors from qim3d.viz.structure_tensor import vectors
...@@ -74,6 +73,7 @@ def structure_tensor( ...@@ -74,6 +73,7 @@ def structure_tensor(
``` ```
""" """
import structure_tensor as st
if vol.ndim != 3: if vol.ndim != 3:
raise ValueError("The input volume must be 3D") raise ValueError("The input volume must be 3D")
......
...@@ -62,7 +62,7 @@ def test_slices_torch_tensor_input(): ...@@ -62,7 +62,7 @@ def test_slices_torch_tensor_input():
def test_slices_wrong_input_format(): def test_slices_wrong_input_format():
input = 'not_a_volume' input = 'not_a_volume'
with pytest.raises(ValueError, match = 'Input must be a numpy.ndarray or torch.Tensor'): with pytest.raises(ValueError, match = 'Data type not supported'):
qim3d.viz.slices(input) qim3d.viz.slices(input)
def test_slices_not_volume(): def test_slices_not_volume():
......
"""Class for choosing the level of data augmentations with albumentations""" """Class for choosing the level of data augmentations with albumentations"""
import albumentations as A
from albumentations.pytorch import ToTensorV2
class Augmentation: class Augmentation:
""" """
...@@ -54,6 +52,8 @@ class Augmentation: ...@@ -54,6 +52,8 @@ class Augmentation:
Raises: Raises:
ValueError: If `level` is neither None, light, moderate nor heavy. ValueError: If `level` is neither None, light, moderate nor heavy.
""" """
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light','moderate','heavy']:
......
...@@ -5,8 +5,15 @@ from qim3d.gui import annotation_tool, data_explorer, iso3d, local_thickness ...@@ -5,8 +5,15 @@ from qim3d.gui import annotation_tool, data_explorer, iso3d, local_thickness
from qim3d.io.loading import DataLoader from qim3d.io.loading import DataLoader
from qim3d.utils import image_preview from qim3d.utils import image_preview
from qim3d import __version__ as version from qim3d import __version__ as version
import outputformat as ouf
import qim3d import qim3d
QIM_TITLE = ouf.rainbow(
f"\n _ _____ __ \n ____ _(_)___ ___ |__ /____/ / \n / __ `/ / __ `__ \ /_ </ __ / \n/ /_/ / / / / / / /__/ / /_/ / \n\__, /_/_/ /_/ /_/____/\__,_/ \n /_/ v{version}\n\n",
return_str=True,
cmap="hot",
)
def main(): def main():
parser = argparse.ArgumentParser(description="Qim3d command-line interface.") parser = argparse.ArgumentParser(description="Qim3d command-line interface.")
...@@ -139,29 +146,14 @@ def main(): ...@@ -139,29 +146,14 @@ def main():
) )
elif args.subcommand is None: elif args.subcommand is None:
print(QIM_TITLE)
welcome_text = ( welcome_text = (
"\n" "\nqim3d is a Python package for 3D image processing and visualization.\n"
" _ _____ __ \n" f"For more information, please visit {ouf.c('https://platform.qim.dk/qim3d/', color='orange', return_str=True)}\n"
" ____ _(_)___ ___ |__ /____/ / \n"
" / __ `/ / __ `__ \ /_ </ __ / \n"
"/ /_/ / / / / / / /__/ / /_/ / \n"
"\__, /_/_/ /_/ /_/____/\__,_/ \n"
" /_/ \n"
"\n"
"--- Welcome to qim3d command-line interface ---\n"
"qim3d is a Python package for 3D image processing and visualization.\n"
"For more information, please visit: https://platform.qim.dk/qim3d/\n"
f"Current version of qim3d: {version}\n"
" \n"
"The qim3d command-line interface provides the following subcommands:\n"
"- gui: Graphical User Interfaces\n"
"- viz: Volumetric visualizations of volumes\n"
"- preview: Preview of an volume directly in the terminal\n"
" \n" " \n"
"For more information on each subcommand, type 'qim3d <subcommand> --help'.\n" "For more information on each subcommand, type 'qim3d <subcommand> --help'.\n"
) )
print(welcome_text) print(welcome_text)
print("--- Help page for qim3d command-line interface shown below ---\n")
parser.print_help() parser.print_help()
print("\n") print("\n")
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
from qim3d.io.logger import log from qim3d.io.logger import log
from torch.utils.data import DataLoader
import torch import torch
import numpy as np import numpy as np
...@@ -187,6 +185,7 @@ def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train ...@@ -187,6 +185,7 @@ def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train
num_workers (int, optional): Defines how many processes should be run in parallel. num_workers (int, optional): Defines how many processes should be run in parallel.
pin_memory (bool, optional): Loads the datasets as CUDA tensors. pin_memory (bool, optional): Loads the datasets as CUDA tensors.
""" """
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory) train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
......
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
import math import math
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps from matplotlib import colormaps
from skimage import color
def rearrange_colors(randRGBcolors_old, min_dist=0.5): def rearrange_colors(randRGBcolors_old, min_dist=0.5):
# Create new list for re-arranged colors # Create new list for re-arranged colors
...@@ -32,6 +32,7 @@ def rearrange_colors(randRGBcolors_old, min_dist = 0.5): ...@@ -32,6 +32,7 @@ def rearrange_colors(randRGBcolors_old, min_dist = 0.5):
return randRGBcolors_new return randRGBcolors_new
def objects( def objects(
nlabels: int, nlabels: int,
style: str = "bright", style: str = "bright",
...@@ -91,6 +92,8 @@ def objects( ...@@ -91,6 +92,8 @@ def objects(
""" """
from skimage import color
# Check style # Check style
if style not in ("bright", "soft", "earth", "ocean"): if style not in ("bright", "soft", "earth", "ocean"):
raise ValueError( raise ValueError(
...@@ -158,9 +161,7 @@ def objects( ...@@ -158,9 +161,7 @@ def objects(
# Convert LAB list to RGB # Convert LAB list to RGB
randRGBcolors = [] randRGBcolors = []
for LabColor in randLABColors: for LabColor in randLABColors:
randRGBcolors.append( randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())
color.lab2rgb([[LabColor]])[0][0].tolist()
)
# Generate color map for ocean colors, based on LAB # Generate color map for ocean colors, based on LAB
if style == "ocean": if style == "ocean":
...@@ -176,9 +177,7 @@ def objects( ...@@ -176,9 +177,7 @@ def objects(
# Convert LAB list to RGB # Convert LAB list to RGB
randRGBcolors = [] randRGBcolors = []
for LabColor in randLABColors: for LabColor in randLABColors:
randRGBcolors.append( randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())
color.lab2rgb([[LabColor]])[0][0].tolist()
)
# Re-arrange colors to have a minimum distance between neighboring colors # Re-arrange colors to have a minimum distance between neighboring colors
randRGBcolors = rearrange_colors(randRGBcolors, min_dist) randRGBcolors = rearrange_colors(randRGBcolors, min_dist)
...@@ -191,16 +190,18 @@ def objects( ...@@ -191,16 +190,18 @@ def objects(
randRGBcolors[-1] = background_color randRGBcolors[-1] = background_color
# Create colormap # Create colormap
objects = LinearSegmentedColormap.from_list( objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=nlabels)
"objects", randRGBcolors, N=nlabels
)
return objects return objects
qim = LinearSegmentedColormap.from_list('qim',
[(0.6, 0.0, 0.0), #990000 qim = LinearSegmentedColormap.from_list(
"qim",
[
(0.6, 0.0, 0.0), # 990000
(1.0, 0.6, 0.0), # ff9900 (1.0, 0.6, 0.0), # ff9900
]) ],
)
""" """
Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'``` Defines colormap in QIM logo colors. Can be accessed as module attribute or easily by ```cmap = 'qim'```
......
...@@ -7,9 +7,7 @@ Volumetric visualization using K3D ...@@ -7,9 +7,7 @@ Volumetric visualization using K3D
""" """
import k3d
import numpy as np import numpy as np
from qim3d.io.logger import log from qim3d.io.logger import log
from qim3d.utils.internal_tools import downscale_img, scale_to_float16 from qim3d.utils.internal_tools import downscale_img, scale_to_float16
...@@ -73,6 +71,7 @@ def vol( ...@@ -73,6 +71,7 @@ def vol(
``` ```
""" """
import k3d
pixel_count = img.shape[0] * img.shape[1] * img.shape[2] pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
# target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
......
albumentations>=1.3.1, albumentations>=1.3.1
gradio>=4.27.0, gradio>=4.27.0
h5py>=3.9.0, h5py>=3.9.0
localthickness>=0.1.2, localthickness>=0.1.2
matplotlib>=3.8.0, matplotlib>=3.8.0
monai>=1.2.0, monai>=1.2.0
numpy>=1.26.0, numpy>=1.26.0
outputformat>=0.1.3, outputformat>=0.1.3
Pillow>=10.0.1, Pillow>=10.0.1
plotly>=5.14.1, plotly>=5.14.1
scipy>=1.11.2, scipy>=1.11.2
seaborn>=0.12.2, seaborn>=0.12.2
pydicom>=2.4.4, pydicom>=2.4.4
setuptools>=68.0.0, setuptools>=68.0.0
tifffile>=2023.4.12, tifffile>=2023.4.12
torch>=2.0.1, torch>=2.0.1
torchvision>=0.15.2, torchvision>=0.15.2
torchinfo>=1.8.0, torchinfo>=1.8.0
tqdm>=4.65.0, tqdm>=4.65.0
nibabel>=5.2.0, nibabel>=5.2.0
ipywidgets>=8.1.2, ipywidgets>=8.1.2
dask>=2023.6.0, dask>=2023.6.0
k3d>=2.16.1 k3d>=2.16.1
olefile>=0.46 olefile>=0.46
psutil>=5.9.0 psutil>=5.9.0
......
...@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f: ...@@ -9,7 +9,7 @@ with open("README.md", "r", encoding="utf-8") as f:
setup( setup(
name="qim3d", name="qim3d",
version="0.3.6", version="0.3.7",
author="Felipe Delestro", author="Felipe Delestro",
author_email="fima@dtu.dk", author_email="fima@dtu.dk",
description="QIM tools and user interfaces", description="QIM tools and user interfaces",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment