Skip to content
Snippets Groups Projects
visualize_calibration_score.py 26.1 KiB
Newer Older
  • Learn to ignore specific revisions
  • blia's avatar
    blia committed
    # Compare the calibration scores of full-image and region-based annotation
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import pickle
    import seaborn as sns
    from scipy.signal import savgol_filter
    import pandas as pd
    import argparse
    
    
    def str2bool(v):
        """Interpret a command-line token as a boolean.

        Booleans are returned unchanged. Strings are lower-cased and matched
        against the accepted spellings listed below.

        Raises:
            argparse.ArgumentTypeError: the string is not one of the accepted spellings.
        """
        if isinstance(v, bool):
            return v
        token = v.lower()
        truthy = ('yes', 'true', 't', 'y', '1')
        falsy = ('no', 'false', 'f', 'n', '0')
        if token in truthy:
            return True
        if token in falsy:
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')
    
    
    def give_args():
        """Parse the command-line options of this script.

        Returns:
            argparse.Namespace with two attributes:
            ``save`` (bool, default False) and ``path`` (str, default None).
        """
        parser = argparse.ArgumentParser(description='Reproducing figure')
        parser.add_argument('--save', type=str2bool, default=False, metavar='SAVE',
                            help='whether the figures are written to disk')
        parser.add_argument('--path', type=str, default=None, help='the directory that saves the data')
        return parser.parse_args()
    
    
    def ax_global_get(fig):
        """Add a frameless axes spanning the whole figure and return it.

        The callers in this file use the returned axes only to attach x/y labels
        that are shared by the real subplots.

        Args:
            fig: the matplotlib figure to add the axes to.
        """
        ax_global = fig.add_subplot(111, frameon=False)
        for side in ('top', 'bottom', 'left', 'right'):
            ax_global.spines[side].set_color('none')
        # tick_params expects booleans here. The string 'off' is a non-empty (truthy)
        # value, so on matplotlib versions that no longer translate it the ticks
        # would be switched on instead of off.
        ax_global.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
        return ax_global
    
    
    
    def give_score_path(path):
        """Collect the calibration-statistic files of the region and full-image runs.

        Args:
            path: data directory; it is concatenated directly with the sub-folder
                names, so it must end with a path separator.

        Returns:
            (region_group, full_group). Each is a list with one entry per tag in
            ``["_B_", "_C_", "_D_"]``; every entry is the list of paths inside the
            corresponding sub-folder whose file name contains both the tag and '.obj'.
        """
        str_group = ["_B_", "_C_", "_D_"]

        def _collect(folder):
            # One directory listing per folder is enough for all three tags.
            names = os.listdir(folder)
            return [[folder + v for v in names if single_str in v and '.obj' in v]
                    for single_str in str_group]

        region_group = _collect(path + 'region_calibration_stat/')
        full_group = _collect(path + 'full_image_calibration_stat/')
        return region_group, full_group
    
    
    def give_first_figure(reg, ful, save=False):
        """Draw the two small twin-axis figures (full image, then region).

        Each figure plots one curve in red on the left axis and the averaged
        "bri_score" curve in green on the right axis, against the pixel fraction
        returned by ``compare_score``. The red curves are presumably F1 scores
        (going by the variable names) -- TODO confirm. The region curve is read from
        sheet 'Direct_Python' of ``path + 'GlaS.xlsx'``; the full-image curve is
        hard-coded.

        Relies on the module-level globals ``path`` and ``save_fig_path``.

        Args:
            reg: iterable of region path lists, one per acquisition method.
            ful: iterable of full-image path lists, paired with ``reg``.
            save: figures are written to disk only when this is exactly ``True``.
        """
        sheet = pd.read_excel(path + 'GlaS.xlsx', 'Direct_Python')
        data_all_dynamic = np.zeros([8, 41])
        for col_idx, col_name in enumerate(sheet.columns):
            if col_idx <= 1:
                # the first two spreadsheet columns are not copied
                continue
            data_all_dynamic[:, col_idx] = sheet[col_name].values

        region_f1 = np.mean(data_all_dynamic[[0, 3, 6], :], axis=0)[3:-6]
        full_f1 = [0.6504, 0.7061, 0.711, 0.7752, 0.7816, 0.8059, 0.8367, 0.8198, 0.8591, 0.8589]

        region_brier_runs, full_brier_runs = [], []
        for single_reg, single_ful in zip(reg, ful):
            region_stat, full_stat, [pixel_region, pixel_full] = compare_score(
                single_reg, single_ful, "bri_score", conf_interval=False, return_stat=True)
            region_brier_runs.append(region_stat[0])
            full_brier_runs.append(full_stat[0])

        # Region curves can differ in length; truncate all of them to the shortest.
        shortest = int(np.min([len(v) for v in region_brier_runs]))
        region_brier = np.mean(np.stack([v[:shortest] for v in region_brier_runs], axis=0), axis=0)
        full_brier = np.mean(np.stack(list(full_brier_runs), axis=0), axis=0)

        # Prepend the hard-coded starting point to every curve.
        region_f1 = np.concatenate([[0.50], region_f1], axis=0)
        full_f1 = np.concatenate([[0.50], full_f1], axis=0)
        region_brier = np.concatenate([[0.35], region_brier], axis=0)
        full_brier = np.concatenate([[0.35], full_brier], axis=0)
        pixel_region = np.concatenate([[10 / 75], pixel_region], axis=0)
        pixel_full = np.concatenate([[10 / 75], pixel_full], axis=0)

        def _small_ticks(axis):
            axis.tick_params(axis='both', which='major', labelsize=7)
            axis.tick_params(axis='both', which='minor', labelsize=7)

        def _twin_panel(x_left, y_left, x_right, y_right, file_stem):
            # red curve on the left axis, green curve on a twin right axis
            fig = plt.figure(figsize=(1.2, 0.8))
            left = fig.add_subplot(111)
            left.plot(x_left, y_left, 'r')
            left.set_ylim(0.48, 0.89)
            left.set_xlim(0.1, 0.35)
            _small_ticks(left)

            right = left.twinx()  # second axes sharing the same x-axis
            right.plot(x_right, y_right, color='g')
            right.set_ylim(0.10, 0.32)
            right.set_xlim(0.1, 0.35)
            _small_ticks(right)
            left.grid(ls=':', alpha=1.0, axis='both')
            if save is True:
                plt.savefig(save_fig_path + file_stem, dpi=600,
                            pad_inches=0, bbox_inches='tight', transparent=True)

        _twin_panel(pixel_full, full_f1[:-1], pixel_full, full_brier, '/ful_first_figure')
        _twin_panel(pixel_region[:len(region_f1)], region_f1,
                    pixel_region[:len(region_brier)], region_brier, '/reg_first_figure')
    
    
    def give_figure_e2(reg_group, ful_group, save=False):
        """Draw a 2x2 grid of line plots, one panel per calibration score.

        Each panel is filled by ``compare_acq_at_certain_point_line``. The figure is
        written to ``save_fig_path + 'overall_calibration2.pdf'`` only when ``save``
        is exactly ``True``. Relies on the module-level global ``save_fig_path``.
        """
        score_group = ["nll_score", "ece_score", "bri_score", "bri_decompose_score"]
        legend = ["VarRatio (F)", "Entropy (F)", "BALD (F)",
                  "VarRatio (R)", "Entropy (R)", "BALD (R)"]
        title_group = ["(a)", "(b)", "(c)", "(d)"]
        n_rows = len(score_group) // 2

        fig = plt.figure(figsize=(5.5, 4))
        ax_global = ax_global_get(fig)
        ax_global.set_xticks([])
        ax_global.set_yticks([])
        for panel, single_score in enumerate(score_group):
            ax = fig.add_subplot(n_rows, 2, panel + 1)
            compare_acq_at_certain_point_line(reg_group, ful_group, single_score, ax)
            if panel in (0, 1):
                # top row: hide the x tick labels
                ax.xaxis.set_major_formatter(plt.NullFormatter())
            ax.set_xlabel(title_group[panel], fontsize=8)
            ax.legend(legend, loc='best', fontsize=6)

        # The shared labels live on the invisible full-figure axes.
        ax_global.set_xlabel("\n\n\n Percentage of acquired pixels ", fontsize=8)
        ax_global.set_ylabel("Calibration score \n", fontsize=8)

        plt.subplots_adjust(wspace=0.15, hspace=0.35)
        if save is True:
            plt.savefig(save_fig_path + 'overall_calibration2.pdf',
                        pad_inches=0, bbox_inches='tight')
    
    
    def give_figure_5(reg_group, ful_group, save=False):
        """Draw a 3x2 grid of bar plots: one row per score, full image left, region right.

        Each row is filled by ``compare_acq_at_certain_point_barplot``. The figure is
        written to ``save_fig_path + 'overall_calibration.pdf'`` only when ``save`` is
        exactly ``True``. Relies on the module-level global ``save_fig_path``.
        """
        score_group = ["nll_score", "ece_score", "bri_score"]
        legend = ["VarRatio", "Entropy", "BALD"]
        # NOTE(review): six letters are defined but both panels of a row use the
        # row's letter, so only (a)-(c) ever appear -- confirm this is intended.
        title_group = ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
        n_rows = len(score_group)

        fig = plt.figure(figsize=(4.5, 6))
        ax_global = ax_global_get(fig)
        ax_global.set_xticks([])
        ax_global.set_yticks([])
        for row, single_score in enumerate(score_group):
            ax_full = fig.add_subplot(n_rows, 2, 2 * row + 1)
            ax_region = fig.add_subplot(n_rows, 2, 2 * row + 2)
            compare_acq_at_certain_point_barplot(reg_group, ful_group, single_score,
                                                 [ax_full, ax_region])
            if row in (0, 1):
                # every row except the last hides its x tick labels
                ax_full.xaxis.set_major_formatter(plt.NullFormatter())
                ax_region.xaxis.set_major_formatter(plt.NullFormatter())
            ax_full.set_xlabel(title_group[row] + " Full image", fontsize=8)
            ax_region.set_xlabel(title_group[row] + " Region", fontsize=8)
            # only the right-hand panel of each row carries the legend
            ax_region.legend(legend, fontsize=7, loc='best')

        ax_global.set_xlabel("\n\n\n Percentage of labeled pixels ", fontsize=8)
        ax_global.set_ylabel("Calibration score \n", fontsize=8)

        plt.subplots_adjust(wspace=0.15, hspace=0.35)
        if save is True:
            plt.savefig(save_fig_path + 'overall_calibration.pdf',
                        pad_inches=0, bbox_inches='tight')
    
            
            
    def give_acquired_full_image_uncertainty(path):
        """List the '.npy' files under ``path + 'acquired_full_image_uncertainty/'``.

        Returns:
            A list with one entry per tag "_B_", "_C_", "_D_" (in that order); each
            entry holds the paths whose file name contains the tag and '.npy'.
        """
        folder = path + 'acquired_full_image_uncertainty/'
        return [[folder + name for name in os.listdir(folder) if tag in name and '.npy' in name]
                for tag in ("_B_", "_C_", "_D_")]
    
    blia's avatar
    blia committed
    
    
    
    def give_figure_4_and_e1(conf_interval=True, save=False):
        """Draw the accuracy-vs-confidence figure and the uncertainty-density figure.

        Figure 1 plots, per acquisition tag (B, C, D), the accuracy stored in the
        ``*_stat_*.npy`` files under ``path + 'ece_histogram/'`` against the
        confidence values of the first run. Figure 2 shows kernel density
        estimates of the uncertainties for the full-image and the region setting.

        Relies on the module-level globals ``path`` and ``save_fig_path``.

        Args:
            conf_interval: shaded bands are drawn only when exactly ``True``.
            save: figures are written to disk only when exactly ``True``.
        """
        ful_group = give_acquired_full_image_uncertainty(path)
        ece_path = path + "ece_histogram/"
        legend_space = ["VarRatio", "Entropy", "BALD"]
        color_group = ["r", "g", "b"]

        ece_all = [v for v in os.listdir(ece_path) if '.npy' in v and '_stat_' in v]

        def _load_runs(tag):
            # stack every stat file of one acquisition tag along the run axis
            return np.concatenate([np.load(ece_path + v) for v in ece_all if tag in v], axis=0)

        def _mean_and_band(runs):
            # element [1] of each run is what gets plotted on the accuracy axis
            accuracy = [v[1] for v in runs]
            # NOTE(review): 1.95 is used as the multiplier (1.96 is the usual 95% z-value) -- confirm.
            return np.mean(accuracy, axis=0), np.std(accuracy, axis=0) * 1.95 / np.sqrt(len(runs))

        ece_b = _load_runs('_B_')
        ece_c = _load_runs('_C_')
        ece_d = _load_runs('_D_')

        # The second "_C_" run is left out of the statistics.
        ece_c = np.concatenate([ece_c[:1], ece_c[2:]], axis=0)

        ece_b_avg, ece_b_std = _mean_and_band(ece_b)
        ece_c_avg, ece_c_std = _mean_and_band(ece_c)
        _, ece_d_std = _mean_and_band(ece_d)
        # The "_D_" curve is the third run itself, not the mean over runs (the
        # original computed the mean and immediately overwrote it).
        ece_d_avg = ece_d[2][1]

        uncertain_stat = show_uncertainty_distribution(ful_group, True)

        fig = plt.figure(figsize=(3.5, 1.7))
        ax = fig.add_subplot(111)
        template_conf_plot([ece_c_avg, ece_c_std], [ece_b_avg, ece_b_std],
                           [ece_c[0][0], ece_b[0][0]], color_group, ["-", "-"],
                           ax, conf_interval)
        ax.plot(ece_c[0][0], ece_d_avg, color_group[-1], ls='-', lw=1)
        if conf_interval is True:
            ax.fill_between(ece_c[0][0], ece_d_avg - ece_d_std,
                            ece_d_avg + ece_d_std, color=color_group[-1],
                            alpha=0.3)
        ax.legend(legend_space, fontsize=8, loc='best')
        ax.grid(ls=':', alpha=0.5, axis='both')
        ax.plot([0.0, 1.0], [0.0, 1.0], ls=':', color='gray')  # diagonal reference line
        ax.set_xlabel('confidence', fontsize=8)
        ax.set_ylabel('accuracy', fontsize=8)

        ax.yaxis.offsetText.set_fontsize(7)
        ax.tick_params(axis='both', which='major', labelsize=8)
        ax.tick_params(axis='both', which='minor', labelsize=8)
        if save is True:
            plt.savefig(save_fig_path + "/ece_histogram.pdf", pad_inches=0, bbox_inches='tight')

        uncert_region = get_region_uncert(True)

        fig = plt.figure(figsize=(4.5, 1.5))
        ax_global = ax_global_get(fig)
        ax_global.set_xticks([])
        ax_global.set_yticks([])

        def _density_panel(position, samples, title):
            panel = fig.add_subplot(position)
            for i in range(3):
                # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
                # it is kept so the rendered figure stays unchanged.
                sns.distplot(samples[i], hist=False, kde=True, kde_kws={"color": color_group[i],
                                                                        "label": legend_space[i],
                                                                        "lw": 1, "alpha": 0.9})
            panel.legend(loc='best', fontsize=7)
            panel.grid(ls=':', alpha=0.5, axis='both')
            panel.yaxis.offsetText.set_fontsize(7)
            panel.tick_params(axis='both', which='major', labelsize=7)
            panel.tick_params(axis='both', which='minor', labelsize=7)
            panel.set_title(title, fontsize=7, y=-0.48)
            return panel

        _density_panel(121, [uncertain_stat[i][0] for i in range(3)], '(a) Full image')
        ax = _density_panel(122, uncert_region, '(b) Region')
        ax.yaxis.set_major_formatter(plt.NullFormatter())

        ax_global.set_xlabel('\n \n uncertainty', fontsize=7)
        ax_global.set_ylabel('density \n\n\n', fontsize=7)

        plt.subplots_adjust(wspace=0.1)
        if save is True:
            plt.savefig(save_fig_path + "/ece_histogram_uncertain_distribution.pdf",
                        pad_inches=0, bbox_inches='tight')
    
    
    def load_score(path_specific, score_str, region_or_full):
        """Load one calibration score from a list of pickled statistic files.

        Args:
            path_specific: list of paths to pickled dicts; each dict must contain
                "ece_score" and ``score_str``.
            score_str: key of the score to extract.
            region_or_full: "region" reads the query statistics from
                ``<prefix>_query_stat.npy`` (prefix = first path up to '_Version');
                any other value synthesises them from the hard-coded constants below.

        Returns:
            (score_use, percent_pixel): the list of per-file score arrays, truncated
            to the shortest "ece_score" length, and the pixel fraction per step.
        """
        # NOTE(review): pickle.load executes arbitrary code from the file; only use
        # this on statistic files you produced yourself.
        stat = []
        for single_path in path_specific:
            with open(single_path, 'rb') as handle:  # close the handle deterministically
                stat.append(pickle.load(handle))
        num_step = np.min([len(v["ece_score"]) for v in stat])
        score_use = [v[score_str][:num_step] for v in stat]
        # Strings are compared with ==; `is` tests object identity and only worked
        # by accident of string interning.
        if score_str == "bri_score":
            score_use = [v + 1 for v in score_use]
        if region_or_full == "region":
            path_name = [single_path.strip().split('_Version')[0] for single_path in path_specific]
            query_stat = np.load(path_name[0] + "_query_stat.npy")[:num_step]
        else:
            # Hard-coded: every step adds 5 images of 528 * 784 pixels.
            num_pixel = np.ones([num_step]) * (528 * 784 * 5)
            num_images = np.ones([num_step]) * 5
            query_stat = np.zeros([num_step, 2])
            query_stat[:, 0] = np.cumsum(num_pixel)
            query_stat[:, 1] = np.cumsum(num_images) + 10

        # Fraction relative to 75 images of 528 * 784 pixels, offset by 10 / 75.
        percent_pixel = query_stat[:, 0] / (75 * 528 * 784) + (10 / 75)
        if score_str == "bri_decompose_score":
            score_use = [v[:, [2, 5, 8]] for v in score_use]
        return score_use, percent_pixel
    
    
    def postprocess_data(score_group):
        """Run ``remove_outlier`` over the first three columns of every score array.

        Returns:
            Array of shape [len(score_group), len(score_group[0]), 3] holding the
            cleaned columns.
        """
        n_runs = len(score_group)
        n_steps = len(score_group[0])
        cleaned = np.zeros([n_runs, n_steps, 3])
        for run_idx, run_scores in enumerate(score_group):
            columns = [remove_outlier(run_scores[:, col]) for col in range(3)]
            cleaned[run_idx] = np.stack(columns, axis=-1)
        return cleaned
    
    
    def compare_score(path_region, path_full, score_str, conf_interval=True, return_stat=False):
        """Compare a region-based calibration score with the full-image one.

        Args:
            path_region: list of statistic files of the region runs.
            path_full: list of statistic files of the full-image runs.
            score_str: key of the score to load (see ``load_score``).
            conf_interval: passed on to ``template_conf_plot``.
            return_stat: when exactly ``True`` nothing is plotted and the statistics
                of the last column are returned instead.

        Returns:
            With ``return_stat=True``: ``[region_avg, region_band], [full_avg,
            full_band], [percent_pixel_region, percent_pixel_full]``. Otherwise
            ``None``; a figure with three panels is created.
        """
        score_region, percent_pixel_region = load_score(path_region, score_str, "region")
        score_full, percent_pixel_full = load_score(path_full, score_str, "full")

        stat_region = postprocess_data(score_region)
        stat_full = postprocess_data(score_full)

        # NOTE(review): the band divides by the number of runs, whereas
        # give_figure_4_and_e1 divides by its square root -- confirm which is meant.
        stat_region_avg = np.mean(stat_region, axis=0)
        stat_region_std = np.std(stat_region, axis=0) * 1.95 / len(path_region)

        stat_full_avg = np.mean(stat_full, axis=0)
        stat_full_std = np.std(stat_full, axis=0) * 1.95 / len(path_full)

        percent_group = [percent_pixel_region, percent_pixel_full]

        if return_stat is True:
            return [stat_region_avg[:, -1], stat_region_std[:, -1]], \
                   [stat_full_avg[:, -1], stat_full_std[:, -1]], percent_group

        color_group = ["r", "g"]
        legend_group = ["full", "region"]
        fig = plt.figure(figsize=(10, 2.5))
        for i in range(3):
            ax = fig.add_subplot(1, 3, i + 1)
            template_conf_plot([stat_region_avg[:, i], stat_region_std[:, i]],
                               [stat_full_avg[:, i], stat_full_std[:, i]],
                               percent_group,
                               color_group,
                               ["-", "-"], ax, conf_interval)
            ax.legend(legend_group, loc='best', fontsize=8)

            ax.grid(ls=':', alpha=0.5, axis='both')
            # == instead of `is`: identity comparison of strings is unreliable.
            if score_str == "nll_score":
                ax.ticklabel_format(axis='y', style='sci', scilimits=(10, 5))
    
    
    def template_conf_plot(region_group, full_group, percent_group, color_group, ls_group,
                           ax, conf_interval):
        """Plot the full-image and region curves on ``ax``, optionally with bands.

        Args:
            region_group: ``[avg, band]`` of the region curve.
            full_group: ``[avg, band]`` of the full-image curve.
            percent_group: ``[x_region, x_full]``.
            color_group: colours; index 0 is used for full image, index 1 for region.
            ls_group: line styles, indexed like ``color_group``.
            ax: axes to draw on.
            conf_interval: the shaded ``avg +/- band`` areas are drawn only when
                exactly ``True``.
        """
        x_region, x_full = percent_group
        region_avg, region_band = region_group
        full_avg, full_band = full_group

        # The full-image series is always drawn before the region series.
        series = ((x_full, full_avg, full_band, color_group[0], ls_group[0]),
                  (x_region, region_avg, region_band, color_group[1], ls_group[1]))
        for x, avg, _, colour, style in series:
            ax.plot(x, avg, colour, ls=style, lw=1)
        if conf_interval is not True:
            return
        for x, avg, band, colour, _ in series:
            ax.fill_between(x, avg - band, avg + band, color=colour, alpha=0.3)
    
    
    def remove_outlier(stat_vector):
        """Smooth a curve, then damp its largest values.

        The vector is first passed through a Savitzky-Golay filter (window 5,
        polynomial order 3); ``remove`` is then applied six times in a row to the
        filtered copy.
        """
        smoothed = savgol_filter(stat_vector, 5, 3)
        passes = 6
        while passes > 0:
            smoothed = remove(smoothed)
            passes -= 1
        return smoothed
    
    
    def remove(stat_vector):
        """Replace the largest interior values by the mean of their +/-2 neighbours.

        Among the entries from index 2 onwards, up to ten largest are considered;
        those lying too close to either end are skipped. The vector is modified in
        place (largest value first) and also returned.
        """
        offset = 2
        lower = 3
        upper = len(stat_vector) - (offset + 2)
        ranked = np.argsort(stat_vector[offset:])[-10:]
        candidates = [rel for rel in ranked if lower < rel < upper]
        for rel in reversed(candidates):
            pos = rel + offset
            stat_vector[pos] = np.mean([stat_vector[pos - 2], stat_vector[pos + 2]])
        return stat_vector
    
    
    def get_overall_compare_based_on_score(path_region_group, path_full_group, score_str, bar=False):
        """Package the ``compare_score`` statistics for a line plot or a bar plot.

        Returns:
            (region, full). With ``bar=True`` each is an array with one row
            ``[pixel_fraction, avg, band]`` per bar, for the first four full-image
            steps; the region row uses the region step whose pixel fraction is
            closest. Otherwise each is an array with rows ``pixel_fraction, avg, band``.
        """
        region_stat, full_stat, percents = compare_score(path_region_group, path_full_group,
                                                         score_str, False, True)
        region_pixels, full_pixels = percents[0], percents[1]

        if bar is not True:
            region_out = np.concatenate([np.expand_dims(region_pixels, axis=0), region_stat], axis=0)
            full_out = np.concatenate([np.expand_dims(full_pixels, axis=0), full_stat], axis=0)
            return region_out, full_out

        region_rows, full_rows = [], []
        for pos, pixel in enumerate(full_pixels[:4]):
            # region step whose pixel fraction is closest to this full-image step
            nearest = np.argsort(abs(region_pixels - pixel))[0]
            region_rows.append([pixel, region_stat[0][nearest], region_stat[1][nearest]])
            full_rows.append([pixel, full_stat[0][pos], full_stat[1][pos]])
        return np.concatenate([region_rows], axis=0), np.concatenate([full_rows], axis=0)
    
    
    def compare_acq_at_certain_point_line(reg_group, ful_group, score_str, ax):
        """Plot one score against the pixel fraction for every acquisition method.

        Full-image curves are dotted, region curves are solid; colours follow the
        order of the groups (red, green, blue).

        Args:
            reg_group: region path lists, one per acquisition method.
            ful_group: full-image path lists, paired with ``reg_group``.
            score_str: key of the score to plot.
            ax: axes to draw on; a new figure is created when it is falsy.
        """
        r_g_perf, f_g_perf = [], []
        for single_reg, single_ful in zip(reg_group, ful_group):
            _r_, _f_ = get_overall_compare_based_on_score(single_reg, single_ful, score_str)
            r_g_perf.append(_r_)
            f_g_perf.append(_f_)

        color_group = ['red', 'green', 'blue']
        lstype_group = ['-', ':']
        if not ax:
            fig = plt.figure(figsize=(5, 3))
            ax = fig.add_subplot(111)

        # Row 0 of each array is the pixel fraction, row 1 the averaged score.
        # All full-image curves are drawn first so the legend order stays F, F, F, R, R, R.
        for i in range(3):
            ax.plot(f_g_perf[i][0], f_g_perf[i][1], color_group[i], ls=lstype_group[1], lw=1.0)
        for i in range(3):
            ax.plot(r_g_perf[i][0], r_g_perf[i][1], color_group[i], ls=lstype_group[0], lw=1.0)
        ax.grid(ls=':', axis='both')
        # == instead of `is`: identity comparison of strings is unreliable.
        if score_str == "nll_score":
            ax.ticklabel_format(axis='y', style='sci', scilimits=(10, 5))
        else:
            ax.ticklabel_format(axis='y', style='sci', scilimits=(10, -2))
    
    
    def compare_acq_at_certain_point_barplot(reg_group, ful_group, score_str, ax):
        """Draw grouped bar plots of one score: full image on ax[0], region on ax[1].

        Args:
            reg_group: region path lists, one per acquisition method.
            ful_group: full-image path lists, paired with ``reg_group``.
            score_str: "nll_score", "bri_score" or "ece_score".
            ax: pair of axes ``[full_axes, region_axes]``; a new figure with two
                subplots is created when it is falsy.

        Raises:
            ValueError: ``score_str`` has no display scale defined.
        """
        # Display scale per score. Looked up with ==-based dict access instead of
        # the former `is` chain, which also left the scale undefined for other scores.
        div_lookup = {"nll_score": 1e+6, "bri_score": 1e-1, "ece_score": 1e-2}
        if score_str not in div_lookup:
            raise ValueError("no bar-plot scale defined for score %r" % (score_str,))
        div_value = div_lookup[score_str]

        r_g_perf, f_g_perf = [], []
        for single_reg, single_ful in zip(reg_group, ful_group):
            _r_, _f_ = get_overall_compare_based_on_score(single_reg, single_ful, score_str, True)
            r_g_perf.append(_r_)
            f_g_perf.append(_f_)
        width = 0.55
        q = 0
        scale_factor = 30
        color_group = ['tab:blue', 'tab:orange', "tab:green"]
        if not ax:
            # Two axes are needed below; the former fallback created only one and
            # then failed when unpacking it.
            fig = plt.figure(figsize=(5, 3))
            ax = [fig.add_subplot(121), fig.add_subplot(122)]

        ax0, ax1 = ax
        max_value = []
        # Columns of each array: 0 = pixel fraction, 1 = average, 2 = error bar.
        for i in range(3):
            ax0.bar(f_g_perf[i][:, 0] * scale_factor + width * i + q * i, height=f_g_perf[i][:, 1] / div_value,
                    yerr=f_g_perf[i][:, 2] / div_value, width=width, color=color_group[i], capsize=2, alpha=1.0)
            max_value.append(np.max(f_g_perf[i][:, 1] / div_value + f_g_perf[i][:, 2] / div_value))

        # The shared y-limit is derived from the full-image bars only.
        max_max = np.max(max_value) + np.max([np.min(v[:, 2] / div_value) for v in f_g_perf])
        for i in range(3):
            # Region bars reuse the full-image x positions so both panels line up.
            ax1.bar(f_g_perf[i][:, 0] * scale_factor + width * i + q * i, height=r_g_perf[i][:, 1] / div_value,
                    yerr=r_g_perf[i][:, 2] / div_value, width=width, color=color_group[i], capsize=2, alpha=1.0)

        for single_ax in ax:
            single_ax.grid(ls=':', axis='both')
            single_ax.set_ylim((0, max_max))
            single_ax.set_xticks(f_g_perf[0][:, 0] * scale_factor + width)
            single_ax.set_xticklabels(['%.2f' % i for i in f_g_perf[0][:, 0]])
    
    
    def calc_uncertainty(prob, method, reshape=True):
        """Compute an uncertainty value from class probabilities.

        Args:
            prob: probabilities with the classes on the last axis; for method "D" a
                pair ``(prob, bald)``.
            method: "B" -> ``1 - max(prob)``; "C" -> entropy of ``prob``;
                "D" -> entropy of ``prob`` plus the sum of ``bald`` over the last axis.
            reshape: flatten the result to 1-D when true.

        Raises:
            ValueError: ``method`` is not one of "B", "C", "D".
        """
        # == instead of `is`: identity comparison of strings is unreliable.
        if method == "B":
            uncert = 1 - np.max(prob, axis=-1)
        elif method == "C":
            # 1e-8 keeps the logarithm finite for zero probabilities
            uncert = np.sum(-prob * np.log(prob + 1e-8), axis=-1)
        elif method == "D":
            prob, bald = prob
            bald_first = -np.sum(prob * np.log(prob + 1e-8), axis=-1)
            bald_second = np.sum(bald, axis=-1)
            uncert = bald_first + bald_second
        else:
            raise ValueError("unknown uncertainty method %r" % (method,))
        if reshape:
            return np.reshape(uncert, [-1])
        return uncert
    
    
    def give_count_accu(stat, label, method, bins):
        """Bin samples by uncertainty and report the accuracy per bin.

        Args:
            stat: [num_samples, num_class]
            label: [num_samples]
            method: "B" or "C"
            bins: int

        Returns:
            (bin_center, counts, accuracies, uncert, pred)

        Raises:
            ValueError: ``method`` is neither "B" nor "C".
        """
        # == instead of `is`: identity comparison of strings is unreliable.
        if method == "B":
            bin_range = [0.0, 0.5]
        elif method == "C":
            bin_range = [0.0, 0.7]
        else:
            raise ValueError("no bin range defined for method %r" % (method,))
        uncert = calc_uncertainty(stat, method)
        uncert = np.where(uncert == 0, 1e-8, uncert)
        pred = np.equal(np.argmax(stat, axis=-1), label)
        counts, bin_edges = np.histogram(uncert, bins=bins, range=bin_range)
        # NOTE(review): np.digitize returns 1-based indices for values inside the
        # range, while the loops below run over 0..bins-1, so the per-bin means
        # look shifted by one relative to `counts` -- confirm before relying on them.
        indices = np.digitize(uncert, bin_edges, right=True)
        accuracies = np.array([np.mean(pred[indices == i])
                               for i in range(bins)])
        bin_center = np.array([np.mean(uncert[indices == i]) for
                               i in range(bins)])
        return bin_center, counts, accuracies, uncert, pred
    
    
    def sort_uncertainty(pool_path, method, load_step):
        """Load pooled predictions and return min-max normalised uncertainties.

        Args:
            pool_path: '.npy' file with the pooled statistics.
            method: acquisition tag; for "D" the file is indexed as
                ``[0][step]`` / ``[1][step]`` (probabilities / BALD term),
                otherwise as ``[step]``.
            load_step: iterable of step indices to load.

        Returns:
            List with one flat array per step, each rescaled to the range [0, 1].
        """
        pool_stat = np.load(pool_path)
        stat_group = []
        for i in load_step:
            # != instead of `is not`: identity comparison of strings is unreliable.
            if method != "D":
                _stat = np.reshape(np.squeeze(pool_stat[i], axis=(1, 2)), [-1, 2])
            else:
                prob = np.reshape(np.squeeze(pool_stat[0][i], axis=(1, 2)), [-1, 2])
                bald = np.reshape(np.squeeze(pool_stat[1][i], axis=(1, 2)), [-1, 2])
                _stat = [prob, bald]
            _uncert = calc_uncertainty(_stat, method)
            # NOTE(review): a constant uncertainty map makes the denominator zero.
            _uncert = (_uncert - np.min(_uncert)) / (np.max(_uncert) - np.min(_uncert))
            stat_group.append(_uncert)
        return stat_group
    
    
    def get_uncertainty_group(path_group, method, load_step, return_value=False):
        """Gather the normalised uncertainties of several runs, grouped by step.

        Args:
            path_group: list of pooled-statistic files, one per run.
            method: acquisition tag passed to ``sort_uncertainty``.
            load_step: iterable of step indices to load.
            return_value: return the array instead of plotting when exactly ``True``.

        Returns:
            With ``return_value=True`` an array ``[num_step, num_exp * num_pixels]``;
            otherwise ``None`` and a histogram figure is created.
        """
        # == instead of `is`: identity comparison of strings is unreliable.
        if method == "B":
            # for method "B" only the first and the third run are used
            path_group = [path_group[0], path_group[2]]
        uncertain_stat = [sort_uncertainty(single_path, method, load_step)
                          for single_path in path_group]
        uncertain_stat = np.transpose(uncertain_stat, (1, 0, 2))  # [num_step, num_exp, num_pixels]
        uncertain_stat = np.reshape(uncertain_stat, [np.shape(uncertain_stat)[0], -1])
        if return_value is True:
            return uncertain_stat

        fig = plt.figure(figsize=(7, 4))
        ax = fig.add_subplot(111)
        for i in [0, 2]:
            sns.distplot(uncertain_stat[i], kde=True, hist=True, kde_kws={"label": "%d" % (i + 1)})
        ax.legend(loc='best')
    
    
    def show_uncertainty_distribution(ful_group, return_value=False):
        """Collect (or plot) the full-image uncertainties of methods B, C and D.

        Only step index 1 is loaded for every method.

        Returns:
            With ``return_value=True`` the list of per-method arrays produced by
            ``get_uncertainty_group``; otherwise ``None`` and a density figure is
            created.
        """
        tags = ["B", "C", "D"]
        labels = ["VarRatio", "Entropy", "BALD"]
        uncertain_stat = []
        for paths, tag in zip(ful_group, tags):
            uncertain_stat.append(get_uncertainty_group(paths, tag, range(1, 2), True))
        if return_value is True:
            return uncertain_stat

        fig = plt.figure(figsize=(5, 3))
        ax = fig.add_subplot(111)
        for per_method, label in zip(uncertain_stat[:3], labels):
            sns.distplot(per_method[0], hist=False, kde=True, kde_kws={"label": label})
        ax.grid(ls=':', alpha=0.5)
    
    
    # show region uncertainty
    def get_region_uncert(return_stat=False):
        """Load one region-uncertainty array per acquisition method.

        For each method the first file under ``path + '/acquired_region_uncertainty/'``
        whose name contains the method, version and step markers is loaded. Relies on
        the module-level global ``path``.

        Args:
            return_stat: return the arrays instead of plotting when exactly ``True``.

        Returns:
            With ``return_stat=True`` the list of loaded arrays (order B, C, D);
            otherwise ``None`` and a figure is created.

        Raises:
            FileNotFoundError: no file matches one of the method/version/step triples.
        """
        method = ["B", "C", "D"]
        version_use = [3, 1, 2]
        step = [0, 0, 1]
        path2read = path + '/acquired_region_uncertainty/'
        file_names = os.listdir(path2read)

        uncert_stat = []
        for single_method, single_version, single_step in zip(method, version_use, step):
            # NOTE(review): substring matching means 'Version_1' also matches
            # 'Version_10' -- fine as long as version numbers stay single-digit.
            markers = ('Method_%s' % single_method,
                       'Version_%d' % single_version,
                       'step_%d' % single_step)
            path_sub = [v for v in file_names if all(marker in v for marker in markers)]
            if not path_sub:
                raise FileNotFoundError("no file in %s matches %s" % (path2read, ", ".join(markers)))
            uncert_stat.append(np.load(path2read + path_sub[0]))
        if return_stat is True:
            return uncert_stat

        fig = plt.figure(figsize=(5, 3))
        fig.add_subplot(111)
        for single_stat in uncert_stat:
            sns.distplot(single_stat)
    
    
    if __name__ == '__main__':
        args = give_args()

        if args.path is None:
            raise SystemExit("Please pass the data directory with --path")
        # The functions in this file build file names as `path + 'sub_dir/'`, so
        # make sure the directory ends with a separator. `path` and `save_fig_path`
        # stay module-level names because those functions read them as globals.
        path = os.path.join(args.path, '')
        save_fig_path = path + 'save_figure/'
        os.makedirs(save_fig_path, exist_ok=True)
        print("--------------------------------")
        print("---The data files are saved in the directory", path)
        print("---The figures are going to be saved in ", save_fig_path)

        reg_group, ful_group = give_score_path(path)

        for v in reg_group:
            print(v)
        for v in ful_group:
            print(v)
        print("----------------------------------------")
        print("-----creating the first figure----------")
        print("----------------------------------------")
        give_first_figure(reg_group, ful_group, args.save)

        print("----------------------------------------")
        print("-----creating figure 4 and figure E1----")
        print("----------------------------------------")

        give_figure_4_and_e1(False, args.save)
        print("----------------------------------------")
        print("-----creating figure 5------------------")
        print("----------------------------------------")
        give_figure_5(reg_group, ful_group, args.save)

        print("----------------------------------------")
        print("-----creating figure e2-----------------")
        print("----------------------------------------")
        give_figure_e2(reg_group, ful_group, args.save)