# Train_Active_Region_Im.py
# -*- coding: utf-8 -*-
"""
Created on Wed Mar  7 16:42:15 2018
This file is used to train the active learning framework with region specific annotation
@author: s161488
"""
import numpy as np
import os
import tensorflow as tf
from data_utils.prepare_data import aug_train_data, generate_batch
from data_utils.update_data import give_init_train_and_val_data, update_training_data, prepare_the_new_uncertain_input
from models.inference import ResNet_V2_DMNN
from optimization.loss_region_specific import Loss, train_op_batchnorm
from sklearn.utils import shuffle
from select_regions import selection as SPR_Region_Im
import pickle


print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
# NOTE: every path below must be edited by the user to match the local data layout.
training_data_path = "DATA/Data/glanddata.npy"  # training/pool data passed to the data and selection helpers
test_data_path = "DATA/Data/glanddata_testb.npy"  # test data; not referenced by the training code in this file
resnet_dir = "pretrain_model/"  # directory containing the pretrained resnet_v2_50.ckpt checkpoint
exp_dir = "Exp_Stat/"  # root directory under which every experiment writes its outputs
ckpt_dir_init = "Exp_Stat/initial_model/"  # model trained on the initial training data (must exist beforehand)
print("-------THE PATH FOR THE INITIAL MODEL NEEDS TO BE USER DEFINED", ckpt_dir_init)
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")


def _clear_directory(directory):
    """Delete every entry of ``directory`` (expects plain files; a sub-directory makes os.remove raise)."""
    for entry in os.listdir(directory):
        os.remove(os.path.join(directory, entry))


def _load_train_val_stat(model_dir_sub):
    """Load the ``trainstat.npy`` / ``valstat.npy`` arrays written by ``train`` into ``model_dir_sub``."""
    train_stat = np.load(os.path.join(model_dir_sub, 'trainstat.npy'))
    val_stat = np.load(os.path.join(model_dir_sub, 'valstat.npy'))
    return train_stat, val_stat


def _select_best_repeat(stat_table):
    """Return the repetition index chosen by a majority vote over four metrics.

    Args:
        stat_table: array with one row per repetition and the columns
            (fb loss, ed loss, fb f1 score, fb auc score).
    Returns:
        The row index that is best (lowest loss / highest score) for the most metrics;
        ties are resolved by ``max`` over the set of candidate indices.
    """
    votes = [np.argmin(stat_table[:, 0]),
             np.argmin(stat_table[:, 1]),
             np.argmax(stat_table[:, 2]),
             np.argmax(stat_table[:, 3])]
    return max(set(votes), key=votes.count)


def run_loop_active_learning_region(stage, round_number=(0, 1, 2, 3)):
    """Train the active learning framework with region-specific annotation.

    Args:
        stage: int, index into ["A", "B", "C", "D"] selecting the acquisition method
            (0 --> random selection, 1 --> VarRatio, 2 --> entropy, 3 --> BALD).
        round_number: iterable of int, ids of independent repetitions of the whole experiment
            (used to get a confidence interval); each id is embedded in the output directory names.

    Procedure:
        1. Requires a model already trained on the initial training data, stored in
           ``ckpt_dir_init`` (its ``pool_data`` sub-directory is passed to SPR_Region_Im).
        2. SPR_Region_Im evaluates the regions of the pool set and selects the
           ``num_most_uncert_patch`` most uncertain ones, which are added to the training data.
        3. In each acquisition step a model is trained ``num_repeat_per_exp`` times; a run whose
           statistics fail fixed thresholds is deleted and retrained (avoids bad local optima).
        4. The best repetition (majority vote over four metrics) is used for the next selection.
        5. Steps 3-4 are repeated ``total_active_step`` times.

    Outputs (per round, under ``exp_dir``): num_of_pixel.npy, total_select_folder.npy,
    num_of_image.npy and the pickled selection state in updated_uncertain.txt.
    """
    flag_arch_name = "resnet_v2_50"
    resnet_ckpt = os.path.join(resnet_dir, flag_arch_name) + '.ckpt'
    total_active_step = 10
    num_repeat_per_exp = 4
    acq_method_total = ["A", "B", "C", "D"]
    stride_size = 30
    num_most_uncert_patch = 20
    epsilon_opt = 0.001
    batch_size = 5
    epoch_size = 1300
    warmup_epoch_size = 20  # short run used only to detect a degenerate initialisation

    for single_round_number in round_number:
        acq_selec_method = acq_method_total[stage]
        kernel_window = np.ones([150, 150])  # fresh per round in case the selector modifies it
        logs_path = os.path.join(exp_dir,
                                 'Method_%s_Stage_%d_Version_%d' % (acq_selec_method, stage, single_round_number))
        most_init_train_data, all_the_time_val_data = give_init_train_and_val_data(training_data_path)
        num_of_pixels_need_to_be_annotate = np.zeros([total_active_step])
        total_folder_info = []
        total_num_im = np.zeros([total_active_step])

        # Initial acquisition, driven by the user-provided initial model.
        tds = os.path.join(ckpt_dir_init, 'pool_data')
        most_uncert_old = SPR_Region_Im(tds, ckpt_dir_init, acq_selec_method, None, None, kernel_window,
                                        stride_size, num_most_uncert_patch=num_most_uncert_patch,
                                        data_path=training_data_path, check_overconfident=False)
        updated_training_data = update_training_data(most_init_train_data[:4], [], most_uncert_old[:4])
        already_selected_imindex = most_uncert_old[-1]
        already_selected_binary_mask = most_uncert_old[-2]

        for single_acq_step in range(total_active_step):
            # most_uncert_old[-2] is the binary mask of selected regions: the number of pixels
            # that need to be annotated equals the number of pixels set to 1.
            num_of_pixels_need_to_be_annotate[single_acq_step] = np.sum(np.reshape(most_uncert_old[-2], [-1]))
            num_im = np.shape(updated_training_data[0])[0]
            total_num_im[single_acq_step] = num_im
            model_dir = os.path.join(logs_path, 'FE_step_%d_version_%d' % (single_acq_step, single_round_number))
            # One row per repetition: fb loss, ed loss, fb f1 score, fb auc score
            tot_train_val_stat_for_diff_exp_same_step = np.zeros([num_repeat_per_exp, 4])
            regu_par = 0.0005 if single_acq_step == 5 else 0.001
            # The ceil variant is only reachable when total_active_step > 10.
            if single_acq_step >= 10:
                decay_steps = np.ceil(num_im / batch_size) * 600
            else:
                decay_steps = (num_im // batch_size) * 600
            train_kwargs = dict(resnet_ckpt=resnet_ckpt,
                                decay_steps=decay_steps,
                                epsilon_opt=epsilon_opt,
                                regu_par=regu_par,
                                batch_size=batch_size,
                                training_data=updated_training_data,
                                val_data=all_the_time_val_data)

            for repeat_time in range(num_repeat_per_exp):
                print("=====================Start Experiment No.%d===========================" % repeat_time)
                model_dir_sub = os.path.join(model_dir, 'rep_%d' % repeat_time)
                while True:
                    # Warm-up from the ResNet checkpoint, restarted until it is not degenerate
                    # (f1 identically 0 or auc identically 0.5).
                    while True:
                        train(ckpt_dir=None, model_dir=model_dir_sub, epoch_size=warmup_epoch_size,
                              FLAG_PRETRAIN=False, **train_kwargs)
                        train_stat, val_stat = _load_train_val_stat(model_dir_sub)
                        sec_cri = [np.mean(train_stat[-10:, 1]), np.mean(val_stat[-1, 1])]  # fb f1 score
                        thir_cri = [np.mean(train_stat[-10:, 2]), np.mean(val_stat[-1, 2])]  # fb auc score
                        if np.mean(sec_cri) == 0.0 or np.mean(thir_cri) == 0.5:
                            _clear_directory(model_dir_sub)
                            print("--------------------The model start from a really bad optimal----------------")
                        else:
                            break
                    # Full training, resumed from the warm-up checkpoint.
                    train(ckpt_dir=model_dir_sub, model_dir=model_dir_sub, epoch_size=epoch_size,
                          FLAG_PRETRAIN=True, **train_kwargs)
                    train_stat, val_stat = _load_train_val_stat(model_dir_sub)
                    first_cri = [np.mean(train_stat[-20:, -1]), np.mean(val_stat[-10:, -1])]  # ed loss
                    sec_cri = [np.mean(train_stat[-20:, 1]), np.mean(val_stat[-10:, 1])]  # fb f1 score
                    thir_cri = [np.mean(train_stat[-20:, 2]), np.mean(val_stat[-10:, 2])]  # fb auc score
                    fourth_cri = [np.mean(train_stat[-20:, 0]), np.mean(val_stat[-10:, 0])]  # fb loss
                    model_failed = (np.mean(first_cri) >= 0.30 or np.mean(sec_cri) <= 0.80
                                    or np.mean(thir_cri) <= 0.80 or np.mean(fourth_cri) > 0.50)
                    if model_failed:
                        _clear_directory(model_dir_sub)
                        print("mmm The trained model doesn't work, I need to retrain it...")
                        continue
                    tot_train_val_stat_for_diff_exp_same_step[repeat_time, :] = [np.mean(fourth_cri),
                                                                                 np.mean(first_cri),
                                                                                 np.mean(sec_cri),
                                                                                 np.mean(thir_cri)]
                    break
                print("=============================Finish Experiment No.%d============================" % repeat_time)

            # ---------Select the best repetition based on the training and validation statistics-----#
            best_per_index = _select_best_repeat(tot_train_val_stat_for_diff_exp_same_step)
            model_dir_goes_into_act_stage = os.path.join(model_dir, 'rep_%d' % best_per_index)
            print("The selected folder", model_dir_goes_into_act_stage)
            total_folder_info.append(model_dir_goes_into_act_stage)
            tds_select = os.path.join(model_dir_goes_into_act_stage, 'pool_data')
            most_uncertain = SPR_Region_Im(tds_select, model_dir_goes_into_act_stage, acq_selec_method,
                                           already_selected_imindex,
                                           already_selected_binary_mask,
                                           kernel_window,
                                           stride_size,
                                           num_most_uncert_patch, data_path=training_data_path,
                                           check_overconfident=False)
            updated_most_uncertain = prepare_the_new_uncertain_input(most_uncert_old, most_uncertain)
            updated_training_data = update_training_data(most_init_train_data[:4], [], updated_most_uncertain[:4])
            already_selected_imindex = updated_most_uncertain[-1]
            already_selected_binary_mask = updated_most_uncertain[-2]
            most_uncert_old = updated_most_uncertain
            print("The numeric image index for the most uncertain image:\n", already_selected_imindex)
            # Persist progress after every step so an interrupted run keeps its statistics.
            np.save(os.path.join(logs_path, 'num_of_pixel'), num_of_pixels_need_to_be_annotate)
            np.save(os.path.join(logs_path, 'total_select_folder'), total_folder_info)
            np.save(os.path.join(logs_path, 'num_of_image'), total_num_im)
            uncertain_data = os.path.join(logs_path, 'updated_uncertain.txt')
            # NOTE: despite the .txt name this file is a binary pickle; load it with pickle.load.
            with open(uncertain_data, 'wb') as f:
                pickle.dump(most_uncert_old, f)


def train(resnet_ckpt, ckpt_dir, model_dir, epoch_size, decay_steps, epsilon_opt, regu_par, batch_size, training_data,
          val_data, FLAG_PRETRAIN=False):
    """Train the ResNet_V2_DMNN network and write checkpoints and statistics to ``model_dir``.

    Args:
        resnet_ckpt: path of the pretrained ResNet checkpoint. When ``FLAG_PRETRAIN`` is False, all
            variables are initialised and the non-logits ``resnet_v2*`` variables are restored from it.
        ckpt_dir: directory whose latest checkpoint is restored (all variables) when
            ``FLAG_PRETRAIN`` is True.
        model_dir: output directory, created if missing. Every file already in it is deleted once
            the model has been initialised/restored.
        epoch_size: number of training epochs.
        decay_steps: learning-rate decay period passed to ``train_op_batchnorm``.
        epsilon_opt: optimizer epsilon passed to ``train_op_batchnorm``.
        regu_par: weight of the L2 penalty on kernel/weights variables (logits excluded).
        batch_size: mini-batch size for training and validation (fixed in the placeholders).
        training_data, val_data: 4-sequences (images, labels, edges, binary masks).
        FLAG_PRETRAIN: see ``resnet_ckpt`` / ``ckpt_dir``.

    Side effects:
        Saves ``model.ckpt-<epoch>`` every ``save_checkpoint_period`` epochs and at the last epoch.
        At the last epoch saves ``trainstat.npy`` (shape [epoch_size, 4]) and ``valstat.npy``
        (shape [ceil(epoch_size / val_step_size), 4]). Columns: fb loss, fb f1, fb auc, edge loss.

    Raises:
        ValueError: the training or validation set holds fewer than ``batch_size`` images
            (no full batch could be formed and all statistics would be NaN).
        FileNotFoundError: ``FLAG_PRETRAIN`` is True but ``ckpt_dir`` holds no checkpoint.
    """
    # --------Fixed hyper-parameters------#
    image_w, image_h, image_c = [480, 480, 3]
    IMAGE_SHAPE = np.array([image_w, image_h, image_c])
    targ_height_npy = 528  # this is for padding images
    targ_width_npy = 784  # this is for padding images
    FLAG_DECAY = True  # only reported in the summary below
    learning_rate = 0.001
    decay_rate = 0.1
    save_checkpoint_period = 200
    FLAG_L2_REGU = True
    MOVING_AVERAGE_DECAY = 0.999
    auxi_weight_num = 1
    auxi_decay_step = 300
    val_step_size = 10

    x_image_tr, y_label_tr, y_edge_tr, y_binary_mask_tr = training_data
    x_image_val, y_label_val, y_edge_val, y_binary_mask_val = val_data
    num_train = np.shape(x_image_tr)[0]
    num_val = np.shape(x_image_val)[0]
    iteration = num_train // batch_size
    val_iteration = num_val // batch_size
    if iteration == 0 or val_iteration == 0:
        raise ValueError("train (%d) and validation (%d) sets must each contain at least batch_size (%d) images"
                         % (num_train, num_val, batch_size))

    os.makedirs(model_dir, exist_ok=True)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')
    with tf.Graph().as_default():
        # Placeholders feeding the augmentation ops (padded size).
        image_aug_placeholder = tf.placeholder(tf.float32, [batch_size, targ_height_npy, targ_width_npy, 3])
        label_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
        edge_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
        binary_mask_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
        # Placeholders feeding the network (network input size).
        images_train = tf.placeholder(tf.float32, [batch_size, image_w, image_h, image_c])
        instance_labels_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
        edges_labels_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
        binary_mask_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
        phase_train = tf.placeholder(tf.bool, shape=None, name="training_state")
        dropout_phase = tf.placeholder(tf.bool, shape=None, name="dropout_state")
        auxi_weight = tf.placeholder(tf.float32, shape=None, name="auxiliary_weight")
        global_step = tf.train.get_or_create_global_step()

        print("-----training data shape----")
        for v in training_data:
            print(np.shape(v))
        print("-----validation data shape---")
        for v in val_data:
            print(np.shape(v))

        # ----------Augmentation is enabled only for the training graph (flag True)-------------------------#
        x_image_aug, y_label_aug, y_edge_aug, y_binary_mask_aug = aug_train_data(image_aug_placeholder,
                                                                                 label_aug_placeholder,
                                                                                 edge_aug_placeholder,
                                                                                 binary_mask_aug_placeholder,
                                                                                 batch_size, True, IMAGE_SHAPE)
        x_image_aug_val, y_label_aug_val, y_edge_aug_val, \
            y_binary_mask_aug_val = aug_train_data(image_aug_placeholder, label_aug_placeholder,
                                                   edge_aug_placeholder, binary_mask_aug_placeholder,
                                                   batch_size, False, IMAGE_SHAPE)

        # ------------------------------Network and losses-------------------------------------------------#
        fb_logits, ed_logits = ResNet_V2_DMNN(images=images_train, training_state=phase_train,
                                              dropout_state=dropout_phase, Num_Classes=2)

        edge_loss, edge_f1_score, edge_auc_score = Loss(logits=ed_logits, labels=edges_labels_train,
                                                        binary_mask=binary_mask_train,
                                                        auxi_weight=auxi_weight, loss_name="ed")
        fb_loss, fb_f1_score, fb_auc_score = Loss(logits=fb_logits, labels=instance_labels_train,
                                                  binary_mask=binary_mask_train,
                                                  auxi_weight=auxi_weight, loss_name="fb")

        var_train = tf.trainable_variables()
        total_loss = edge_loss + fb_loss
        if FLAG_L2_REGU:
            var_l2 = [v for v in var_train if (('kernel' in v.name) or ('weights' in v.name))]
            total_loss = tf.add_n(
                [total_loss, tf.add_n([tf.nn.l2_loss(v) for v in var_l2 if 'logits' not in v.name]) * regu_par],
                name="Total_Loss")

        # Named train_op so it does not shadow this function's own name.
        train_op = train_op_batchnorm(total_loss=total_loss, global_step=global_step,
                                      initial_learning_rate=learning_rate,
                                      lr_decay_rate=decay_rate, decay_steps=decay_steps,
                                      epsilon_opt=epsilon_opt, var_opt=tf.trainable_variables(),
                                      MOVING_AVERAGE_DECAY=MOVING_AVERAGE_DECAY)

        if not FLAG_PRETRAIN:
            set_resnet_var = [v for v in var_train if (v.name.startswith('resnet_v2') & ('logits' not in v.name))]
            saver_set_resnet = tf.train.Saver(set_resnet_var, max_to_keep=3)
            saver_set_all = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        else:
            saver_set_all = tf.train.Saver(max_to_keep=1)

        print("\n =====================================================")
        print("The shape of new training data", num_train)
        print("The final validation data size %d" % num_val)
        print("There are %d iteratioins in each epoch" % iteration)
        print("ckpt files are saved to: ", model_dir)
        print("Epsilon used in Adam optimizer: ", epsilon_opt)
        print("Initial learning rate", learning_rate)
        print("Use the Learning rate weight decay", FLAG_DECAY)
        print("The learning is decayed every %d steps by %.3f " % (decay_steps, decay_rate))
        print("The moving average parameter is ", MOVING_AVERAGE_DECAY)
        print("Batch Size:", batch_size)
        print("Max epochs: ", epoch_size)
        print("Use pretrained model:", FLAG_PRETRAIN)
        print("The checkpoing file is saved every %d steps" % save_checkpoint_period)
        print("The L2 regularization is turned on:", FLAG_L2_REGU)
        print(" =====================================================")
        with tf.Session() as sess:
            if not FLAG_PRETRAIN:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                saver_set_resnet.restore(sess, resnet_ckpt)
            else:
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    # Without a checkpoint nothing would be initialised and training would fail later.
                    raise FileNotFoundError("No checkpoint found in %s" % ckpt_dir)
                # The Saver only restores global variables; local ones still need initialising.
                sess.run(tf.local_variables_initializer())
                saver_set_all.restore(sess, ckpt.model_checkpoint_path)
                print("restore parameter from ", ckpt.model_checkpoint_path)
            # The model is in memory now; clear the directory so only this run's outputs remain.
            for v in os.listdir(model_dir):
                os.remove(os.path.join(model_dir, v))
                print("-------remove the initial trained model-----")

            # Validation runs at epochs 0, val_step_size, 2 * val_step_size, ... -> ceil(epoch_size / step) rows.
            num_val_rounds = -(-epoch_size // val_step_size)
            train_tot_stat = np.zeros([epoch_size, 4])
            val_tot_stat = np.zeros([num_val_rounds, 4])
            print("Epoch, foreground-background loss, foreground-background f1, "
                  "foreground-background auc, contour loss")
            for single_epoch in range(epoch_size):
                # Auxiliary-loss weight decays by 10x every auxi_decay_step epochs, then is pinned at 0.
                if auxi_weight_num > 0.001:
                    auxi_weight_num = np.power(0.1, np.floor(single_epoch / auxi_decay_step))
                else:
                    auxi_weight_num = 0
                x_image_sh, y_label_sh, y_edge_sh, y_binary_mask_sh = shuffle(x_image_tr, y_label_tr, y_edge_tr,
                                                                              y_binary_mask_tr)

                batch_index = 0
                train_stat_per_epoch = np.zeros([iteration, 4])
                for single_batch in range(iteration):
                    x_image_batch, y_label_batch, y_edge_batch, y_binary_mask_batch, batch_index = generate_batch(
                        x_image_sh,
                        y_label_sh,
                        y_edge_sh,
                        y_binary_mask_sh,
                        batch_index, batch_size)
                    feed_dict_aug = {image_aug_placeholder: x_image_batch,
                                     label_aug_placeholder: y_label_batch,
                                     edge_aug_placeholder: y_edge_batch,
                                     binary_mask_aug_placeholder: y_binary_mask_batch}
                    x_image_npy, y_label_npy, y_edge_npy, y_binary_mask_npy = sess.run(
                        [x_image_aug, y_label_aug, y_edge_aug, y_binary_mask_aug], feed_dict=feed_dict_aug)

                    feed_dict_op = {images_train: x_image_npy,
                                    instance_labels_train: y_label_npy,
                                    edges_labels_train: y_edge_npy,
                                    binary_mask_train: y_binary_mask_npy,
                                    auxi_weight: auxi_weight_num,
                                    phase_train: True,
                                    dropout_phase: True}
                    fetches_train = [train_op, fb_loss, fb_f1_score, fb_auc_score, edge_loss]
                    _, _fb_loss, _fb_f1, _fb_auc, _ed_loss = sess.run(fetches=fetches_train, feed_dict=feed_dict_op)
                    train_stat_per_epoch[single_batch, :] = [_fb_loss, _fb_f1, _fb_auc, _ed_loss]
                train_tot_stat[single_epoch, :] = np.mean(train_stat_per_epoch, axis=0)
                print(single_epoch, train_tot_stat[single_epoch, :])

                if single_epoch % val_step_size == 0:
                    print("start validating .......with %d images and %d iterations" % (num_val, val_iteration))
                    val_batch_index = 0
                    val_stat_per_epoch = np.zeros([val_iteration, 4])
                    for single_batch_val in range(val_iteration):
                        x_image_batch_val, y_label_batch_val, y_edge_batch_val, \
                            y_binary_mask_batch_val, val_batch_index = generate_batch(x_image_val, y_label_val,
                                                                                      y_edge_val, y_binary_mask_val,
                                                                                      val_batch_index, batch_size)
                        feed_dict_aug_val = {image_aug_placeholder: x_image_batch_val,
                                             label_aug_placeholder: y_label_batch_val,
                                             edge_aug_placeholder: y_edge_batch_val,
                                             binary_mask_aug_placeholder: y_binary_mask_batch_val}
                        x_image_npy_val, y_label_npy_val, y_edge_npy_val, y_binary_mask_npy_val = sess.run(
                            [x_image_aug_val,
                             y_label_aug_val,
                             y_edge_aug_val,
                             y_binary_mask_aug_val], feed_dict=feed_dict_aug_val)

                        fetches_valid = [fb_loss, fb_f1_score, fb_auc_score, edge_loss]
                        feed_dict_valid = {images_train: x_image_npy_val,
                                           instance_labels_train: y_label_npy_val,
                                           edges_labels_train: y_edge_npy_val,
                                           binary_mask_train: y_binary_mask_npy_val,
                                           auxi_weight: 0,
                                           phase_train: False,
                                           dropout_phase: False}
                        _fbloss_val, _fb_f1_val, _fb_auc_val, _edloss_val = sess.run(fetches=fetches_valid,
                                                                                     feed_dict=feed_dict_valid)
                        val_stat_per_epoch[single_batch_val, :] = [_fbloss_val, _fb_f1_val, _fb_auc_val,
                                                                   _edloss_val]

                    val_tot_stat[single_epoch // val_step_size, :] = np.mean(val_stat_per_epoch, axis=0)
                    print("validation", single_epoch, val_tot_stat[single_epoch // val_step_size, :])

                is_last_epoch = single_epoch == (epoch_size - 1)
                if single_epoch % save_checkpoint_period == 0 or is_last_epoch:
                    saver_set_all.save(sess, checkpoint_path, global_step=single_epoch)
                if is_last_epoch:
                    np.save(os.path.join(model_dir, 'trainstat'), train_tot_stat)
                    np.save(os.path.join(model_dir, 'valstat'), val_tot_stat)

#