Skip to content
Snippets Groups Projects
acquisition_region.py 14.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • blia's avatar
    blia committed
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Wed Oct 10 18:03:32 2018
    This file is the new version for calculating the uncertainty value in each patch.
    It is better because:
        1. the way of choosing the most uncertain patch is automated
        2. the weight ratio for each region can easily be changed to any value
        @author: s161488
    """
    import numpy as np
    from scipy import signal, ndimage
    from skimage.morphology import dilation, disk
    
    
    def select_most_uncertain_patch(x_image_pl, y_label_pl, fb_pred, ed_pred, fb_prob_mean_bald, kernel_window, stride_size,
                                    already_select_image_index, previously_selected_binary_mask, num_most_uncert_patch,
                                    method):
        """Acquire the most uncertain patches in the pooling set.

        Args:
            x_image_pl: [Num_Im, Im_h, Im_w, 3]
            y_label_pl: [Num_Im, Im_h, Im_w, 1]
            fb_pred: [Num_Im, Im_h, Im_w, 2]
            ed_pred: [Num_Im, Im_h, Im_w, 2]
            fb_prob_mean_bald: per-image input for the BALD score, only read when method == 'D'
            kernel_window: [kh, kw] determine the size of the region
            stride_size: int, determine the stride between every two regions
            already_select_image_index: if it's None, then it means that's the first acquisition step,
                otherwise it's the numeric image index for the previously selected patches
            previously_selected_binary_mask: [num_already_selected_images, Im_h, Im_w, 1]
            num_most_uncert_patch: int, number of patches that are selected in each acquisition step
            method: acquisition method: 'B', 'C', 'D'
        Returns:
            A list with five entries, in this order:
                most_uncert_im_tot: x_image_pl indexed by the selected image indices
                pseudo_fb_la_tot: list with one pseudo foreground/background label per selected image
                pseudo_ed_la_tot: list with one pseudo edge label per selected image
                binary_mask_updated_tot: binary masks returned by calculate_score_for_patch
                selected_numeric_image_index: image indices returned by calculate_score_for_patch
        Raises:
            ValueError: if method is not one of 'B', 'C', 'D'
        """
        # Reject unknown methods up front; otherwise the loop below would fail on an unbound variable.
        if method not in ('B', 'C', 'D'):
            raise ValueError("Unknown acquisition method %r, expected one of 'B', 'C', 'D'" % (method,))
        num_im = np.shape(x_image_pl)[0]
        uncertainty_map_tot = []
        for i in range(num_im):
            if method == 'B':
                var_stat = get_uncert_heatmap(x_image_pl[i], fb_pred[i])
            elif method == 'C':
                var_stat = get_entropy_heatmap(fb_pred[i])
            else:  # method == 'D'
                var_stat = get_bald_heatmap(fb_prob_mean_bald[i], fb_pred[i])
            uncertainty_map_tot.append(var_stat)
        uncertainty_map_tot = np.array(uncertainty_map_tot)
        if already_select_image_index is None:
            print("--------This is the beginning of the selection process-------")
        else:
            print(
                "----------Some patches have already been annotated, I need to deal with that")
            previously_selected_binary_mask = np.squeeze(previously_selected_binary_mask, axis=-1)
            # Zero the uncertainty inside regions that were already annotated so they
            # contribute nothing to the patch scores of their image.
            for i in range(np.shape(previously_selected_binary_mask)[0]):
                image_index = already_select_image_index[i]
                uncertainty_map_tot[image_index] = uncertainty_map_tot[image_index] * (
                    1 - previously_selected_binary_mask[i])
        selected_numeric_image_index, binary_mask_updated_tot = calculate_score_for_patch(uncertainty_map_tot,
                                                                                          kernel_window, stride_size,
                                                                                          num_most_uncert_patch)
        pseudo_fb_la_tot = []
        pseudo_ed_la_tot = []
        for index, single_selected_image_index in enumerate(selected_numeric_image_index):
            pseudo_fb_la, pseudo_ed_la = return_pseudo_label(y_label_pl[single_selected_image_index],
                                                             fb_pred[single_selected_image_index],
                                                             ed_pred[single_selected_image_index],
                                                             binary_mask_updated_tot[index])

            pseudo_fb_la_tot.append(pseudo_fb_la)
            pseudo_ed_la_tot.append(pseudo_ed_la)
        most_uncert_im_tot = x_image_pl[selected_numeric_image_index]
        most_uncertain = [most_uncert_im_tot,
                          pseudo_fb_la_tot,
                          pseudo_ed_la_tot,
                          binary_mask_updated_tot,
                          selected_numeric_image_index]
        return most_uncertain
    
    
    def calculate_score_for_patch(uncert_est, kernel, stride_size, num_most_uncertain_patch):
        """Calculate the utility score for each patch and pick the highest-scoring ones.

        Args:
            uncert_est: [num_image, imh, imw]
            kernel: 2-D array whose shape [kh, kw] is the size of each searching region
            stride_size: the stride between every two regions
            num_most_uncertain_patch: int, the number of selected regions
        Returns:
            A list [image_index, binary_mask]:
                image_index: indices of the images that contain the selected patches. When no
                    image is selected twice there is one entry per patch; otherwise the
                    unique image indices are returned.
                binary_mask: one [imh, imw, 1] mask per entry of image_index. When an image
                    holds several selected patches, their masks are merged into one.
        """
        num_im, imh, imw = np.shape(uncert_est)
        kh, kw = np.shape(kernel)
        # Number of valid top-left corners without any stride.
        h_num_patch = imh - kh + 1
        w_num_patch = imw - kw + 1
        num_row_wise = h_num_patch // stride_size
        num_col_wise = w_num_patch // stride_size
        if stride_size == 1:
            patch_per_row = num_col_wise
            tot_num_patch_per_im = num_row_wise * num_col_wise
        else:
            # With a stride, one extra row and column anchored at the last valid position is added.
            patch_per_row = num_col_wise + 1
            tot_num_patch_per_im = (num_row_wise + 1) * (num_col_wise + 1)
        print('-------------------------------There are %d patches in per image' % tot_num_patch_per_im)
        patch_tot = []
        for i in range(num_im):
            patch_subset = select_patches_in_image_area(uncert_est[i], kernel, stride_size, num_row_wise, num_col_wise)
            patch_tot.append(np.reshape(patch_subset, [-1]))
        patch_tot = np.reshape(np.array(patch_tot), [-1])
        # argsort is ascending, so the tail holds the largest scores.
        sorted_index = np.argsort(patch_tot)
        select_most_uncert_patch = (sorted_index[-num_most_uncertain_patch:]).astype('int64')
        # Decompose the flat index into (image, row, column) on the strided grid.
        select_most_uncert_patch_imindex = (select_most_uncert_patch // tot_num_patch_per_im).astype('int64')
        select_most_uncert_patch_index_per_im = (select_most_uncert_patch % tot_num_patch_per_im).astype('int64')
        select_most_uncert_patch_rownum_per_im = (select_most_uncert_patch_index_per_im // patch_per_row).astype(
            'int64')
        select_most_uncert_patch_colnum_per_im = (select_most_uncert_patch_index_per_im % patch_per_row).astype(
            'int64')
        transfered_rownum, transfered_colnum = transfer_strid_rowcol_backto_nostride_rowcol(
            select_most_uncert_patch_rownum_per_im,
            select_most_uncert_patch_colnum_per_im,
            [h_num_patch, w_num_patch],
            [num_row_wise + 1, num_col_wise + 1],
            stride_size)

        binary_mask_tot = []
        for i in range(num_most_uncertain_patch):
            single_binary_mask = generate_binary_mask(imh, imw,
                                                      transfered_rownum[i],
                                                      transfered_colnum[i],
                                                      kh, kw)
            binary_mask_tot.append(single_binary_mask)
        binary_mask_tot = np.array(binary_mask_tot)
        unique_im_index = np.unique(select_most_uncert_patch_imindex)
        if np.shape(unique_im_index)[0] == num_most_uncertain_patch:
            print("----------------------------There is no replication for the selected images")
            uncertain_info = [select_most_uncert_patch_imindex, binary_mask_tot]
        else:
            print("-----These images have been selected more than twice", unique_im_index)
            binary_mask_final_tot = []
            for i in unique_im_index:
                loc_im = np.where(select_most_uncert_patch_imindex == i)[0].astype('int64')
                # Union of all patch masks that fall into the same image.
                binary_mask_combine = (np.sum(binary_mask_tot[loc_im], axis=0) != 0).astype('int64')
                binary_mask_final_tot.append(binary_mask_combine)
            uncertain_info = [unique_im_index.astype('int64'), np.array(binary_mask_final_tot)]
        # Report the mask that is actually returned; the merged-mask list only exists in the else branch.
        print("the shape for binary mask", np.shape(uncertain_info[1]))
        return uncertain_info
    
    
    def return_pseudo_label(single_gt, single_fb_pred, single_ed_pred, single_binary_mask):
        """Build pseudo labels: ground truth inside the mask, thresholded predictions outside.

        Args:
            single_gt: [imh, imw, 1], any non-zero value counts as foreground
            single_fb_pred: [imh, imw, 2], last channel is thresholded at 0.5
            single_ed_pred: [imh, imw, 2], last channel is thresholded at 0.2
            single_binary_mask: binary mask of the selected region; it is multiplied
                element-wise with [imh, imw, 1] arrays (presumably [imh, imw, 1] itself,
                verify against the caller)
        Return:
            pseudo_fb_la: foreground/background pseudo label
            pseudo_ed_la: edge pseudo label
        """
        gt_binary = (single_gt != 0).astype('int64')
        gt_edge = extract_edge(gt_binary)
        pred_fb = (single_fb_pred[:, :, -1:] >= 0.5).astype('int64')
        pred_ed = (single_ed_pred[:, :, -1:] >= 0.2).astype('int64')
        print(np.shape(pred_fb), np.shape(single_binary_mask), np.shape(gt_binary), np.shape(gt_edge))
        # Complement of the selected region: here the network prediction is kept.
        outside_mask = 1 - single_binary_mask
        pseudo_fb_la = pred_fb * outside_mask + gt_binary * single_binary_mask
        pseudo_ed_la = pred_ed * outside_mask + gt_edge * single_binary_mask
        return pseudo_fb_la, pseudo_ed_la
    
    
    def extract_edge(la_sep):
        """Extract a dilated edge map from a label map.

        Args:
            la_sep: label map; only its first two axes are used for the output size
        Return
            edge_gt: int64 array with a trailing singleton axis, 1 on (dilated) edges, 0 elsewhere
        """
        footprint = disk(3)
        grad_rows = ndimage.sobel(la_sep, axis=0, mode='constant')
        grad_cols = ndimage.sobel(la_sep, axis=1, mode='constant')
        magnitude = np.hypot(grad_rows, grad_cols)
        height, width = np.shape(magnitude)[0], np.shape(magnitude)[1]
        # Any non-zero gradient magnitude marks an edge pixel.
        edge_flat = (np.reshape(magnitude, -1) > 0) * 1
        edge_map = dilation(np.reshape(edge_flat, [height, width]), footprint)
        return np.expand_dims(edge_map, axis=-1).astype('int64')
    
    
    def generate_binary_mask(imh, imw, rowindex, colindex, kh, kw):
        """Create an [imh, imw, 1] mask that is 1 inside the selected patch and 0 elsewhere.

        Args:
            imh, imw: height and width of the mask
            rowindex, colindex: top-left corner of the selected patch
            kh, kw: height and width of the patch
        Output:
            binary_mask: array of zeros with rows rowindex:rowindex+kh and
                columns colindex:colindex+kw set to 1
        """
        mask = np.zeros((imh, imw, 1))
        patch_rows = slice(rowindex, rowindex + kh)
        patch_cols = slice(colindex, colindex + kw)
        mask[patch_rows, patch_cols] = 1
        return mask
    
    
    def transfer_strid_rowcol_backto_nostride_rowcol(rownum, colnum, no_stride_row_col, stride_row_col, stride_size):
        """Map row/column indices on the strided patch grid back to the un-strided grid.

        An index that is not the last one of the strided grid maps to index * stride_size.
        The last index of the strided grid maps to the last index of the un-strided grid.
        With stride_size == 1 the indices are returned unchanged (cast to int64).

        Args:
            rownum, colnum: row and column indices on the strided grid
            no_stride_row_col: [number of rows, number of columns] of the un-strided grid
            stride_row_col: [number of rows, number of columns] of the strided grid
            stride_size: stride between two neighbouring patches
        Returns:
            The transferred row indices and column indices, both as int64 arrays.
        """
        if stride_size == 1:
            return rownum.astype('int64'), colnum.astype('int64')

        last_row_dense, last_col_dense = (extent - 1 for extent in no_stride_row_col)
        last_row_strided, last_col_strided = (extent - 1 for extent in stride_row_col)
        rows = np.asarray(rownum)
        cols = np.asarray(colnum)
        mapped_rows = np.where(rows != last_row_strided, stride_size * rows, last_row_dense)
        mapped_cols = np.where(cols != last_col_strided, cols * stride_size, last_col_dense)
        return mapped_rows.astype('int64'), mapped_cols.astype('int64')
    
    
    def select_patches_in_image_area(single_fb, kernel, stride_size, num_row_wise, num_col_wise):
        """Score every patch of one map by convolving it with kernel, then sub-sample by stride.

        Args:
            single_fb: 2-D map to score
            kernel: 2-D kernel passed to signal.convolve with mode='valid'
            stride_size: stride between two neighbouring patches
            num_row_wise, num_col_wise: number of strided positions along rows and columns
        Returns:
            For stride_size == 1 the full convolution result. Otherwise a
            [num_row_wise + 1, num_col_wise + 1] array holding the scores at the strided
            positions, plus one extra last row and column taken from the last row and
            column of the convolution result.
        """
        dense_scores = signal.convolve(single_fb, kernel, mode='valid')
        if stride_size == 1:
            return dense_scores

        row_stop = num_row_wise * stride_size
        col_stop = num_col_wise * stride_size
        strided_scores = np.zeros([num_row_wise + 1, num_col_wise + 1])
        # Regular strided grid.
        strided_scores[:num_row_wise, :num_col_wise] = dense_scores[:row_stop:stride_size, :col_stop:stride_size]
        # Extra column / row anchored at the last valid position, and the corner.
        strided_scores[:num_row_wise, -1] = dense_scores[:row_stop:stride_size, -1]
        strided_scores[-1, :num_col_wise] = dense_scores[-1, :col_stop:stride_size]
        strided_scores[-1, -1] = dense_scores[-1, -1]
        return strided_scores
    
    
    def get_uncert_heatmap(image_single, fb_prob_single, check_rank=False):
        """Return one minus the probability of the predicted class for every pixel.

        Args:
            image_single: image, only read when check_rank is True
            fb_prob_single: [imh, imw, 2] class probabilities
            check_rank: when True, the probabilities are first cropped to the bounding
                box of the pixels whose channel mean in image_single is non-zero
        Returns:
            2-D uncertainty map: 1 - p, where p is channel 1 if it is >= 0.5 and channel 0 otherwise
        """
        if check_rank is True:
            nonzero_rows, nonzero_cols = np.where(np.mean(image_single, -1) != 0)[:2]
            row_start, row_stop = np.min(nonzero_rows), np.max(nonzero_rows + 1)
            col_start, col_stop = np.min(nonzero_cols), np.max(nonzero_cols + 1)
            fb_prob_single = fb_prob_single[row_start:row_stop, col_start:col_stop, :]
        prob_background = fb_prob_single[:, :, 0]
        prob_foreground = fb_prob_single[:, :, 1]
        is_foreground = (prob_foreground >= 0.5).astype('int64')
        predicted_class_prob = is_foreground * prob_foreground + (1 - is_foreground) * prob_background
        return 1 - predicted_class_prob
    
    
    def get_entropy_heatmap(fb_prob_single):
        """Return the per-pixel entropy of the class probabilities.

        Args:
            fb_prob_single: class probabilities with the classes on the last axis
        Returns:
            Sum over the last axis of -p * log(p + 1e-8)
        """
        log_prob = np.log(fb_prob_single + 1e-8)  # the small offset avoids log(0)
        neg_p_log_p = -fb_prob_single * log_prob
        return np.sum(neg_p_log_p, axis=-1)
    
    
    def get_bald_heatmap(fb_prob_mean_bald_single, fb_prob_single):
        """Return the per-pixel BALD score from its two terms.

        Args:
            fb_prob_mean_bald_single: array that is summed over its last axis to form the
                second term (presumably the averaged p * log(p) of the stochastic
                forward passes, confirm against the caller)
            fb_prob_single: class probabilities with the classes on the last axis
        Returns:
            Entropy of fb_prob_single plus the sum of fb_prob_mean_bald_single over the last axis
        """
        log_prob = np.log(fb_prob_single + 1e-08)
        entropy_term = -np.sum(fb_prob_single * log_prob, axis=-1)
        expectation_term = np.sum(fb_prob_mean_bald_single, axis=-1)
        return entropy_term + expectation_term