#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep  1 17:39:05 2024

@author: Maya Coulson Theodorsen (mcoth@dtu.dk)

This module calculates descriptive statistics for variables across the entire
dataset and by cluster. Outputs are formatted for LaTeX tables, including the
median and interquartile range (IQR) for continuous variables and counts with
percentages (of observations equal to 1) for binary variables.

"""
import pandas as pd

def total_descriptives(data_complete, questionnaireClusters, categorical_variables, continuous_variables, binary_variables, sorter):
    """Summarise variables over the whole sample as LaTeX-ready strings.

    Parameters
    ----------
    data_complete, questionnaireClusters : pandas.DataFrame
        Not used directly; kept for signature parity with
        ``cluster_descriptives``. Each variable tuple carries its own frame.
    categorical_variables : iterable of (str, str, pandas.DataFrame)
        ``(column, label, frame)`` triples. Columns listed in
        ``binary_variables`` are reported as ``"N (pct\\%)"``, where N is the
        number of values equal to 1 and pct is its share of the non-missing
        values. Other columns are reported as ``"median (q25–q75)"``.
    continuous_variables : iterable of (str, str, pandas.DataFrame)
        ``(column, label, frame)`` triples reported as ``"median (q25–q75)"``.
    binary_variables : container of str
        Column names to summarise as counts/percentages.
    sorter : sequence of str
        Labels giving the row order of the output. Labels without a summary
        become NaN rows; summaries whose label is not in ``sorter`` are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Variable'`` and ``'Median(IQR)/N(%)'``, one row per
        entry of ``sorter``.
    """

    def _median_iqr(series):
        # Quartiles use pandas' default (linear interpolation) quantiles.
        median = series.median()
        q25 = series.quantile(0.25)
        q75 = series.quantile(0.75)
        return f"{median:.1f} ({q25:.1f}–{q75:.1f})"

    def _count_percent(series):
        # Percentages are relative to non-missing values, as with
        # value_counts(normalize=True). A variable with no 1s is reported as
        # "0 (0.0\%)" rather than being silently omitted. Only an all-missing
        # column yields no summary.
        valid = series.dropna()
        if valid.empty:
            return None
        cnt = int((valid == 1).sum())
        pct = cnt / len(valid) * 100
        # "\\%" escapes the percent sign for LaTeX.
        return f"{cnt} ({pct:.1f}\\%)"

    descriptives = {}

    # Categorical variables: binary -> N (%), otherwise median (IQR).
    for var, label, df in categorical_variables:
        if var in binary_variables:
            summary = _count_percent(df[var])
            if summary is not None:
                descriptives[label] = summary
        else:
            descriptives[label] = _median_iqr(df[var])

    # Continuous variables: median (IQR).
    for var, label, df in continuous_variables:
        descriptives[label] = _median_iqr(df[var])

    # One row per variable, ordered by `sorter`.
    descriptives_total = pd.DataFrame(list(descriptives.items()), columns=['Variable', 'Median(IQR)/N(%)'])
    descriptives_total = descriptives_total.set_index('Variable')
    descriptives_total = descriptives_total.reindex(sorter)
    descriptives_total = descriptives_total.reset_index()

    return descriptives_total


def cluster_descriptives(data_complete, questionnaireClusters, categorical_variables, continuous_variables, cluster_column, binary_variables, sorter):
    """Summarise variables within each cluster as LaTeX-ready strings.

    Parameters
    ----------
    data_complete, questionnaireClusters : pandas.DataFrame
        Both frames must contain ``cluster_column``. A variable is read from
        ``data_complete`` when its tuple's frame *is* ``data_complete``
        (identity check), and from ``questionnaireClusters`` otherwise.
    categorical_variables : iterable of (str, str, pandas.DataFrame)
        ``(column, label, frame)`` triples. Columns listed in
        ``binary_variables`` are reported as ``"N (pct\\%)"``, where N is the
        number of values equal to 1 and pct is its share of the cluster's
        non-missing values. Other columns are reported as
        ``"median (q25–q75)"``.
    continuous_variables : iterable of (str, str, pandas.DataFrame)
        ``(column, label, frame)`` triples reported as ``"median (q25–q75)"``.
    cluster_column : str
        Column used to group both frames.
    binary_variables : container of str
        Column names to summarise as counts/percentages.
    sorter : sequence of str
        Labels giving the row order of the output.

    Returns
    -------
    pandas.DataFrame
        Indexed by variable label (ordered by ``sorter``), with one column per
        cluster named ``'Cluster <id>'``. A cell is NaN when the variable is
        absent from that cluster's frame or, for binary variables, has no
        non-missing values in that cluster.
    """

    def _median_iqr(series):
        # Quartiles use pandas' default (linear interpolation) quantiles.
        median = series.median()
        q25 = series.quantile(0.25)
        q75 = series.quantile(0.75)
        return f"{median:.1f} ({q25:.1f}–{q75:.1f})"

    def _count_percent(series):
        # Percentages are relative to non-missing values, as with
        # value_counts(normalize=True). A cluster with no 1s is reported as
        # "0 (0.0\%)" instead of being left blank.
        valid = series.dropna()
        if valid.empty:
            return None
        cnt = int((valid == 1).sum())
        pct = cnt / len(valid) * 100
        # "\\%" escapes the percent sign for LaTeX.
        return f"{cnt} ({pct:.1f}\\%)"

    # Group each source frame once; variables are routed to the grouping of
    # the frame they were taken from.
    grouped_data_complete = data_complete.groupby(cluster_column)
    grouped_data_questionnaire = questionnaireClusters.groupby(cluster_column)

    def _groups_for(df):
        return grouped_data_complete if df is data_complete else grouped_data_questionnaire

    descriptives = {}

    # Categorical variables: binary -> N (%), otherwise median (IQR). Using
    # median (IQR) keeps these rows comparable with the whole-sample column
    # produced by total_descriptives.
    for var, label, df in categorical_variables:
        descriptives[label] = {}
        for cluster, cluster_data in _groups_for(df):
            if var not in cluster_data.columns:
                continue
            if var in binary_variables:
                summary = _count_percent(cluster_data[var])
            else:
                summary = _median_iqr(cluster_data[var])
            if summary is not None:
                descriptives[label][f'Cluster {cluster}'] = summary

    # Continuous variables: median (IQR).
    for var, label, df in continuous_variables:
        descriptives[label] = {}
        for cluster, cluster_data in _groups_for(df):
            if var in cluster_data.columns:
                descriptives[label][f'Cluster {cluster}'] = _median_iqr(cluster_data[var])

    # Dict-of-dicts gives labels as columns; transpose so labels are rows,
    # then order the rows by `sorter`.
    descriptives_cluster = pd.DataFrame(descriptives).T
    descriptives_cluster = descriptives_cluster.reindex(sorter)

    return descriptives_cluster