Skip to content
Snippets Groups Projects
Commit 53c3a716 authored by s204159's avatar s204159 :sunglasses: Committed by fima
Browse files

Data explorer series

parent 94d7ec42
No related branches found
No related tags found
1 merge request!76Data explorer series
...@@ -17,6 +17,7 @@ app.launch() ...@@ -17,6 +17,7 @@ app.launch()
import datetime import datetime
import os import os
import re
import gradio as gr import gradio as gr
import matplotlib import matplotlib
...@@ -125,24 +126,49 @@ class Interface: ...@@ -125,24 +126,49 @@ class Interface:
render=True, render=True,
file_count="single", file_count="single",
interactive=True, interactive=True,
elem_classes="h-256 hide-overflow", elem_classes="h-320 hide-overflow",
) )
with gr.Column(scale=1): with gr.Column(scale=1):
gr.Markdown("### Parameters") gr.Markdown("### Parameters")
virtual_stack = gr.Checkbox(value=False, label="Virtual stack")
cmap = gr.Dropdown( cmap = gr.Dropdown(
value="viridis", value="viridis",
choices=plt.colormaps(), choices=plt.colormaps(),
label="Colormap", label="Colormap",
interactive=True, interactive=True,
) )
virtual_stack = gr.Checkbox(
value=False,
label="Virtual stack",
info="If checked, will use less memory by loading the images on demand.",
)
load_series = gr.Checkbox(
value=False,
label="Load series",
info="If checked, will load the whole series of images in the same folder as the selected file.",
)
series_contains = gr.Textbox(
label="Specify common part of file names for series",
value="",
visible=False,
)
dataset_name = gr.Textbox( dataset_name = gr.Textbox(
label="Dataset name (in case of H5 files, for example)", label="Dataset name (in case of H5 files, for example)",
value="exchange/data", value="exchange/data",
) )
def toggle_show(checkbox):
return (
gr.update(visible=True)
if checkbox
else gr.update(visible=False)
)
# Show series_contains only if load_series is checked
load_series.change(toggle_show, load_series, series_contains)
with gr.Column(scale=1): with gr.Column(scale=1):
gr.Markdown("### Operations") gr.Markdown("### Operations")
operations = gr.CheckboxGroup( operations = gr.CheckboxGroup(
...@@ -233,6 +259,8 @@ class Interface: ...@@ -233,6 +259,8 @@ class Interface:
cmap, cmap,
dataset_name, dataset_name,
virtual_stack, virtual_stack,
load_series,
series_contains,
] ]
# Outputs # Outputs
outputs = [ outputs = [
...@@ -243,7 +271,6 @@ class Interface: ...@@ -243,7 +271,6 @@ class Interface:
min_projection_plot, min_projection_plot,
hist_plot, hist_plot,
data_summary, data_summary,
] ]
### Listeners ### Listeners
...@@ -283,6 +310,8 @@ class Interface: ...@@ -283,6 +310,8 @@ class Interface:
def start_session(self, *args): def start_session(self, *args):
# Starts a new session dictionary # Starts a new session dictionary
session = Session() session = Session()
# Tells the rest of the pipeline whether the session failed at some earlier point, so later stages can skip their work
session.failed = False
session.all_operations = Interface().operations session.all_operations = Interface().operations
session.operations = args[0] session.operations = args[0]
session.base_path = args[1] session.base_path = args[1]
...@@ -293,14 +322,42 @@ class Interface: ...@@ -293,14 +322,42 @@ class Interface:
session.cmap = args[6] session.cmap = args[6]
session.dataset_name = args[7] session.dataset_name = args[7]
session.virtual_stack = args[8] session.virtual_stack = args[8]
session.load_series = args[9]
session.series_contains = args[10]
if session.load_series and session.series_contains == "":
# Try to guess the common part of the file names
try:
filename = session.explorer.split("/")[-1] # Extract filename from path
series_contains = re.search(r"[^0-9]+", filename).group()
gr.Info(f"Using '{series_contains}' as common file name part for loading.")
session.series_contains = series_contains
except:
session.failed = True
raise gr.Error(
"For series, common part of file name must be provided in 'series_contains' field."
)
# Get the file path from the explorer or base path # Get the file path from the explorer or base path
if session.base_path and os.path.isfile(session.base_path): # priority is given to the explorer if file is selected
session.file_path = session.base_path # else the base path is used
elif session.explorer and os.path.isfile(session.explorer): if session.explorer and (
os.path.isfile(session.explorer) or session.load_series
):
session.file_path = session.explorer session.file_path = session.explorer
elif session.base_path and (
os.path.isfile(session.base_path) or session.load_series
):
session.file_path = session.base_path
else: else:
raise ValueError("Invalid file path") session.failed = True
raise gr.Error("Invalid file path")
# If we are loading a series, we need to get the directory
if session.load_series:
session.file_path = os.path.dirname(session.file_path)
return session return session
...@@ -348,6 +405,7 @@ class Session: ...@@ -348,6 +405,7 @@ class Session:
def __init__(self): def __init__(self):
self.virtual_stack = False self.virtual_stack = False
self.file_path = None self.file_path = None
self.failed = None
self.vol = None self.vol = None
self.zpos = 0.5 self.zpos = 0.5
self.ypos = 0.5 self.ypos = 0.5
...@@ -401,14 +459,27 @@ class Pipeline: ...@@ -401,14 +459,27 @@ class Pipeline:
self.session = None self.session = None
def load_data(self, session): def load_data(self, session):
# skip loading if session failed at some point prior
if session.failed:
return session
try: try:
session.vol = load( session.vol = load(
session.file_path, session.file_path,
virtual_stack=session.virtual_stack, virtual_stack=session.virtual_stack,
dataset_name=session.dataset_name, dataset_name=session.dataset_name,
contains=session.series_contains,
) )
# In case the data is 4D (RGB, for example), we take the mean of the last dimension
if session.vol.ndim == 4:
session.vol = np.mean(session.vol, axis=-1)
# The rest of the pipeline expects 3D data
if session.vol.ndim != 3: if session.vol.ndim != 3:
raise ValueError("Invalid data shape should be 3 dimensional, not shape: ", session.vol.shape) raise gr.Error(
f"Invalid data shape should be 3 dimensional, not shape: {session.vol.shape}"
)
except Exception as error_message: except Exception as error_message:
raise gr.Error( raise gr.Error(
f"Failed to load the image: {error_message}" f"Failed to load the image: {error_message}"
...@@ -433,6 +504,10 @@ class Pipeline: ...@@ -433,6 +504,10 @@ class Pipeline:
return session return session
def run_pipeline(self, session): def run_pipeline(self, session):
# skip running the pipeline if the session failed at some point prior
if session.failed:
return []
self.session = session self.session = session
outputs = [] outputs = []
log.info(session.all_operations) log.info(session.all_operations)
...@@ -616,7 +691,10 @@ class Pipeline: ...@@ -616,7 +691,10 @@ class Pipeline:
_ = self.get_projections(self.session.vol) _ = self.get_projections(self.session.vol)
vol_hist, bin_edges = self.vol_histogram( vol_hist, bin_edges = self.vol_histogram(
self.session.vol, self.session.nbins, self.session.min_value, self.session.max_value self.session.vol,
self.session.nbins,
self.session.min_value,
self.session.max_value,
) )
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
...@@ -649,6 +727,7 @@ class Pipeline: ...@@ -649,6 +727,7 @@ class Pipeline:
return vol_hist, bin_edges return vol_hist, bin_edges
def run_interface(host="0.0.0.0"): def run_interface(host="0.0.0.0"):
gradio_interface = Interface().create_interface() gradio_interface = Interface().create_interface()
internal_tools.run_gradio_app(gradio_interface, host) internal_tools.run_gradio_app(gradio_interface, host)
......
...@@ -458,11 +458,35 @@ class DataLoader: ...@@ -458,11 +458,35 @@ class DataLoader:
Args: Args:
path (str): Directory path path (str): Directory path
""" """
# loop over all .dcm files in the directory if not self.contains:
files = [f for f in os.listdir(path) if f.endswith('.dcm')] raise ValueError(
files.sort() "Please specify a part of the name that is common for the DICOM file stack with the argument 'contains'"
)
dicom_stack = [
file
for file in os.listdir(path)
if self.contains in file
]
dicom_stack.sort() # Ensure proper ordering
# Check that only one DICOM stack in the directory contains the provided string in its name
dicom_stack_only_letters = []
for filename in dicom_stack:
name = os.path.splitext(filename)[0] # Remove file extension
dicom_stack_only_letters.append(
"".join(filter(str.isalpha, name))
) # Remove everything else than letters from the name
# Get unique elements from dicom_stack_only_letters
unique_names = list(set(dicom_stack_only_letters))
if len(unique_names) > 1:
raise ValueError(
f"The provided part of the filename for the DICOM stack matches multiple DICOM stacks: {unique_names}.\nPlease provide a string that is unique for the DICOM stack that is intended to be loaded"
)
# dicom_list contains the dicom objects with metadata # dicom_list contains the dicom objects with metadata
dicom_list = [pydicom.dcmread(os.path.join(path, f)) for f in files] dicom_list = [pydicom.dcmread(os.path.join(path, f)) for f in dicom_stack]
# vol contains the pixel data # vol contains the pixel data
vol = np.stack([dicom.pixel_array for dicom in dicom_list], axis=0) vol = np.stack([dicom.pixel_array for dicom in dicom_list], axis=0)
...@@ -520,11 +544,11 @@ class DataLoader: ...@@ -520,11 +544,11 @@ class DataLoader:
# Load a directory # Load a directory
elif os.path.isdir(path): elif os.path.isdir(path):
# load dicom if directory contains dicom files else load tiff stack as default # load tiff stack if folder contains tiff files else load dicom directory
if any([f.endswith('.dcm') for f in os.listdir(path)]): if any([f.endswith('.tif') or f.endswith('.tiff') for f in os.listdir(path)]):
return self.load_dicom_dir(path)
else:
return self.load_tiff_stack(path) return self.load_tiff_stack(path)
else:
return self.load_dicom_dir(path)
# Fails # Fails
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment