diff --git a/docs/notebooks/Unet.ipynb b/docs/notebooks/Unet.ipynb
index 4b7a3d9e52aeb8466c67fab035daa131c97e5b2b..37d71cb90229f6636711354d793519de508452e2 100644
--- a/docs/notebooks/Unet.ipynb
+++ b/docs/notebooks/Unet.ipynb
@@ -20,6 +20,7 @@ "source": [
     "from os.path import join\n",
     "import qim3d\n",
+    "import os\n",
     "\n",
     "%matplotlib inline"
    ]
   },
@@ -35,14 +36,25 @@
    "def get_dataset_path(name: str, datasets):\n",
    "    assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)\n",
    "    dataset_idx = datasets.index(name)\n",
-   "    datasets_path = [\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
-   "        '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
-   "    ]\n",
+   "    if os.name == 'nt':\n",
+   "        datasets_path = [\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
+   "            '//home.cc.dtu.dk/3dimage/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
+   "        ]\n",
+   "    else:\n",
+   "        datasets_path = [\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',\n",
+   "            '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'\n",
+   "        ]\n",
+   "\n",
    "    return datasets_path[dataset_idx]"
    ]
   },
@@ -154,7 +166,7 @@
    "outputs": [],
    "source": [
     "# model hyperparameters\n",
-    "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=25,\n",
+    "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=5,\n",
     "    learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)\n",
     "\n",
     "# training model\n",
@@ -197,7 +209,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.11.6"
   }
  },
 "nbformat": 4,
diff --git a/qim3d/io/load.py b/qim3d/io/load.py
index 620a2eb81d10a0913677f3505a256c1370f1035b..3721c07463505752d669e64b37dc5058ef5b3a45 100644
--- a/qim3d/io/load.py
+++ b/qim3d/io/load.py
@@ -401,4 +401,4 @@ class ImgExamples:
 
         # Generate loader for each image found
         for idx, name in enumerate(img_names):
-            exec(f"self.{name} = qim3d.io.load(path = img_paths[idx])")
+            exec(f"self.{name} = qim3d.io.load(path = img_paths[idx])")
\ No newline at end of file
diff --git a/qim3d/tests/models/test_unet.py b/qim3d/tests/models/test_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b17605329af78f3b04cdcedd35de60608577f6
--- /dev/null
+++ b/qim3d/tests/models/test_unet.py
@@ -0,0 +1,34 @@
+import qim3d
+import torch
+
+# unit tests for UNet()
+def test_starting_unet():
+    unet = qim3d.models.UNet()
+
+    assert unet.size == 'medium'
+
+
+def test_forward_pass():
+    unet = qim3d.models.UNet()
+
+    # Size: B x C x H x W
+    x = torch.ones([1,1,256,256])
+
+    output = unet(x)
+    assert x.shape == output.shape
+
+# unit tests for Hyperparameters()
+def test_hyper():
+    unet = qim3d.models.UNet()
+    hyperparams = qim3d.models.Hyperparameters(unet)
+
+    assert hyperparams.n_epochs == 10
+
+def test_hyper_dict():
+    unet = qim3d.models.UNet()
+    hyperparams = qim3d.models.Hyperparameters(unet)
+
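+    # calling the Hyperparameters object itself should hand back the settings as a plain dict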
+    hyper_dict = hyperparams()
+
+    assert type(hyper_dict) == dict
diff --git a/qim3d/tests/utils/test_augmentations.py b/qim3d/tests/utils/test_augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..da6c490d2706011a4dddfe274bb661e3dc303127
--- /dev/null
+++ b/qim3d/tests/utils/test_augmentations.py
@@ -0,0 +1,32 @@
+import qim3d
+import albumentations
+import pytest
+
+# unit tests for Augmentation()
+def test_augmentation():
+    augment_class = qim3d.utils.Augmentation()
+
+    assert augment_class.resize == 'crop'
+
+def test_augment():
+    augment_class = qim3d.utils.Augmentation()
+
+    album_augment = augment_class.augment(256,256)
+
+    assert type(album_augment) == albumentations.core.composition.Compose
+
+# unit tests for ValueErrors in Augmentation()
+def test_resize():
+    resize_str = 'not valid resize'
+
+    with pytest.raises(ValueError,match = f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'."):
+        augment_class = qim3d.utils.Augmentation(resize = resize_str)
+
+
+def test_levels():
+    augment_class = qim3d.utils.Augmentation()
+
+    level = 'Not a valid level'
+
+    with pytest.raises(ValueError, match=f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."):
+        augment_class.augment(256,256,level)
\ No newline at end of file
diff --git a/qim3d/tests/utils/test_data.py b/qim3d/tests/utils/test_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..928ef26f2f785b7760c073f2519f37eb9ebb61f6
--- /dev/null
+++ b/qim3d/tests/utils/test_data.py
@@ -0,0 +1,79 @@
+import qim3d
+import pytest
+
+from torch.utils.data.dataloader import DataLoader
+from qim3d.utils.internal_tools import temp_data
+
+# unit tests for Dataset()
+def test_dataset():
+    img_shape = (32,32)
+    folder = 'folder_data'
+    temp_data(folder, img_shape = img_shape)
+
+    images = qim3d.utils.Dataset(folder)
+
+    assert images[0][0].shape == img_shape
+
+    temp_data(folder,remove=True)
+
+
+# unit tests for check_resize()
+def test_check_resize():
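+    # a UNet with n_channels = 6 needs sizes divisible by 2**6 = 64, so 240x240 should be cropped down to 192x192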
+    h_adjust,w_adjust = qim3d.utils.data.check_resize(240,240,resize = 'crop',n_channels = 6)
+
+    assert (h_adjust,w_adjust) == (192,192)
+
+def test_check_resize_pad():
+    h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'padding',n_channels = 6)
+
+    assert (h_adjust,w_adjust) == (64,64)
+
+def test_check_resize_fail():
+
+    with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
+        h_adjust,w_adjust = qim3d.utils.data.check_resize(16,16,resize = 'crop',n_channels = 6)
+
+
+# unit tests for prepare_datasets()
+def test_prepare_datasets():
+    n = 3
+    validation = 1/3
+
+    folder = 'folder_data'
+    img = temp_data(folder,n = n)
+
+    my_model = qim3d.models.UNet()
+    my_augmentation = qim3d.utils.Augmentation(transform_test='light')
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,validation,my_model,my_augmentation)
+
+    assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for validation in prepare_datasets()
+def test_validation():
+    validation = 10
+
+    with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."):
+        augment_class = qim3d.utils.prepare_datasets('folder',validation,'my_model','my_augmentation')
+
+
+# unit test for prepare_dataloaders()
+def test_prepare_dataloaders():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    batch_size = 1
+    my_model = qim3d.models.UNet()
+    my_augmentation = qim3d.utils.Augmentation()
+    train_set, val_set, test_set = qim3d.utils.prepare_datasets(folder,1/3,my_model,my_augmentation)
+
+    _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                     batch_size,num_workers = 1,
+                                                     pin_memory = False)
+
+    assert type(val_loader) == DataLoader
+
+    temp_data(folder,remove=True)
\ No newline at end of file
diff --git a/qim3d/tests/utils/test_models.py b/qim3d/tests/utils/test_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..37262ad6517b3e603c972de6526e7219cccf2611
--- /dev/null
+++ b/qim3d/tests/utils/test_models.py
@@ -0,0 +1,109 @@
+import qim3d
+import pytest
+from torch import ones
+
+from qim3d.utils.internal_tools import temp_data
+
+# unit test for model summary()
+def test_model_summary():
+    n = 10
+    img_shape = (32,32)
+    folder = 'folder_data'
+    temp_data(folder,img_shape=img_shape,n = n)
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+
+    _,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                     batch_size = 1,num_workers = 1,
+                                                     pin_memory = False)
+    summary = qim3d.utils.model_summary(val_loader,unet)
+
+    assert summary.input_size[0] == (1,1) + img_shape
+
+    temp_data(folder,remove=True)
+
+
+# unit test for inference()
+def test_inference():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    train_set,_,_ = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+
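+    # the second value returned by inference holds the targets; binary labels should only contain 0 and 1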
+    _, targ,_ = qim3d.utils.inference(train_set,unet)
+
+    assert tuple(targ[0].unique()) == (0,1)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for tuple ValueError().
+def test_inference_tuple():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [1,2,3]
+    with pytest.raises(ValueError,match="Data items must be tuples"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for tensor ValueError().
+def test_inference_tensor():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [(1,2)]
+    with pytest.raises(ValueError,match="Data items must consist of tensors"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for dimension ValueError().
+def test_inference_dim():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    unet = qim3d.models.UNet(size = 'small')
+
+    data = [(ones(1),ones(1))]
+    # need the r"" for special characters
+    with pytest.raises(ValueError,match=r"Input image must be \(C,H,W\) format"):
+        qim3d.utils.inference(data,unet)
+
+    temp_data(folder,remove=True)
+
+
+# unit test for train_model()
+def test_train_model():
+    folder = 'folder_data'
+    temp_data(folder)
+
+    n_epochs = 1
+
+    unet = qim3d.models.UNet(size = 'small')
+    augment = qim3d.utils.Augmentation(transform_train=None)
+    hyperparams = qim3d.models.Hyperparameters(unet,n_epochs=n_epochs)
+    train_set,val_set,test_set = qim3d.utils.prepare_datasets(folder,1/3,unet,augment)
+    train_loader,val_loader,_ = qim3d.utils.prepare_dataloaders(train_set,val_set,test_set,
+                                                                batch_size = 1,num_workers = 1,
+                                                                pin_memory = False)
+
+    train_loss,_ = qim3d.utils.train_model(unet,hyperparams,train_loader,val_loader,
+                                           plot = False, return_loss = True)
+
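+    # with return_loss = True the loss history is returned, with one entry per training epoch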
+    assert len(train_loss['loss']) == n_epochs
+
+    temp_data(folder,remove=True)
\ No newline at end of file
diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py
index 5eb9ccab061ad3ca6225622987bbed4cb11491bf..332856e43248ffc083059ac93276697625a48de3 100644
--- a/qim3d/utils/data.py
+++ b/qim3d/utils/data.py
@@ -160,21 +160,21 @@ def prepare_datasets(path: str, val_fraction: float, model, augmentation):
 
     orig_h, orig_w = image.size[:2]
     final_h, final_w = check_resize(orig_h, orig_w, resize, n_channels)
-    
+
     train_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_train))
     val_set = Dataset(root_path = path, transform = augmentation.augment(final_h, final_w, augmentation.transform_validation))
     test_set = Dataset(root_path = path, split='test', transform = augmentation.augment(final_h, final_w, augmentation.transform_test))
 
     split_idx = int(np.floor(val_fraction * len(train_set)))
     indices = torch.randperm(len(train_set))
-    
+
     train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
     val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
 
     return train_set, val_set, test_set
 
 
-def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = True):
+def prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train = True, num_workers = 8, pin_memory = False):
     """
     Prepares the dataloaders for model training.
 
diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index e31001266288c405f3fb1d9092341d4e51cf1e98..b8a4b6b7bcb86be7cb4496b576dd68381db57556 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -8,6 +8,9 @@
 import matplotlib
 import numpy as np
 import socket
 import os
+import shutil
+from PIL import Image
+from pathlib import Path
 
 
@@ -177,10 +180,53 @@ def is_server_running(ip, port):
         return True
     except:
         return False
+
+def temp_data(folder,remove = False,n = 3,img_shape = (32,32)):
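+    # test helper: builds a small train/test folder tree of random binary images, and removes it again when called with remove=True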
+    folder_trte = ['train','test']
+    sub_folders = ['images','labels']
+
+    # Creating train/test folder
+    path_train = Path(folder) / folder_trte[0]
+    path_test = Path(folder) / folder_trte[1]
+
+    # Creating folders for images and labels
+    path_train_im = path_train / sub_folders[0]
+    path_train_lab = path_train / sub_folders[1]
+    path_test_im = path_test / sub_folders[0]
+    path_test_lab = path_test / sub_folders[1]
+
+    # Random image
+    img = np.random.randint(2,size = img_shape,dtype = np.uint8)
+    img = Image.fromarray(img)
+
+    if not os.path.exists(path_train):
+        os.makedirs(path_train_im)
+        os.makedirs(path_test_im)
+        os.makedirs(path_train_lab)
+        os.makedirs(path_test_lab)
+        for i in range(n):
+            img.save(path_train_im / f'img_train{i}.png')
+            img.save(path_train_lab / f'img_train{i}.png')
+            img.save(path_test_im / f'img_test{i}.png')
+            img.save(path_test_lab / f'img_test{i}.png')
+
+    if remove:
+        for filename in os.listdir(folder):
+            file_path = os.path.join(folder, filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print('Failed to delete %s. Reason: %s' % (file_path, e))
+
+        os.rmdir(folder)
 
 def stringify_path(path):
     """Converts an os.PathLike object to a string
     """
     if isinstance(path,os.PathLike):
         path = path.__fspath__()
-    return path
\ No newline at end of file
+    return path
diff --git a/qim3d/utils/models.py b/qim3d/utils/models.py
index a011fe77ac42846ce963ad4a60c38ba9d52f2d0b..f0693186aecb2634a3d2d4113788bad74b6010ef 100644
--- a/qim3d/utils/models.py
+++ b/qim3d/utils/models.py
@@ -10,7 +10,7 @@ from tqdm.auto import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
-def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
+def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True, return_loss = False):
     """
     Function for training Neural Network models.
 
     Args:
@@ -20,6 +20,7 @@
         val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
         eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
         print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
+        return_loss (bool, optional): if True, returns the training and validation loss histories. Defaults to False.
 
     Returns:
         tuple:
@@ -65,10 +66,11 @@
         for data in train_loader:
             inputs, targets = data
             inputs = inputs.to(device)
-            targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
+            targets = targets.to(device).unsqueeze(1)
 
             optimizer.zero_grad()
             outputs = model(inputs)
+
             loss = criterion(outputs, targets)
 
             # Backpropagation
@@ -94,8 +96,8 @@
             for data in val_loader:
                 inputs, targets = data
                 inputs = inputs.to(device)
-                targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
-
+                targets = targets.to(device).unsqueeze(1)
+
                 with torch.no_grad():
                     outputs = model(inputs)
                     loss = criterion(outputs, targets)
@@ -122,6 +124,9 @@
             plot_metrics(val_loss,color = 'orange', label = 'Valid.')
             fig.show()
+
+    if return_loss:
+        return train_loss,val_loss
 
 def model_summary(dataloader,model):
     """Prints the summary of a PyTorch model.
 
@@ -196,7 +201,7 @@
     else:
         raise ValueError("Input image must be (C,H,W) format")
-    
+
     model.to(device)
     model.eval()
 
     # Make new list such that possible augmentations remain identical for all three rows