From d614e7bd5ca20bd633c4e08bcf7f880431da1902 Mon Sep 17 00:00:00 2001 From: s184058 <s184058@student.dtu.dk> Date: Thu, 26 Oct 2023 16:23:34 +0200 Subject: [PATCH] Implemented beta version of annoation tool + unit test --- docs/notebooks/annotation_tool.ipynb | 149 +++++++++ qim3d/gui/__init__.py | 3 +- qim3d/gui/annotation_tool.py | 300 ++++++++++++++++++ qim3d/tests/gui/test_annotation_tool.py | 36 +++ ...{test_local_thickness.py => test_iso3d.py} | 0 5 files changed, 487 insertions(+), 1 deletion(-) create mode 100644 docs/notebooks/annotation_tool.ipynb create mode 100644 qim3d/gui/annotation_tool.py create mode 100644 qim3d/tests/gui/test_annotation_tool.py rename qim3d/tests/gui/{test_local_thickness.py => test_iso3d.py} (100%) diff --git a/docs/notebooks/annotation_tool.ipynb b/docs/notebooks/annotation_tool.ipynb new file mode 100644 index 00000000..066b0766 --- /dev/null +++ b/docs/notebooks/annotation_tool.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image annotation tool\n", + "This notebook shows how the annotation interface can be used to create masks for images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import qim3d\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import numpy as np\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load 2D example image\n", + "img = qim3d.examples.blobs_256x256\n", + "\n", + "# Display image\n", + "plt.imshow(img)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start annotation tool\n", + "interface = qim3d.gui.annotation_tool.Interface()\n", + "interface.max_masks = 4\n", + "\n", + "# We can directly pass the image we loaded to the interface\n", + "interface.launch(img=img)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# When 'prepare mask for download' is pressed once, the mask can be retrieved with the get_result() method\n", + "mask = interface.get_result()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check the obtained mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print (f\"Original image shape..: {img.shape}\")\n", + "print (f\"Mask image shape......: {mask.shape}\")\n", + "print (f\"\\nNumber of masks: {np.max(mask)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Show the masked regions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "nmasks = np.max(mask)\n", + "fig, axs = plt.subplots(nrows=1, ncols=nmasks+2, figsize=(12,3))\n", + "\n", + "# Show original image\n", + "axs[0].imshow(img)\n", + "axs[0].set_title(\"Original\")\n", + "axs[0].axis('off')\n", + "\n", + "\n", + "# Show masks\n", + "cmap = mpl.colormaps[\"rainbow\"].copy()\n", + "cmap.set_under(color='black') # Sets the background to black\n", + "axs[1].imshow(mask, interpolation='none', cmap=cmap, vmin=1, vmax=nmasks+1)\n", + "axs[1].set_title(\"Masks\")\n", + "axs[1].axis('off')\n", + "\n", + "# Show masked regions\n", + "for idx in np.arange(2, nmasks+2):\n", + " mask_id = idx-1\n", + " submask = mask.copy()\n", + " submask[submask != mask_id] = 0\n", + " \n", + " masked_img = img.copy()\n", + " masked_img[submask==0] = 0\n", + " axs[idx].imshow(masked_img)\n", + " axs[idx].set_title(f\"Mask {mask_id}\")\n", + "\n", + " axs[idx].axis('off')\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/qim3d/gui/__init__.py b/qim3d/gui/__init__.py index a438bee9..e06589b4 100644 --- a/qim3d/gui/__init__.py +++ b/qim3d/gui/__init__.py @@ -1,3 +1,4 @@ from . import data_explorer from . import iso3d -from . import local_thickness \ No newline at end of file +from . import local_thickness +from . import annotation_tool \ No newline at end of file diff --git a/qim3d/gui/annotation_tool.py b/qim3d/gui/annotation_tool.py new file mode 100644 index 00000000..ae44f498 --- /dev/null +++ b/qim3d/gui/annotation_tool.py @@ -0,0 +1,300 @@ +import tifffile +import os +import numpy as np +import gradio as gr +from qim3d.io import load # load or DataLoader? + +class Interface: + def __init__(self): + self.verbose = False + self.title = "Annotation tool" + #self.plot_height = 768 + self.height = 1024 + #self.width = 960 + self.max_masks = 3 + self.mask_opacity = 0.5 + self.cmy_hex = ['#00ffff','#ff00ff','#ffff00'] # Colors for max_masks>3? + + # CSS path + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.css_path = os.path.join(current_dir, "..", "css", "gradio.css") + + def launch(self, img=None, **kwargs): + # Create gradio interfaces + self.interface = self.create_interface(img=img) + + # Set gradio verbose level + if self.verbose: + quiet = False + else: + quiet = True + + self.interface.launch( + quiet=quiet, + height=self.height, + #width=self.width, + show_tips=False, + **kwargs + ) + + return + + + def get_result(self): + # Get the temporary files from gradio + temp_sets = self.interface.temp_file_sets + for temp_set in temp_sets: + if "mask" in str(temp_set): + # Get the list of the temporary files + temp_path_list = list(temp_set) + + # Files are not in creation order, + # so we need to get find the latest + creation_time_list = [] + for path in temp_path_list: + creation_time_list.append(os.path.getctime(path)) + + # Get index for the latest file + file_idx = np.argmax(creation_time_list) + + # Load the temporary file + mask = load(temp_path_list[file_idx]) + + return mask + + + def create_interface(self, img=None): + with gr.Blocks(css=self.css_path) as gradio_interface: + masks_state = gr.State(value={}) + counts = gr.Number(value=1,visible=False) + + with gr.Row(): + with gr.Column(scale=1,min_width=320): + upload_img_btn = gr.UploadButton( + label='Upload image', + file_types=['image'], + interactive=True if img is None else False + ) + clear_img_btn = gr.Button( + value='Clear image', + interactive=False if img is None else True + ) + + with gr.Row(): + with gr.Column(scale=2,min_width=32): + selected_mask = gr.Radio( + choices = ["Mask 1"], + value = "Mask 1", + label="Choose which mask to draw", + scale=1 + ) + with gr.Column(scale=1,min_width=64): + add_mask_btn = gr.Button( + value='Add mask', + scale=2, + ) + with gr.Row(): + prep_dl_btn = gr.Button( + value='Prepare mask for download', + visible=False if img is None else True + ) + with gr.Row(): + save_output = gr.File( + show_label=True, + label="Output file", + visible=False, + ) + + with gr.Column(scale=4): + with gr.Row(): + input_img = gr.Image( + label="Input", + tool='sketch', + value=img, + height=600, + width=600, + brush_color='#00ffff', + mask_opacity=self.mask_opacity, + interactive=False if img is None else True + ) + + output_masks = [] + for mask_idx in range(self.max_masks): + with gr.Row(): # make a new row for every mask + output_mask=gr.Image( + label=f"Mask {mask_idx+1}", + visible=True if mask_idx==0 else False, + image_mode='L', + height=600, + width=600, + interactive=False if img is None else True, # If statement added bc of bug after Gradio 3.44.x + show_download_button=False + ) + output_masks.append(output_mask) + + # Operations + operations = Operations(max_masks=self.max_masks,cmy_hex=self.cmy_hex) + + # Update component configuration when image is uploaded + upload_img_btn.upload(fn=operations.upload_img_update, + inputs=upload_img_btn, + outputs=[input_img,clear_img_btn,upload_img_btn,prep_dl_btn] + output_masks + ) + + # Add mask below when 'add mask' button is clicked + add_mask_btn.click( + fn=operations.increment_mask, + inputs=counts, + outputs=[counts, selected_mask] + output_masks + ) + + # Draw mask when input image is edited + input_img.edit( + fn=operations.update_masks, + inputs=[input_img,selected_mask,masks_state,upload_img_btn], + outputs=output_masks + ) + + # Update brush color according to radio setting + selected_mask.change( + fn=operations.update_brush_color, + inputs=selected_mask,outputs=input_img + ) + + # Make file download visible + prep_dl_btn.click( + fn=operations.save_mask, + inputs=output_masks, + outputs=[save_output,save_output] + ) + + # Update 'Add mask' button interactivit according to the current count + counts.change( + fn=operations.set_add_mask_btn_interactivity, + inputs=counts, + outputs=add_mask_btn + ) + + # Reset component configuration when image is cleared + clear_img_btn.click( + fn=operations.clear_img_update, + inputs=None, + outputs=[selected_mask,prep_dl_btn,save_output,counts,input_img,upload_img_btn,clear_img_btn] + output_masks + ) + + return gradio_interface + +class Operations: + def __init__(self, max_masks, cmy_hex): + self.max_masks = max_masks + self.cmy_hex = cmy_hex + + + def update_masks(self,input_img,selected_mask,masks_state,file): + # Binarize mask (it is not per default due to anti-aliasing) + input_mask = input_img['mask'] + input_mask[input_mask>0]=255 + + try: + file_name = file.name + except AttributeError: + file_name = 'nb_img' + + # Add new file to state dictionary when this function sees it first time + if file_name not in masks_state.keys(): + masks_state[file_name]=[[] for _ in range(self.max_masks)] + + # Get index of currently selected and non-selected masks + sel_mask_idx = int(selected_mask[-1])-1 + nonsel_mask_idxs = [mask_idx for mask_idx in list(range(self.max_masks)) if mask_idx != sel_mask_idx] + + # Add background to state first time function is invoked in current session + if len(masks_state[file_name][0])==0: + for i in range(len(masks_state[file_name])): + masks_state[file_name][i].append(input_mask) + + # Check for discrepancy between what is drawn and what is shown as output masks + masks_state_combined = 0 + for i in range(len(masks_state[file_name])): + masks_state_combined+=masks_state[file_name][i][-1] + discrepancy = masks_state_combined!=input_mask + if np.any(discrepancy): # Correct discrepancy in output masks + for i in range(self.max_masks): + masks_state[file_name][i][-1][discrepancy]=0 + + # Add most recent change in input to currently selected mask + mask2append = input_mask + for mask_idx in nonsel_mask_idxs: + mask2append -= masks_state[file_name][mask_idx][-1] + masks_state[file_name][sel_mask_idx].append(mask2append) + + return [masks_state[file_name][i][-1] for i in range(self.max_masks)] + + def save_mask(self,*masks): + # Go from multi-channel to single-channel mask + stacked_masks = np.stack(masks,axis=-1) + final_mask = np.zeros_like(masks[0]) + final_mask[np.where(stacked_masks==255)[:2]]=np.where(stacked_masks==255)[-1]+1 + + # Save output image in a temp space (and to current directory which is a bug) + filename = "mask.tif" + tifffile.imwrite(filename,final_mask) + + save_output_update = gr.File(visible=True) + + return save_output_update, filename + + def increment_mask(self,counts): + # increment count by 1 + counts+=1 + counts=int(counts) + + counts_update = gr.Number(value=counts) + selected_mask_update = gr.Radio(value = f"Mask {counts}", choices = [f"Mask {i+1}" for i in range(counts)]) + output_masks_update = [gr.Image(visible=True)]*counts + [gr.Image(visible=False)]*(self.max_masks-counts) + + return [counts_update, selected_mask_update] + output_masks_update + + def update_brush_color(self,selected_mask): + sel_mask_idx = int(selected_mask[-1])-1 + if sel_mask_idx<len(self.cmy_hex): + input_img_update = gr.Image(brush_color=self.cmy_hex[sel_mask_idx]) + else: + input_img_update = gr.Image(brush_color='#000000') # Return black brush + + return input_img_update + + def set_add_mask_btn_interactivity(self,counts): + add_mask_btn_update = gr.Button(interactive=True) if counts<self.max_masks else gr.Button(interactive=False) + return add_mask_btn_update + + def clear_img_update(self): + selected_mask_update = gr.Radio(choices = ["Mask 1"], value = "Mask 1") # Reset radio component to only show 'Mask 1' + prep_dl_btn_update = gr.Button(visible=False) # Make 'Prepare mask for download' button invisible + save_output_update = gr.File(visible=False) # Make File save box invisible + counts_update = gr.Number(value=1) # Reset invisible counter to 1 + input_img_update = gr.Image(value=None,interactive=False) # Set input image component to non-interactive (so a new image cannot be uploaded directly in the component) + upload_img_btn_update = gr.Button(interactive=True) # Make 'Upload image' button interactive + clear_img_btn_update = gr.Button(interactive=False) # Make 'Clear image' button non-interactive + output_masks_update = [gr.Image(value=None,visible=True if i==0 else False,interactive=False) for i in range(self.max_masks)] # Remove drawn masks and set as invisible except mask 1. 'interactive=False' added bc of bug after Gradio 3.44.x + + return [selected_mask_update, + prep_dl_btn_update, + save_output_update, + counts_update, + input_img_update, + upload_img_btn_update, + clear_img_btn_update] + output_masks_update + + def upload_img_update(self,file): + input_img_update = gr.Image(value=load(file.name),interactive=True) # Upload image from button to Image components + clear_img_btn_update = gr.Button(interactive=True) # Make 'Clear image' button interactive + upload_img_btn_update = gr.Button(interactive=False) # Make 'Upload image' button non-interactive + prep_dl_btn_update = gr.Button(visible=True) # Make 'Prepare mask for download' button visible + output_masks_update = [gr.Image(interactive=True)]*self.max_masks # This line is added bc of bug in Gradio after 3.44.x + + return [input_img_update, + clear_img_btn_update, + upload_img_btn_update, + prep_dl_btn_update] + output_masks_update + \ No newline at end of file diff --git a/qim3d/tests/gui/test_annotation_tool.py b/qim3d/tests/gui/test_annotation_tool.py new file mode 100644 index 00000000..65c02931 --- /dev/null +++ b/qim3d/tests/gui/test_annotation_tool.py @@ -0,0 +1,36 @@ +import qim3d +import multiprocessing +import time + + +def test_starting_class(): + app = qim3d.gui.annotation_tool.Interface() + + assert app.title == "Annotation tool" + + +def test_app_launch(): + ip = "0.0.0.0" + port = 65432 + + def start_server(ip, port): + app = qim3d.gui.annotation_tool.Interface() + app.launch(server_name=ip, server_port=port) + + proc = multiprocessing.Process(target=start_server, args=(ip, port)) + proc.start() + + # App is running in a separate process + # So we try to get a response for a while + max_checks = 5 + check = 0 + server_running = False + while check < max_checks and not server_running: + server_running = qim3d.utils.internal_tools.is_server_running(ip, port) + time.sleep(1) + check += 1 + + # Terminate tre process before assertions + proc.terminate() + + assert server_running is True diff --git a/qim3d/tests/gui/test_local_thickness.py b/qim3d/tests/gui/test_iso3d.py similarity index 100% rename from qim3d/tests/gui/test_local_thickness.py rename to qim3d/tests/gui/test_iso3d.py -- GitLab