    # -*- coding: utf-8 -*-
    """
    Created on Wed May 20 12:49 2020
    This file calculates the calibration scores of the experiments
    @author: s161488
    """
    import os
    import numpy as np
    import time
    import pickle
    import matplotlib.pyplot as plt
    import shutil
    
    import eval_calibration.calibration_lib as calib
    from data_utils.prepare_data import collect_test_data, prepare_pool_data
    
    
    def get_group_region(method_use):
        """Give the calibration score for region active learning"""
        if method_use is "B":
            path_input = ["Method_B_Stage_1_Version_3", "Method_B_Stage_1_Version_4"]
        elif method_use is "C":
            path_input = ["Method_C_Stage_2_Version_1", "Method_C_Stage_2_Version_2"]
        elif method_use is "D":
            path_input = ["Method_D_Stage_3_Version_2", "Method_D_Stage_3_Version_3"]
        for single_path in path_input:
            calc_calibration_value(single_path, 6)
    
    
    def get_group_full(version_use):
        """Give the calibration score for full image active learning"""
        path_input = sorted(os.listdir("/scratch/blia/Act_Learn_Desperate_V%d/" % version_use))
        for single_path in path_input:
            calc_calibration_value(single_path, version_use)
    
    
    def calculate_aggre_pixel(path_input, version, path2read=None):
        """This function calculates the acquired number of pixels in per step
        based on the data in collect
        """
        if not path2read:
            path2read = '/scratch/blia/Act_Learn_Desperate_V%d/%s/' % (version, path_input)
        path2save = '/home/blia/Exp_Data/calibration_score/Act_Learn_Desperate_V%d/' % version
        exp_version = int(path_input.strip().split('_')[-1])  # experiment version encoded in the folder name
        num_of_pixel_old = np.load(path2read + '/num_of_pixel.npy')
        num_of_image_old = np.load(path2read + '/num_of_image.npy')
        acq_step = len(os.listdir(path2read + '/collect_data/'))
        new_stat = np.zeros([acq_step, 2])
        im_index = []
        unique_im = np.zeros([acq_step])
        for i in range(acq_step):
            p = ["FE_step_00_version_%d" % version if i == 0 else "FE_step_%d_version_%d" % (i - 1, version)][0]
            stat = pickle.load(open(path2read + '/collect_data/%s/updated_uncertain.txt' % p, 'rb'))
            new_stat[i, 0] = np.sum(stat[-2])
            im_index.append(stat[-1])
            _im_ind = np.unique([v for j in im_index for v in j])
            unique_im[i] = len(_im_ind)
        new_stat[:, 0] = np.cumsum(new_stat[:, 0])
        new_stat[:, 1] = unique_im + 10  # the offset of 10 presumably accounts for the initially labelled images
        print("number of acquired pixels old", num_of_pixel_old)
        print("number of acquired pixels new", new_stat[:, 0])
        print("number of acquired images old", num_of_image_old)
        print("number of acquired images new", new_stat[:, 1])
        np.save(path2save + 'query_stat_%s' % path_input, new_stat)
    
    
    def transfer_numeric_index_back_to_imindex(numeric_index):
        """Map the sequentially recorded pool positions back to the original image indices.

        After every acquisition the selected image is removed from the pool, so each stored
        numeric index refers to the pool as it was at that acquisition step.
        """
        im_index_pool = np.arange(65)  # the pool set contains 65 images
        numeric_index = numeric_index.astype('int32')
        selected_imindex = np.zeros(np.shape(numeric_index))
        for i in range(len(numeric_index)):
            select_index = numeric_index[i]
            selected_imindex[i] = im_index_pool[select_index]
            im_index_pool = np.delete(im_index_pool, select_index)
        return selected_imindex.astype('int32')
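
    # Worked example for transfer_numeric_index_back_to_imindex (illustrative values, not taken
    # from an actual experiment): with numeric_index = [0, 0, 1] the pool starts as [0, 1, ..., 64];
    # step 0 picks pool position 0 -> image 0, step 1 picks position 0 of the shrunken pool
    # -> image 1, and step 2 picks position 1 of the remaining pool [2, 3, ...] -> image 3,
    # so the function returns [0, 1, 3].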
    
    
    def collect_pool_data_multi_acquisition_step(version, use_str):
        """Args:
        version: the experiment version
        use_str: "Method_B_", or "Method_C_" or "Method_D_"
        """
        path = "/scratch/blia/Act_Learn_Desperate_V%d/" % version
        _, y_label_gt, _ = prepare_pool_data("/home/blia/Exp_Data/Data/glanddata.npy", True)
        path_group = [v for v in os.listdir(path) if use_str in v]
        for single_path in path_group:
            collect_pool_data(version, single_path, y_label_gt)
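
    # Illustrative call (placeholder arguments, under the directory layout assumed in this file):
    # collect_pool_data_multi_acquisition_step(6, "Method_C_") would gather the pool statistics
    # of every Method_C experiment folder under Act_Learn_Desperate_V6.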
    
    
    def collect_pool_data(version, path_use, pool_fb_label):
        """This function returns the predicted uncertainty at each acquisition step for different
        acquisition functions
        Args:
            version: the experiment version
            path_use: the experiment path, such as "Method_B_Stage_1_Version_0"
            pool_fb_label: the pixel-wise ground truth label for images in the pool set
        """
        exp_version = int(path_use.strip().split('_')[-1])
        path = '/scratch/blia/Act_Learn_Desperate_V%d/%s/' % (version, path_use)
        save_dir = '/home/blia/Exp_Data/calibration_score/Act_Learn_Desperate_V%d/' % version
        numeric_index = np.load(path + 'total_acqu_index.npy')
        selected_imindex = transfer_numeric_index_back_to_imindex(numeric_index)
        pool_data_dir = path + 'Pool_Data/'
        fb_group = []
        gt_group = []
        bald_group = []
        print(path_use)
        for i in range(len(numeric_index) - 3):  # the last three acquisition steps are skipped
            print("acquisition step ", i)
            fb_ = np.load(pool_data_dir + "/FE_step_%d_version_%d/Test_Data_A/fbprob.npy" % (i, exp_version))
            fb_group.append(fb_[selected_imindex[i + 1]])
            if "_D_" in path_use:
                bald_ = np.load(pool_data_dir + "/FE_step_%d_version_%d/Test_Data_A/fbbald.npy" % (i, exp_version))
                bald_group.append(bald_[selected_imindex[i + 1]])
            gt_group.append(pool_fb_label[selected_imindex[i + 1]])
        if "_D_" in path_use:
            fb_group = [fb_group, bald_group]
        np.save(save_dir + path_use + "pool_label", np.array(gt_group))
        np.save(save_dir + path_use + "pool_stat", np.array(fb_group))
    
    
    def give_calibration_histogram(path_group, version, test_step, method):
        """This function gives the ece histogram (expected calibration error) at different step
        for different acquisition functions"""
        x_image_test, y_label_test = collect_test_data(resize=False)
        y_label_binary = (y_label_test != 0).astype('int32')
        y_label_binary = np.reshape(y_label_binary, [-1])
        save_dir = "/home/blia/Exp_Data/save_fig/"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        path_ = "/scratch/blia/Act_Learn_Desperate_V%d/" % version
        stat_group = []
        for single_path in path_group:
            exp_version = int(single_path.strip().split("_")[-1])
            path_use = path_ + single_path + "/Test_Data/"
            _stat = _give_calibration_histogram(path_use, exp_version, test_step, y_label_binary)
            stat_group.append(_stat)
        np.save(save_dir + "/ece_historgram_stat_%d_%s_%d" % (test_step, method, version), stat_group)
        accu_mean = np.mean([v[1] for v in stat_group], axis=0)
        accu_std = np.std([v[1] for v in stat_group], axis=0) * 1.95 / np.sqrt(len(path_group))
        fig = plt.figure(figsize=(6, 3))
        ax = fig.add_subplot(111)
        ax.plot(stat_group[0][0], accu_mean, 'r')
        ax.fill_between(stat_group[0][0], accu_mean - accu_std, accu_mean + accu_std,
                        color='r', alpha=0.5)
        ax.plot([0.0, 1.0], [0.0, 1.0], color='b', ls=':')
        plt.savefig(save_dir + '/%d_%s_%d.jpg' % (test_step, method, version))
    
    
    def _give_calibration_histogram(path2read, exp_version, test_step, y_label_binary):
        path2read = path2read + "FE_step_%d_version_%d/" % (test_step, exp_version)
        path_a, path_b = "Test_Data_A", "Test_Data_B"
        fb_prob = []
        for single_path in [path_a, path_b]:
            fb = np.load(path2read + single_path + "/fbprob.npy")
            fb = np.reshape(fb, [-1, 2])
            fb_prob.append(fb)
        fb_prob = np.concatenate(fb_prob, axis=0)
        top_k_probs, is_correct = calib.get_multiclass_predictions_and_correctness(fb_prob,
                                                                                   y_label_binary, None)
        top_k_probs = top_k_probs.flatten()
        is_correct = is_correct.flatten()
        bin_edges, accuracies, counts = calib.bin_predictions_and_accuracies(top_k_probs,
                                                                             is_correct,
                                                                             bins=10)
        bin_centers = calib.bin_centers_of_mass(top_k_probs, bin_edges)
        return bin_centers, accuracies, counts
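

    # A minimal sketch (not part of the original pipeline) of how a scalar expected calibration
    # error could be recovered from the output of _give_calibration_histogram: ECE is the
    # count-weighted average gap between per-bin accuracy and per-bin mean confidence.
    def _ece_from_histogram(bin_centers, accuracies, counts):
        """Hypothetical helper: count-weighted |accuracy - confidence| over the non-empty bins."""
        counts = np.asarray(counts, dtype='float64')
        keep = counts > 0  # empty bins contribute nothing (and may hold NaN accuracies/centers)
        weights = counts[keep] / np.sum(counts)
        gaps = np.abs(np.asarray(accuracies)[keep] - np.asarray(bin_centers)[keep])
        return float(np.sum(weights * gaps))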
    
    
    def collect_region_uncertainty(version, path_subset, step):
        pathmom = '/scratch/blia/Act_Learn_Desperate_V%d/%s' % (version, path_subset)
        model_version = int(path_subset.strip().split('_')[-1])
        method = str(path_subset.strip().split('_')[1])
        path2read = pathmom + '/collect_data/'
        print("====loading experiment statistics from folder:", path2read)
        for i in range(step):
            print("===step %d=====" % i)
            path = path2read + 'FE_step_%d_version_%d/' % (i, model_version)
            _region_uncertainty(path, method)
    
    
    def _region_uncertainty(path, method):
        """Calculates the uncertainty from the region acquisition"""
        savepath = '/home/blia/Exp_Data/calibration_score/region_uncertainty/'
        if not os.path.exists(savepath):
            os.makedirs(savepath)
        path_split = path.strip().split('/')
        savename = (path_split[3].strip().split('_')[-1] + '_' + path_split[4]
                    + '_step_' + path_split[6].strip().split('_')[2])
        stat = pickle.load(open(path + 'updated_uncertain.txt', 'rb'))
        selected_image, fb_label, ed_label, binary_mask, imindex = stat
        fbprob = np.load(path + 'fbprob.npy')
        fbprob_subset = fbprob[imindex]
        if method is "D":
            fbbald = np.load(path + 'fbbald.npy')
            fbbald_subset = fbbald[imindex]
            fbprob_subset = [fbprob_subset, fbbald_subset]
        uncert = calc_uncertainty(fbprob_subset, method, False)
        print("-----maximum of uncertainty %.2f minimum of uncertainty %.2f------" % (np.max(uncert),
                                                                                      np.min(uncert)))
        uncert = (uncert - np.min(uncert)) / (np.max(uncert) - np.min(uncert))
        uncert = uncert * binary_mask[:, :, :, 0]
        uncert_aggre = uncert[uncert != 0]
    
        print(np.shape(uncert_aggre), np.sum(binary_mask))
        np.save(savepath + '/' + savename, uncert_aggre)
    
    
    def calc_uncertainty(prob, method, reshape=True):
        """Pixel-wise uncertainty: 1 - max probability for method "B", the predictive entropy for
        method "C", and a BALD-style score (predictive entropy plus the pre-computed `bald` term)
        for method "D"."""
        if method == "B":
            uncert = 1 - np.max(prob, axis=-1)
        elif method == "C":
            uncert = np.sum(-prob * np.log(prob + 1e-8), axis=-1)
        elif method == "D":
            prob, bald = prob
            bald_first = -np.sum(prob * np.log(prob + 1e-8), axis=-1)
            bald_second = np.sum(bald, axis=-1)
            uncert = bald_first + bald_second
        else:
            raise ValueError("method needs to be 'B', 'C' or 'D', got %r" % method)
        if reshape:
            return np.reshape(uncert, [-1])
        else:
            return uncert
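
    # Quick numeric illustration (synthetic probabilities, not experiment data): for a single
    # pixel with prob = [[0.9, 0.1]], method "B" gives 1 - 0.9 = 0.1 and method "C" gives the
    # entropy -(0.9*log(0.9) + 0.1*log(0.1)) ~= 0.325; a more confident pixel such as
    # [[0.99, 0.01]] scores lower under both, which is what drives the pixel/region ranking.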
    
    
    def aggregate_stat(version):
        """This function aggregates all the calibration score in one folder
        Args:
            version: int, Act_Learn_Desperate_%d % version
        Ops:
            1. The calibration score needs to be read
            2. Then this calibration score file will copied to a home directory
        """
        path2read = '/scratch/blia/Act_Learn_Desperate_V%d' % version
        path2save = '/home/blia/Exp_Data/calibration_score/'
        path2save = path2save + 'Act_Learn_Desperate_V%d' % version
        if not os.path.exists(path2save):
            os.makedirs(path2save)
        all_model = sorted([v for v in os.listdir(path2read) if 'Method_' in v])
        for single_model in all_model:
            print(single_model)
            orig_file_path = path2read + '/' + single_model + '/' + 'calibration_score.obj'
            new_file_path = path2save + '/%s.obj' % single_model
            shutil.copy(orig_file_path, new_file_path)
    
    
    def calc_calibration_value(path_input, version_use):
        """The calculated calibration score here is used to create Figure 1, 5 and E2 in the paper"""
        x_image_test, y_label_test = collect_test_data(resize=False)
        y_label_binary = (y_label_test != 0).astype('int32')
        num_image, imh, imw = np.shape(y_label_test)
    
        path_mom = os.path.join("/scratch/blia/Act_Learn_Desperate_V%d/" % version_use, path_input)
        path_sub = np.load(os.path.join(path_mom, 'total_select_folder.npy'))
        test_data_path = path_mom + '/Test_Data/'
        num_class = 2
        num_benign, num_mali = 37, 43
        y_label_benign_binary = np.reshape(y_label_binary[:num_benign], [num_benign * imh * imw])
        y_label_mali_binary = np.reshape(y_label_binary[num_benign:], [num_mali * imh * imw])
        y_label_binary = np.reshape(y_label_binary, [num_image * imh * imw])
        stat = {}
        # each score is computed for the benign subset, the malignant subset, and all test images
        # for the ECE, since this is binary classification, only the top-1 prediction is used
        ece_score = []
        brier_score = []
        nll_score = []
        brier_decompose_score = []
        for single_sub in path_sub:
            single_folder_name = single_sub.strip().split('/')[-2]
            tds_dir = test_data_path + single_folder_name + '/'
            pred = []
            for single_test in ["Test_Data_A/", "Test_Data_B/"]:
                tds_use = tds_dir + single_test
                fb_prob = np.load(tds_use + 'fbprob.npy')
                fb_reshape = np.reshape(np.squeeze(fb_prob, axis=(1, 2)),
                                        [len(fb_prob) * imh * imw, num_class])
                pred.append(fb_reshape)
    
            # --- first, nll score --------#
            time_init = time.time()
    
            nll_benign, nll_mali = calib.nll(pred[0]), calib.nll(pred[1])
            #        time_init = get_time(time_init, "nll")
            ece_benign = calib.expected_calibration_error_multiclass(pred[0], y_label_benign_binary, 10)
            ece_mali = calib.expected_calibration_error_multiclass(pred[1], y_label_mali_binary, 10)
            #        time_init = get_time(time_init, "ece")
            brier_benign = calib.brier_scores(y_label_benign_binary, probs=pred[0])
            brier_mali = calib.brier_scores(y_label_mali_binary, probs=pred[1])
            #        time_init = get_time(time_init, "brier score")
            brier_benign_decomp = calib.brier_decomp_npy(labels=y_label_benign_binary, probabilities=pred[0])
            brier_mali_decomp = calib.brier_decomp_npy(labels=y_label_mali_binary, probabilities=pred[1])
            #        time_init = get_time(time_init, "brier score decomposition")
            pred_conc = np.concatenate(pred, axis=0)
            nll_all = calib.nll(pred_conc)
            ece_all = calib.expected_calibration_error_multiclass(pred_conc, y_label_binary, 10)
            brier_all = calib.brier_scores(y_label_binary, probs=pred_conc)
            brier_all_decomp = calib.brier_decomp_npy(labels=y_label_binary, probabilities=pred_conc)
    
            #        print(time.time() - time_init)
            ece_score.append([ece_benign, ece_mali, ece_all])
            brier_score.append([np.mean(brier_benign), np.mean(brier_mali), np.mean(brier_all)])
            brier_decompose_score.append([brier_benign_decomp, brier_mali_decomp, brier_all_decomp])
            nll_score.append([nll_benign, nll_mali, nll_all])
    
        stat["ece_score"] = np.reshape(np.array(ece_score), [len(ece_score), 3])
        stat["nll_score"] = np.reshape(np.array(nll_score), [len(nll_score), 3])
        stat["bri_score"] = np.reshape(np.array(brier_score), [len(brier_score), 3])
        stat["bri_decompose_score"] = np.reshape(np.array(brier_decompose_score), [len(brier_decompose_score), 9])
    
        print("ece score", stat["ece_score"][0], ece_score[0])
        print("nll score", stat["nll_score"][0], nll_score[0])
        print("bri score", stat["bri_score"][0], brier_score[0])
        print("brier decompose score", stat["bri_decompose_score"][0], brier_decompose_score[0])
    
        with open(path_mom + "/calibration_score.obj", 'wb') as f:
            pickle.dump(stat, f)
    
    
    def get_time(time_init, opt):
        time_end = time.time()
        print("%s use time--------%.4f" % (opt, time_end - time_init))
        return time_end
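

    # A minimal, illustrative entry point (not present in the original file); the method letter
    # and version numbers below are placeholders and would need to match an actual experiment.
    if __name__ == "__main__":
        # calibration scores for the region-based active learning experiments of one method
        get_group_region("B")
        # calibration scores for every full-image active learning experiment of one version
        get_group_full(6)
        # copy the resulting calibration_score.obj files into the shared results folder
        aggregate_stat(6)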