Commit 2ea7f365 authored by ofhkr, committed by fima

Update for slice_viz function

parent bb1c1d6a
1 merge request: !31 Update for slice_viz function
@@ -45,7 +45,7 @@ def test_grid_pred():
def test_slice_viz():
example_volume = ones(10,10,10)
img_width = 3
fig = qim3d.viz.slice_viz(example_volume,img_width = img_width)
fig = qim3d.viz.slice_viz(example_volume,n_slices = 1, img_width = img_width)
assert fig.get_figwidth() == img_width
......
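The test pins n_slices to 1 because the updated slice_viz sizes its figure as img_width times the number of slices plotted (see the figsize call further down in this commit). A sketch of a possible multi-slice test, assuming numpy for the dummy volume and a hypothetical test name:

import numpy as np
import qim3d

def test_slice_viz_multiple_slices():
    # Sketch only: with n_slices > 1 the figure width should scale as
    # img_width * n_slices, matching figsize=(img_width * num_images, img_height).
    example_volume = np.ones((10, 10, 10))
    n_slices = 3
    img_width = 3
    fig = qim3d.viz.slice_viz(example_volume, n_slices=n_slices, img_width=img_width)
    assert fig.get_figwidth() == img_width * n_slices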
"""Class for choosing the level of data augmentations with albumentations"""
import albumentations as A
from albumentations.pytorch import ToTensorV2
from qim3d.io.logger import log
class Augmentation:
"""
......
@@ -9,9 +9,10 @@ import numpy as np
import socket
import os
import shutil
from PIL import Image
from pathlib import Path
from qim3d.io.logger import log
def mock_plot():
@@ -182,6 +183,21 @@ def is_server_running(ip, port):
return False
def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
"""Creates a temporary folder to test deep learning tools.
Creates two folders, 'train' and 'test', which each contain two subfolders, 'images' and 'labels'.
n random images are then added to all four subfolders.
If the 'remove' variable is True, the folders and their content are removed.
Args:
folder (str): The path where the folders should be placed.
remove (bool, optional): If True, all folders are removed from their location.
n (int, optional): Number of random images and labels in the temporary dataset.
img_shape (tuple, optional): Tuple with the height and width of the images and labels.
Example:
>>> temp_data('temporary_folder',n = 10, img_shape = (16,16))
"""
folder_trte = ['train','test']
sub_folders = ['images','labels']
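The layout described in the temp_data docstring (train/test, each with images and labels subfolders, filled with n random images) can be reproduced with standard-library and PIL calls alone. The sketch below is only illustrative and is not the qim3d implementation; the helper name and the PNG file naming are assumptions:

import os
import numpy as np
from PIL import Image

def make_temp_dataset(folder, n=3, img_shape=(32, 32)):
    # Illustrative sketch: build the train/test x images/labels tree described
    # above and fill each subfolder with n random 8-bit grayscale images.
    for split in ["train", "test"]:
        for sub in ["images", "labels"]:
            path = os.path.join(folder, split, sub)
            os.makedirs(path, exist_ok=True)
            for i in range(n):
                arr = np.random.randint(0, 256, size=img_shape, dtype=np.uint8)
                Image.fromarray(arr).save(os.path.join(path, f"img_{i}.png"))

make_temp_dataset("temporary_folder", n=10, img_shape=(16, 16))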
@@ -219,7 +235,7 @@ def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print('Failed to delete %s. Reason: %s' % (file_path, e))
log.warning('Failed to delete %s. Reason: %s' % (file_path, e))
os.rmdir(folder)
......
@@ -15,17 +15,19 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
Args:
model (torch.nn.Module): PyTorch model.
hyperparameters (class): dictionary with n_epochs, optimizer and criterion.
hyperparameters (class): Dictionary with n_epochs, optimizer and criterion.
train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
eval_every (int, optional): Frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): Frequency of log for model performance. Defaults to every 5 epochs.
plot (bool, optional): If True, plots the training and validation loss after the model is done training.
return_loss (bool, optional): If True, returns a dictionary with the history of the train and validation losses.
Returns:
if return_loss = True:
tuple:
train_loss (dict): dictionary with average losses and batch losses for training loop.
val_loss (dict): dictionary with average losses and batch losses for validation loop.
train_loss (dict): Dictionary with average losses and batch losses for training loop.
val_loss (dict): Dictionary with average losses and batch losses for validation loop.
Example:
# defining the model.
@@ -141,7 +143,6 @@ def model_summary(dataloader,model):
summary = model_summary(model, dataloader)
print(summary)
"""
images,_ = next(iter(dataloader))
batch_size = tuple(images.shape)
model_s = summary(model,batch_size,depth = torch.inf)
......
"""Provides tools for obtaining informaion about the system."""
"""Provides tools for obtaining information about the system."""
import psutil
from qim3d.utils.internal_tools import sizeof
from qim3d.io.logger import log
......
@@ -32,7 +32,7 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
- The grid layout and dimensions vary based on the presence of a mask.
Returns:
None
fig (matplotlib.figure.Figure): The figure with an overview of the images and their labels.
Example:
data = [(image1, label1, mask1), (image2, label2, mask2)]
@@ -117,7 +117,7 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
show (bool, optional): If True, displays the plot. Defaults to False.
Returns:
None
fig (matplotlib.figure.Figure): The figure with images, labels and the label prediction from the trained models.
Raises:
None
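With this change grid_overview and grid_pred return the matplotlib figure instead of None, so callers can save or further edit the plot. A usage sketch for grid_overview, assuming it is exposed under qim3d.viz like slice_viz and that a plain list of (image, label, mask) tuples is accepted as in its docstring example; the random arrays and file name are illustrative:

import numpy as np
import qim3d

# Two tiny (image, label, mask) tuples in the format shown in the docstring.
data = [
    (np.random.rand(32, 32), np.zeros((32, 32), dtype=np.uint8), np.ones((32, 32), dtype=np.uint8)),
    (np.random.rand(32, 32), np.ones((32, 32), dtype=np.uint8), np.zeros((32, 32), dtype=np.uint8)),
]

fig = qim3d.viz.grid_overview(data, num_images=2)
fig.savefig("grid_overview.png")  # only possible now that the figure is returned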
@@ -197,20 +197,31 @@ def grid_pred(in_targ_preds, num_images=7, cmap_im="gray", cmap_segm="viridis",
return fig
def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2, img_width=2,show = False):
def slice_viz(input, position = None, n_slices = 5, cmap = "viridis", axis = False, img_height = 4, img_width = 4, show = False):
""" Displays one or several slices from a 3d array.
By default, if `position` is None, slice_viz plots an overview of the entire stack.
If `position` is given as a string or integer, slice_viz plots `n_slices` slices around that position.
If `position` is given as a list or array, `n_slices` is ignored and the indices in `position` are plotted.
Args:
input (str, numpy.ndarray): Path to the file or 3-dimensional array.
position (str, int, list, array, optional): One or several slicing levels.
n_slices (int, optional): Number of slices to display.
cmap (str, optional): Specifies the color map for the image.
axis (bool, optional): Specifies whether the axes should be included.
img_height(int, optional): Height of the figure.
img_width(int, optional): Width of the figure.
show (bool, optional): If True, displays the plot. Defaults to False.
Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
Raises:
ValueError: If the file or array is not a 3D volume.
ValueError: If the string provided for the 'position' argument is not valid (not upper, middle or bottom).
Usage:
Example:
image_path = '/my_image_path/my_image.tif'
slice_viz(image_path)
"""
@@ -220,35 +231,52 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
vol = qim3d.io.load(input) # Function has its own ValueErrors
dim = vol.ndim
if dim == 3:
pass
else:
raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
# Numpy array input
elif isinstance(input,(np.ndarray,torch.Tensor)):
vol = input
dim = input.ndim
if dim == 3:
vol = input
else:
if dim != 3:
raise ValueError(f"Given array is not a volume! Current dimension: {dim}")
if position is None:
height = np.linspace(0,vol.shape[0]-1,n_slices).astype(int)
# Position is a string
if isinstance(position,str):
elif isinstance(position,str):
if position.lower() in ['mid','middle']:
height = [int(vol.shape[0]/2)]
expansion_start = int(vol.shape[0]/2)
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
elif position.lower() in ['top','upper', 'start']:
height = [0]
expansion_start = 0
height = np.linspace(expansion_start,n_slices-1,n_slices).astype(int)
elif position.lower() in ['bot','bottom', 'end']:
height = [vol.shape[0]-1]
expansion_start = vol.shape[0]-1
height = np.linspace(expansion_start - n_slices,expansion_start,n_slices).astype(int)
else:
raise ValueError('Position not recognized. Choose an integer, list, array or "start","mid","end".')
# Position is an integer
elif isinstance(position,int):
height = [position]
expansion_start = position
n_stacks = vol.shape[0]-1
# if linspace would extend beyond n_stacks
if expansion_start + n_slices > n_stacks:
height = np.linspace(n_stacks - n_slices,n_stacks,n_slices).astype(int)
# if linspace would extend below 0
elif expansion_start - n_slices < 0:
height = np.linspace(0,n_slices-1,n_slices).astype(int)
else:
height = np.linspace(expansion_start - n_slices / 2,expansion_start + n_slices / 2,n_slices).astype(int)
# Position is a list or array of integers
elif isinstance(position,(list,np.ndarray)):
@@ -256,12 +284,13 @@ def slice_viz(input, position = 'mid', cmap="viridis", axis=False, img_height=2,
num_images = len(height)
fig = plt.figure(figsize=(img_width * num_images, img_height), constrained_layout = True)
axs = fig.subplots(nrows = 1, ncols = num_images)
for col, ax in enumerate(np.atleast_1d(axs)):
ax.imshow(vol[height[col],:,:],cmap = cmap)
ax.set_title(f'Slice {height[col]}', fontsize=8)
ax.set_title(f'Slice {height[col]}', fontsize=6*img_height)
if not axis:
ax.axis('off')
......
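The integer-position branch shifts the linspace window so it never runs outside the stack. The standalone helper below mirrors that logic for illustration (pick_slice_indices is hypothetical, not part of qim3d):

import numpy as np

def pick_slice_indices(n_stacks, position, n_slices):
    # Mirror of the integer-position logic above: n_slices indices around
    # `position`, shifted so they stay within [0, n_stacks].
    if position + n_slices > n_stacks:        # window would run past the last slice
        idx = np.linspace(n_stacks - n_slices, n_stacks, n_slices)
    elif position - n_slices < 0:             # window would run below zero
        idx = np.linspace(0, n_slices - 1, n_slices)
    else:                                     # window fits: centre it on position
        idx = np.linspace(position - n_slices / 2, position + n_slices / 2, n_slices)
    return idx.astype(int)

print(pick_slice_indices(n_stacks=49, position=2, n_slices=5))   # [0 1 2 3 4]
print(pick_slice_indices(n_stacks=49, position=25, n_slices=5))  # window around 25
print(pick_slice_indices(n_stacks=49, position=48, n_slices=5))  # [44 45 46 47 49]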
@@ -22,7 +22,6 @@ def plot_metrics(*metrics,
show (bool, optional): If True, displays the plot. Defaults to False.
Returns:
if return_fig:
fig (matplotlib.figure.Figure): plot with metrics.
Example:
@@ -30,7 +29,6 @@
val_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss,val_loss, labels=['Train','Valid.'])
"""
if labels == None:
labels = [None]*len(metrics)
elif len(metrics) != len(labels):
......