loss_region_specific.py
    # -*- coding: utf-8 -*-
    """
    Created on Fri Mar  2 11:56:26 2018
    
    @author: s161488
    """
    ############################################################
    #  Loss Functions
    ############################################################
    import tensorflow as tf
    import numpy as np
    from sklearn.metrics import f1_score, roc_curve, auc
    
    
    def calc_f1_score(pred, label):
        """Compute the binary F1 score over all pixels.
        pred: predicted labels, shape [batch_size, im_h, im_w, 1]
        label: ground-truth labels, shape [batch_size, im_h, im_w, 1]
        """
        # Flatten both maps so sklearn scores a single vector of pixels.
        pred = np.reshape(pred, [-1])
        label = np.reshape(label, [-1])
        f1score = f1_score(y_true=label, y_pred=pred)

        return f1score.astype('float32')
    
    
    def tf_f1score(pred, label):
        # Wrap the numpy F1 computation as a graph op (TF 1.x tf.py_func).
        f1score_tensor = tf.py_func(calc_f1_score, [pred, label], tf.float32)
        return f1score_tensor
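

    # Illustrative sketch (not in the original module): a tiny sanity check
    # for calc_f1_score. The toy 1x2x2x1 maps are assumptions chosen for the
    # example, not values from the training pipeline.
    def _example_calc_f1_score():
        # pred = [1, 0, 1, 1] vs label = [1, 1, 0, 1] gives TP=2, FP=1, FN=1,
        # so precision = recall = 2/3 and F1 = 2/3.
        pred = np.array([1, 0, 1, 1]).reshape([1, 2, 2, 1])
        label = np.array([1, 1, 0, 1]).reshape([1, 2, 2, 1])
        return calc_f1_score(pred, label)  # ~0.6667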
    
    
    def calc_auc_score(pred, label):
        """Compute the ROC AUC over all pixels.
        Unlike calc_f1_score, here `pred` holds the predicted probability
        (a continuous score), not the hard predicted label.
        pred: [batch_size, im_h, im_w, 1]
        label: [batch_size, im_h, im_w, 1]
        """
        pred = np.reshape(pred, [-1])
        label = np.reshape(label, [-1])
        fpr, tpr, thresholds = roc_curve(y_true=label, y_score=pred,
                                         pos_label=1)
        auc_value = auc(fpr, tpr)
        return auc_value.astype('float32')
    
    
    def tf_auc_score(pred, label):
        # Wrap the numpy AUC computation as a graph op (TF 1.x tf.py_func).
        auc_score_tensor = tf.py_func(calc_auc_score, [pred, label], tf.float32)
        return auc_score_tensor
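

    # Illustrative sketch (not in the original module): evaluating
    # tf_auc_score in a TF 1.x session. The placeholders and toy scores below
    # are assumptions chosen for the example.
    def _example_tf_auc_score():
        pred_ph = tf.placeholder(tf.float32, [None])
        label_ph = tf.placeholder(tf.int32, [None])
        auc_t = tf_auc_score(pred=pred_ph, label=label_ph)
        with tf.Session() as sess:
            # Perfectly ranked scores give AUC = 1.0.
            return sess.run(auc_t, feed_dict={pred_ph: [0.1, 0.2, 0.8, 0.9],
                                              label_ph: [0, 0, 1, 1]})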
    
    
    def Loss(logits, labels, binary_mask, auxi_weight, loss_name):
        """Calculate the cross-entropy loss for the edge detection channel.
        Args:
        logits: The logit maps from the edge detection channel. A list of 6
        tensors of shape [Num_Batch, Image_Height, Image_Width, 2] (6 because
        there are 5 side-outputs and 1 fuse-output).
        labels: The edge label. Shape [Num_Batch, Image_Height, Image_Width, 1].
        binary_mask: Region mask, shape [Num_Batch, Image_Height, Image_Width, 1];
        only pixels where the mask is non-zero contribute to the loss and metrics.
        auxi_weight: Weight applied to the summed per-map (auxiliary) losses.
        loss_name: Name prefix for the loss ops added to the 'loss' collection.

        Returns:
        cost: the total loss, auxi_weight * (sum of per-map losses) + fuse loss.
        accuracy: the F1 score of the fused prediction inside the mask.
        auc_score: the ROC AUC of the fused foreground logit inside the mask.

        Operations:
        Because of class imbalance (the number of background pixels is much
        larger than the number of edge pixels), a penalty beta = Y_/Y is added.
        """
        # Binarize the labels and region mask, and flatten to per-pixel vectors.
        y = tf.reshape(labels, [-1])
        y = tf.cast(tf.not_equal(y, 0), tf.int32)
        Num_Map = len(logits)  # 5 side-output maps + 1 fuse-output map
        cost = 0
        binary_mask = tf.reshape(binary_mask, [-1])  # from [batch_size, im_h, im_w, 1]
        binary_mask = tf.cast(tf.not_equal(binary_mask, 0), tf.int32)
        for i in range(Num_Map):
            # Cross-entropy for each map, restricted to pixels inside the mask.
            cross_entropy_sep = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tf.reshape(logits[i], [-1, 2]),
                                                                               labels=y)
            cross_entropy_sep = tf.boolean_mask(cross_entropy_sep, tf.equal(binary_mask, 1))
            cross_entropy_sep = tf.reduce_mean(cross_entropy_sep, name='auxiliary' + loss_name + '%d' % i)
            tf.add_to_collection('loss', cross_entropy_sep)
            cost += cross_entropy_sep
        # The fused prediction is the element-wise sum of all logit maps.
        fuse_map = tf.add_n(logits)
        fuse_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=tf.reshape(fuse_map, [-1, 2]))
        fuse_cost = tf.reduce_mean(tf.boolean_mask(fuse_cost, tf.equal(binary_mask, 1)), name=loss_name + 'fuse_cost')
        tf.add_to_collection('loss', fuse_cost)
        cost = auxi_weight * cost + fuse_cost
        # Evaluate F1 and AUC on the fused prediction, inside the mask only.
        pred = tf.reshape(tf.argmax(fuse_map, axis=-1, output_type=tf.int32), [-1])
        pred_bi = tf.boolean_mask(pred, tf.equal(binary_mask, 1))
        y_bi = tf.boolean_mask(y, tf.equal(binary_mask, 1))
        # F1 is undefined when both prediction and label are all background;
        # treat that case as a perfect score of 1.0.
        pred_bi_cond_f1 = tf.equal(tf.reduce_sum(pred_bi), 0)
        y_bi_cond_f1 = tf.equal(tf.reduce_sum(y_bi), 0)
        accuracy = tf.cond(tf.logical_and(pred_bi_cond_f1, y_bi_cond_f1),
                           lambda: tf.constant(1.0),
                           lambda: tf_f1score(pred=pred_bi, label=y_bi))
        # Rank pixels by the foreground logit of the fused map. ROC AUC is
        # undefined when only one class is present inside the mask, so return
        # 0.0 in that case (the integer reduce_mean(y_bi) == 1 iff all labels
        # are 1; reduce_sum(y_bi) == 0 iff all labels are 0).
        auc_pred_bi = tf.boolean_mask(tf.reshape(fuse_map[:, :, :, 1], [-1]), tf.equal(binary_mask, 1))
        single_class = tf.logical_or(tf.equal(tf.reduce_sum(y_bi), 0),
                                     tf.equal(tf.reduce_mean(y_bi), 1))
        auc_score = tf.cond(single_class,
                            lambda: tf.constant(0.0),
                            lambda: tf_auc_score(pred=auc_pred_bi, label=y_bi))
        return cost, accuracy, auc_score
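

    # Illustrative sketch (not in the original module): wiring Loss into a
    # graph. The shapes, the six random logit maps, and auxi_weight=0.3 are
    # assumptions for the example; the real maps come from the edge channel.
    def _example_loss_usage():
        batch, im_h, im_w = 2, 8, 8
        # Six two-channel logit maps: 5 side-outputs plus 1 fuse-output.
        logits = [tf.random_normal([batch, im_h, im_w, 2]) for _ in range(6)]
        labels = tf.ones([batch, im_h, im_w, 1], dtype=tf.int32)
        binary_mask = tf.ones([batch, im_h, im_w, 1], dtype=tf.int32)
        return Loss(logits, labels, binary_mask,
                    auxi_weight=0.3, loss_name='edge')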
    
    
    def _add_loss_summaries(total_loss):
        """Add summaries for the losses in the 'loss' collection.
        Generates a moving average for all losses and attaches scalar
        summaries for visualizing the performance of the network.
        Args:
          total_loss: Total loss from Loss().
        Returns:
          loss_averages_op: op for generating moving averages of losses.
        """
        # Compute the moving average of all individual losses and the total loss.
        loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
        losses = tf.get_collection('loss')
        loss_averages_op = loss_averages.apply(losses + [total_loss])
        # Attach a scalar summary to each raw loss and to its moving average.
        for l in losses + [total_loss]:
            tf.summary.scalar(l.op.name + '_raw', l)
            tf.summary.scalar(l.op.name, loss_averages.average(l))
        return loss_averages_op
    
    
    def train_op_batchnorm(total_loss, global_step, initial_learning_rate, lr_decay_rate, decay_steps, epsilon_opt, var_opt,
                           MOVING_AVERAGE_DECAY):
        """Build the Adam training op: exponentially decayed learning rate,
        batch-norm statistics updates, loss moving averages, and exponential
        moving averages of the optimized variables."""
        # Batch-norm layers register their moving-statistics update ops here.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    
        lr = tf.train.exponential_decay(initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        lr_decay_rate,
                                        staircase=False)
        # If staircase were True, global_step / decay_steps would be an integer
        # division and the learning rate would decay in discrete steps; with
        # staircase=False the decay is continuous.
        tf.summary.scalar('learning_rate', lr)
    
        variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        with tf.control_dependencies(update_ops):
            loss_averages_op = _add_loss_summaries(total_loss)
            with tf.control_dependencies([loss_averages_op]):
                opt = tf.train.AdamOptimizer(lr, epsilon=epsilon_opt)
                grads = opt.compute_gradients(total_loss, var_list=var_opt)
    
            # Add histograms for trainable variables.
            for var in tf.trainable_variables():
                tf.summary.histogram(var.op.name, var)
    
            # Add histograms for gradients.
            for grad, var in grads:
                if grad is not None:
                    tf.summary.histogram(var.op.name + '/gradients', grad)
    
            apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
            # Keep exponential moving averages of the optimized variables.
            with tf.control_dependencies([apply_gradient_op]):
                train_op = variables_averages.apply(var_opt)
    
        return train_op
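

    # Illustrative sketch (not in the original module): building the training
    # op. The schedule constants below are assumptions for the example, not
    # tuned values from the project.
    def _example_train_op(total_loss):
        global_step = tf.train.get_or_create_global_step()
        return train_op_batchnorm(total_loss=total_loss,
                                  global_step=global_step,
                                  initial_learning_rate=1e-3,
                                  lr_decay_rate=0.96,
                                  decay_steps=1000,
                                  epsilon_opt=1e-8,
                                  var_opt=tf.trainable_variables(),
                                  MOVING_AVERAGE_DECAY=0.999)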