import os
import re
import warnings
from collections import defaultdict

import numpy as np
from tqdm import tqdm

from measure import measure_cumulative_volume, measure_cumulative_points

def parse_params(path):
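    """Read the origin coordinates and section thickness from params.txt in `path`."""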
    with open(os.path.join(path, 'params.txt')) as f:
        lines = f.readlines()

    coords = []
    thickness = []

    for line in lines:
        if 'origin' in line:
            coords = [float(x) for x in re.findall(r'-?\d+\.?[0-9]*', line)]
        if 'thickness' in line:
            thickness = [float(x) for x in re.findall(r'-?\d+\.?[0-9]*', line)]

    if len(coords) != 3:
        print(f"found {len(coords)} coordinates when looking for 'origin' in params.txt")
    if len(thickness) != 1:
        print(f"found {len(coords)} coordinates when looking for 'origin' in params.txt")

    thickness = thickness[0]

    return coords, thickness

def load_contours(path):
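    """Load the contour sections from the domain_contours/ subdirectory of `path`, returned as a z-sorted list of (z, contour) tuples."""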
    contour_dir = os.path.join(path, 'domain_contours')
    file_fnames = [fname for fname in os.listdir(contour_dir) if fname.lower().startswith('z_')]
    z_values = [float(file_fname[2:-4]) for file_fname in file_fnames]
    z_values, file_fnames = zip(*sorted(zip(z_values, file_fnames)))
    contours = []

    for file_fname, z in zip(file_fnames, z_values):
        if '.git' in file_fname:
            continue
        X = np.loadtxt(os.path.join(contour_dir, file_fname))
        # close the contour by wrapping its first vertex onto the end
        Xp = np.pad(X, ((0, 1), (0, 0)), mode='wrap')
        contours.append((z, Xp))

    return contours


def load_points(path):
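    """Load the point sections from each points_marker_* subdirectory of `path`, returned as a dict mapping marker name to a z-sorted list of (z, points) tuples."""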
    point_dirs = [os.path.join(path, path_name) for path_name in os.listdir(path) if path_name.startswith('points_marker_')]
    marker_names = [dir_name.split('points_marker_')[-1] for dir_name in point_dirs]

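    # map each marker name to its list of (z, points) sections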
    points = defaultdict(list)
    for point_dir, marker_name in zip(point_dirs, marker_names):
        file_fnames = [fname for fname in os.listdir(point_dir) if '.git' not in fname]
        z_values = [float(file_fname[2:-4]) for file_fname in file_fnames]
        z_values, file_fnames = zip(*sorted(zip(z_values, file_fnames)))
        for file_fname, z in zip(file_fnames, z_values):
            X = np.loadtxt(os.path.join(point_dir, file_fname))
            points[marker_name].append((z, X))            

    return points

def write_csv(path, R, dP, dV):
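    """Write cumulative and interval point counts, volumes, and densities over the radial bins R to a CSV at `path`."""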

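    # interval readings are successive differences of the cumulative readings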
    intdP = dP[1:] - dP[:-1]
    intdV = dV[1:] - dV[:-1]

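    # empty volume bins divide by zero; suppress the resulting warnings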
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        Dens = dP / dV
        intDens = intdP / intdV

    with open(path, 'w') as f:
        f.write(','.join(['Cumulative Reading','','','','','Interval Reading\n']))
        f.write(','.join(['Distance','Point Counts','Volume (mu^3)','Density (n/mu^3)','','Interval start (mu)',
                          'Interval end (mu)','Point Counts (n)','Volume (mu^3)','Density (n/mu^3)\n']))
        for i0,i1,a,b,c,d,e,v in zip(R[:-1], R[1:], dP, dV, Dens, intdP, intdV, intDens):
            f.write(f'{i0},{a},{b},{c},,{i0},{i1},{d},{e},{v}\n')

def run(base_dir, output_dir):
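    """Process every experiment/reading under `base_dir`: measure cumulative and interval point densities for each marker and write one CSV per (experiment, reading, marker) to `output_dir`."""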

    # radial distances (mu) at which cumulative readings are taken
    R = np.arange(0, 3001, 125)

    os.makedirs(output_dir, exist_ok=True)

    for experiment in tqdm(os.listdir(base_dir), desc='experiment'):
        if '.git' in experiment:
            continue
        for reading in os.listdir(os.path.join(base_dir, experiment)):

            data_dir = os.path.join(base_dir, experiment, reading)
            if '.git' in data_dir:
                continue
            origin, thickness = parse_params(data_dir)
            
            contours = sorted(load_contours(data_dir))
            points = load_points(data_dir)
            for marker in points.keys():
                all_points = sorted(points[marker])
                dV = []
                dP = []
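                # accumulate per-section cumulative volumes and point counts, then sum over all z-sections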
                for sec_contour, sec_points in zip(contours, all_points):
                    if sec_contour[0] != sec_points[0]:
                        raise ValueError('z-coordinate misalignment: contour and point sections do not have matching z-values')
                    z, P = sec_contour
                    dV.append(measure_cumulative_volume(z, P, thickness, origin, R))
                    z, X = sec_points
                    dP.append(measure_cumulative_points(z, X, thickness, origin, R))
                
                dP = np.sum(dP, axis=0)
                dV = np.sum(dV, axis=0)

                fname = f'{experiment}_{reading}_{marker}.csv'
                path = os.path.join(output_dir, fname)

                write_csv(path, R, dP, dV)

if __name__ == '__main__':
    base_dir = 'data'
    output_dir = 'output'
    run(base_dir, output_dir)