From 28318fcac280acf47349d842c903fbfe74a9e99a Mon Sep 17 00:00:00 2001
From: Felipe <fima@dtu.dk>
Date: Fri, 26 Apr 2024 15:12:10 +0200
Subject: [PATCH] Hotfix for the scale_to_float16 function

---
 qim3d/utils/internal_tools.py | 65 +++++++++++++++--------------------
 1 file changed, 28 insertions(+), 37 deletions(-)

diff --git a/qim3d/utils/internal_tools.py b/qim3d/utils/internal_tools.py
index 113b66eb..cf1416f9 100644
--- a/qim3d/utils/internal_tools.py
+++ b/qim3d/utils/internal_tools.py
@@ -17,6 +17,7 @@ from fastapi import FastAPI
 import gradio as gr
 from uvicorn import run
 
+
 def mock_plot():
     """Creates a mock plot of a sine wave.
 
@@ -267,7 +268,8 @@ def get_port_dict():
 
     return port_dict
 
-def run_gradio_app(gradio_interface, host = "0.0.0.0"):
+
+def run_gradio_app(gradio_interface, host="0.0.0.0"):
 
     # Get port using the QIM API
     port_dict = get_port_dict()
@@ -278,7 +280,7 @@ def run_gradio_app(gradio_interface, host = "0.0.0.0"):
         port = port_dict["port"]
     else:
         raise Exception("Port not specified from QIM API")
-    
+
     gradio_header(gradio_interface.title, port)
 
     # Create FastAPI with mounted gradio interface
@@ -297,51 +299,40 @@ def get_css():
 
     current_directory = os.path.dirname(os.path.abspath(__file__))
     parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir))
-    css_path = os.path.join(parent_directory,"css","gradio.css")
-    
-    with open(css_path,'r') as file:
+    css_path = os.path.join(parent_directory, "css", "gradio.css")
+
+    with open(css_path, "r") as file:
         css_content = file.read()
-    
+
     return css_content
 
 
-def scale_to_float16(arr):
+def scale_to_float16(arr: np.ndarray):
     """
-    Scale a NumPy array to fit within the limits of the float16 data type.
+    Scale the input array to the float16 data type.
 
     Parameters:
-    - arr (numpy.ndarray): The input array to be scaled.
+    arr (np.ndarray): Input array to be scaled.
 
     Returns:
-    - numpy.ndarray: The scaled array, with values adjusted to fit within the limits of float16 data type.
+    np.ndarray: Scaled array with dtype=np.float16.
 
-    This function takes a NumPy array as input and checks if its maximum and minimum values
-    exceed the limits of the float16 data type. If necessary, it scales the positive and negative
-    parts of the array independently to fit within the range of float16.
+    This function scales the input array to the float16 data type, ensuring that the
+    maximum value of the array does not exceed the maximum representable value
+    for float16. If the maximum value of the input array exceeds the maximum
+    representable value for float16, the array is scaled down proportionally
+    to fit within the float16 range.
     """
-        
-    # Determine maximum and minimum values of the array
+
+    # Get the maximum value to comprare with the float16 maximum value
     arr_max = np.max(arr)
-    arr_min = np.min(arr)
-    
-    # Check if scaling is necessary for positive and negative parts separately
-    if arr_max > np.finfo(np.float16).max:
-        pos_scaled_arr = np.interp(arr[arr >= 0], (0, arr_max), (0, np.finfo(np.float16).max))
-    else:
-        pos_scaled_arr = arr[arr >= 0].astype(np.float16)
-        
-    if arr_min < -np.finfo(np.float16).max:
-        neg_scaled_arr = np.interp(arr[arr < 0], (arr_min, 0), (-np.finfo(np.float16).max, 0))
-    else:
-        neg_scaled_arr = arr[arr < 0].astype(np.float16)
-    
-    # Combine the scaled positive and negative parts
-    scaled_arr = np.concatenate((neg_scaled_arr, pos_scaled_arr))
-
-    # Reshape the scaled array to match the original shape
-    scaled_arr = scaled_arr.reshape(arr.shape)
-    
+    float16_max = np.finfo(np.float16).max
+
+    # If the maximum value of the array exceeds the float16 maximum value, scale the array
+    if arr_max > float16_max:
+        arr = (arr / arr_max) * float16_max
+
     # Convert the scaled array to float16 data type
-    scaled_arr = scaled_arr.astype(np.float16)
-    
-    return scaled_arr
\ No newline at end of file
+    arr = arr.astype(np.float16)
+
+    return arr
-- 
GitLab