diff --git a/qim3d/gui/data_explorer.py b/qim3d/gui/data_explorer.py index 15a744ca53a3f65597758700b886d14cc6868fba..a8b4470163eced67cde9e90e0c26ee455f3e7cbe 100644 --- a/qim3d/gui/data_explorer.py +++ b/qim3d/gui/data_explorer.py @@ -17,6 +17,7 @@ app.launch() import datetime import os +import re import gradio as gr import matplotlib @@ -119,30 +120,55 @@ class Interface: value="⟳", elem_classes="btn-html h-36" ) explorer = gr.FileExplorer( - ignore_glob="*/.*", # ignores hidden files + ignore_glob="*/.*", # ignores hidden files root_dir=os.getcwd(), label=os.getcwd(), render=True, file_count="single", interactive=True, - elem_classes="h-256 hide-overflow", + elem_classes="h-320 hide-overflow", ) with gr.Column(scale=1): gr.Markdown("### Parameters") - virtual_stack = gr.Checkbox(value=False, label="Virtual stack") - cmap = gr.Dropdown( value="viridis", choices=plt.colormaps(), label="Colormap", interactive=True, ) + + virtual_stack = gr.Checkbox( + value=False, + label="Virtual stack", + info="If checked, will use less memory by loading the images on demand.", + ) + load_series = gr.Checkbox( + value=False, + label="Load series", + info="If checked, will load the whole series of images in the same folder as the selected file.", + ) + series_contains = gr.Textbox( + label="Specify common part of file names for series", + value="", + visible=False, + ) + dataset_name = gr.Textbox( label="Dataset name (in case of H5 files, for example)", value="exchange/data", ) + def toggle_show(checkbox): + return ( + gr.update(visible=True) + if checkbox + else gr.update(visible=False) + ) + + # Show series_contains only if load_series is checked + load_series.change(toggle_show, load_series, series_contains) + with gr.Column(scale=1): gr.Markdown("### Operations") operations = gr.CheckboxGroup( @@ -197,7 +223,7 @@ class Interface: # Intensity histogram with gr.Column(visible=False) as result_intensity_histogram: hist_plot = gr.Plot(label="Volume intensity histogram") - + # Text box with data summary with gr.Column(visible=False) as result_data_summary: data_summary = gr.Text( @@ -233,6 +259,8 @@ class Interface: cmap, dataset_name, virtual_stack, + load_series, + series_contains, ] # Outputs outputs = [ @@ -243,7 +271,6 @@ class Interface: min_projection_plot, hist_plot, data_summary, - ] ### Listeners @@ -283,6 +310,8 @@ class Interface: def start_session(self, *args): # Starts a new session dictionary session = Session() + # Tells rest of the pipeline if the session failed at some point prior, so should skip the rest + session.failed = False session.all_operations = Interface().operations session.operations = args[0] session.base_path = args[1] @@ -293,14 +322,42 @@ class Interface: session.cmap = args[6] session.dataset_name = args[7] session.virtual_stack = args[8] + session.load_series = args[9] + session.series_contains = args[10] + + if session.load_series and session.series_contains == "": + + # Try to guess the common part of the file names + try: + filename = session.explorer.split("/")[-1] # Extract filename from path + series_contains = re.search(r"[^0-9]+", filename).group() + gr.Info(f"Using '{series_contains}' as common file name part for loading.") + session.series_contains = series_contains + + except: + session.failed = True + raise gr.Error( + "For series, common part of file name must be provided in 'series_contains' field." + ) # Get the file path from the explorer or base path - if session.base_path and os.path.isfile(session.base_path): - session.file_path = session.base_path - elif session.explorer and os.path.isfile(session.explorer): + # priority is given to the explorer if file is selected + # else the base path is used + if session.explorer and ( + os.path.isfile(session.explorer) or session.load_series + ): session.file_path = session.explorer + elif session.base_path and ( + os.path.isfile(session.base_path) or session.load_series + ): + session.file_path = session.base_path else: - raise ValueError("Invalid file path") + session.failed = True + raise gr.Error("Invalid file path") + + # If we are loading a series, we need to get the directory + if session.load_series: + session.file_path = os.path.dirname(session.file_path) return session @@ -348,6 +405,7 @@ class Session: def __init__(self): self.virtual_stack = False self.file_path = None + self.failed = None self.vol = None self.zpos = 0.5 self.ypos = 0.5 @@ -401,14 +459,27 @@ class Pipeline: self.session = None def load_data(self, session): + # skip loading if session failed at some point prior + if session.failed: + return session + try: session.vol = load( session.file_path, virtual_stack=session.virtual_stack, dataset_name=session.dataset_name, + contains=session.series_contains, ) + + # Incase the data is 4D (RGB for example), we take the mean of the last dimension + if session.vol.ndim == 4: + session.vol = np.mean(session.vol, axis=-1) + + # The rest of the pipeline expects 3D data if session.vol.ndim != 3: - raise ValueError("Invalid data shape should be 3 dimensional, not shape: ", session.vol.shape) + raise gr.Error( + f"Invalid data shape should be 3 dimensional, not shape: {session.vol.shape}" + ) except Exception as error_message: raise gr.Error( f"Failed to load the image: {error_message}" @@ -433,6 +504,10 @@ class Pipeline: return session def run_pipeline(self, session): + # skip loading if session failed at some point prior + if session.failed: + return [] + self.session = session outputs = [] log.info(session.all_operations) @@ -611,12 +686,15 @@ class Pipeline: def plot_vol_histogram(self): - # The Histogram needs results from the projections + # The Histogram needs results from the projections if not self.session.projections_calculated: _ = self.get_projections(self.session.vol) - + vol_hist, bin_edges = self.vol_histogram( - self.session.vol, self.session.nbins, self.session.min_value, self.session.max_value + self.session.vol, + self.session.nbins, + self.session.min_value, + self.session.max_value, ) fig, ax = plt.subplots(figsize=(6, 4)) @@ -649,11 +727,12 @@ class Pipeline: return vol_hist, bin_edges -def run_interface(host = "0.0.0.0"): + +def run_interface(host="0.0.0.0"): gradio_interface = Interface().create_interface() - internal_tools.run_gradio_app(gradio_interface,host) + internal_tools.run_gradio_app(gradio_interface, host) if __name__ == "__main__": # Creates interface - run_interface() \ No newline at end of file + run_interface() diff --git a/qim3d/io/loading.py b/qim3d/io/loading.py index 83b4c3d50415d32c576689a3ce8955ba25049fb0..0f2ccc09bd4dfa8c546d2162cfef97cc4f1e4a6a 100644 --- a/qim3d/io/loading.py +++ b/qim3d/io/loading.py @@ -458,11 +458,35 @@ class DataLoader: Args: path (str): Directory path """ - # loop over all .dcm files in the directory - files = [f for f in os.listdir(path) if f.endswith('.dcm')] - files.sort() + if not self.contains: + raise ValueError( + "Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'" + ) + + dicom_stack = [ + file + for file in os.listdir(path) + if self.contains in file + ] + dicom_stack.sort() # Ensure proper ordering + + # Check that only one DICOM stack in the directory contains the provided string in its name + dicom_stack_only_letters = [] + for filename in dicom_stack: + name = os.path.splitext(filename)[0] # Remove file extension + dicom_stack_only_letters.append( + "".join(filter(str.isalpha, name)) + ) # Remove everything else than letters from the name + + # Get unique elements from tiff_stack_only_letters + unique_names = list(set(dicom_stack_only_letters)) + if len(unique_names) > 1: + raise ValueError( + f"The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded" + ) + # dicom_list contains the dicom objects with metadata - dicom_list = [pydicom.dcmread(os.path.join(path, f)) for f in files] + dicom_list = [pydicom.dcmread(os.path.join(path, f)) for f in dicom_stack] # vol contains the pixel data vol = np.stack([dicom.pixel_array for dicom in dicom_list], axis=0) @@ -520,11 +544,11 @@ class DataLoader: # Load a directory elif os.path.isdir(path): - # load dicom if directory contains dicom files else load tiff stack as default - if any([f.endswith('.dcm') for f in os.listdir(path)]): - return self.load_dicom_dir(path) - else: + # load tiff stack if folder contains tiff files else load dicom directory + if any([f.endswith('.tif') or f.endswith('.tiff') for f in os.listdir(path)]): return self.load_tiff_stack(path) + else: + return self.load_dicom_dir(path) # Fails else: