Skip to content
Snippets Groups Projects
acquistion_full_image.py 5.13 KiB
Newer Older
  • Learn to ignore specific revisions
  • blia's avatar
    blia committed
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Mon Jun 18 11:45:49 2018
    This file contains the image-level acquisition functions.
    Every function in this file takes an aggregation method that turns the per-pixel utility scores of an image
    into a single score for that image:
    1. 'Simple_Sum': consider the utility scores of all the pixels in the image, i.e. sum over all the pixels.
    2. 'Quantile': consider only the most uncertain pixels in the image, i.e. compute the given percentile
       (quantile criterion) of the image's pixel scores and sum only the pixels whose utility score is at or
       above that criterion.
    3. Other aggregation methods are not implemented yet.
    @author: s161488
    """
    import numpy as np
    
    
    def extract_uncertainty_index(images, fb_prob, agg_method, quantile_cri):
        """Score each image by least-confidence uncertainty over its non-zero region.

        For every image, the bounding box of the pixels whose channel mean is non-zero is
        located in ``images``. The same window is cropped out of ``fb_prob``, and each pixel
        gets the score ``1 - max_c p_c`` (one minus the probability of the predicted class).
        The per-pixel scores are then aggregated into one value per image.

        Args:
            images: array of shape [num_image, im_h, im_w, channels]; only used to locate
                the non-zero region of each image.
            fb_prob: predicted class probabilities, shape [num_image, im_h, im_w, num_classes].
            agg_method: 'Simple_Sum' sums the scores of every pixel in the region;
                'Quantile' sums only the scores >= the ``quantile_cri``-th percentile of
                that image's scores.
            quantile_cri: percentile in [0, 100], used only by 'Quantile'.

        Returns:
            Array of shape [num_image, 1] holding one uncertainty score per image. An image
            with no non-zero pixel has an empty region and scores 0.

        Raises:
            ValueError: if ``agg_method`` is neither 'Simple_Sum' nor 'Quantile'.
        """
        if agg_method not in ('Simple_Sum', 'Quantile'):
            # Fail loudly instead of returning all-zero scores that would later be ranked as real ones.
            raise ValueError(
                "Unknown aggregation method {!r}; expected 'Simple_Sum' or 'Quantile'".format(agg_method))
        num_image = np.shape(fb_prob)[0]
        uncert = np.zeros([num_image, 1])
        for i in range(num_image):
            rows, cols = np.where(np.mean(images[i], -1) != 0)
            if rows.size == 0:
                # Fully zero image: there is no region to score (np.min would raise on it).
                continue
            fb_prob_single = fb_prob[i, rows.min():rows.max() + 1, cols.min():cols.max() + 1, :]
            # 1 - probability of the predicted class; equal to the former two-class
            # argmax-based formula and also correct for more than two classes.
            uncert_flat = np.reshape(1 - np.max(fb_prob_single, axis=-1), [-1])
            if agg_method == 'Simple_Sum':
                uncert[i, 0] = np.sum(uncert_flat)
            else:
                threshold = np.percentile(uncert_flat, q=quantile_cri)
                uncert[i, 0] = np.sum(uncert_flat[uncert_flat >= threshold])
        return uncert
    
    
    def extract_entropy_index(fb_prob, images, agg_method, quantile_cri):
        """Score each image by the predictive entropy over its non-zero region.

        For every image, the bounding box of the pixels whose channel mean is non-zero is
        located in ``images``. The same window is cropped out of ``fb_prob``, and each pixel
        gets the entropy ``-sum_c p_c * log(p_c + 1e-8)``. The per-pixel entropies are then
        aggregated into one value per image.

        Args:
            fb_prob: predicted class probabilities, shape [num_image, im_h, im_w, num_classes].
            images: array of shape [num_image, im_h, im_w, channels]; only used to locate
                the non-zero region of each image.
            agg_method: 'Simple_Sum' sums the entropies of every pixel in the region;
                'Quantile' sums only the entropies >= the ``quantile_cri``-th percentile
                of that image's entropies.
            quantile_cri: percentile in [0, 100], used only by 'Quantile'.

        Returns:
            Array of shape [num_image, 1] holding one entropy score per image. An image
            with no non-zero pixel has an empty region and scores 0.

        Raises:
            ValueError: if ``agg_method`` is neither 'Simple_Sum' nor 'Quantile'.
        """
        if agg_method not in ('Simple_Sum', 'Quantile'):
            # Fail loudly instead of returning all-zero scores that would later be ranked as real ones.
            raise ValueError(
                "Unknown aggregation method {!r}; expected 'Simple_Sum' or 'Quantile'".format(agg_method))
        num_image = np.shape(images)[0]
        entropy_value = np.zeros([num_image, 1])
        for i in range(num_image):
            rows, cols = np.where(np.mean(images[i], -1) != 0)
            if rows.size == 0:
                # Fully zero image: there is no region to score (np.min would raise on it).
                continue
            fb_prob_single = fb_prob[i, rows.min():rows.max() + 1, cols.min():cols.max() + 1, :]
            # Sum over the class axis; the 1e-8 keeps log finite where a probability is exactly 0.
            fb_entropy = np.sum(-fb_prob_single * np.log(fb_prob_single + 1e-8), axis=-1)
            entropy_flat = np.reshape(fb_entropy, [-1])
            if agg_method == 'Simple_Sum':
                entropy_value[i, 0] = np.sum(entropy_flat)
            else:
                threshold = np.percentile(entropy_flat, q=quantile_cri)
                entropy_value[i, 0] = np.sum(entropy_flat[entropy_flat >= threshold])
        return entropy_value
    
    
    def extract_bald_index(fb_prob_mean_bald, fb_prob, x_image_pl, agg_method, quantile_cri):
        """Score each image with the BALD acquisition function over its non-zero region.

        For every image, the bounding box of the pixels whose channel mean is non-zero is
        located in ``x_image_pl``, and the same window is cropped out of both probability
        arrays. The per-pixel score is the entropy of ``fb_prob``,
        ``-sum_c p_c * log(p_c + 1e-8)``, plus ``sum_c fb_prob_mean_bald[..., c]``. The
        per-pixel scores are then aggregated into one value per image.

        Args:
            fb_prob_mean_bald: shape [num_image, im_h, im_w, num_classes]. Per the original
                notes this is 1/T * sum_t p_c * log(p_c) over T stochastic forward passes,
                i.e. the negative expected entropy per class (assumption; confirm with caller).
            fb_prob: predicted (mean) probabilities, shape [num_image, im_h, im_w, num_classes].
            x_image_pl: images, shape [num_image, im_h, im_w, channels]; only used to
                locate the non-zero region of each image.
            agg_method: 'Simple_Sum' sums the scores of every pixel in the region;
                'Quantile' sums only the scores >= the ``quantile_cri``-th percentile of
                that image's scores.
            quantile_cri: percentile in [0, 100], used only by 'Quantile'.

        Returns:
            Array of shape [num_image, 1] holding one BALD score per image. An image with
            no non-zero pixel has an empty region and scores 0.

        Raises:
            ValueError: if ``agg_method`` is neither 'Simple_Sum' nor 'Quantile'.
        """
        if agg_method not in ('Simple_Sum', 'Quantile'):
            # Fail loudly instead of returning all-zero scores that would later be ranked as real ones.
            raise ValueError(
                "Unknown aggregation method {!r}; expected 'Simple_Sum' or 'Quantile'".format(agg_method))
        num_image = np.shape(x_image_pl)[0]
        BALD_value = np.zeros([num_image, 1])
        for i in range(num_image):
            rows, cols = np.where(np.mean(x_image_pl[i], -1) != 0)
            if rows.size == 0:
                # Fully zero image: there is no region to score (np.min would raise on it).
                continue
            window = (i, slice(rows.min(), rows.max() + 1), slice(cols.min(), cols.max() + 1), slice(None))
            fb_prob_mean_bald_single = fb_prob_mean_bald[window]
            fb_prob_single = fb_prob[window]
            # Entropy of the predictive distribution (1e-08 keeps log finite at p == 0) ...
            bald_first_term = -np.sum(fb_prob_single * np.log(fb_prob_single + 1e-08), axis=-1)
            # ... plus the summed second term supplied by the caller.
            bald_second_term = np.sum(fb_prob_mean_bald_single, axis=-1)
            bald_flat = np.reshape(bald_first_term + bald_second_term, [-1])
            if agg_method == 'Simple_Sum':
                BALD_value[i, 0] = np.sum(bald_flat)
            else:
                threshold = np.percentile(bald_flat, q=quantile_cri)
                BALD_value[i, 0] = np.sum(bald_flat[bald_flat >= threshold])
        return BALD_value
    
    
    def extract_informative_index(acq_method, x_image_pl, fb_prob, fb_prob_var, fb_prob_mean_bald, num_select_point,
                                  agg_method, quantile_cri):
        """Return the indices of the most informative images under the chosen acquisition function.

        Args:
            acq_method: "B" (least-confidence uncertainty), "C" (entropy) or "D" (BALD).
            x_image_pl: images, shape [num_image, im_h, im_w, channels]; their non-zero
                region is the part that gets scored.
            fb_prob: predicted class probabilities, shape [num_image, im_h, im_w, num_classes].
            fb_prob_var: unused; kept so existing callers keep working.
            fb_prob_mean_bald: second BALD term, only used when ``acq_method`` is "D".
            num_select_point: number of images to select.
            agg_method: 'Simple_Sum' or 'Quantile'; forwarded to the scoring function.
            quantile_cri: percentile in [0, 100] used by 'Quantile'.

        Returns:
            1-D integer array with the indices of the ``num_select_point`` highest-scoring
            images, ordered by ascending score (the most informative image is last). Empty
            when ``num_select_point <= 0``.

        Raises:
            ValueError: if ``acq_method`` is not "B", "C" or "D".
        """
        # Compare with ==: `is` tests object identity and only matched by accident for interned literals.
        if acq_method == "B":
            print("acquisition function is uncertainty")
            scores = extract_uncertainty_index(x_image_pl, fb_prob, agg_method, quantile_cri)
        elif acq_method == "C":
            print("acquisition function is entropy")
            scores = extract_entropy_index(fb_prob, x_image_pl, agg_method, quantile_cri)
        elif acq_method == "D":
            print("acquisition function is BALD")
            scores = extract_bald_index(fb_prob_mean_bald, fb_prob, x_image_pl, agg_method, quantile_cri)
        else:
            # Previously this printed and then crashed with UnboundLocalError on the score variable.
            raise ValueError("Unknown acquisition method {!r}; expected 'B', 'C' or 'D'".format(acq_method))
        order = np.argsort(scores[:, 0], axis=0)
        if num_select_point <= 0:
            # order[-0:] would return every index instead of none.
            return order[:0]
        acq_index = order[-num_select_point:]
        return acq_index