Commit 3ff594d9 authored by fima

Merge branch 'Progress_UNet_July' into 'main'

Progress UNet July Oskar

See merge request !6
parents 2d546cd3 e1991a80
Source diff could not be displayed: it is too large.
@@ -2,4 +2,5 @@ import qim3d.io
import qim3d.gui
import qim3d.utils
import qim3d.viz
import qim3d.models
import logging
\ No newline at end of file
@@ -24,7 +24,10 @@ class DataLoader:
load_tiff(path): Load a TIFF file from the specified path.
load_h5(path): Load an HDF5 file from the specified path.
load_tiff_stack(path): Load a stack of TIFF files from the specified path.
<<<<<<< HEAD
=======
load_txrm(path): Load a TXRM/TXM/XRM file from the specified path.
>>>>>>> main
load(path): Load a file or directory based on the given path.
Raises:
@@ -41,7 +44,11 @@ class DataLoader:
Args:
path (str): The path to the file or directory.
virtual_stack (bool, optional): Specifies whether to use virtual
<<<<<<< HEAD
stack when loading TIFF and HDF5 files. Default is False.
=======
stack when loading files. Default is False.
>>>>>>> main
dataset_name (str, optional): Specifies the name of the dataset to be loaded
in case multiple datasets exist within the same file. Default is None (only for HDF5 files)
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 files)
@@ -88,6 +95,8 @@ class DataLoader:
ValueError: If the specified dataset_name is not found or is invalid.
ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file
ValueError: If no datasets are found in the file.
<<<<<<< HEAD
=======
"""
# Read file
@@ -158,6 +167,89 @@ class DataLoader:
else:
return vol
def load_tiff_stack(self, path):
"""Load a stack of TIFF files from the specified path.
Args:
path (str): The path to the stack of TIFF files.
Returns:
numpy.ndarray: The loaded volume as a NumPy array.
>>>>>>> main
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory
"""
<<<<<<< HEAD
# Read file
f = h5py.File(path, "r")
data_keys = self._get_h5_dataset_keys(f)
datasets = []
metadata = {}
for key in data_keys:
if (
f[key].ndim > 1
): # Data is assumed to be a dataset if it has two or more dimensions
datasets.append(key)
if f[key].attrs.keys():
metadata[key] = {
"value": f[key][()],
**{attr_key: val for attr_key, val in f[key].attrs.items()},
}
# Only one dataset was found
if len(datasets) == 1:
if self.dataset_name:
log.info(
"'dataset_name' argument is unused since there is only one dataset in the file"
)
name = datasets[0]
vol = f[name]
# Multiple datasets were found
elif len(datasets) > 1:
if self.dataset_name in datasets: # Provided dataset name is valid
name = self.dataset_name
vol = f[name]
else:
if self.dataset_name: # Dataset name is provided
similar_names = difflib.get_close_matches(
self.dataset_name, datasets
) # Find closest matching name if any
if similar_names:
suggestion = similar_names[0] # Get the closest match
raise ValueError(
f"Invalid dataset name. Did you mean '{suggestion}'?"
)
else:
raise ValueError(
f"Invalid dataset name. Please choose between the following datasets: {datasets}"
)
else:
raise ValueError(
f"Found multiple datasets: {datasets}. Please specify which of them that you want to load with the argument 'dataset_name'"
)
# No datasets were found
else:
raise ValueError(f"Did not find any data in the file: {path}")
if not self.virtual_stack:
vol = vol[()] # Load dataset into memory
f.close()
else:
log.info("Using virtual stack")
log.info("Loaded the following dataset: %s", name)
log.info("Loaded shape: %s", vol.shape)
log.info("Using %s of memory", sizeof(sys.getsizeof(vol)))
if self.return_metadata:
return vol, metadata
else:
return vol
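# The dataset-name suggestion above relies on Python's standard difflib. A
# minimal sketch of the matching behaviour (the dataset keys are hypothetical):
import difflib

datasets = ["exchange/data", "exchange/data_white", "defaults"]

# A misspelled name still yields a close match that can be suggested
similar_names = difflib.get_close_matches("exchange/dta", datasets)
print(similar_names[0])  # -> exchange/data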
def load_tiff_stack(self, path):
"""Load a stack of TIFF files from the specified path.
@@ -170,8 +262,11 @@ class DataLoader:
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory
"""
=======
>>>>>>> main
if not self.contains:
raise ValueError(
"Please specify a part of the name that is common for the TIFF file stack with the argument 'contains'"
@@ -192,7 +287,11 @@ class DataLoader:
raise ValueError(f"The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded")
<<<<<<< HEAD
vol = tifffile.imread([os.path.join(path, file) for file in tiff_stack])
=======
vol = tifffile.imread([os.path.join(path, file) for file in tiff_stack],out='memmap')
>>>>>>> main
if not self.virtual_stack:
vol = np.copy(vol) # Copy to memory
@@ -205,6 +304,8 @@ class DataLoader:
return vol
<<<<<<< HEAD
=======
def load_txrm(self,path):
"""Load a TXRM/XRM/TXM file from the specified path.
@@ -239,6 +340,7 @@ class DataLoader:
else:
return vol
>>>>>>> main
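# Going by the docstrings above, a TXRM volume and its metadata would be loaded
# roughly like this (a sketch; the file name is hypothetical):
import qim3d

vol, metadata = qim3d.io.load("scan.txrm", return_metadata=True)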
def load(self, path):
"""
Load a file or directory based on the given path.
@@ -287,12 +389,21 @@ class DataLoader:
else:
raise ValueError("Invalid path")
<<<<<<< HEAD
def _get_h5_dataset_keys(self, f):
keys = []
f.visit(
lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None
)
return keys
=======
def _get_h5_dataset_keys(f):
keys = []
f.visit(
lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None
)
return keys
>>>>>>> main
def load(path, virtual_stack=False, dataset_name=None, return_metadata=False, contains=None, **kwargs):
@@ -305,7 +416,11 @@ def load(path, virtual_stack=False, dataset_name=None, return_metadata=False, co
stack when loading TIFF and HDF5 files. Default is False.
dataset_name (str, optional): Specifies the name of the dataset to be loaded
in case multiple datasets exist within the same file. Default is None (only for HDF5 files)
<<<<<<< HEAD
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 files)
=======
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 and TXRM files)
>>>>>>> main
contains (str, optional): Specifies a part of the name that is common for the TIFF file stack to be loaded (only for TIFF stacks)
**kwargs: Additional keyword arguments to be passed
to the DataLoader constructor.
......
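# A sketch of the public load() helper for TIFF stacks (directory and file
# prefix are hypothetical): 'contains' picks the stack whose file names share
# the given substring, e.g. slice_0000.tif, slice_0001.tif, ...
import qim3d

vol = qim3d.io.load("my_scan_folder/", contains="slice_")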
from .unet import UNet, Hyperparameters
\ No newline at end of file
"""Implementing the UNet model class and Hyperparameters class."""
from monai.networks.nets import UNet as monai_UNet
from monai.losses import FocalLoss, DiceLoss, DiceCELoss
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam, SGD, RMSprop
from qim3d.io.logger import log
class UNet(nn.Module):
"""
2D UNet model for QIM imaging.
This class represents a 2D UNet model designed for imaging segmentation tasks.
Args:
size (str, optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
kernel_size (int, optional): Convolution kernel size. Defaults to 3.
up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
activation (str, optional): Activation function. Defaults to 'PReLU'.
bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'.
Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
Example:
model = UNet(size='large')
"""
def __init__(self, size = 'medium',
dropout = 0,
kernel_size = 3,
up_kernel_size = 3,
activation = 'PReLU',
bias = True,
adn_order = 'NDA'
):
super().__init__()
if size not in ['small','medium','large']:
raise ValueError(
f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
)
self.size = size
self.dropout = dropout
self.kernel_size = kernel_size
self.up_kernel_size = up_kernel_size
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.model = self._model_choice()
def _model_choice(self):
if self.size == 'small':
self.channels = (64, 128, 256)
elif self.size == 'medium':
self.channels = (64, 128, 256, 512, 1024)
elif self.size == 'large':
self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet(
spatial_dims=2,
in_channels=1, #TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
kernel_size=self.kernel_size,
up_kernel_size=self.up_kernel_size,
act=self.activation,
dropout=self.dropout,
bias=self.bias,
adn_ordering=self.adn_order
)
return model
def forward(self,x):
x = self.model(x)
return x
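# A minimal forward-pass sketch, assuming a single-channel input as implied by
# in_channels=1 above (batch and image sizes are arbitrary; 256 is divisible by
# 2**4, matching the four stride-2 levels of the 'medium' model):
import torch

model = UNet(size='medium')
x = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
y = model(x)
print(y.shape)  # -> torch.Size([1, 1, 256, 256])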
class Hyperparameters:
"""
Hyperparameters for QIM segmentation.
Args:
model (torch.nn.Module): PyTorch model.
n_epochs (int, optional): Number of training epochs. Defaults to 10.
learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
optimizer (str, optional): Optimizer algorithm. Must be one of 'Adam', 'SGD', 'RMSprop'. Defaults to 'Adam'.
momentum (float, optional): Momentum value for SGD and RMSprop optimizers. Defaults to 0.
weight_decay (float, optional): Weight decay (L2 penalty) for the optimizer. Defaults to 0.
loss_function (str, optional): Loss function criterion. Must be one of 'BCE', 'Dice', 'Focal', 'DiceCE'. Defaults to 'Focal'.
Raises:
ValueError: If `loss_function` is not one of 'BCE', 'Dice', 'Focal', 'DiceCE'.
ValueError: If `optimizer` is not one of 'Adam', 'SGD', 'RMSprop'.
Example:
# Create hyperparameters instance
hyperparams = Hyperparameters(model=my_model, n_epochs=20, learning_rate=0.001)
# Get the hyperparameters
params = hyperparams()
# Access the optimizer and criterion
optimizer = params['optimizer']
criterion = params['criterion']
n_epochs = params['n_epochs']
"""
def __init__(self,
model,
n_epochs = 10,
learning_rate = 1e-3,
optimizer = 'Adam',
momentum = 0,
weight_decay = 0,
loss_function = 'Focal'):
# TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
if loss_function not in ['BCE','Dice','Focal','DiceCE']:
raise ValueError(f"Invalid loss function: {loss_function}. Loss criterion must "
"be one of the following: 'BCE','Dice','Focal','DiceCE'.")
# TODO: implement custom optimizer? and add check to see if valid.
if optimizer not in ['Adam','SGD','RMSprop']:
raise ValueError(f"Invalid optimizer: {optimizer}. Optimizer must "
"be one of the following: 'Adam', 'SGD', 'RMSprop'.")
if (momentum != 0) and optimizer == 'Adam':
log.info("Momentum isn't an input in the 'Adam' optimizer. "
"Change optimizer to 'SGD' or 'RMSprop' to use momentum.")
self.model = model
self.n_epochs = n_epochs
self.learning_rate = learning_rate
self.optimizer = optimizer
self.momentum = momentum
self.weight_decay = weight_decay
self.loss_function = loss_function
def __call__(self):
return self.model_params(self.model, self.n_epochs, self.optimizer, self.learning_rate,
self.weight_decay, self.momentum, self.loss_function)
def model_params(self, model, n_epochs, optimizer, learning_rate, weight_decay, momentum, loss_function):
optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
criterion = self._loss_functions(loss_function)
hyper_dict = {'optimizer': optim,
'criterion': criterion,
'n_epochs' : n_epochs,
}
return hyper_dict
# selecting the optimizer
def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
if optimizer == 'Adam':
optim = Adam(model.parameters(), lr = learning_rate,
weight_decay = weight_decay)
elif optimizer == 'SGD':
optim = SGD(model.parameters(), lr = learning_rate,
momentum = momentum, weight_decay = weight_decay)
elif optimizer == 'RMSprop':
optim = RMSprop(model.parameters(),lr = learning_rate,
weight_decay = weight_decay, momentum = momentum)
return optim
# selecting the loss function
def _loss_functions(self,loss_function):
if loss_function =='BCE':
criterion = BCEWithLogitsLoss(reduction='mean')
elif loss_function == 'Dice':
criterion = DiceLoss(sigmoid=True,reduction='mean')
elif loss_function == 'Focal':
criterion = FocalLoss(reduction='mean')
elif loss_function == 'DiceCE':
criterion = DiceCELoss(sigmoid=True,reduction='mean')
return criterion
\ No newline at end of file
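# A short sketch of the intended flow for the two classes above (argument
# values are arbitrary):
model = UNet(size='small')
hyperparams = Hyperparameters(model, n_epochs=20, optimizer='SGD', momentum=0.9, loss_function='DiceCE')
params = hyperparams()

optimizer = params['optimizer']  # torch.optim.SGD bound to model.parameters()
criterion = params['criterion']  # monai.losses.DiceCELoss
n_epochs = params['n_epochs']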
from . import internal_tools
from . import models
from .models import train_model, model_summary, inference
from .augmentations import Augmentation
from .data import Dataset, prepare_datasets, prepare_dataloaders
\ No newline at end of file
"""Class for choosing or customizing data augmentations with albumentations"""
"""Class for choosing the level of data augmentations with albumentations"""
import albumentations as A
from albumentations.pytorch import ToTensorV2
from qim3d.io.logger import log
class Augmentation:
"""
Class for defining image augmentation transformations using Albumentations library.
Class for defining image augmentation transformations using the Albumentations library.
Args:
resize ((int,tuple), optional): The target size to resize the image.
trainsform_train (str, optional): level of transformation for the training set.
transform_validation (str, optional): level of transformation for the validation set.
transform_test (str, optional): level of transformation for the test set.
mean (float, optional): The mean value for normalizing pixel intensities.
std (float, optional): The standard deviation value for normalizing pixel intensities.
Raises:
ValueError: If the provided level is neither None, 'light', 'moderate', 'heavy', nor a custom augmentation.
ValueError: If `resize` is neither a None, int nor tuple.
Attributes:
resize (int): The target size to resize the image.
mean (float): The mean value for normalizing pixel intensities.
std (float): The standard deviation value for normalizing pixel intensities.
Example:
my_augmentation = Augmentation(resize = (256,256), transform_train = 'heavy')
"""
Methods:
augment(level=None): Apply image augmentation transformations based on the specified level, or on a
custom albumentations augmentation. The available levels are None, 'light', 'moderate', and 'heavy'.
def __init__(self,
resize = None,
transform_train = 'moderate',
transform_validation = None,
transform_test = None,
mean: float = 0.5,
std: float = 0.5
):
Usage:
my_augmentation = Augmentation()
moderate_augment = augmentation.augment(level='moderate')
"""
def __init__(self, resize=256, mean=0.5, std=0.5):
if not isinstance(resize,(type(None),int,tuple)):
raise ValueError(f"Invalid input for resize: {resize}. Use an integer or tuple to modify the data.")
self.resize = resize
self.mean = mean
self.std = std
self.transform_train = transform_train
self.transform_validation = transform_validation
self.transform_test = transform_test
def augment(self, im_h, im_w, level=None):
"""
Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
The A.Resize() function is used if the user has specified a 'resize' int or tuple at the creation of the Augmentation class.
def augment(self, level=None):
Args:
im_h (int): image height for resize.
im_w (int): image width for resize.
level (str, optional): level of augmentation.
Raises:
ValueError: If `level` is neither None, light, moderate nor heavy.
"""
# Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
# Check if the custom transformation is an albumentation:
if not isinstance(level, A.core.composition.Compose):
raise ValueError("Custom Transformations need to be an instance of Albumentations Compose class, "
"or one of the following levels: None, 'light', 'moderate', 'heavy'")
# Custom transformation
else:
return level
# Default transformation
elif level is None:
augment = A.Compose([
A.Resize(self.resize, self.resize),
# Baseline
baseline_aug = [
A.Resize(im_h, im_w),
A.Normalize(mean = (self.mean),std = (self.std)),
ToTensorV2()
])
]
# Choosing light augmentation
# Level of augmentation
if level == None:
level_aug = []
elif level == 'light':
augment = A.Compose([
A.Resize(self.resize, self.resize),
A.RandomRotate90(),
A.Normalize(mean = (self.mean), std = (self.std)),
ToTensorV2()
])
# Choosing moderate augmentation
level_aug = [
A.RandomRotate90()
]
elif level == 'moderate':
augment = A.Compose([
A.Resize(self.resize, self.resize),
level_aug = [
A.RandomRotate90(),
A.HorizontalFlip(p = 0.3),
A.VerticalFlip(p = 0.3),
A.GlassBlur(sigma = 0.7, p = 0.1),
A.Affine(scale = [0.8,1.2], translate_percent = (0.1,0.1)),
A.Normalize(mean = (self.mean), std = (self.std)),
ToTensorV2()
])
# Choosing heavy augmentation
A.Affine(scale = [0.9,1.1], translate_percent = (0.1,0.1))
]
elif level == 'heavy':
augment = A.Compose([
A.Resize(self.resize,self.resize),
level_aug = [
A.RandomRotate90(),
A.HorizontalFlip(p = 0.7),
A.VerticalFlip(p = 0.7),
A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15)),
A.Normalize(mean = (self.mean), std = (self.std)),
ToTensorV2()
])
A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
]
augment = A.Compose(level_aug + baseline_aug)
return augment
\ No newline at end of file
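# A quick usage sketch for the class above (sizes are arbitrary):
aug = Augmentation(resize=(256, 256), transform_train='moderate')

# augment() stacks the level-specific transforms on top of the resize/normalize baseline
train_transform = aug.augment(256, 256, level=aug.transform_train)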
"""Provides a custom Dataset class for building a PyTorch dataset"""
"""Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path
from PIL import Image
from qim3d.io.logger import log
from torch.utils.data import DataLoader
import torch
import numpy as np
class Dataset(torch.utils.data.Dataset):
"""
Custom Dataset class for building a PyTorch dataset.
Args:
root_path (str): The root directory path of the dataset.
@@ -51,6 +54,9 @@ class Dataset(torch.utils.data.Dataset):
self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
assert len(self.sample_images) == len(self.sample_targets)
# checking the characteristics of the dataset
self.check_shape_consistency(self.sample_images)
def __len__(self):
return len(self.sample_images)
@@ -69,3 +75,121 @@ class Dataset(torch.utils.data.Dataset):
target = transformed["mask"]
return image, target
# TODO: working with images of different sizes
def check_shape_consistency(self,sample_images):
image_shapes= []
for image_path in sample_images:
image_shape = self._get_shape(image_path)
image_shapes.append(image_shape)
# check if all images have the same size.
consistency_check = all(i == image_shapes[0] for i in image_shapes)
if not consistency_check:
raise NotImplementedError(
"Only images of all the same size can be processed at the moment"
)
else:
log.debug(
"Images are all the same size!"
)
return consistency_check
def _get_shape(self,image_path):
return Image.open(str(image_path)).size
def check_resize(im_height: int, im_width: int, n_channels: int):
"""
Checks the compatibility of the image shape with the depth of the model.
If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
If so, the image is reshaped to the closest appropriate dimension, and the user is notified with a warning.
Args:
im_height (int): Height of the image chosen by the user.
im_width (int): Width of the image chosen by the user.
n_channels (int): Number of channels in the model.
Raises:
ValueError: If the image size is smaller than minimum required for the model's depth.
"""
h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels , (im_width // 2**n_channels) * 2**n_channels
if h_adjust == 0 or w_adjust == 0:
raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
elif (h_adjust!=im_height) or (w_adjust != im_width):
log.warning(f"The image size doesn't match the Unet model's depth. The image is resized to: {h_adjust,w_adjust}")
return h_adjust, w_adjust
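# A worked instance of the check above: a 'medium' UNet has five channel
# levels, so n_channels = 5 and each side must be divisible by 2**5 = 32.
# (1000 // 32) * 32 = 992, so a 1000x1000 image is trimmed to 992x992:
h_adjust, w_adjust = check_resize(1000, 1000, n_channels=5)
print(h_adjust, w_adjust)  # -> 992 992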
def prepare_datasets(path: str, val_fraction: float, model, augmentation):
"""
Splits and augments the train/validation/test datasets.
Args:
path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
Raises:
ValueError: If the validation fraction is not a float between 0 and 1.
"""
if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
raise ValueError("The validation fraction must be a float between 0 and 1.")
resize = augmentation.resize
n_channels = len(model.channels)
if isinstance(resize,type(None)):
# OPEN THE FIRST IMAGE
im_path = Path(path) / 'train'
first_img = sorted((im_path / "images").iterdir())[0]
image = Image.open(str(first_img))
im_w, im_h = image.size[:2]  # PIL's Image.size is (width, height)
log.info("User did not choose a specific value for 'resize'. Checking the first image in the dataset...")
elif isinstance(resize,int):
im_h, im_w = resize, resize
else:
im_h,im_w = resize
final_h, final_w = check_resize(im_h, im_w, n_channels)
train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set))
train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
return train_set, val_set, test_set
def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = True):
"""
Prepares the dataloaders for model training.
Args:
train_set (torch.utils.data): Training dataset.
val_set (torch.utils.data): Validation dataset.
test_set (torch.utils.data): Testing dataset.
batch_size (int): Size of the batches that should be trained upon.
shuffle_train (bool, optional): Whether to shuffle the training data, for training robustness.
num_workers (int, optional): Number of worker processes used to load data in parallel.
pin_memory (bool, optional): Pins host memory so batches can be transferred to the GPU faster.
"""
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
return train_loader,val_loader,test_loader
\ No newline at end of file
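# A hypothetical end-to-end data pipeline (path, fraction and batch size are
# invented; assumes UNet and Augmentation are importable as wired up in the
# __init__ files in this merge request):
from qim3d.models import UNet
from qim3d.utils import Augmentation

model = UNet(size='medium')
aug = Augmentation(resize=None, transform_train='light')

train_set, val_set, test_set = prepare_datasets('path/to/dataset', val_fraction=0.2, model=model, augmentation=aug)
train_loader, val_loader, test_loader = prepare_dataloaders(train_set, val_set, test_set, batch_size=8)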
""" Tools performed with trained models."""
""" Tools performed with models."""
import torch
import matplotlib.pyplot as plt
from torchinfo import summary
from qim3d.io.logger import log, level
from qim3d.viz.visualizations import plot_metrics
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
def train_model(model, qim_hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
""" Function for training Neural Network models.
Args:
model (torch.nn.Module): PyTorch model.
qim_hyperparameters (dict): dictionary with n_epochs, optimizer and criterion.
train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
Returns:
tuple:
train_loss (dict): dictionary with average losses and batch losses for training loop.
val_loss (dict): dictionary with average losses and batch losses for validation loop.
Example:
# defining the model.
model = qim3d.models.UNet()
# choosing the hyperparameters
hyperparameters = qim3d.models.Hyperparameters(model)
hyper_dict = hyperparameters()
# DataLoaders
train_loader = MyTrainLoader()
val_loader = MyValLoader()
# training the model.
train_loss,val_loss = train_model(model, hyper_dict, train_loader, val_loader)
"""
n_epochs = qim_hyperparameters['n_epochs']
optimizer = qim_hyperparameters['optimizer']
criterion = qim_hyperparameters['criterion']
# Choosing best device available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Avoid logging twice.
log.propagate = False
train_loss = {'loss' : [],'batch_loss': []}
val_loss = {'loss' : [], 'batch_loss' : []}
with logging_redirect_tqdm():
for epoch in tqdm(range(n_epochs)):
epoch_loss = 0
step = 0
model.train()
for data in train_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).float().unsqueeze(1)  # .float() stays on the chosen device, unlike torch.cuda.FloatTensor
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backpropagation
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
step += 1
# Log and store batch training loss.
train_loss['batch_loss'].append(loss.detach().item())
# Log and store average training loss per epoch.
epoch_loss = epoch_loss / step
train_loss['loss'].append(epoch_loss)
if epoch % eval_every ==0:
eval_loss = 0
step = 0
model.eval()
for data in val_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).float().unsqueeze(1)  # .float() stays on the chosen device, unlike torch.cuda.FloatTensor
with torch.no_grad():
outputs = model(inputs)
loss = criterion(outputs, targets)
eval_loss += loss.item()
step += 1
# Log and store batch validation loss.
val_loss['batch_loss'].append(loss.item())
# Log and store average validation loss.
eval_loss = eval_loss / step
val_loss['loss'].append(eval_loss)
if epoch % print_every == 0:
log.info(
f"Epoch {epoch: 3}, train loss: {train_loss['loss'][epoch]:.4f}, "
f"val loss: {val_loss['loss'][epoch]:.4f}"
)
if plot:
fig = plt.figure(figsize=(16, 6), constrained_layout = True)
plot_metrics(train_loss, label = 'Train')
plot_metrics(val_loss,color = 'orange', label = 'Valid.')
fig.show()
def model_summary(dataloader,model):
"""Prints the summary of a PyTorch model.
Args:
model (torch.nn.Module): The PyTorch model to summarize.
dataloader (torch.utils.data.DataLoader): The data loader used to determine the input shape.
Returns:
str: Summary of the model architecture.
Example:
model = MyModel()
dataloader = DataLoader(dataset, batch_size=32)
summary = model_summary(dataloader, model)
print(summary)
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
images,_ = next(iter(dataloader))
batch_size = tuple(images.shape)
model_s = summary(model,batch_size,depth = torch.inf)
return model_s
def inference(data,model):
"""Performs inference on input data using the specified model.
@@ -68,7 +218,7 @@ def inference(data,model):
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
preds = outputs.cpu().squeeze() > 0.5 # TODO: outputs from model are not between [0,1] yet, need to implement that
else:
preds = outputs.cpu().argmax(axis=1)
......
<<<<<<< HEAD
from .img import grid_pred, grid_overview
from .visualizations import plot_metrics
=======
from .img import grid_pred, grid_overview, slice_viz
>>>>>>> main
"""Visualization tools"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as snb
def plot_metrics(metric, color = 'blue', linestyle = '-', batch_linestyle = 'dotted', label = None):
"""
Plots the metrics over epochs and batches.
Args:
metric (dict): A dictionary containing the metrics per epochs and per batches.
color (str, optional): The color of the plotted lines. Defaults to 'blue'.
linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
label (str, optional): The label for the epoch metric line. Defaults to None.
Returns:
None
Example:
train_loss = {'loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss, color = 'red', label='Train')
"""
# plotting parameters
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
plt.rcParams['lines.linewidth'] = 2
metric_name = list(metric.keys())[0]
epoch_metric = metric[list(metric.keys())[0]]
batch_metric = metric[list(metric.keys())[1]]
x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric))
plt.plot(epoch_metric,linestyle = linestyle, color = color,label = label)
plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = color, alpha = 0.4)
plt.ylabel(metric_name)
plt.xlabel('epoch')
plt.legend()
# reset plotting parameters
snb.set_style('white')
\ No newline at end of file