diff --git a/qim3d/tests/viz/test_img.py b/qim3d/tests/viz/test_img.py index f62ffc2a7e2342363fa7e6313ecd9894fd70ba06..06fec0d0c41d0c0269ffc94291a77a8b63fa2b36 100644 --- a/qim3d/tests/viz/test_img.py +++ b/qim3d/tests/viz/test_img.py @@ -35,7 +35,7 @@ def test_grid_pred(): in_targ_pred = qim3d.utils.models.inference(train_set,model) fig = qim3d.viz.grid_pred(in_targ_pred) - + assert (fig.get_figwidth(),fig.get_figheight()) == (2*(n),10) temp_data(folder,remove = True) @@ -45,7 +45,7 @@ def test_grid_pred(): def test_slice_viz(): example_volume = ones(10,10,10) img_width = 3 - fig = qim3d.viz.slice_viz(example_volume,img_width = img_width) + fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width) assert fig.get_figwidth() == img_width diff --git a/qim3d/utils/augmentations.py b/qim3d/utils/augmentations.py index a57880e7c52ede52247e7fe666f7abd8a40b68aa..19154b7010f92a893bae81f8ea8dbce48e227e4f 100644 --- a/qim3d/utils/augmentations.py +++ b/qim3d/utils/augmentations.py @@ -1,7 +1,6 @@ """Class for choosing the level of data augmentations with albumentations""" import albumentations as A from albumentations.pytorch import ToTensorV2 -from qim3d.io.logger import log class Augmentation: """ diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py index b8a4b6b7bcb86be7cb4496b576dd68381db57556..bc06e5f9285cdc2062903e9f8ec651484c7fb985 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/internal_tools.py @@ -9,9 +9,10 @@ import numpy as np import socket import os import shutil + from PIL import Image from pathlib import Path - +from qim3d.io.logger import log def mock_plot(): @@ -182,6 +183,21 @@ def is_server_running(ip, port): return False def temp_data(folder,remove = False,n = 3,img_shape = (32,32)): + """Creates a temporary folder to test deep learning tools. + + Creates two folders, 'train' and 'test', who each also have two subfolders 'images' and 'labels'. + n random images are then added to all four subfolders. + If the 'remove' variable is True, the folders and their content are removed. + + Args: + folder (str): The path where the folders should be placed. + remove (bool, optional): If True, all folders are removed from their location. + n (int, optional): Number of random images and labels in the temporary dataset. + img_shape (tuple, options): Tuple with the height and width of the images and labels. + + Example: + >>> tempdata('temporary_folder',n = 10, img_shape = (16,16)) + """ folder_trte = ['train','test'] sub_folders = ['images','labels'] @@ -219,7 +235,7 @@ def temp_data(folder,remove = False,n = 3,img_shape = (32,32)): elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + log.warning('Failed to delete %s. Reason: %s' % (file_path, e)) os.rmdir(folder) diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py index 0294f14a0aa8bd8971f29d3ae6873ec55088685e..19a2a844fe225932e03b041c922328129aa771e8 100644 --- a/qim3d/utils/models.py +++ b/qim3d/utils/models.py @@ -15,17 +15,19 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 Args: model (torch.nn.Module): PyTorch model. - hyperparameters (class): dictionary with n_epochs, optimizer and criterion. + hyperparameters (class): Dictionary with n_epochs, optimizer and criterion. train_loader (torch.utils.data.DataLoader): DataLoader for the training data. val_loader (torch.utils.data.DataLoader): DataLoader for the validation data. - eval_every (int, optional): frequency of model evaluation. Defaults to every epoch. - print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs. + eval_every (int, optional): Frequency of model evaluation. Defaults to every epoch. + print_every (int, optional): Frequency of log for model performance. Defaults to every 5 epochs. + plot (bool, optional): If True, plots the training and validation loss after the model is done training. + return_loss (bool, optional), If True, returns a dictionary with the history of the train and validation losses. - Returns: - tuple: - train_loss (dict): dictionary with average losses and batch losses for training loop. - val_loss (dict): dictionary with average losses and batch losses for validation loop. + if return_loss = True: + tuple: + train_loss (dict): Dictionary with average losses and batch losses for training loop. + val_loss (dict): Dictionary with average losses and batch losses for validation loop. Example: # defining the model. @@ -141,7 +143,6 @@ def model_summary(dataloader,model): summary = model_summary(model, dataloader) print(summary) """ - images,_ = next(iter(dataloader)) batch_size = tuple(images.shape) model_s = summary(model,batch_size,depth = torch.inf) diff --git a/qim3d/utils/system.py b/qim3d/utils/system.py index 5fd6e8eed040d3a51de5219ce1188d74622932b8..383060df8c8fad875df4c4984325b7fb0f92ed6a 100644 --- a/qim3d/utils/system.py +++ b/qim3d/utils/system.py @@ -1,4 +1,4 @@ -"""Provides tools for obtaining informaion about the system.""" +"""Provides tools for obtaining information about the system.""" import psutil from qim3d.utils.internal_tools import sizeof from qim3d.io.logger import log diff --git a/qim3d/viz/img.py b/qim3d/viz/img.py index 3a30005db39ff6b34ea996ee3104cd5d7ff31ec4..83edde50b5c98af6a5b21cbd2b46e32dda8515a1 100644 --- a/qim3d/viz/img.py +++ b/qim3d/viz/img.py @@ -32,7 +32,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha - The grid layout and dimensions vary based on the presence of a mask. Returns: - None + fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels. Example: data = [(image1, label1, mask1), (image2, label2, mask2)] @@ -117,7 +117,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", show (bool, optional): If True, displays the plot. Defaults to False. Returns: - None + fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models. Raises: None @@ -197,20 +197,31 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis", return fig -def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2,show = False): +def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False): """ Displays one or several slices from a 3d array. + By default if `position` is None, slice_viz plots an overview of the entire stack. + If `position` is given as a string or integer, slice_viz will plot an overview with `n_slices` figures around that position. + If `position` is given as a list or array, `n_slices` will be ignored and the idxs from `position` will be plotted. + Args: input (str, numpy.ndarray): Path to the file or 3-dimensional array. position (str, int, list, array, optional): One or several slicing levels. + n_slices (int, optional): Defines how many slices the user wants. cmap (str, optional): Specifies the color map for the image. axis (bool, optional): Specifies whether the axes should be included. + img_height(int, optional): Height of the figure. + img_width(int, optional): Width of the figure. show (bool, optional): If True, displays the plot. Defaults to False. + Returns: + fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. + Raises: + ValueError: If the file or array is not a 3D volume. ValueError: If provided string for 'position' argument is not valid (not upper, middle or bottom). - - Usage: + + Example: image_path = '/my_image_path/my_image.tif' slice_viz(image_path) """ @@ -218,50 +229,68 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, # Filepath input if isinstance(input,str): vol = qim3d.io.load(input) # Function has its own ValueErrors - dim = vol.ndim - - if dim == 3: - pass - else: - raise ValueError(f"Given array is not a volume! Current dimension: {dim}") - + dim = vol.ndim # Numpy array input elif isinstance(input,(np.ndarray,torch.Tensor)): + vol = input dim = input.ndim - if dim == 3: - vol = input - else: - raise ValueError(f"Given array is not a volume! Current dimension: {dim}") + if dim != 3: + raise ValueError(f"Given array is not a volume! Current dimension: {dim}") + if position is None: + height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int) + # Position is a string - if isinstance(position,str): + elif isinstance(position,str): + if position.lower() in ['mid','middle']: - height = [int(vol.shape[0]/2)] + expansion_start = int(vol.shape[0]/2) + height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) + elif position.lower() in ['top','upper', 'start']: - height = [0] + expansion_start = 0 + height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int) + elif position.lower() in ['bot','bottom', 'end']: - height = [vol.shape[0]-1] + expansion_start = vol.shape[0]-1 + height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int) + else: raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".') + # Position is an integer elif isinstance(position,int): - height = [position] + expansion_start = position + n_stacks = vol.shape[0]-1 + + # if linspace would extend beyond n_stacks + if expansion_start + n_slices > n_stacks: + height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int) + + # if linspace would extend below 0 + elif expansion_start - n_slices < 0: + height = np.linspace(0,n_slices-1,n_slices).astype(int) + else: + height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int) + + # Position is a list or array of integers elif isinstance(position,(list,np.ndarray)): height = position num_images = len(height) - + + fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True) axs = fig.subplots(nrows = 1, ncols = num_images) for col, ax in enumerate(np.atleast_1d(axs)): ax.imshow(vol[height[col],:,:],cmap = cmap) - ax.set_title(f'Slice {height[col]}', fontsize=8) + ax.set_title(f'Slice {height[col]}', fontsize=6*img_height) if not axis: ax.axis('off') diff --git a/qim3d/viz/visualizations.py b/qim3d/viz/visualizations.py index f412a2b60204c054fb968502092ea0d763a0c919..4d64b0ce5b397486c90d4e3af69363e4a333dc2c 100644 --- a/qim3d/viz/visualizations.py +++ b/qim3d/viz/visualizations.py @@ -22,15 +22,13 @@ def plot_metrics(*metrics, show (bool, optional): If True, displays the plot. Defaults to False. Returns: - if return_fig: - fig (matplotlib.figure.Figure): plot with metrics. + fig (matplotlib.figure.Figure): plot with metrics. Example: train_loss = {'epoch_loss' : [...], 'batch_loss': [...]} val_loss = {'epoch_loss' : [...], 'batch_loss': [...]} plot_metrics(train_loss,val_loss, labels=['Train','Valid.']) """ - if labels == None: labels = [None]*len(metrics) elif len(metrics) != len(labels):