Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 18 11:45:49 2018
This file includes all the acquisition functions.
All functions in this file share one option: the method used to aggregate the per-pixel scores of an image can differ.
1. 'Simple_Sum': consider the utility score of every pixel in each image, i.e. sum the scores over all pixels.
2. 'Quantile': consider only the most uncertain pixels in each image, i.e. compute a quantile threshold and
only sum the pixels whose utility score is at least that threshold.
3. Other aggregation methods are planned but not yet implemented.
@author: s161488
"""
import numpy as np
def extract_uncertainty_index(images, fb_prob, agg_method, quantile_cri):
    """Score each image by least-confidence uncertainty, i.e. 1 - probability of the predicted class.

    Only the bounding box of the image's non-zero pixels (after averaging over the last
    axis) is scored, so zero padding around the image content is ignored.

    Args:
        images: array indexed as [image, row, col, channel]; used only to locate the
            non-zero content region of each image.
        fb_prob: predicted class probabilities indexed as [image, row, col, class],
            spatially aligned with ``images``. Any number of classes is supported.
        agg_method: 'Simple_Sum' sums the per-pixel uncertainty; 'Quantile' sums only the
            pixels whose uncertainty is >= the ``quantile_cri``-th percentile.
        quantile_cri: percentile in [0, 100], used when agg_method is 'Quantile'.

    Returns:
        Array of shape [num_image, 1] with one score per image. An image without any
        non-zero pixel has no content region and keeps a score of 0.

    Raises:
        ValueError: if agg_method is not 'Simple_Sum' or 'Quantile'.
    """
    if agg_method not in ('Simple_Sum', 'Quantile'):
        raise ValueError("Unsupported agg_method %r; expected 'Simple_Sum' or 'Quantile'" % (agg_method,))
    num_image = np.shape(fb_prob)[0]
    uncert = np.zeros([num_image, 1])
    for i in range(num_image):
        rows, cols = np.nonzero(np.mean(images[i], axis=-1) != 0)
        if rows.size == 0:
            # Nothing to crop to; np.min on the empty index arrays would raise.
            continue
        fb_prob_single = fb_prob[i, rows.min():rows.max() + 1, cols.min():cols.max() + 1, :]
        # 1 - max class probability: bit-identical to the previous two-class
        # argmax/select formula, and also correct for more than two classes.
        fb_prob_reshape = (1 - np.max(fb_prob_single, axis=-1)).ravel()
        if agg_method == 'Simple_Sum':
            uncert[i, 0] = np.sum(fb_prob_reshape)
        else:
            num_quant = np.percentile(fb_prob_reshape, q=quantile_cri)
            uncert[i, 0] = np.sum(fb_prob_reshape[fb_prob_reshape >= num_quant])
    return uncert
def extract_entropy_index(fb_prob, images, agg_method, quantile_cri):
    """Score each image by the Shannon entropy of its per-pixel predicted class distribution.

    Only the bounding box of the image's non-zero pixels (after averaging over the last
    axis) is scored, so zero padding around the image content is ignored.

    Args:
        fb_prob: predicted class probabilities indexed as [image, row, col, class],
            spatially aligned with ``images``.
        images: array indexed as [image, row, col, channel]; used only to locate the
            non-zero content region of each image.
        agg_method: 'Simple_Sum' sums the per-pixel entropy; 'Quantile' sums only the
            pixels whose entropy is >= the ``quantile_cri``-th percentile.
        quantile_cri: percentile in [0, 100], used when agg_method is 'Quantile'.

    Returns:
        Array of shape [num_image, 1] with one score per image. An image without any
        non-zero pixel has no content region and keeps a score of 0.

    Raises:
        ValueError: if agg_method is not 'Simple_Sum' or 'Quantile'.
    """
    if agg_method not in ('Simple_Sum', 'Quantile'):
        raise ValueError("Unsupported agg_method %r; expected 'Simple_Sum' or 'Quantile'" % (agg_method,))
    num_image = np.shape(images)[0]
    entropy_value = np.zeros([num_image, 1])
    for i in range(num_image):
        rows, cols = np.nonzero(np.mean(images[i], axis=-1) != 0)
        if rows.size == 0:
            # Nothing to crop to; np.min on the empty index arrays would raise.
            continue
        fb_prob_single = fb_prob[i, rows.min():rows.max() + 1, cols.min():cols.max() + 1, :]
        # Entropy summed over the class axis; 1e-8 keeps log finite where a probability is 0.
        fb_entropy = np.sum(-fb_prob_single * np.log(fb_prob_single + 1e-8), axis=-1)
        fb_entropy_reshape = fb_entropy.ravel()
        if agg_method == 'Simple_Sum':
            entropy_value[i, 0] = np.sum(fb_entropy_reshape)
        else:
            num_quant = np.percentile(fb_entropy_reshape, q=quantile_cri)
            entropy_value[i, 0] = np.sum(fb_entropy_reshape[fb_entropy_reshape >= num_quant])
    return entropy_value
def extract_bald_index(fb_prob_mean_bald, fb_prob, x_image_pl, agg_method, quantile_cri):
    """This is for acquiring image based on BALD method

    Per pixel, the score is ``-sum_c p_c * log(p_c + 1e-8) + sum_c fb_prob_mean_bald_c``.
    Only the bounding box of the image's non-zero pixels (after averaging over the last
    axis) is scored, so zero padding around the image content is ignored.

    Args:
        fb_prob_mean_bald: shape [Number_of_Image, im_h, im_w, 2]; per the original notes
            this holds 1/t * p_c * log(p_c) (presumably averaged over t stochastic
            forward passes -- confirm against the caller).
        fb_prob: the predicted probability, shape [Number_of_Image, im_h, im_w, 2].
        x_image_pl: [num_image, imh, imw, 3]; used only to locate the non-zero content
            region of each image.
        agg_method: 'Simple_Sum' sums the per-pixel score; 'Quantile' sums only the pixels
            whose score is >= the ``quantile_cri``-th percentile.
        quantile_cri: percentile in [0, 100], used when agg_method is 'Quantile'.

    Returns:
        BALD_value: array of shape [num_image, 1]. An image without any non-zero pixel
        has no content region and keeps a score of 0.

    Raises:
        ValueError: if agg_method is not 'Simple_Sum' or 'Quantile'.
    """
    if agg_method not in ('Simple_Sum', 'Quantile'):
        raise ValueError("Unsupported agg_method %r; expected 'Simple_Sum' or 'Quantile'" % (agg_method,))
    num_image = np.shape(x_image_pl)[0]
    BALD_value = np.zeros([num_image, 1])
    for i in range(num_image):
        rows, cols = np.nonzero(np.mean(x_image_pl[i], axis=-1) != 0)
        if rows.size == 0:
            # Nothing to crop to; np.min on the empty index arrays would raise.
            continue
        # Same spatial crop for both probability tensors; the class axis is kept whole.
        crop = (i, slice(rows.min(), rows.max() + 1), slice(cols.min(), cols.max() + 1))
        fb_prob_mean_bald_single = fb_prob_mean_bald[crop]
        fb_prob_single = fb_prob[crop]
        bald_first_term = -np.sum(fb_prob_single * np.log(fb_prob_single + 1e-08), axis=-1)
        bald_second_term = np.sum(fb_prob_mean_bald_single, axis=-1)
        bald_reshape = (bald_first_term + bald_second_term).ravel()
        if agg_method == 'Simple_Sum':
            BALD_value[i, 0] = np.sum(bald_reshape)
        else:
            num_quant = np.percentile(bald_reshape, q=quantile_cri)
            BALD_value[i, 0] = np.sum(bald_reshape[bald_reshape >= num_quant])
    return BALD_value
def extract_informative_index(acq_method, x_image_pl, fb_prob, fb_prob_var, fb_prob_mean_bald, num_select_point,
                              agg_method, quantile_cri):
    """Score every image with the chosen acquisition function and return the top-scoring indices.

    Args:
        acq_method: "B" (uncertainty, via extract_uncertainty_index), "C" (entropy, via
            extract_entropy_index) or "D" (BALD, via extract_bald_index).
        x_image_pl: input images indexed as [image, row, col, channel].
        fb_prob: predicted class probabilities indexed as [image, row, col, class].
        fb_prob_var: not used by this function; kept for interface compatibility.
        fb_prob_mean_bald: second BALD term, only used when acq_method is "D".
        num_select_point: number of images to select; values <= 0 select none.
        agg_method: per-image aggregation passed through ('Simple_Sum' or 'Quantile').
        quantile_cri: percentile passed through, used when agg_method is 'Quantile'.

    Returns:
        1-D integer array of at most num_select_point image indices, in increasing order
        of score, so the highest-scoring image is last.

    Raises:
        ValueError: if acq_method is not "B", "C" or "D".
    """
    # `==`, not `is`: identity of equal strings is an interpreter detail, not a guarantee.
    if acq_method == "B":
        print("acquisition function is uncertainty")
        margin_diff = extract_uncertainty_index(x_image_pl, fb_prob, agg_method, quantile_cri)
    elif acq_method == "C":
        print("acquisition function is entropy")
        margin_diff = extract_entropy_index(fb_prob, x_image_pl, agg_method, quantile_cri)
    elif acq_method == "D":
        print("acquisition function is BALD")
        margin_diff = extract_bald_index(fb_prob_mean_bald, fb_prob, x_image_pl, agg_method, quantile_cri)
    else:
        # Previously this printed and then crashed with NameError on margin_diff.
        raise ValueError("Unsupported acq_method %r; expected 'B', 'C' or 'D'" % (acq_method,))
    marg_index = np.argsort(margin_diff[:, 0], axis=0)
    if num_select_point <= 0:
        # marg_index[-0:] would return every index rather than none.
        return marg_index[:0]
    return marg_index[-num_select_point:]