From 4b1565ac7b5b0e61fa028c398650cc9162db4a53 Mon Sep 17 00:00:00 2001
From: glia <glia@dtu.dk>
Date: Tue, 13 Aug 2024 15:10:03 +0200
Subject: [PATCH] Upload helper.py

---
 helper.py | 180 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 180 insertions(+)
 create mode 100644 helper.py

diff --git a/helper.py b/helper.py
new file mode 100644
index 0000000..ac6606f
--- /dev/null
+++ b/helper.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+"""
+Generic helper functions
+
+@author: Qianliang Li (glia@dtu.dk)
+"""
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import matplotlib
+import time
+
+# Function to get the current time in my preferred format
+def time_now():
+    c_time = time.localtime()
+    c_time = time.strftime("%a %d %b %Y %H:%M:%S", c_time)
+    return c_time
+
+def numpy_arr_to_pandas_df(array, col_names: list, col_values: list, dtypes: list):
+    """ Convert array to 2D Dataframe """
+    # The dimensions of the array will each be a column with numbers
+    # and the last column will be the actual values
+    arr = np.column_stack(list(map(np.ravel, np.meshgrid(*map(np.arange, array.shape),
+                                                         indexing="ij"))) + [array.ravel()])
+    # Initialize the dataframe    
+    df = pd.DataFrame(arr, columns = col_names)
+    # Change of the numerical coding to the actual values
+    temp_df = df.copy() # make temp df to not sequentially overwrite when modifying
+    for col in range(len(col_values)):
+        col_name = df.columns[col]
+        # Fix dtype
+        temp_df[col_name] = temp_df[col_name].astype(dtypes[col])
+        # Insert col values
+        for shape in range(array.shape[col]):
+            temp_df.loc[df.iloc[:,col] == shape,col_name]\
+            = col_values[col][shape]
+    df = temp_df # replace original df
+    return df
+
+def get_feature_indices(X, feature_name_list):
+    col_idx = np.zeros(len(X.columns), dtype=bool)
+    for fset in range(len(feature_name_list)):
+        temp_feat = feature_name_list[fset]    
+        col_idx0 = X.columns.str.contains(temp_feat)
+        col_idx = np.logical_or(col_idx,col_idx0) # append all trues
+    return col_idx
+
+def heatmap(data, row_labels, col_labels, ax=None, rotate_x_labels=True,
+            cbar_kw={}, cbarlabel="", **kwargs):
+    """
+    Create a heatmap from a numpy array and two lists of labels.
+
+    Parameters
+    ----------
+    data
+        A 2D numpy array of shape (N, M).
+    row_labels
+        A list or array of length N with the labels for the rows.
+    col_labels
+        A list or array of length M with the labels for the columns.
+    ax
+        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
+        not provided, use current axes or create a new one.  Optional.
+    rotate_x_labels
+        Boolean variable for rotating the x labels on top of the figure.
+    cbar_kw
+        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
+    cbarlabel
+        The label for the colorbar.  Optional.
+    **kwargs
+        All other arguments are forwarded to `imshow`.
+    """
+
+    if not ax:
+        ax = plt.gca()
+
+    # Plot the heatmap
+    im = ax.imshow(data, **kwargs)
+
+    # Create colorbar
+    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
+    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
+
+    # We want to show all ticks...
+    ax.set_xticks(np.arange(data.shape[1]))
+    ax.set_yticks(np.arange(data.shape[0]))
+    # ... and label them with the respective list entries.
+    ax.set_xticklabels(col_labels)
+    ax.set_yticklabels(row_labels)
+
+    # Let the horizontal axes labeling appear on top.
+    ax.tick_params(top=True, bottom=False,
+                   labeltop=True, labelbottom=False)
+
+    # Rotate the tick labels and set their alignment.
+    if rotate_x_labels == True:
+        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
+                 rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    for edge, spine in ax.spines.items():
+        spine.set_visible(False)
+
+    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
+    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    return im, cbar
+
+def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
+                     sem=None, p_value=None,
+                     textcolors=["black", "white"],
+                     threshold=None, **textkw):
+    """
+    A function to annotate a heatmap.
+
+    Parameters
+    ----------
+    im
+        The AxesImage to be labeled.
+    data
+        Data used to annotate.  If None, the image's data is used.  Optional.
+    sem
+        Standard error or mean. In the same shape as Data. Optional.
+        SD can also be used, this input just insert the number as text.
+    valfmt
+        The format of the annotations inside the heatmap.  This should either
+        use the string format method, e.g. "$ {x:.2f}", or be a
+        `matplotlib.ticker.Formatter`.  Optional.
+    textcolors
+        A list or array of two color specifications.  The first is used for
+        values below a threshold, the second for those above.  Optional.
+    threshold
+        Value in data units according to which the colors from textcolors are
+        applied.  If None (the default) uses the middle of the colormap as
+        separation.  Optional.
+    **kwargs
+        All other arguments are forwarded to each call to `text` used to create
+        the text labels.
+    """
+
+    if not isinstance(data, (list, np.ndarray)):
+        data = im.get_array()
+
+    # Normalize the threshold to the images color range.
+    if threshold is not None:
+        threshold = im.norm(threshold)
+    else:
+        threshold = im.norm(data.max())/2.
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = dict(horizontalalignment="center",
+              verticalalignment="center")
+    kw.update(textkw)
+
+    # Get the formatter in case a string is supplied
+    if isinstance(valfmt, str):
+        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    texts = []
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+            # Add text and SEM if not none
+            if sem is not None:
+                text = im.axes.text(j, i, (valfmt(data[i, j], None)+"\n"+u"\u00B1"+
+                                        valfmt(sem[i, j], None)), **kw)
+            elif p_value is not None:
+                text = im.axes.text(j, i, (valfmt(data[i, j], None)+"\n"+"p = "+
+                                        valfmt(p_value[i, j], None)), **kw)
+            else:            
+                text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
+            
+            texts.append(text)
+
+    return texts
-- 
GitLab