From 4b1565ac7b5b0e61fa028c398650cc9162db4a53 Mon Sep 17 00:00:00 2001 From: glia <glia@dtu.dk> Date: Tue, 13 Aug 2024 15:10:03 +0200 Subject: [PATCH] Upload helper.py --- helper.py | 180 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 helper.py diff --git a/helper.py b/helper.py new file mode 100644 index 0000000..ac6606f --- /dev/null +++ b/helper.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +Generic helper functions + +@author: Qianliang Li (glia@dtu.dk) +""" +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib +import time + +# Function to get the current time in my preferred format +def time_now(): + c_time = time.localtime() + c_time = time.strftime("%a %d %b %Y %H:%M:%S", c_time) + return c_time + +def numpy_arr_to_pandas_df(array, col_names: list, col_values: list, dtypes: list): + """ Convert array to 2D Dataframe """ + # The dimensions of the array will each be a column with numbers + # and the last column will be the actual values + arr = np.column_stack(list(map(np.ravel, np.meshgrid(*map(np.arange, array.shape), + indexing="ij"))) + [array.ravel()]) + # Initialize the dataframe + df = pd.DataFrame(arr, columns = col_names) + # Change of the numerical coding to the actual values + temp_df = df.copy() # make temp df to not sequentially overwrite when modifying + for col in range(len(col_values)): + col_name = df.columns[col] + # Fix dtype + temp_df[col_name] = temp_df[col_name].astype(dtypes[col]) + # Insert col values + for shape in range(array.shape[col]): + temp_df.loc[df.iloc[:,col] == shape,col_name]\ + = col_values[col][shape] + df = temp_df # replace original df + return df + +def get_feature_indices(X, feature_name_list): + col_idx = np.zeros(len(X.columns), dtype=bool) + for fset in range(len(feature_name_list)): + temp_feat = feature_name_list[fset] + col_idx0 = X.columns.str.contains(temp_feat) + col_idx = np.logical_or(col_idx,col_idx0) # append all trues + return col_idx + +def heatmap(data, row_labels, col_labels, ax=None, rotate_x_labels=True, + cbar_kw={}, cbarlabel="", **kwargs): + """ + Create a heatmap from a numpy array and two lists of labels. + + Parameters + ---------- + data + A 2D numpy array of shape (N, M). + row_labels + A list or array of length N with the labels for the rows. + col_labels + A list or array of length M with the labels for the columns. + ax + A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If + not provided, use current axes or create a new one. Optional. + rotate_x_labels + Boolean variable for rotating the x labels on top of the figure. + cbar_kw + A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. + cbarlabel + The label for the colorbar. Optional. + **kwargs + All other arguments are forwarded to `imshow`. + """ + + if not ax: + ax = plt.gca() + + # Plot the heatmap + im = ax.imshow(data, **kwargs) + + # Create colorbar + cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) + cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") + + # We want to show all ticks... + ax.set_xticks(np.arange(data.shape[1])) + ax.set_yticks(np.arange(data.shape[0])) + # ... and label them with the respective list entries. + ax.set_xticklabels(col_labels) + ax.set_yticklabels(row_labels) + + # Let the horizontal axes labeling appear on top. + ax.tick_params(top=True, bottom=False, + labeltop=True, labelbottom=False) + + # Rotate the tick labels and set their alignment. + if rotate_x_labels == True: + plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", + rotation_mode="anchor") + + # Turn spines off and create white grid. + for edge, spine in ax.spines.items(): + spine.set_visible(False) + + ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) + ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) + ax.grid(which="minor", color="w", linestyle='-', linewidth=3) + ax.tick_params(which="minor", bottom=False, left=False) + + return im, cbar + +def annotate_heatmap(im, data=None, valfmt="{x:.2f}", + sem=None, p_value=None, + textcolors=["black", "white"], + threshold=None, **textkw): + """ + A function to annotate a heatmap. + + Parameters + ---------- + im + The AxesImage to be labeled. + data + Data used to annotate. If None, the image's data is used. Optional. + sem + Standard error or mean. In the same shape as Data. Optional. + SD can also be used, this input just insert the number as text. + valfmt + The format of the annotations inside the heatmap. This should either + use the string format method, e.g. "$ {x:.2f}", or be a + `matplotlib.ticker.Formatter`. Optional. + textcolors + A list or array of two color specifications. The first is used for + values below a threshold, the second for those above. Optional. + threshold + Value in data units according to which the colors from textcolors are + applied. If None (the default) uses the middle of the colormap as + separation. Optional. + **kwargs + All other arguments are forwarded to each call to `text` used to create + the text labels. + """ + + if not isinstance(data, (list, np.ndarray)): + data = im.get_array() + + # Normalize the threshold to the images color range. + if threshold is not None: + threshold = im.norm(threshold) + else: + threshold = im.norm(data.max())/2. + + # Set default alignment to center, but allow it to be + # overwritten by textkw. + kw = dict(horizontalalignment="center", + verticalalignment="center") + kw.update(textkw) + + # Get the formatter in case a string is supplied + if isinstance(valfmt, str): + valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) + + # Loop over the data and create a `Text` for each "pixel". + # Change the text's color depending on the data. + texts = [] + for i in range(data.shape[0]): + for j in range(data.shape[1]): + kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) + # Add text and SEM if not none + if sem is not None: + text = im.axes.text(j, i, (valfmt(data[i, j], None)+"\n"+u"\u00B1"+ + valfmt(sem[i, j], None)), **kw) + elif p_value is not None: + text = im.axes.text(j, i, (valfmt(data[i, j], None)+"\n"+"p = "+ + valfmt(p_value[i, j], None)), **kw) + else: + text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) + + texts.append(text) + + return texts -- GitLab