From 28318fcac280acf47349d842c903fbfe74a9e99a Mon Sep 17 00:00:00 2001 From: Felipe <fima@dtu.dk> Date: Fri, 26 Apr 2024 15:12:10 +0200 Subject: [PATCH] Hotfix for the scale_to_float16 function --- qim3d/utils/internal_tools.py | 65 +++++++++++++++-------------------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py index 113b66eb..cf1416f9 100644 --- a/qim3d/utils/internal_tools.py +++ b/qim3d/utils/internal_tools.py @@ -17,6 +17,7 @@ from fastapi import FastAPI import gradio as gr from uvicorn import run + def mock_plot(): """Creates a mock plot of a sine wave. @@ -267,7 +268,8 @@ def get_port_dict(): return port_dict -def run_gradio_app(gradio_interface, host = "0.0.0.0"): + +def run_gradio_app(gradio_interface, host="0.0.0.0"): # Get port using the QIM API port_dict = get_port_dict() @@ -278,7 +280,7 @@ def run_gradio_app(gradio_interface, host = "0.0.0.0"): port = port_dict["port"] else: raise Exception("Port not specified from QIM API") - + gradio_header(gradio_interface.title, port) # Create FastAPI with mounted gradio interface @@ -297,51 +299,40 @@ def get_css(): current_directory = os.path.dirname(os.path.abspath(__file__)) parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir)) - css_path = os.path.join(parent_directory,"css","gradio.css") - - with open(css_path,'r') as file: + css_path = os.path.join(parent_directory, "css", "gradio.css") + + with open(css_path, "r") as file: css_content = file.read() - + return css_content -def scale_to_float16(arr): +def scale_to_float16(arr: np.ndarray): """ - Scale a NumPy array to fit within the limits of the float16 data type. + Scale the input array to the float16 data type. Parameters: - - arr (numpy.ndarray): The input array to be scaled. + arr (np.ndarray): Input array to be scaled. Returns: - - numpy.ndarray: The scaled array, with values adjusted to fit within the limits of float16 data type. + np.ndarray: Scaled array with dtype=np.float16. - This function takes a NumPy array as input and checks if its maximum and minimum values - exceed the limits of the float16 data type. If necessary, it scales the positive and negative - parts of the array independently to fit within the range of float16. + This function scales the input array to the float16 data type, ensuring that the + maximum value of the array does not exceed the maximum representable value + for float16. If the maximum value of the input array exceeds the maximum + representable value for float16, the array is scaled down proportionally + to fit within the float16 range. """ - - # Determine maximum and minimum values of the array + + # Get the maximum value to comprare with the float16 maximum value arr_max = np.max(arr) - arr_min = np.min(arr) - - # Check if scaling is necessary for positive and negative parts separately - if arr_max > np.finfo(np.float16).max: - pos_scaled_arr = np.interp(arr[arr >= 0], (0, arr_max), (0, np.finfo(np.float16).max)) - else: - pos_scaled_arr = arr[arr >= 0].astype(np.float16) - - if arr_min < -np.finfo(np.float16).max: - neg_scaled_arr = np.interp(arr[arr < 0], (arr_min, 0), (-np.finfo(np.float16).max, 0)) - else: - neg_scaled_arr = arr[arr < 0].astype(np.float16) - - # Combine the scaled positive and negative parts - scaled_arr = np.concatenate((neg_scaled_arr, pos_scaled_arr)) - - # Reshape the scaled array to match the original shape - scaled_arr = scaled_arr.reshape(arr.shape) - + float16_max = np.finfo(np.float16).max + + # If the maximum value of the array exceeds the float16 maximum value, scale the array + if arr_max > float16_max: + arr = (arr / arr_max) * float16_max + # Convert the scaled array to float16 data type - scaled_arr = scaled_arr.astype(np.float16) - - return scaled_arr \ No newline at end of file + arr = arr.astype(np.float16) + + return arr -- GitLab