#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 10 18:03:32 2018
This file is the new version for calculating the uncertainty value of each patch.
It is better because:
    1. the most uncertain patch is chosen automatically
    2. the weight ratio for each region can easily be changed to any value
    @author: s161488
"""
import numpy as np
from scipy import signal, ndimage
from skimage.morphology import dilation, disk


def select_most_uncertain_patch(x_image_pl, y_label_pl, fb_pred, ed_pred, fb_prob_mean_bald, kernel_window, stride_size,
                                already_select_image_index, previously_selected_binary_mask, num_most_uncert_patch,
                                method):
    """This function is used to acquire the #most uncertain patches in the pooling set.
    Args:
        x_image_pl: [Num_Im, Im_h, Im_w,3]
        y_label_pl: [Num_Im, Im_h, Im_w,1]
        fb_pred: [Num_Im, Im_h, Im_w, 2]
        ed_pred: [Num_Im, Im_h, Im_w, 2]
        fb_prob_mean_bald: [Num_Im, Im_h, Im_w]
        kernel_window: [kh, kw] determines the size of each region
        stride_size: int, determines the stride between two neighbouring regions
        already_select_image_index: if None, this is the first acquisition step;
            otherwise, the numeric image indices of the previously selected patches
        previously_selected_binary_mask: [num_already_selected_images, Im_h, Im_w,1]
        num_most_uncert_patch: int, number of patches that are selected in each acquisition step
        method: acquisition method, 'B' (confidence-based), 'C' (entropy) or 'D' (BALD)
    Returns:
        Most_Uncert_Im: [Num_Selected, Im_h, Im_w, 3]
        Most_Uncert_FB_GT: [Num_Selected, Im_h, Im_w,1]
        Most_Uncert_ED_GT: [Num_Selected, Im_h, Im_w,1]
        Most_Uncert_Binary_Mask: [Num_Selected, Im_h, Im_w,1]
        Selected_Image_Index: [Num_Selected]
    """
    num_im = np.shape(x_image_pl)[0]
    uncertainty_map_tot = []
    for i in range(num_im):
        if method == 'B':
            var_stat = get_uncert_heatmap(x_image_pl[i], fb_pred[i])
        elif method == 'C':
            var_stat = get_entropy_heatmap(fb_pred[i])
        elif method == 'D':
            var_stat = get_bald_heatmap(fb_prob_mean_bald[i], fb_pred[i])
        else:
            raise ValueError("method must be one of 'B', 'C' or 'D', but got %r" % (method,))
        uncertainty_map_tot.append(var_stat)
    uncertainty_map_tot = np.array(uncertainty_map_tot)
    if already_select_image_index is None:
        print("--------This is the beginning of the selection process-------")
    else:
        print(
            "----------Some patches have already been annotated, so their uncertainty is masked out")
        previously_selected_binary_mask = np.squeeze(previously_selected_binary_mask, axis=-1)
        for i in range(np.shape(previously_selected_binary_mask)[0]):
            uncertainty_map_single = uncertainty_map_tot[already_select_image_index[i]]
            uncertainty_map_updated = uncertainty_map_single * (1 - previously_selected_binary_mask[i])
            uncertainty_map_tot[already_select_image_index[i]] = uncertainty_map_updated
    selected_numeric_image_index, binary_mask_updated_tot = calculate_score_for_patch(uncertainty_map_tot,
                                                                                      kernel_window, stride_size,
                                                                                      num_most_uncert_patch)
    pseudo_fb_la_tot = []
    pseudo_ed_la_tot = []
    for index, single_selected_image_index in enumerate(selected_numeric_image_index):
        pseudo_fb_la, pseudo_ed_la = return_pseudo_label(y_label_pl[single_selected_image_index],
                                                         fb_pred[single_selected_image_index],
                                                         ed_pred[single_selected_image_index],
                                                         binary_mask_updated_tot[index])

        pseudo_fb_la_tot.append(pseudo_fb_la)
        pseudo_ed_la_tot.append(pseudo_ed_la)
    most_uncert_im_tot = x_image_pl[selected_numeric_image_index]
    most_uncertain = [most_uncert_im_tot,
                      pseudo_fb_la_tot,
                      pseudo_ed_la_tot,
                      binary_mask_updated_tot,
                      selected_numeric_image_index]
    return most_uncertain


def calculate_score_for_patch(uncert_est, kernel, stride_size, num_most_uncertain_patch):
    """This function is used to calculate the utility score for each patch.
    Args:
        uncert_est: [num_image, imh, imw]
        kernel: the size of each searching shape
        stride_size: the stride between every two regions
        num_most_uncertain_patch: int, the number of selected regions
    Returns:
        most_uncert_image_index: [Num_Most_Selec] the numeric indices of the selected images
        binary_mask: [Num_Most_Selec, Im_h, Im_w, 1] binary masks marking the selected patch in each image
    Note:
        If several selected patches fall into the same image, their binary masks are merged,
        so Num_Most_Selec can be smaller than num_most_uncertain_patch.
    """
    num_im, imh, imw = np.shape(uncert_est)
    kh, kw = np.shape(kernel)
    h_num_patch = imh - kh + 1
    w_num_patch = imw - kw + 1
    num_row_wise = h_num_patch // stride_size
    num_col_wise = w_num_patch // stride_size
    if stride_size == 1:
        tot_num_patch_per_im = num_row_wise * num_col_wise
    else:
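        # the extra row/column accounts for the patch anchored at the last valid position,
        # which select_patches_in_image_area appends so the bottom/right border is covered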
        tot_num_patch_per_im = (num_row_wise + 1) * (num_col_wise + 1)
    print('-------------------------------There are %d patches per image' % tot_num_patch_per_im)
    patch_tot = []
    for i in range(num_im):
        patch_subset = select_patches_in_image_area(uncert_est[i], kernel, stride_size, num_row_wise, num_col_wise)
        patch_tot.append(np.reshape(patch_subset, [-1]))
    patch_tot = np.reshape(np.array(patch_tot), [-1])
    # print('Based on the experiments, there are %d patches in total'%np.shape(patch_tot)[0])
    # print('Based on the calculation, there supposed to be %d patches in tot'%(Num_Im*tot_num_patch_per_im))
    sorted_index = np.argsort(patch_tot)
    select_most_uncert_patch = (sorted_index[-num_most_uncertain_patch:]).astype('int64')
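    # the flattened scores are ordered image by image, so a flat index k corresponds to
    # image k // tot_num_patch_per_im and patch k % tot_num_patch_per_im within that image
    # (e.g. with 144 patches per image, flat index 300 is patch 12 of image 2)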
    select_most_uncert_patch_imindex = (select_most_uncert_patch // tot_num_patch_per_im).astype('int64')
    select_most_uncert_patch_index_per_im = (select_most_uncert_patch % tot_num_patch_per_im).astype('int64')
    if stride_size == 1:
        select_most_uncert_patch_rownum_per_im = (select_most_uncert_patch_index_per_im // num_col_wise).astype('int64')
        select_most_uncert_patch_colnum_per_im = (select_most_uncert_patch_index_per_im % num_col_wise).astype('int64')
    else:
        select_most_uncert_patch_rownum_per_im = (select_most_uncert_patch_index_per_im // (num_col_wise + 1)).astype(
            'int64')
        select_most_uncert_patch_colnum_per_im = (select_most_uncert_patch_index_per_im % (num_col_wise + 1)).astype(
            'int64')
    transfered_rownum, transfered_colnum = transfer_strid_rowcol_backto_nostride_rowcol(
        select_most_uncert_patch_rownum_per_im,
        select_most_uncert_patch_colnum_per_im,
        [h_num_patch, w_num_patch],
        [num_row_wise + 1, num_col_wise + 1],
        stride_size)

    binary_mask_tot = []
    # print("The numeric index for the selected most uncertain patches-----", select_most_uncert_patch)
    # print("The corresponding uncertainty value in the selected patch-----", patch_tot[select_most_uncert_patch])
    # print("The image index for the selected most uncertain patches-------", select_most_uncert_patch_imindex)
    # print("The index of the patch in per image---------------------------", select_most_uncert_patch_index_per_im)
    # print("The row index for the selected patch--------------------------", select_most_uncert_patch_rownum_per_im)
    # print("The col index for the selected patch--------------------------", select_most_uncert_patch_colnum_per_im)
    # print("The transfered row index for the selected patch---------------", transfered_rownum)
    # print("The transfered col index for the selected patch---------------", transfered_colnum)

    for i in range(num_most_uncertain_patch):
        single_binary_mask = generate_binary_mask(imh, imw,
                                                  transfered_rownum[i],
                                                  transfered_colnum[i],
                                                  kh, kw)
        binary_mask_tot.append(single_binary_mask)
    binary_mask_tot = np.array(binary_mask_tot)
    unique_im_index = np.unique(select_most_uncert_patch_imindex)
    if np.shape(unique_im_index)[0] == num_most_uncertain_patch:
        print("----------------------------There is no replication for the selected images")
        uncertain_info = [select_most_uncert_patch_imindex, binary_mask_tot]
    else:
        print("-----These images have been selected more than twice", unique_im_index)
        binary_mask_final_tot = []
        for i in unique_im_index:
            loc_im = np.where(select_most_uncert_patch_imindex == i)[0].astype('int64')
            binary_mask_combine = (np.sum(binary_mask_tot[loc_im], axis=0) != 0).astype('int64')
            binary_mask_final_tot.append(binary_mask_combine)
        uncertain_info = [unique_im_index.astype('int64'), np.array(binary_mask_final_tot)]
    print("the shape for binary mask", np.shape(binary_mask_final_tot))
    return uncertain_info


def return_pseudo_label(single_gt, single_fb_pred, single_ed_pred, single_binary_mask):
    """This function is used to return the pseudo label for the selected patches in per image
    Args:
        single_gt: [imh, imw,1]
        single_fb_pred: [imh, imw, 2]
        single_ed_pred: [imh, imw, 2]
        single_binary_mask: [imh, imw]
    Return:
        pseudo_fb_la: [Im_h, Im_w, 1]
        pseudo_ed_la: [Im_h, Im_w, 1]
    """
    single_gt = (single_gt != 0).astype('int64')
    edge_gt = extract_edge(single_gt)
    fake_pred = (single_fb_pred[:, :, -1:] >= 0.5).astype('int64')
    fake_ed_pred = (single_ed_pred[:, :, -1:] >= 0.2).astype('int64')
    print(np.shape(fake_pred), np.shape(single_binary_mask), np.shape(single_gt), np.shape(edge_gt))
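    # inside the selected patch (binary mask == 1) the ground truth is used,
    # outside it the thresholded network prediction serves as a pseudo label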
    pseudo_fb_la = fake_pred * (1 - single_binary_mask) + single_gt * single_binary_mask
    pseudo_ed_la = fake_ed_pred * (1 - single_binary_mask) + edge_gt * single_binary_mask
    return pseudo_fb_la, pseudo_ed_la


def extract_edge(la_sep):
    """This function is utilized to extract the edge from the ground truth
    Args:
        la_sep [im_h, im_w]
    Return 
        edge_gt [im_h, im_w]
    """
    selem = disk(3)
    sx = ndimage.sobel(la_sep, axis=0, mode='constant')
    sy = ndimage.sobel(la_sep, axis=1, mode='constant')
    sob = np.hypot(sx, sy)
    row = (np.reshape(sob, -1) > 0) * 1
    edge_sep = np.reshape(row, [np.shape(sob)[0], np.shape(sob)[1]])
    edge_sep = dilation(edge_sep, selem)
    edge_sep = np.expand_dims(edge_sep, axis=-1)

    return edge_sep.astype('int64')


def generate_binary_mask(imh, imw, rowindex, colindex, kh, kw):
    """This function is used to generate the binary mask for the selected most uncertain images
    Args:
        Im_h, Im_w are the size of the binary mask
        row_index, col_index are the corresponding row and column index for most uncertain patch
        kh,kw are the kernel size
    Output:
        Binary_Mask
    Opts: 
        To transform from the selected patch index to the original image. It will be like
        rowindex:rowindex+kh
        colindex:colindex+kw
    """
    binary_mask = np.zeros([imh, imw, 1])
    binary_mask[rowindex:(rowindex + kh), colindex:(colindex + kw)] = 1
    return binary_mask


def transfer_strid_rowcol_backto_nostride_rowcol(rownum, colnum, no_stride_row_col, stride_row_col, stride_size):
    """This function is used to map the row index and col index from the strided version back to the original version
    if the row_num and col_num are not equal to the last row num or last col num
    then the transfer is just rownum*stride_size, colnum*stride_size
    but if the row_num and colnum are actually the last row num or last col num
    then the transfer is that rownum*stride_size, colnum_no_stride, or row_num_no_stride, colnum*stride_size
    """
    if stride_size != 1:
        row_num_no_stride, col_num_no_stride = no_stride_row_col
        row_num_stride, col_num_stride = stride_row_col
        transfered_row_num = np.zeros([np.shape(rownum)[0]])
        for i in range(np.shape(rownum)[0]):
            if rownum[i] != (row_num_stride - 1):
                transfered_row_num[i] = stride_size * rownum[i]
            else:
                transfered_row_num[i] = row_num_no_stride - 1
        transfered_col_num = np.zeros([np.shape(colnum)[0]])
        for i in range(np.shape(colnum)[0]):
            if colnum[i] != (col_num_stride - 1):
                transfered_col_num[i] = colnum[i] * stride_size
            else:
                transfered_col_num[i] = col_num_no_stride - 1
    else:
        transfered_row_num = rownum
        transfered_col_num = colnum
    return transfered_row_num.astype('int64'), transfered_col_num.astype('int64')


def select_patches_in_image_area(single_fb, kernel, stride_size, num_row_wise, num_col_wise):
    """There needs to be a stride"""
    utility_patches = signal.convolve(single_fb, kernel, mode='valid')
    if stride_size != 1:
        subset_patch = np.zeros([num_row_wise + 1, num_col_wise + 1])
        for i in range(num_row_wise):
            for j in range(num_col_wise):
                subset_patch[i, j] = utility_patches[i * stride_size, j * stride_size]
        for i in range(num_row_wise):
            subset_patch[i, -1] = utility_patches[i * stride_size, -1]
        for j in range(num_col_wise):
            subset_patch[-1, j] = utility_patches[-1, j * stride_size]
        subset_patch[-1, -1] = utility_patches[-1, -1]
    else:
        subset_patch = utility_patches
    return subset_patch


def get_uncert_heatmap(image_single, fb_prob_single, check_rank=False):
    if check_rank:
        # restrict the heatmap to the bounding box of the non-zero image pixels
        sele_index = np.where(np.mean(image_single, -1) != 0)
        fb_prob_single = fb_prob_single[np.min(sele_index[0]):np.max(sele_index[0]) + 1,
                                        np.min(sele_index[1]):np.max(sele_index[1]) + 1, :]
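    # confidence-based heat value: take the probability of the predicted class and map it
    # to 1 - p, so low-confidence pixels receive high uncertainty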
    fb_index = (fb_prob_single[:, :, 1] >= 0.5).astype('int64')
    fb_prob_map = fb_index * fb_prob_single[:, :, 1] + (1 - fb_index) * fb_prob_single[:, :, 0]
    only_base_fb = 1 - fb_prob_map
    return only_base_fb


def get_entropy_heatmap(fb_prob_single):
    fb_entropy = np.sum(-fb_prob_single * np.log(fb_prob_single + 1e-8),
                        axis=-1)  # calculate the sum w.r.t the number of classes
    return fb_entropy


def get_bald_heatmap(fb_prob_mean_bald_single, fb_prob_single):
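    # BALD-style score: entropy of the (MC-averaged) predictive probabilities plus the term
    # carried in fb_prob_mean_bald_single, which is assumed to hold the per-class mean of
    # p * log(p) over the MC samples, i.e. an estimate of H[E[p]] - E[H[p]]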
    bald_first_term = -np.sum(fb_prob_single * np.log(fb_prob_single + 1e-08), axis=-1)
    bald_second_term = np.sum(fb_prob_mean_bald_single, axis=-1)
    bald_value = bald_first_term + bald_second_term
    return bald_value
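

if __name__ == "__main__":
    # Minimal smoke test with synthetic data (shapes follow the docstrings above; the
    # kernel size, stride and patch count are arbitrary assumptions, not the values
    # used in the original experiments).
    np.random.seed(0)
    num_im, imh, imw = 4, 64, 64
    x_image = np.random.rand(num_im, imh, imw, 3)
    y_label = (np.random.rand(num_im, imh, imw, 1) > 0.5).astype('int64')
    fb_pred = np.random.rand(num_im, imh, imw, 2)
    fb_pred /= np.sum(fb_pred, axis=-1, keepdims=True)
    ed_pred = np.random.rand(num_im, imh, imw, 2)
    ed_pred /= np.sum(ed_pred, axis=-1, keepdims=True)
    fb_prob_mean_bald = np.random.rand(num_im, imh, imw)
    selected = select_most_uncertain_patch(x_image, y_label, fb_pred, ed_pred, fb_prob_mean_bald,
                                           kernel_window=np.ones([10, 10]), stride_size=5,
                                           already_select_image_index=None,
                                           previously_selected_binary_mask=None,
                                           num_most_uncert_patch=2, method='B')
    print("selected image indices:", selected[-1])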