Skip to content
Snippets Groups Projects
Commit 76e3a4c7 authored by fima's avatar fima :beers:
Browse files

Merge branch 'fix-unit-tests' into 'main'

Fix unit tests

See merge request !140
parents 31fba17b d5ab756a
Branches
No related tags found
1 merge request!140Fix unit tests
Showing
with 124 additions and 109 deletions
......@@ -8,9 +8,10 @@ from skimage import morphology
import dask.array as da
import dask_image.ndfilters as dask_ndfilters
from qim3d.utils._logger import log
from qim3d.utils import log
__all__ = [
"FilterBase",
"Gaussian",
"Median",
"Maximum",
......
......@@ -27,10 +27,7 @@ import tempfile
import gradio as gr
import numpy as np
from PIL import Image
from qim3d.io import load, save
from qim3d.operations._common_operations_methods import overlay_rgb_images
from qim3d.gui.interface import BaseInterface
import qim3d
# TODO: img in launch should be self.img
......@@ -68,7 +65,7 @@ class Interface(BaseInterface):
for temp_file in temp_path_list:
mask_file = os.path.basename(temp_file)
mask_name = os.path.splitext(mask_file)[0]
masks[mask_name] = load(temp_file)
masks[mask_name] = qim3d.io.load(temp_file)
return masks
......@@ -98,7 +95,7 @@ class Interface(BaseInterface):
def create_preview(self, img_editor: gr.ImageEditor) -> np.ndarray:
background = img_editor["background"]
masks = img_editor["layers"][0]
overlay_image = overlay_rgb_images(background, masks)
overlay_image = qim3d.operations.overlay_rgb_images(background, masks)
return overlay_image
def cerate_download_list(self, img_editor: gr.ImageEditor) -> list[str]:
......@@ -122,7 +119,7 @@ class Interface(BaseInterface):
filepath = os.path.join(self.temp_dir, filename)
files_list.append(filepath)
save(filepath, mask, replace=True)
qim3d.io.save(filepath, mask, replace=True)
self.temp_files.append(filepath)
return files_list
......
......@@ -22,9 +22,8 @@ import numpy as np
import plotly.graph_objects as go
from scipy import ndimage
from qim3d.io import load
import qim3d
from qim3d.utils._logger import log
from qim3d.gui.interface import InterfaceWithExamples
......@@ -46,7 +45,7 @@ class Interface(InterfaceWithExamples):
def load_data(self, gradiofile: gr.File):
try:
self.vol = load(gradiofile.name)
self.vol = qim3d.io.load(gradiofile.name)
assert self.vol.ndim == 3
except AttributeError:
raise gr.Error("You have to select a file")
......
......@@ -21,7 +21,6 @@ import os
import gradio as gr
import numpy as np
from .interface import BaseInterface
# from qim3d.processing import layers2d as l2d
......@@ -358,7 +357,7 @@ class Interface(BaseInterface):
raise gr.Error("Invalid file path")
try:
self.data = load(
self.data = qim3d.io.load(
file_path,
progress_bar=False
)
......@@ -403,7 +402,7 @@ class Interface(BaseInterface):
if self.is_transposed(slicing_axis, segmenting_axis):
slice = np.rot90(slice)
self.__dict__[seg_key] = segment_layers(slice, inverted = inverted, n_layers = n_layers, delta = delta, min_margin = min_margin, wrap = wrap)
self.__dict__[seg_key] = qim3d.processing.segment_layers(slice, inverted = inverted, n_layers = n_layers, delta = delta, min_margin = min_margin, wrap = wrap)
return process
......@@ -456,13 +455,13 @@ class Interface(BaseInterface):
seg = np.rot90(seg, k = 3)
# slice = 255 * (slice/np.max(slice))
# return image_with_overlay(np.repeat(slice[..., None], 3, -1), seg, alpha)
return overlay_rgb_images(slice, seg, alpha)
return qim3d.operations.overlay_rgb_images(slice, seg, alpha)
else:
lines = get_lines(seg)
lines = qim3d.processing.get_lines(seg)
if self.is_transposed(slicing_axis, segmenting_axis):
return image_with_lines(np.rot90(slice), lines, line_thickness).rotate(270, expand = True)
return qim3d.viz.image_with_lines(np.rot90(slice), lines, line_thickness).rotate(270, expand = True)
else:
return image_with_lines(slice, lines, line_thickness)
return qim3d.viz.image_with_lines(slice, lines, line_thickness)
return plot_output_img
......
......@@ -40,12 +40,11 @@ import gradio as gr
import numpy as np
import tifffile
import localthickness as lt
import qim3d
from qim3d.io import load
from qim3d.gui.interface import InterfaceWithExamples
class Interface(InterfaceWithExamples):
class Interface(qim3d.gui.interface.InterfaceWithExamples):
def __init__(self,
img: np.ndarray = None,
verbose:bool = False,
......@@ -79,7 +78,7 @@ class Interface(InterfaceWithExamples):
file_idx = np.argmax(creation_time_list)
# Load the temporary file
vol_lt = load(temp_path_list[file_idx])
vol_lt = qim3d.io.load(temp_path_list[file_idx])
return vol_lt
......@@ -251,7 +250,7 @@ class Interface(InterfaceWithExamples):
def process_input(self, data: np.ndarray, dark_objects: bool):
# Load volume
try:
self.vol = load(data.name)
self.vol = qim3d.io.load(data.name)
assert self.vol.ndim == 3
except AttributeError:
self.vol = data
......
......@@ -10,7 +10,7 @@ from tqdm import tqdm
import zarr.core
from qim3d.utils._misc import stringify_path
from qim3d.io._saving import save
from qim3d.io import save
class Convert:
......
......@@ -8,7 +8,7 @@ from tqdm import tqdm
from pathlib import Path
from qim3d.io import load
from qim3d.utils._logger import log
from qim3d.utils import log
import outputformat as ouf
......
......@@ -23,9 +23,9 @@ from dask import delayed
from PIL import Image, UnidentifiedImageError
import qim3d
from qim3d.utils._logger import log
from qim3d.utils import log
from qim3d.utils._misc import get_file_size, sizeof, stringify_path
from qim3d.utils._system import Memory
from qim3d.utils import Memory
from qim3d.utils._progress_bar import FileLoadingProgressBar
import trimesh
......@@ -718,7 +718,7 @@ class DataLoader:
# Fails
else:
# Find the closest matching path to warn the user
similar_paths = qim3d.utils.misc.find_similar_paths(path)
similar_paths = qim3d.utils._misc.find_similar_paths(path)
if similar_paths:
suggestion = similar_paths[0] # Get the closest match
......
......@@ -31,7 +31,7 @@ from skimage.transform import (
resize,
)
from qim3d.utils._logger import log
from qim3d.utils import log
from qim3d.utils._progress_bar import OmeZarrExportProgressBar
from qim3d.utils._ome_zarr import get_n_chunks
......
......@@ -37,7 +37,7 @@ from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import UID
import trimesh
from qim3d.utils._logger import log
from qim3d.utils import log
from qim3d.utils._misc import sizeof, stringify_path
......
......@@ -2,7 +2,7 @@
import os
import subprocess
import outputformat as ouf
from qim3d.utils._logger import log
from qim3d.utils import log
from pathlib import Path
......
"""Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path
from PIL import Image
from qim3d.utils._logger import log
from qim3d.utils import log
import torch
import numpy as np
from typing import Optional, Callable
......
......@@ -2,7 +2,7 @@
import torch.nn as nn
from qim3d.utils._logger import log
from qim3d.utils import log
class UNet(nn.Module):
......
import numpy as np
import qim3d.filters as filters
from qim3d.utils._logger import log
from qim3d.utils import log
__all__ = ["remove_background", "fade_mask", "overlay_rgb_images"]
......
......@@ -2,7 +2,7 @@
import numpy as np
from typing import Optional
from qim3d.utils._logger import log
from qim3d.utils import log
import qim3d
......
......@@ -3,7 +3,7 @@
from typing import Tuple
import logging
import numpy as np
from qim3d.utils._logger import log
from qim3d.utils import log
def structure_tensor(
......@@ -108,7 +108,7 @@ def structure_tensor(
val, vec = st.eig_special_3d(s_vol, full=full)
if visualize:
from qim3d.viz._structure_tensor import vectors
from qim3d.viz import vectors
display(vectors(vol, vec, **viz_kwargs))
......
import qim3d
from qim3d.filters import *
import numpy as np
import pytest
import re
def test_filter_base_initialization():
filter_base = qim3d.processing.filters.FilterBase(3,size=2)
filter_base = qim3d.filters.FilterBase(3,size=2)
assert filter_base.args == (3,)
assert filter_base.kwargs == {'size': 2}
......@@ -13,10 +12,10 @@ def test_gaussian_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = gaussian(input_image,sigma=1.5)
filtered_image_fn = qim3d.filters.gaussian(input_image,sigma=1.5)
# Testing the class method
gaussian_filter_cls = Gaussian(sigma=1.5)
gaussian_filter_cls = qim3d.filters.Gaussian(sigma=1.5)
filtered_image_cls = gaussian_filter_cls(input_image)
# Assertions
......@@ -28,10 +27,10 @@ def test_median_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = median(input_image, size=3)
filtered_image_fn = qim3d.filters.median(input_image, size=3)
# Testing the class method
median_filter_cls = Median(size=3)
median_filter_cls = qim3d.filters.Median(size=3)
filtered_image_cls = median_filter_cls(input_image)
# Assertions
......@@ -43,10 +42,10 @@ def test_maximum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = maximum(input_image, size=3)
filtered_image_fn = qim3d.filters.maximum(input_image, size=3)
# Testing the class method
maximum_filter_cls = Maximum(size=3)
maximum_filter_cls = qim3d.filters.Maximum(size=3)
filtered_image_cls = maximum_filter_cls(input_image)
# Assertions
......@@ -58,10 +57,10 @@ def test_minimum_filter():
input_image = np.random.rand(50, 50)
# Testing the function
filtered_image_fn = minimum(input_image, size=3)
filtered_image_fn = qim3d.filters.minimum(input_image, size=3)
# Testing the class method
minimum_filter_cls = Minimum(size=3)
minimum_filter_cls = qim3d.filters.Minimum(size=3)
filtered_image_cls = minimum_filter_cls(input_image)
# Assertions
......@@ -73,16 +72,16 @@ def test_sequential_filter_pipeline():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
gaussian_filter = qim3d.filters.Gaussian(sigma=1.5)
median_filter = qim3d.filters.Median(size=3)
maximum_filter = qim3d.filters.Maximum(size=3)
# Testing the sequential pipeline
sequential_pipeline = Pipeline(gaussian_filter, median_filter, maximum_filter)
sequential_pipeline = qim3d.filters.Pipeline(gaussian_filter, median_filter, maximum_filter)
filtered_image_pipeline = sequential_pipeline(input_image)
# Testing the equivalence to maximum(median(gaussian(input,**kwargs),**kwargs),**kwargs)
expected_output = maximum(median(gaussian(input_image, sigma=1.5), size=3), size=3)
expected_output = qim3d.filters.maximum(qim3d.filters.median(qim3d.filters.gaussian(input_image, sigma=1.5), size=3), size=3)
# Assertions
assert filtered_image_pipeline.shape == expected_output.shape == input_image.shape
......@@ -93,16 +92,16 @@ def test_sequential_filter_appending():
input_image = np.random.rand(50, 50)
# Individual filters
gaussian_filter = Gaussian(sigma=1.5)
median_filter = Median(size=3)
maximum_filter = Maximum(size=3)
gaussian_filter = qim3d.filters.Gaussian(sigma=1.5)
median_filter = qim3d.filters.Median(size=3)
maximum_filter = qim3d.filters.Maximum(size=3)
# Sequential pipeline with filter initialized at the beginning
sequential_pipeline_initial = Pipeline(gaussian_filter, median_filter, maximum_filter)
sequential_pipeline_initial = qim3d.filters.Pipeline(gaussian_filter, median_filter, maximum_filter)
filtered_image_initial = sequential_pipeline_initial(input_image)
# Sequential pipeline with filter appended
sequential_pipeline_appended = Pipeline(gaussian_filter, median_filter)
sequential_pipeline_appended = qim3d.filters.Pipeline(gaussian_filter, median_filter)
sequential_pipeline_appended.append(maximum_filter)
filtered_image_appended = sequential_pipeline_appended(input_image)
......@@ -113,7 +112,7 @@ def test_sequential_filter_appending():
def test_assertion_error_not_filterbase_subclass():
# Get valid filter classes
valid_filters = [subclass.__name__ for subclass in qim3d.processing.filters.FilterBase.__subclasses__()]
valid_filters = [subclass.__name__ for subclass in qim3d.filters.FilterBase.__subclasses__()]
# Create invalid object
invalid_filter = object() # An object that is not an instance of FilterBase
......@@ -124,4 +123,4 @@ def test_assertion_error_not_filterbase_subclass():
# Use pytest.raises to catch the AssertionError
with pytest.raises(AssertionError, match=re.escape(message)):
sequential_pipeline = Pipeline(invalid_filter)
\ No newline at end of file
sequential_pipeline = qim3d.filters.Pipeline(invalid_filter)
\ No newline at end of file
import qim3d
import os
import pytest
from pathlib import Path
import shutil
def test_download():
folder = 'Cowry_Shell'
file = 'Cowry_DOWNSAMPLED.tif'
path = os.path.join(folder,file)
@pytest.fixture
def setup_temp_folder():
"""Fixture to create and clean up a temporary folder for tests."""
folder = "Cowry_Shell"
file = "Cowry_DOWNSAMPLED.tif"
path = Path(folder) / file
dl = qim3d.io.Downloader()
# Ensure clean environment before running tests
if Path(folder).exists():
shutil.rmtree(folder)
yield folder, path
# Cleanup after tests
if path.exists():
path.unlink()
if Path(folder).exists():
shutil.rmtree(folder)
dl.Cowry_Shell.Cowry_DOWNSAMPLED()
def test_download(setup_temp_folder):
folder, path = setup_temp_folder
img = qim3d.io.load(path)
dl = qim3d.io.Downloader()
dl.Cowry_Shell.Cowry_DOWNSAMPLED()
# Remove temp file
os.remove(path)
os.rmdir(folder)
# Verify the file was downloaded correctly
assert path.exists(), f"{path} does not exist after download."
img = qim3d.io.load(str(path))
assert img.shape == (500, 350, 350)
def test_get_file_size_right():
coal_file = 'https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Coal/CoalBrikett.tif'
size = qim3d.io.downloader._get_file_size(coal_file)
# Cleanup is handled by the fixture
assert size == 2_400_082_900
def test_get_file_size():
"""Tests for correct and incorrect file size retrieval."""
coal_file = "https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Coal/CoalBrikett.tif"
folder_url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository/"
def test_get_file_size_wrong():
file_to_folder = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository/'
size = qim3d.io.downloader._get_file_size(file_to_folder)
# Correct file size
size = qim3d.io._downloader._get_file_size(coal_file)
assert size == 2_400_082_900, f"Expected size mismatch for {coal_file}."
assert size == -1
# Wrong URL (not a file)
size = qim3d.io._downloader._get_file_size(folder_url)
assert size == -1, "Expected size -1 for non-file URL."
def test_extract_html():
url = 'https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository'
html = qim3d.io.downloader._extract_html(url)
url = "https://archive.compute.dtu.dk/files/public/projects/viscomp_data_repository"
html = qim3d.io._downloader._extract_html(url)
assert 'data-path="/files/public/projects/viscomp_data_repository"' in html, \
"Expected HTML content not found in extracted HTML."
assert 'data-path="/files/public/projects/viscomp_data_repository"' in html
......@@ -3,13 +3,13 @@ import torch
# unit tests for UNet()
def test_starting_unet():
unet = qim3d.models.UNet()
unet = qim3d.ml.models.UNet()
assert unet.size == 'medium'
def test_forward_pass():
unet = qim3d.models.UNet()
unet = qim3d.ml.models.UNet()
# Size: B x C x H x W
x = torch.ones([1,1,256,256])
......@@ -19,14 +19,14 @@ def test_forward_pass():
# unit tests for Hyperparameters()
def test_hyper():
unet = qim3d.models.UNet()
hyperparams = qim3d.models.Hyperparameters(unet)
unet = qim3d.ml.models.UNet()
hyperparams = qim3d.ml.models.Hyperparameters(unet)
assert hyperparams.n_epochs == 10
def test_hyper_dict():
unet = qim3d.models.UNet()
hyperparams = qim3d.models.Hyperparameters(unet)
unet = qim3d.ml.models.UNet()
hyperparams = qim3d.ml.models.Hyperparameters(unet)
hyper_dict = hyperparams()
......
......@@ -12,16 +12,16 @@ def test_model_summary():
folder = "folder_data"
temp_data(folder, img_shape=img_shape, n=n)
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
train_set, val_set, test_set = qim3d.models.prepare_datasets(
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, 1 / 3, unet, augment
)
_, val_loader, _ = qim3d.models.prepare_dataloaders(
_, val_loader, _ = qim3d.ml.prepare_dataloaders(
train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False
)
summary = qim3d.models.model_summary(val_loader, unet)
summary = qim3d.ml.model_summary(val_loader, unet)
assert summary.input_size[0] == (1, 1) + img_shape
......@@ -33,11 +33,11 @@ def test_inference():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
train_set, _, _ = qim3d.models.prepare_datasets(folder, 1 / 3, unet, augment)
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
train_set, _, _ = qim3d.ml.prepare_datasets(folder, 1 / 3, unet, augment)
_, targ, _ = qim3d.models.inference(train_set, unet)
_, targ, _ = qim3d.ml.inference(train_set, unet)
assert tuple(targ[0].unique()) == (0, 1)
......@@ -49,11 +49,11 @@ def test_inference_tuple():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [1, 2, 3]
with pytest.raises(ValueError, match="Data items must be tuples"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -63,11 +63,11 @@ def test_inference_tensor():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [(1, 2)]
with pytest.raises(ValueError, match="Data items must consist of tensors"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -77,12 +77,12 @@ def test_inference_dim():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [(ones(1), ones(1))]
# need the r"" for special characters
with pytest.raises(ValueError, match=r"Input image must be \(C,H,W\) format"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -94,17 +94,17 @@ def test_train_model():
n_epochs = 1
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
hyperparams = qim3d.models.Hyperparameters(unet, n_epochs=n_epochs)
train_set, val_set, test_set = qim3d.models.prepare_datasets(
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
hyperparams = qim3d.ml.Hyperparameters(unet, n_epochs=n_epochs)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, 1 / 3, unet, augment
)
train_loader, val_loader, _ = qim3d.models.prepare_dataloaders(
train_loader, val_loader, _ = qim3d.ml.prepare_dataloaders(
train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False
)
train_loss, _ = qim3d.models.train_model(
train_loss, _ = qim3d.ml.train_model(
unet, hyperparams, train_loader, val_loader, plot=False, return_loss=True
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment