Skip to content
Snippets Groups Projects
Train_Active_Region_Im.py 26 KiB
Newer Older
  • Learn to ignore specific revisions
  • blia's avatar
    blia committed
    # -*- coding: utf-8 -*-
    """
    Created on Wed Mar  7 16:42:15 2018
    This file is used to train the active learning framework with region specific annotation
    @author: s161488
    """
    import numpy as np
    import os
    import tensorflow as tf
    from data_utils.prepare_data import aug_train_data, generate_batch
    from data_utils.update_data import give_init_train_and_val_data, update_training_data, prepare_the_new_uncertain_input
    from models.inference import ResNet_V2_DMNN
    from optimization.loss_region_specific import Loss, train_op_batchnorm
    from sklearn.utils import shuffle
    from select_regions import selection as SPR_Region_Im
    import pickle
    
    
    
    # ---- User configuration -------------------------------------------------------------------
    # Paths are relative to the working directory the script is launched from.
    training_data_path = "DATA/Data/glanddata.npy"  # NOTE, NEED TO BE MANUALLY DEFINED
    test_data_path = "DATA/Data/glanddata_testb.npy"  # NOTE, NEED TO BE MANUALLY DEFINED
    # Directory holding the pretrained ResNet checkpoint (joined with the architecture name later).
    resnet_dir = "pretrain_model/"
    # Root directory under which every experiment writes its logs and checkpoints.
    exp_dir = "Exp_Stat/"  # NOTE, NEED TO BE MANUALLY DEFINED
    # Directory of the model trained on the initial data; used for the very first region selection.
    ckpt_dir_init = "Exp_Stat/initial_model/"

    # Start-up banner reminding the user that the paths above must be edited by hand.
    print("--------------------------------------------------------------",
          "---------------DEFINE YOUR TRAINING DATA PATH-----------------",
          "--------------------------------------------------------------",
          sep="\n")
    print("-------THE PATH FOR THE INITIAL MODEL NEEDS TO BE USER DEFINED", ckpt_dir_init)
    print("--------------------------------------------------------------",
          "---------------DEFINE YOUR TRAINING DATA PATH-----------------",
          "--------------------------------------------------------------",
          sep="\n")
    
    blia's avatar
    blia committed
    
    
    
    def run_loop_active_learning_region(stage, round_number=np.array([0, 1, 2, 3])):
        """Train the active learning framework with region specific annotation.

        Args:
            stage: int, index into the acquisition codes ["A", "B", "C", "D"];
                0--> random selection, 1--> VarRatio, 2--> entropy, 3--> BALD
            round_number: iterable of int, one complete experiment is run per entry;
                the experiments are repeated in order to get a confidence interval

        Ops:
            1. This script can only be run given the model that is trained with the initial
               training data (10); that model is read from ``ckpt_dir_init``.
            2. In each acquisition step the training is repeated ``num_repeat_per_exp`` times to
               avoid bad local optima, and the best repetition is kept.
            3. With that model, ``SPR_Region_Im`` evaluates the regions in the pool data and selects
               the ``num_most_uncert_patch`` most uncertain ones, which are added to the training data.
            4. A new model is trained as described in step 2 with the updated data from step 3.
            5. Steps 2 to 4 are repeated ``total_active_step`` times.

        Side effects:
            For every round a folder ``Method_<code>_Stage_<stage>_Version_<round>`` is created under
            ``exp_dir``. After each acquisition step it is (re)written with ``num_of_pixel.npy``,
            ``total_select_folder.npy``, ``num_of_image.npy`` and the pickled selection
            ``updated_uncertain.txt``.

        Note:
            Both retry loops below are unbounded: a repetition is retrained until it meets the
            acceptance thresholds, so a configuration that can never meet them will not terminate.
        """
        # Hyper-parameters that are identical for every round and acquisition step.
        flag_arch_name = "resnet_v2_50"
        resnet_ckpt = os.path.join(resnet_dir, flag_arch_name) + '.ckpt'
        total_active_step = 10
        num_repeat_per_exp = 4
        acq_method_total = ["A", "B", "C", "D"]
        acq_selec_method = acq_method_total[stage]
        stride_size = 30
        num_most_uncert_patch = 20
        epsilon_opt = 0.001
        batch_size = 5
        epoch_size = 1300

        for single_round_number in round_number:
            kernel_window = np.ones([150, 150])
            logs_path = os.path.join(exp_dir,
                                     'Method_%s_Stage_%d_Version_%d' % (acq_selec_method, stage, single_round_number))
            most_init_train_data, all_the_time_val_data = give_init_train_and_val_data(training_data_path)
            num_of_pixels_need_to_be_annotate = np.zeros([total_active_step])
            total_folder_info = []
            total_num_im = np.zeros([total_active_step])

            for single_acq_step in range(total_active_step):
                if single_acq_step == 0:
                    # Bootstrap: the first regions are selected with the user-supplied initial model,
                    # with no previously selected images or masks to exclude.
                    tds = os.path.join(ckpt_dir_init, 'pool_data')
                    most_uncertain_data = SPR_Region_Im(tds, ckpt_dir_init, acq_selec_method, None, None,
                                                        kernel_window, stride_size,
                                                        num_most_uncert_patch=num_most_uncert_patch,
                                                        data_path=training_data_path,
                                                        check_overconfident=False)
                    updated_training_data = update_training_data(most_init_train_data[:4], [], most_uncertain_data[:4])
                    already_selected_imindex = most_uncertain_data[-1]
                    already_selected_binary_mask = most_uncertain_data[-2]
                    most_uncert_old = most_uncertain_data

                # The second-to-last entry is the binary mask; the number of pixels that need to be
                # annotated equals the number of pixels that are assigned to be 1.
                num_of_pixels_need_to_be_annotate[single_acq_step] = np.sum(np.reshape(most_uncert_old[-2], [-1]))
                num_im = np.shape(updated_training_data[0])[0]
                total_num_im[single_acq_step] = num_im

                model_dir = os.path.join(logs_path, 'FE_step_%d_version_%d' % (single_acq_step, single_round_number))
                # One row per repetition: fb loss, ed loss, fb f1 score, fb auc score
                tot_train_val_stat_for_diff_exp_same_step = np.zeros([num_repeat_per_exp, 4])

                # NOTE(review): step 5 alone uses a smaller L2 weight; the reason is not documented.
                regu_par = 0.0005 if single_acq_step == 5 else 0.001
                # Learning-rate decay interval: batches per epoch times 600.
                # NOTE(review): with total_active_step = 10 the ``>= 10`` branch is never taken.
                if single_acq_step >= 10:
                    decay_steps = np.ceil(num_im / batch_size) * 600
                else:
                    decay_steps = (num_im // batch_size) * 600

                for repeat_time in range(num_repeat_per_exp):
                    print("=====================Start Experiment No.%d===========================" % repeat_time)
                    model_dir_sub = os.path.join(model_dir, 'rep_%d' % repeat_time)
                    while True:
                        # Phase 1: short run starting from the ResNet checkpoint. It is repeated until
                        # the run does not end with f1 == 0 or auc == 0.5.
                        while True:
                            train(resnet_ckpt=resnet_ckpt,
                                  ckpt_dir=None,
                                  model_dir=model_dir_sub,
                                  epoch_size=20,
                                  decay_steps=decay_steps,
                                  epsilon_opt=epsilon_opt,
                                  regu_par=regu_par,
                                  batch_size=batch_size,
                                  training_data=updated_training_data,
                                  val_data=all_the_time_val_data,
                                  FLAG_PRETRAIN=False)
                            train_stat, val_stat = _load_train_val_stat(model_dir_sub)
                            sec_cri = [np.mean(train_stat[-10:, 1]), np.mean(val_stat[-1, 1])]  # fb f1 score
                            thir_cri = [np.mean(train_stat[-10:, 2]), np.mean(val_stat[-1, 2])]  # fb auc score
                            if np.mean(sec_cri) == 0.0 or np.mean(thir_cri) == 0.5:
                                _remove_all_files(model_dir_sub)
                                print("--------------------The model start from a really bad optimal----------------")
                                continue
                            break

                        # Phase 2: full-length training that resumes from the phase-1 checkpoint.
                        train(resnet_ckpt=resnet_ckpt,
                              ckpt_dir=model_dir_sub,
                              model_dir=model_dir_sub,
                              epoch_size=epoch_size,
                              decay_steps=decay_steps,
                              epsilon_opt=epsilon_opt,
                              regu_par=regu_par,
                              batch_size=batch_size,
                              training_data=updated_training_data,
                              val_data=all_the_time_val_data,
                              FLAG_PRETRAIN=True)
                        train_stat, val_stat = _load_train_val_stat(model_dir_sub)
                        first_cri = [np.mean(train_stat[-20:, -1]), np.mean(val_stat[-10:, -1])]  # ed loss
                        sec_cri = [np.mean(train_stat[-20:, 1]), np.mean(val_stat[-10:, 1])]  # fb f1 score
                        thir_cri = [np.mean(train_stat[-20:, 2]), np.mean(val_stat[-10:, 2])]  # fb auc score
                        fourth_cri = [np.mean(train_stat[-20:, 0]), np.mean(val_stat[-10:, 0])]  # fb loss
                        model_rejected = (np.mean(first_cri) >= 0.30 or np.mean(sec_cri) <= 0.80
                                          or np.mean(thir_cri) <= 0.80 or np.mean(fourth_cri) > 0.50)
                        if model_rejected:
                            _remove_all_files(model_dir_sub)
                            print("mmm The trained model doesn't work, I need to retrain it...")
                            continue
                        tot_train_val_stat_for_diff_exp_same_step[repeat_time, :] = [np.mean(fourth_cri),
                                                                                     np.mean(first_cri),
                                                                                     np.mean(sec_cri),
                                                                                     np.mean(thir_cri)]
                        break
                    print("=============================Finish Experiment No.%d============================" % repeat_time)

                # ---------Below is for selecting the best experiment based on the training and validation statistics-----#
                # Each metric votes for its best repetition; the repetition with the most votes wins.
                fb_loss_index = np.argmin(tot_train_val_stat_for_diff_exp_same_step[:, 0])
                ed_loss_index = np.argmin(tot_train_val_stat_for_diff_exp_same_step[:, 1])
                fb_f1_index = np.argmax(tot_train_val_stat_for_diff_exp_same_step[:, 2])
                fb_auc_index = np.argmax(tot_train_val_stat_for_diff_exp_same_step[:, 3])
                perf_comp = [fb_loss_index, ed_loss_index, fb_f1_index, fb_auc_index]
                best_per_index = max(set(perf_comp), key=perf_comp.count)
                model_dir_goes_into_act_stage = os.path.join(model_dir, 'rep_%d' % best_per_index)
                print("The selected folder", model_dir_goes_into_act_stage)
                total_folder_info.append(model_dir_goes_into_act_stage)

                # Use the winning model to select the next regions, passing in what was already selected.
                tds_select = os.path.join(model_dir_goes_into_act_stage, 'pool_data')
                most_uncertain = SPR_Region_Im(tds_select, model_dir_goes_into_act_stage, acq_selec_method,
                                               already_selected_imindex,
                                               already_selected_binary_mask,
                                               kernel_window,
                                               stride_size,
                                               num_most_uncert_patch, data_path=training_data_path,
                                               check_overconfident=False)
                updated_most_uncertain = prepare_the_new_uncertain_input(most_uncert_old, most_uncertain)
                updated_training_data = update_training_data(most_init_train_data[:4], [], updated_most_uncertain[:4])
                already_selected_imindex = updated_most_uncertain[-1]
                already_selected_binary_mask = updated_most_uncertain[-2]
                most_uncert_old = updated_most_uncertain
                print("The numeric image index for the most uncertain image:\n", already_selected_imindex)

                # Persist the bookkeeping after every step so that an interrupted run keeps its history.
                np.save(os.path.join(logs_path, 'num_of_pixel'), num_of_pixels_need_to_be_annotate)
                np.save(os.path.join(logs_path, 'total_select_folder'), total_folder_info)
                np.save(os.path.join(logs_path, 'num_of_image'), total_num_im)
                uncertain_data = os.path.join(logs_path, 'updated_uncertain.txt')
                with open(uncertain_data, 'wb') as f:
                    pickle.dump(most_uncert_old, f)


    def _load_train_val_stat(model_dir_sub):
        """Load the ``trainstat.npy`` and ``valstat.npy`` arrays stored in ``model_dir_sub``."""
        train_stat = np.load(os.path.join(model_dir_sub, 'trainstat.npy'))
        val_stat = np.load(os.path.join(model_dir_sub, 'valstat.npy'))
        return train_stat, val_stat


    def _remove_all_files(directory):
        """Delete every entry directly inside ``directory`` so a rejected run can be retrained."""
        for single_file in os.listdir(directory):
            os.remove(os.path.join(directory, single_file))
    
    
    def train(resnet_ckpt, ckpt_dir, model_dir, epoch_size, decay_steps, epsilon_opt, regu_par, batch_size, training_data,
              val_data, FLAG_PRETRAIN=False):
        """Build the graph, train for ``epoch_size`` epochs and write the results into ``model_dir``.

        Args:
            resnet_ckpt: path of the ResNet checkpoint; only restored when ``FLAG_PRETRAIN`` is False.
            ckpt_dir: directory holding a checkpoint of the complete model; only read when
                ``FLAG_PRETRAIN`` is True.
            model_dir: output directory. It is created if missing and every entry already in it is
                removed before training starts.
            epoch_size: number of training epochs.
            decay_steps: decay interval forwarded to ``train_op_batchnorm``.
            epsilon_opt: optimizer epsilon forwarded to ``train_op_batchnorm``.
            regu_par: weight of the L2 regularisation term.
            batch_size: number of images per batch, for training and for validation.
            training_data: sequence of four arrays (images, labels, edges, binary masks).
            val_data: sequence of four arrays with the same layout as ``training_data``.
            FLAG_PRETRAIN: False --> initialise all variables and load the ResNet weights from
                ``resnet_ckpt``; True --> resume the complete model from ``ckpt_dir``.

        Raises:
            FileNotFoundError: if ``FLAG_PRETRAIN`` is True but ``ckpt_dir`` holds no checkpoint.

        Outputs written to ``model_dir``:
            * ``model.ckpt-<epoch>`` every ``save_checkpoint_period`` epochs and after the last epoch.
            * ``trainstat.npy``: one row per epoch with [fb loss, fb f1, fb auc, ed loss].
            * ``valstat.npy``: same columns, one row per validated epoch (every ``val_step_size`` epochs).
        """
        # --------Here lots of parameters need to be set------Or maybe we could set it in the configuration file-----#
        os.makedirs(model_dir, exist_ok=True)
        image_w, image_h, image_c = [480, 480, 3]
        IMAGE_SHAPE = np.array([image_w, image_h, image_c])
        targ_height_npy = 528  # this is for padding images
        targ_width_npy = 784  # this is for padding images
        FLAG_DECAY = True
        learning_rate = 0.001
        decay_rate = 0.1
        save_checkpoint_period = 200
        FLAG_L2_REGU = True
        MOVING_AVERAGE_DECAY = 0.999
        auxi_weight_num = 1
        auxi_decay_step = 300
        val_step_size = 10

        checkpoint_path = os.path.join(model_dir, 'model.ckpt')
        with tf.Graph().as_default():
            #  These four placeholders are for extracting the augmented training data
            image_aug_placeholder = tf.placeholder(tf.float32, [batch_size, targ_height_npy, targ_width_npy, 3])
            label_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
            edge_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
            binary_mask_aug_placeholder = tf.placeholder(tf.int64, [batch_size, targ_height_npy, targ_width_npy, 1])
            #  The placeholders below are the inputs of the network
            images_train = tf.placeholder(tf.float32, [batch_size, image_w, image_h, image_c])
            instance_labels_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
            edges_labels_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
            binary_mask_train = tf.placeholder(tf.int64, [batch_size, image_w, image_h, 1])
            phase_train = tf.placeholder(tf.bool, shape=None, name="training_state")
            dropout_phase = tf.placeholder(tf.bool, shape=None, name="dropout_state")
            auxi_weight = tf.placeholder(tf.float32, shape=None, name="auxiliary_weight")
            global_step = tf.train.get_or_create_global_step()

            #  ----------------------Here is for preparing the dataset for training and validation---#
            x_image_tr, y_label_tr, y_edge_tr, y_binary_mask_tr = training_data
            x_image_val, y_label_val, y_edge_val, y_binary_mask_val = val_data

            print("-----training data shape----")
            for v in training_data:
                print(np.shape(v))
            print("-----validation data shape---")
            for v in val_data:
                print(np.shape(v))

            # Whole batches only: trailing samples that do not fill a batch are skipped.
            iteration = np.shape(x_image_tr)[0] // batch_size
            val_iteration = np.shape(x_image_val)[0] // batch_size

            # ----------Perform data augmentation only on training data------------------------------------------------#
            # Both pipelines read the same placeholders; only the flag passed after batch_size differs.
            x_image_aug, y_label_aug, y_edge_aug, y_binary_mask_aug = aug_train_data(image_aug_placeholder,
                                                                                     label_aug_placeholder,
                                                                                     edge_aug_placeholder,
                                                                                     binary_mask_aug_placeholder,
                                                                                     batch_size, True, IMAGE_SHAPE)
            x_image_aug_val, y_label_aug_val, y_edge_aug_val, \
                y_binary_mask_aug_val = aug_train_data(image_aug_placeholder, label_aug_placeholder,
                                                       edge_aug_placeholder, binary_mask_aug_placeholder,
                                                       batch_size, False, IMAGE_SHAPE)

            # ------------------------------Here is for build up the network-------------------------------------------#
            fb_logits, ed_logits = ResNet_V2_DMNN(images=images_train, training_state=phase_train,
                                                  dropout_state=dropout_phase, Num_Classes=2)

            edge_loss, edge_f1_score, edge_auc_score = Loss(logits=ed_logits, labels=edges_labels_train,
                                                            binary_mask=binary_mask_train,
                                                            auxi_weight=auxi_weight, loss_name="ed")
            fb_loss, fb_f1_score, fb_auc_score = Loss(logits=fb_logits, labels=instance_labels_train,
                                                      binary_mask=binary_mask_train,
                                                      auxi_weight=auxi_weight, loss_name="fb")

            var_train = tf.trainable_variables()
            total_loss = edge_loss + fb_loss
            if FLAG_L2_REGU:
                # L2 penalty on kernels/weights, excluding the variables with 'logits' in their name.
                var_l2 = [v for v in var_train if (('kernel' in v.name) or ('weights' in v.name))]
                total_loss = tf.add_n(
                    [total_loss, tf.add_n([tf.nn.l2_loss(v) for v in var_l2 if 'logits' not in v.name]) * regu_par],
                    name="Total_Loss")

            # -------------Conduct BackPropagation------------------------------------------------------------#
            # Named train_op so that it does not shadow this function.
            train_op = train_op_batchnorm(total_loss=total_loss, global_step=global_step,
                                          initial_learning_rate=learning_rate,
                                          lr_decay_rate=decay_rate, decay_steps=decay_steps,
                                          epsilon_opt=epsilon_opt, var_opt=tf.trainable_variables(),
                                          MOVING_AVERAGE_DECAY=MOVING_AVERAGE_DECAY)

            if FLAG_PRETRAIN is False:
                # Only the ResNet variables (without 'logits') are loaded from resnet_ckpt.
                set_resnet_var = [v for v in var_train if v.name.startswith('resnet_v2') and 'logits' not in v.name]
                saver_set_resnet = tf.train.Saver(set_resnet_var, max_to_keep=3)
                saver_set_all = tf.train.Saver(tf.global_variables(), max_to_keep=1)
            else:
                saver_set_all = tf.train.Saver(max_to_keep=1)

            print("\n =====================================================")
            print("The shape of new training data", np.shape(x_image_tr)[0])
            print("The final validation data size %d" % np.shape(x_image_val)[0])
            print("There are %d iteratioins in each epoch" % iteration)
            print("ckpt files are saved to: ", model_dir)
            print("Epsilon used in Adam optimizer: ", epsilon_opt)
            print("Initial learning rate", learning_rate)
            print("Use the Learning rate weight decay", FLAG_DECAY)
            print("The learning is decayed every %d steps by %.3f " % (decay_steps, decay_rate))
            print("The moving average parameter is ", MOVING_AVERAGE_DECAY)
            print("Batch Size:", batch_size)
            print("Max epochs: ", epoch_size)
            print("Use pretrained model:", FLAG_PRETRAIN)
            print("The checkpoing file is saved every %d steps" % save_checkpoint_period)
            print("The L2 regularization is turned on:", FLAG_L2_REGU)
            print(" =====================================================")
            with tf.Session() as sess:
                if FLAG_PRETRAIN is False:
                    sess.run(tf.global_variables_initializer())
                    sess.run(tf.local_variables_initializer())
                    saver_set_resnet.restore(sess, resnet_ckpt)
                else:
                    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                    if not (ckpt and ckpt.model_checkpoint_path):
                        # Fail before model_dir is wiped; otherwise training would start with
                        # uninitialised variables and crash on the first step.
                        raise FileNotFoundError("no checkpoint found in %s" % ckpt_dir)
                    # NOTE(review): unlike the branch above, no local-variable initialiser is run here;
                    # confirm that no op in the graph depends on local variables.
                    saver_set_all.restore(sess, ckpt.model_checkpoint_path)
                    print("restore parameter from ", ckpt.model_checkpoint_path)

                # The weights are in memory now, so the old contents of model_dir can be dropped
                # (ckpt_dir may be the same directory as model_dir).
                all_file = os.listdir(model_dir)
                for v in all_file:
                    os.remove(os.path.join(model_dir, v))
                    print("-------remove the initial trained model-----")

                train_tot_stat = np.zeros([epoch_size, 4])
                # Validation runs at epochs 0, val_step_size, 2 * val_step_size, ...; round up so the
                # last validated epoch has a row even when epoch_size is not a multiple of val_step_size.
                val_tot_stat = np.zeros([(epoch_size + val_step_size - 1) // val_step_size, 4])
                print(
                    "Epoch, foreground-background loss, foreground-background f1, "
                    "foreground-background auc, contour loss")

                fetches_train = [train_op, fb_loss, fb_f1_score, fb_auc_score, edge_loss]
                fetches_valid = [fb_loss, fb_f1_score, fb_auc_score, edge_loss]
                for single_epoch in range(epoch_size):
                    # Auxiliary weight: 0.1 ** floor(epoch / auxi_decay_step) while the previous value
                    # is above 0.001, then 0 for the rest of the training.
                    if auxi_weight_num > 0.001:
                        auxi_weight_num = np.power(0.1, np.floor(single_epoch / auxi_decay_step))
                    else:
                        auxi_weight_num = 0
                    x_image_sh, y_label_sh, y_edge_sh, y_binary_mask_sh = shuffle(x_image_tr, y_label_tr, y_edge_tr,
                                                                                  y_binary_mask_tr)

                    batch_index = 0
                    train_stat_per_epoch = np.zeros([iteration, 4])
                    for single_batch in range(iteration):
                        x_image_batch, y_label_batch, y_edge_batch, y_binary_mask_batch, batch_index = generate_batch(
                            x_image_sh,
                            y_label_sh,
                            y_edge_sh,
                            y_binary_mask_sh,
                            batch_index, batch_size)
                        # First run: turn the raw batch into network-sized inputs.
                        feed_dict_aug = {image_aug_placeholder: x_image_batch,
                                         label_aug_placeholder: y_label_batch,
                                         edge_aug_placeholder: y_edge_batch,
                                         binary_mask_aug_placeholder: y_binary_mask_batch}
                        x_image_npy, y_label_npy, y_edge_npy, y_binary_mask_npy = sess.run(
                            [x_image_aug, y_label_aug, y_edge_aug, y_binary_mask_aug], feed_dict=feed_dict_aug)

                        # Second run: one optimisation step on the prepared batch.
                        feed_dict_op = {images_train: x_image_npy,
                                        instance_labels_train: y_label_npy,
                                        edges_labels_train: y_edge_npy,
                                        binary_mask_train: y_binary_mask_npy,
                                        auxi_weight: auxi_weight_num,
                                        phase_train: True,
                                        dropout_phase: True}
                        _, _fb_loss, _fb_f1, _fb_auc, _ed_loss = sess.run(fetches=fetches_train, feed_dict=feed_dict_op)
                        train_stat_per_epoch[single_batch, :] = [_fb_loss, _fb_f1, _fb_auc, _ed_loss]
                    train_tot_stat[single_epoch, :] = np.mean(train_stat_per_epoch, axis=0)
                    print(single_epoch, train_tot_stat[single_epoch, :])

                    if single_epoch % val_step_size == 0:
                        print("start validating .......with %d images and %d iterations" % (
                            np.shape(x_image_val)[0], val_iteration))

                        val_batch_index = 0
                        val_stat_per_epoch = np.zeros([val_iteration, 4])
                        for single_batch_val in range(val_iteration):
                            x_image_batch_val, y_label_batch_val, y_edge_batch_val, \
                                y_binary_mask_batch_val, val_batch_index = generate_batch(x_image_val, y_label_val,
                                                                                          y_edge_val, y_binary_mask_val,
                                                                                          val_batch_index, batch_size)
                            feed_dict_aug_val = {image_aug_placeholder: x_image_batch_val,
                                                 label_aug_placeholder: y_label_batch_val,
                                                 edge_aug_placeholder: y_edge_batch_val,
                                                 binary_mask_aug_placeholder: y_binary_mask_batch_val}
                            x_image_npy_val, y_label_npy_val, y_edge_npy_val, y_binary_mask_npy_val = sess.run(
                                [x_image_aug_val,
                                 y_label_aug_val,
                                 y_edge_aug_val,
                                 y_binary_mask_aug_val], feed_dict=feed_dict_aug_val)

                            # Evaluation only: no train op is fetched and both phase flags are off.
                            feed_dict_valid = {images_train: x_image_npy_val,
                                               instance_labels_train: y_label_npy_val,
                                               edges_labels_train: y_edge_npy_val,
                                               binary_mask_train: y_binary_mask_npy_val,
                                               auxi_weight: 0,
                                               phase_train: False,
                                               dropout_phase: False}
                            _fbloss_val, _fb_f1_val, _fb_auc_val, _edloss_val = sess.run(fetches=fetches_valid,
                                                                                         feed_dict=feed_dict_valid)
                            val_stat_per_epoch[single_batch_val, :] = [_fbloss_val, _fb_f1_val, _fb_auc_val,
                                                                       _edloss_val]

                        val_tot_stat[single_epoch // val_step_size, :] = np.mean(val_stat_per_epoch, axis=0)
                        print("validation", single_epoch, val_tot_stat[single_epoch // val_step_size, :])

                    if single_epoch % save_checkpoint_period == 0 or single_epoch == (epoch_size - 1):
                        saver_set_all.save(sess, checkpoint_path, global_step=single_epoch)
                    if single_epoch == (epoch_size - 1):
                        # The statistics are only written once, after the final epoch.
                        np.save(os.path.join(model_dir, 'trainstat'), train_tot_stat)
                        np.save(os.path.join(model_dir, 'valstat'), val_tot_stat)
    
    #