Commit e1991a80 authored by ofhkr, committed by fima

Progress UNet July Oskar (merged with main before)

parent 2d546cd3
1 merge request: !6 Progress UNet July Oskar
@@ -2,4 +2,5 @@ import qim3d.io
import qim3d.gui
import qim3d.utils
import qim3d.viz
import qim3d.models
import logging
\ No newline at end of file
@@ -24,7 +24,10 @@ class DataLoader:
load_tiff(path): Load a TIFF file from the specified path.
load_h5(path): Load an HDF5 file from the specified path.
load_tiff_stack(path): Load a stack of TIFF files from the specified path.
<<<<<<< HEAD
=======
load_txrm(path): Load a TXRM/TXM/XRM file from the specified path.
>>>>>>> main
load(path): Load a file or directory based on the given path.
Raises:
@@ -41,7 +44,11 @@ class DataLoader:
Args:
path (str): The path to the file or directory.
virtual_stack (bool, optional): Specifies whether to use virtual
<<<<<<< HEAD
stack when loading TIFF and HDF5 files. Default is False.
=======
stack when loading files. Default is False.
>>>>>>> main
dataset_name (str, optional): Specifies the name of the dataset to be loaded
in case multiple datasets exist within the same file. Default is None (only for HDF5 files)
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 files)
@@ -88,6 +95,8 @@ class DataLoader:
ValueError: If the specified dataset_name is not found or is invalid.
ValueError: If the dataset_name is not specified in case of multiple datasets in the HDF5 file.
ValueError: If no datasets are found in the file.
<<<<<<< HEAD
=======
"""
# Read file
@@ -158,6 +167,89 @@ class DataLoader:
else:
return vol
def load_tiff_stack(self, path):
"""Load a stack of TIFF files from the specified path.
Args:
path (str): The path to the stack of TIFF files.
Returns:
numpy.ndarray: The loaded volume as a NumPy array.
>>>>>>> main
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory
"""
<<<<<<< HEAD
# Read file
f = h5py.File(path, "r")
data_keys = self._get_h5_dataset_keys(f)
datasets = []
metadata = {}
for key in data_keys:
if (
f[key].ndim > 1
): # Data is assumed to be a dataset if it has two or more dimensions
datasets.append(key)
if f[key].attrs.keys():
metadata[key] = {
"value": f[key][()],
**{attr_key: val for attr_key, val in f[key].attrs.items()},
}
# Only one dataset was found
if len(datasets) == 1:
if self.dataset_name:
log.info(
"'dataset_name' argument is unused since there is only one dataset in the file"
)
name = datasets[0]
vol = f[name]
# Multiple datasets were found
elif len(datasets) > 1:
if self.dataset_name in datasets: # Provided dataset name is valid
name = self.dataset_name
vol = f[name]
else:
if self.dataset_name: # Dataset name is provided
similar_names = difflib.get_close_matches(
self.dataset_name, datasets
) # Find closest matching name if any
if similar_names:
suggestion = similar_names[0] # Get the closest match
raise ValueError(
f"Invalid dataset name. Did you mean '{suggestion}'?"
)
else:
raise ValueError(
f"Invalid dataset name. Please choose between the following datasets: {datasets}"
)
else:
raise ValueError(
f"Found multiple datasets: {datasets}. Please specify which of them you want to load with the argument 'dataset_name'"
)
# No datasets were found
else:
raise ValueError(f"Did not find any data in the file: {path}")
if not self.virtual_stack:
vol = vol[()] # Load dataset into memory
f.close()
else:
log.info("Using virtual stack")
log.info("Loaded the following dataset: %s", name)
log.info("Loaded shape: %s", vol.shape)
log.info("Using %s of memory", sizeof(sys.getsizeof(vol)))
if self.return_metadata:
return vol, metadata
else:
return vol
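# For orientation, a minimal usage sketch of the HDF5 branch above (file and
# dataset names are hypothetical; assumes DataLoader is importable from qim3d.io):
import qim3d.io
loader = qim3d.io.DataLoader(dataset_name="exchange/data", return_metadata=True)
vol, metadata = loader.load_h5("scan.h5")
# A close-but-wrong name such as "exchange/dat" would raise:
# ValueError: Invalid dataset name. Did you mean 'exchange/data'?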
def load_tiff_stack(self, path):
"""Load a stack of TIFF files from the specified path.
@@ -170,8 +262,11 @@ class DataLoader:
Raises:
ValueError: If the 'contains' argument is not specified.
ValueError: If the 'contains' argument matches multiple TIFF stacks in the directory
"""
=======
>>>>>>> main
if not self.contains:
raise ValueError(
"Please specify a part of the name that is common for the TIFF file stack with the argument 'contains'"
@@ -192,7 +287,11 @@ class DataLoader:
raise ValueError(f"The provided part of the filename for the TIFF stack matches multiple TIFF stacks: {unique_names}.\nPlease provide a string that is unique for the TIFF stack that is intended to be loaded")
<<<<<<< HEAD
vol = tifffile.imread([os.path.join(path, file) for file in tiff_stack])
=======
vol = tifffile.imread([os.path.join(path, file) for file in tiff_stack], out='memmap')
>>>>>>> main
if not self.virtual_stack:
vol = np.copy(vol) # Copy to memory
@@ -205,6 +304,8 @@ class DataLoader:
return vol
<<<<<<< HEAD
=======
def load_txrm(self,path):
"""Load a TXRM/XRM/TXM file from the specified path.
@@ -239,6 +340,7 @@ class DataLoader:
else:
return vol
>>>>>>> main
def load(self, path):
"""
Load a file or directory based on the given path.
@@ -287,12 +389,21 @@ class DataLoader:
else:
raise ValueError("Invalid path")
<<<<<<< HEAD
def _get_h5_dataset_keys(self, f):
keys = []
f.visit(
lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None
)
return keys
=======
def _get_h5_dataset_keys(f):
keys = []
f.visit(
lambda key: keys.append(key) if isinstance(f[key], h5py.Dataset) else None
)
return keys
>>>>>>> main
def load(path, virtual_stack=False, dataset_name=None, return_metadata=False, contains=None, **kwargs):
@@ -305,7 +416,11 @@ def load(path, virtual_stack=False, dataset_name=None, return_metadata=False, co
stack when loading TIFF and HDF5 files. Default is False.
dataset_name (str, optional): Specifies the name of the dataset to be loaded
in case multiple datasets exist within the same file. Default is None (only for HDF5 files)
<<<<<<< HEAD
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 files)
=======
return_metadata (bool, optional): Specifies whether to return metadata or not. Default is False (only for HDF5 and TXRM files)
>>>>>>> main
contains (str, optional): Specifies a part of the name that is common for the TIFF file stack to be loaded (only for TIFF stacks)
**kwargs: Additional keyword arguments to be passed
to the DataLoader constructor.
......
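# A short sketch of the load() convenience wrapper above (file names are hypothetical):
import qim3d.io
# Single HDF5 file; 'dataset_name' selects a dataset when the file holds several
vol = qim3d.io.load("scan.h5", dataset_name="data", virtual_stack=True)
# TIFF stack in a directory; 'contains' must single out one stack
vol = qim3d.io.load("scans/", contains="tomo_")
# HDF5 (and, on the main branch, TXRM) files can also return metadata
vol, metadata = qim3d.io.load("scan.h5", dataset_name="data", return_metadata=True)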
from .unet import UNet, Hyperparameters
\ No newline at end of file
"""Implementing the UNet model class and Hyperparameters class."""
from monai.networks.nets import UNet as monai_UNet
from monai.losses import FocalLoss, DiceLoss, DiceCELoss
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam, SGD, RMSprop
from qim3d.io.logger import log
class UNet(nn.Module):
"""
2D UNet model for QIM imaging.
This class represents a 2D UNet model designed for image segmentation tasks.
Args:
size (str, optional): Size of the UNet model. Must be one of 'small', 'medium', or 'large'. Defaults to 'medium'.
dropout (float, optional): Dropout rate between 0 and 1. Defaults to 0.
kernel_size (int, optional): Convolution kernel size. Defaults to 3.
up_kernel_size (int, optional): Up-convolution kernel size. Defaults to 3.
activation (str, optional): Activation function. Defaults to 'PReLU'.
bias (bool, optional): Whether to include bias in convolutions. Defaults to True.
adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Defaults to 'NDA'.
Raises:
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
Example:
model = UNet(size='large')
"""
def __init__(self, size = 'medium',
dropout = 0,
kernel_size = 3,
up_kernel_size = 3,
activation = 'PReLU',
bias = True,
adn_order = 'NDA'
):
super().__init__()
if size not in ['small','medium','large']:
raise ValueError(
f"Invalid model size: {size}. Size must be one of the following: 'small', 'medium', 'large'."
)
self.size = size
self.dropout = dropout
self.kernel_size = kernel_size
self.up_kernel_size = up_kernel_size
self.activation = activation
self.bias = bias
self.adn_order = adn_order
self.model = self._model_choice()
def _model_choice(self):
if self.size == 'small':
self.channels = (64, 128, 256)
elif self.size == 'medium':
self.channels = (64, 128, 256, 512, 1024)
elif self.size == 'large':
self.channels = (64, 128, 256, 512, 1024, 2048)
model = monai_UNet(
spatial_dims=2,
in_channels=1, #TODO: check if image has 1 or multiple input channels
out_channels=1,
channels=self.channels,
strides=(2,) * (len(self.channels) - 1),
kernel_size=self.kernel_size,
up_kernel_size=self.up_kernel_size,
act=self.activation,
dropout=self.dropout,
bias=self.bias,
adn_ordering=self.adn_order
)
return model
def forward(self,x):
x = self.model(x)
return x
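# For reference, a quick sketch of building the model and running a dummy forward
# pass (shapes are illustrative; each side length must be divisible by 2 for every
# downsampling stage):
import torch
from qim3d.models import UNet
model = UNet(size='small', dropout=0.1)
x = torch.randn(4, 1, 256, 256) # batch of 4 single-channel 256x256 images
y = model(x) # logits with the same shape, (4, 1, 256, 256)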
class Hyperparameters:
"""
Hyperparameters for QIM segmentation.
Args:
model (torch.nn.Module): PyTorch model.
n_epochs (int, optional): Number of training epochs. Defaults to 10.
learning_rate (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
optimizer (str, optional): Optimizer algorithm. Must be one of 'Adam', 'SGD', 'RMSprop'. Defaults to 'Adam'.
momentum (float, optional): Momentum value for SGD and RMSprop optimizers. Defaults to 0.
weight_decay (float, optional): Weight decay (L2 penalty) for the optimizer. Defaults to 0.
loss_function (str, optional): Loss function criterion. Must be one of 'BCE', 'Dice', 'Focal', 'DiceCE'. Defaults to 'Focal'.
Raises:
ValueError: If `loss_function` is not one of 'BCE', 'Dice', 'Focal', 'DiceCE'.
ValueError: If `optimizer` is not one of 'Adam', 'SGD', 'RMSprop'.
Example:
# Create hyperparameters instance
hyperparams = Hyperparameters(model=my_model, n_epochs=20, learning_rate=0.001)
# Get the hyperparameters
params = hyperparams()
# Access the optimizer and criterion
optimizer = params['optimizer']
criterion = params['criterion']
n_epochs = params['n_epochs']
"""
def __init__(self,
model,
n_epochs = 10,
learning_rate = 1e-3,
optimizer = 'Adam',
momentum = 0,
weight_decay = 0,
loss_function = 'Focal'):
# TODO: implement custom loss_functions? then add a check to see if loss works for segmentation.
if loss_function not in ['BCE','Dice','Focal','DiceCE']:
raise ValueError(f"Invalid loss function: {loss_function}. Loss criterion must "
"be one of the following: 'BCE','Dice','Focal','DiceCE'.")
# TODO: implement custom optimizer? and add check to see if valid.
if optimizer not in ['Adam','SGD','RMSprop']:
raise ValueError(f"Invalid optimizer: {optimizer}. Optimizer must "
"be one of the following: 'Adam', 'SGD', 'RMSprop'.")
if (momentum != 0) and optimizer == 'Adam':
log.info("Momentum isn't an input in the 'Adam' optimizer. "
"Change optimizer to 'SGD' or 'RMSprop' to use momentum.")
self.model = model
self.n_epochs = n_epochs
self.learning_rate = learning_rate
self.optimizer = optimizer
self.momentum = momentum
self.weight_decay = weight_decay
self.loss_function = loss_function
def __call__(self):
return self.model_params(self.model, self.n_epochs, self.optimizer, self.learning_rate,
self.weight_decay, self.momentum, self.loss_function)
def model_params(self, model, n_epochs, optimizer, learning_rate, weight_decay, momentum, loss_function):
optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
criterion = self._loss_functions(loss_function)
hyper_dict = {'optimizer': optim,
'criterion': criterion,
'n_epochs' : n_epochs,
}
return hyper_dict
# selecting the optimizer
def _optimizer(self, model, optimizer, learning_rate, weight_decay, momentum):
if optimizer == 'Adam':
optim = Adam(model.parameters(), lr = learning_rate,
weight_decay = weight_decay)
elif optimizer == 'SGD':
optim = SGD(model.parameters(), lr = learning_rate,
momentum = momentum, weight_decay = weight_decay)
elif optimizer == 'RMSprop':
optim = RMSprop(model.parameters(),lr = learning_rate,
weight_decay = weight_decay, momentum = momentum)
return optim
# selecting the loss function
def _loss_functions(self,loss_function):
if loss_function == 'BCE':
criterion = BCEWithLogitsLoss(reduction='mean')
elif loss_function == 'Dice':
criterion = DiceLoss(sigmoid=True,reduction='mean')
elif loss_function == 'Focal':
criterion = FocalLoss(reduction='mean')
elif loss_function == 'DiceCE':
criterion = DiceCELoss(sigmoid=True,reduction='mean')
return criterion
\ No newline at end of file
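# A minimal sketch tying the two classes together (parameter values are illustrative):
from qim3d.models import UNet, Hyperparameters
model = UNet(size='medium')
hyperparameters = Hyperparameters(model, n_epochs=20, optimizer='SGD', momentum=0.9, loss_function='DiceCE')
params = hyperparameters() # {'optimizer': ..., 'criterion': ..., 'n_epochs': 20}
optimizer, criterion = params['optimizer'], params['criterion']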
from . import internal_tools
-from . import models
+from .models import train_model, model_summary, inference
from .augmentations import Augmentation
-from .data import Dataset
+from .data import Dataset, prepare_datasets, prepare_dataloaders
\ No newline at end of file
"""Class for choosing or customizing data augmentations with albumentations""" """Class for choosing the level of data augmentations with albumentations"""
import albumentations as A import albumentations as A
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from qim3d.io.logger import log
class Augmentation: class Augmentation:
""" """
Class for defining image augmentation transformations using Albumentations library. Class for defining image augmentation transformations using the Albumentations library.
Args:
resize ((int,tuple), optional): The target size to resize the image.
trainsform_train (str, optional): level of transformation for the training set.
transform_validation (str, optional): level of transformation for the validation set.
transform_test (str, optional): level of transformation for the test set.
mean (float, optional): The mean value for normalizing pixel intensities.
std (float, optional): The standard deviation value for normalizing pixel intensities.
Raises: Raises:
ValueError: If the provided level is neither None, 'light', 'moderate', 'heavy', nor a custom augmentation. ValueError: If `resize` is neither a None, int nor tuple.
Attributes: Example:
resize (int): The target size to resize the image. my_augmentation = Augmentation(resize = (256,256), transform_train = 'heavy')
mean (float): The mean value for normalizing pixel intensities. """
std (float): The standard deviation value for normalizing pixel intensities.
Methods: def __init__(self,
augment(level=None): Apply image augmentation transformations based on the specified level, or on a resize = None,
custom albumentations augmentation. The available levels are None, 'light', 'moderate', and 'heavy'. transform_train = 'moderate',
transform_validation = None,
transform_test = None,
mean: float = 0.5,
std: float = 0.5
):
Usage: if not isinstance(resize,(type(None),int,tuple)):
my_augmentation = Augmentation() raise ValueError(f"Invalid input for resize: {resize}. Use an integer or tuple to modify the data.")
moderate_augment = augmentation.augment(level='moderate')
"""
def __init__(self, resize=256, mean=0.5, std=0.5):
self.resize = resize self.resize = resize
self.mean = mean self.mean = mean
self.std = std self.std = std
self.transform_train = transform_train
self.transform_validation = transform_validation
self.transform_test = transform_test
def augment(self, im_h, im_w, level=None):
"""
Returns an albumentations.core.composition.Compose class depending on the augmentation level.
A baseline augmentation is implemented regardless of the level, and a set of augmentations are added depending of the level.
The A.Resize() function is used if the user has specified a 'resize' int or tuple at the creation of the Augmentation class.
def augment(self, level=None): Args:
im_h (int): image height for resize.
im_w (int): image width for resize.
level (str, optional): level of augmentation.
Raises:
ValueError: If `level` is neither None, light, moderate nor heavy.
"""
# Check if one of standard augmentation levels # Check if one of standard augmentation levels
if level not in [None,'light','moderate','heavy']: if level not in [None,'light','moderate','heavy']:
raise ValueError(f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'.")
# Check if the custom transformation is an albumentation: # Baseline
if not isinstance(level, A.core.composition.Compose): baseline_aug = [
raise ValueError("Custom Transformations need to be an instance of Albumentations Compose class, " A.Resize(im_h, im_w),
"or one of the following levels: None, 'light', 'moderate', 'heavy'")
# Custom transformation
else:
return level
# Default transformation
elif level is None:
augment = A.Compose([
A.Resize(self.resize, self.resize),
A.Normalize(mean = (self.mean),std = (self.std)), A.Normalize(mean = (self.mean),std = (self.std)),
ToTensorV2() ToTensorV2()
]) ]
# Choosing light augmentation # Level of augmentation
if level == None:
level_aug = []
elif level == 'light': elif level == 'light':
augment = A.Compose([ level_aug = [
A.Resize(self.resize, self.resize), A.RandomRotate90()
A.RandomRotate90(), ]
A.Normalize(mean = (self.mean), std = (self.std)),
ToTensorV2()
])
# Choosing moderate augmentation
elif level == 'moderate': elif level == 'moderate':
augment = A.Compose([ level_aug = [
A.Resize(self.resize, self.resize),
A.RandomRotate90(), A.RandomRotate90(),
A.HorizontalFlip(p = 0.3), A.HorizontalFlip(p = 0.3),
A.VerticalFlip(p = 0.3), A.VerticalFlip(p = 0.3),
A.GlassBlur(sigma = 0.7, p = 0.1), A.GlassBlur(sigma = 0.7, p = 0.1),
A.Affine(scale = [0.8,1.2], translate_percent = (0.1,0.1)), A.Affine(scale = [0.9,1.1], translate_percent = (0.1,0.1))
A.Normalize(mean = (self.mean), std = (self.std)), ]
ToTensorV2()
])
# Choosing heavy augmentation
elif level == 'heavy': elif level == 'heavy':
augment = A.Compose([ level_aug = [
A.Resize(self.resize,self.resize),
A.RandomRotate90(), A.RandomRotate90(),
A.HorizontalFlip(p = 0.7), A.HorizontalFlip(p = 0.7),
A.VerticalFlip(p = 0.7), A.VerticalFlip(p = 0.7),
A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3), A.GlassBlur(sigma = 1.2, iterations = 2, p = 0.3),
A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15)), A.Affine(scale = [0.8,1.4], translate_percent = (0.2,0.2), shear = (-15,15))
A.Normalize(mean = (self.mean), std = (self.std)), ]
ToTensorV2()
]) augment = A.Compose(level_aug + baseline_aug)
return augment return augment
\ No newline at end of file
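# A minimal sketch of the new Augmentation API (synthetic image and mask;
# augment() expects the target height and width plus one of the predefined levels):
import numpy as np
from qim3d.utils import Augmentation
augmentation = Augmentation(transform_train='light')
transform = augmentation.augment(256, 256, level=augmentation.transform_train)
image = np.random.randint(0, 256, (300, 400), dtype=np.uint8)
mask = (image > 127).astype(np.uint8)
out = transform(image=image, mask=mask) # resized, normalized, converted to tensors
image_tensor, mask_tensor = out["image"], out["mask"]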
"""Provides a custom Dataset class for building a PyTorch dataset""" """Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
from qim3d.io.logger import log
from torch.utils.data import DataLoader
import torch import torch
import numpy as np import numpy as np
class Dataset(torch.utils.data.Dataset): class Dataset(torch.utils.data.Dataset):
""" """
Custom Dataset class for building a PyTorch dataset Custom Dataset class for building a PyTorch dataset.
Args: Args:
root_path (str): The root directory path of the dataset. root_path (str): The root directory path of the dataset.
...@@ -51,6 +54,9 @@ class Dataset(torch.utils.data.Dataset): ...@@ -51,6 +54,9 @@ class Dataset(torch.utils.data.Dataset):
self.sample_targets = [file for file in sorted((path / "labels").iterdir())] self.sample_targets = [file for file in sorted((path / "labels").iterdir())]
assert len(self.sample_images) == len(self.sample_targets) assert len(self.sample_images) == len(self.sample_targets)
# checking the characteristics of the dataset
self.check_shape_consistency(self.sample_images)
def __len__(self): def __len__(self):
return len(self.sample_images) return len(self.sample_images)
...@@ -69,3 +75,121 @@ class Dataset(torch.utils.data.Dataset): ...@@ -69,3 +75,121 @@ class Dataset(torch.utils.data.Dataset):
target = transformed["mask"] target = transformed["mask"]
return image, target return image, target
# TODO: working with images of different sizes
def check_shape_consistency(self,sample_images):
image_shapes = []
for image_path in sample_images:
image_shape = self._get_shape(image_path)
image_shapes.append(image_shape)
# check if all images have the same size.
consistency_check = all(i == image_shapes[0] for i in image_shapes)
if not consistency_check:
raise NotImplementedError(
"Only images of all the same size can be processed at the moment"
)
else:
log.debug(
"Images are all the same size!"
)
return consistency_check
def _get_shape(self,image_path):
return Image.open(str(image_path)).size
def check_resize(im_height: int, im_width: int, n_channels: int):
"""
Checks the compatibility of the image shape with the depth of the model.
If the image height and width cannot be divided by 2 `n_channels` times, then the image size is inappropriate.
If so, the image is reshaped to the closest appropriate dimension, and the user is notified with a warning.
Args:
im_height (int): Height of the image chosen by the user.
im_width (int): Width of the image chosen by the user.
n_channels (int): Number of channels in the model.
Raises:
ValueError: If the image size is smaller than minimum required for the model's depth.
"""
h_adjust, w_adjust = (im_height // 2**n_channels) * 2**n_channels, (im_width // 2**n_channels) * 2**n_channels
if h_adjust == 0 or w_adjust == 0:
raise ValueError("The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model.")
elif (h_adjust != im_height) or (w_adjust != im_width):
log.warning(f"The image size doesn't match the UNet model's depth. The image is resized to: {h_adjust, w_adjust}")
return h_adjust, w_adjust
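# To make the divisibility arithmetic concrete: a 'medium' UNet has five channel
# stages, so check_resize requires both sides to be divisible by 2**5 = 32.
check_resize(300, 300, 5) # (300 // 32) * 32 = 288 -> returns (288, 288) and logs a resize warning
check_resize(16, 16, 5) # 16 // 32 = 0 -> raises ValueError, image too small for this depth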
def prepare_datasets(path: str, val_fraction: float, model, augmentation):
"""
Splits and augments the train/validation/test datasets.
Args:
path (str): Path to the dataset.
val_fraction (float): Fraction of the data for the validation set.
model (torch.nn.Module): PyTorch Model.
augmentation (albumentations.core.composition.Compose): Augmentation class for the dataset with predefined augmentation levels.
Raises:
ValueError: If the validation fraction is not a float between 0 and 1.
"""
if not isinstance(val_fraction,float) or not (0 <= val_fraction < 1):
raise ValueError("The validation fraction must be a float between 0 and 1.")
resize = augmentation.resize
n_channels = len(model.channels)
if resize is None:
# Open the first image to infer the size
im_path = Path(path) / 'train'
first_img = sorted((im_path / "images").iterdir())[0]
image = Image.open(str(first_img))
im_w, im_h = image.size # PIL size is (width, height)
log.info("User did not choose a specific value for 'resize'. Checking the first image in the dataset...")
elif isinstance(resize,int):
im_h, im_w = resize, resize
else:
im_h,im_w = resize
final_h, final_w = check_resize(im_h, im_w, n_channels)
train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
split_idx = int(np.floor(val_fraction * len(train_set)))
indices = torch.randperm(len(train_set))
train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
return train_set, val_set, test_set
def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = True):
"""
Prepares the dataloaders for model training.
Args:
train_set (torch.utils.data): Training dataset.
val_set (torch.utils.data): Validation dataset.
test_set (torch.utils.data): Testing dataset.
batch_size (int): Size of the batches that should be trained upon.
shuffle_train (bool, optional): Whether to shuffle the training data (improves training robustness).
num_workers (int, optional): Defines how many processes should be run in parallel.
pin_memory (bool, optional): Pins host memory for faster transfer to the GPU.
"""
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
return train_loader,val_loader,test_loader
\ No newline at end of file
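# Putting the pieces together, a minimal end-to-end sketch (the dataset path is
# hypothetical and is assumed to follow the train/images, train/labels,
# test/images, test/labels layout expected by Dataset):
from qim3d.models import UNet
from qim3d.utils import Augmentation, prepare_datasets, prepare_dataloaders
model = UNet(size='medium')
augmentation = Augmentation(resize=None, transform_train='moderate')
train_set, val_set, test_set = prepare_datasets(path='my_dataset/', val_fraction=0.2, model=model, augmentation=augmentation)
train_loader, val_loader, test_loader = prepare_dataloaders(train_set, val_set, test_set, batch_size=8)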
""" Tools performed with trained models.""" """ Tools performed with models."""
import torch import torch
import matplotlib.pyplot as plt
from torchinfo import summary
from qim3d.io.logger import log, level
from qim3d.viz.visualizations import plot_metrics
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
def train_model(model, qim_hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
""" Function for training Neural Network models.
Args:
model (torch.nn.Module): PyTorch model.
qim_hyperparameters (dict): dictionary with n_epochs, optimizer and criterion.
train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
Returns:
tuple:
train_loss (dict): dictionary with average losses and batch losses for training loop.
val_loss (dict): dictionary with average losses and batch losses for validation loop.
Example:
# defining the model.
model = qim3d.models.UNet()
# choosing the hyperparameters
hyperparameters = qim3d.models.Hyperparameters(model)
hyper_dict = hyperparameters()
# DataLoaders
train_loader = MyTrainLoader()
val_loader = MyValLoader()
# training the model.
train_loss,val_loss = train_model(model, hyper_dict, train_loader, val_loader)
"""
n_epochs = qim_hyperparameters['n_epochs']
optimizer = qim_hyperparameters['optimizer']
criterion = qim_hyperparameters['criterion']
# Choosing best device available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Avoid logging twice.
log.propagate = False
train_loss = {'loss' : [],'batch_loss': []}
val_loss = {'loss' : [], 'batch_loss' : []}
with logging_redirect_tqdm():
for epoch in tqdm(range(n_epochs)):
epoch_loss = 0
step = 0
model.train()
for data in train_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).float().unsqueeze(1) # .float() also works on CPU, unlike torch.cuda.FloatTensor
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backpropagation
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
step += 1
# Log and store batch training loss.
train_loss['batch_loss'].append(loss.detach().item())
# Log and store average training loss per epoch.
epoch_loss = epoch_loss / step
train_loss['loss'].append(epoch_loss)
if epoch % eval_every == 0:
eval_loss = 0
step = 0
model.eval()
for data in val_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).float().unsqueeze(1) # .float() also works on CPU, unlike torch.cuda.FloatTensor
with torch.no_grad():
outputs = model(inputs)
loss = criterion(outputs, targets)
eval_loss += loss.item()
step += 1
# Log and store batch validation loss.
val_loss['batch_loss'].append(loss.item())
# Log and store average validation loss.
eval_loss = eval_loss / step
val_loss['loss'].append(eval_loss)
if epoch % print_every == 0:
log.info(
f"Epoch {epoch: 3}, train loss: {train_loss['loss'][epoch]:.4f}, "
f"val loss: {val_loss['loss'][epoch]:.4f}"
)
if plot:
fig = plt.figure(figsize=(16, 6), constrained_layout = True)
plot_metrics(train_loss, label = 'Train')
plot_metrics(val_loss,color = 'orange', label = 'Valid.')
fig.show()
def model_summary(dataloader,model):
"""Prints the summary of a PyTorch model.
Args:
model (torch.nn.Module): The PyTorch model to summarize.
dataloader (torch.utils.data.DataLoader): The data loader used to determine the input shape.
Returns:
str: Summary of the model architecture.
Example:
model = MyModel()
dataloader = DataLoader(dataset, batch_size=32)
summary = model_summary(dataloader, model)
print(summary)
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
images,_ = next(iter(dataloader))
input_size = tuple(images.shape) # full batch shape, e.g. (B, C, H, W)
model_s = summary(model, input_size, depth = torch.inf)
return model_s
def inference(data,model):
"""Performs inference on input data using the specified model.
@@ -68,7 +218,7 @@ def inference(data,model):
inputs = inputs.cpu().squeeze()
targets = targets.squeeze()
if outputs.shape[1] == 1:
-preds = outputs.cpu().squeeze() > 0.5
+preds = outputs.cpu().squeeze() > 0.5 # TODO: outputs from model are not between [0,1] yet, need to implement that
else:
preds = outputs.cpu().argmax(axis=1)
......
<<<<<<< HEAD
from .img import grid_pred, grid_overview
from .visualizations import plot_metrics
=======
from .img import grid_pred, grid_overview, slice_viz
>>>>>>> main
"""Visualization tools"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as snb
def plot_metrics(metric, color = 'blue', linestyle = '-', batch_linestyle = 'dotted', label = None):
"""
Plots the metrics over epochs and batches.
Args:
metric (dict): A dictionary containing the metrics per epochs and per batches.
color (str, optional): The color of the plotted lines. Defaults to 'blue'.
linestyle (str, optional): The style of the epoch metric line. Defaults to '-'.
batch_linestyle (str, optional): The style of the batch metric line. Defaults to 'dotted'.
label (str, optional): The label for the epoch metric line. Defaults to None.
Returns:
None
Example:
train_loss = {'loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss, color = 'red', label='Train')
"""
# plotting parameters
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
plt.rcParams['lines.linewidth'] = 2
metric_name = list(metric.keys())[0]
epoch_metric = metric[list(metric.keys())[0]]
batch_metric = metric[list(metric.keys())[1]]
x_axis = np.linspace(0,len(epoch_metric)-1,len(batch_metric))
plt.plot(epoch_metric,linestyle = linestyle, color = color,label = label)
plt.plot(x_axis, batch_metric, linestyle = batch_linestyle, color = color, alpha = 0.4)
plt.ylabel(metric_name)
plt.xlabel('epoch')
plt.legend()
# reset plotting parameters
snb.set_style('white')
\ No newline at end of file
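# A small synthetic demo of plot_metrics (three epochs of four batches each; the
# np.linspace call above stretches the twelve batch values across the three-point
# epoch axis so both curves share the x-axis):
train_loss = {'loss': [0.9, 0.6, 0.5], 'batch_loss': [1.1, 0.9, 0.85, 0.8, 0.7, 0.65, 0.6, 0.58, 0.55, 0.52, 0.5, 0.49]}
plot_metrics(train_loss, color='red', label='Train')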