    # -*- coding: utf-8 -*-
    """
    Created on Wed May 16 14:27:21 2018
    This script prepares the input for active learning: training data, pool data and validation data.
    @author: s161488
    """
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.random_crop) is used below
    import cv2
    import matplotlib.pyplot as plt
    
    
    path_mom = "DATA/"  # NOTE: needs to be defined manually
    
    
    def prepare_train_data(path, select_benign_train, select_mali_train):
        """
        Args:
            path: the path where the data is saved
            select_benign_train: a list of indices of the benign images selected for training
            select_mali_train: a list of indices of the malignant images selected for training
        Ops:
            First, the images, labels, edges, im_index and cls_index are extracted from the np.load file:
            images: shape [85, im_h, im_w, 3]
            labels: shape [85, im_h, im_w, 1]
            im_index: shape [85]
            cls_index: shape [85]
            select_benign_train and select_mali_train determine which images initialize the training set
        Output:
            data_train, data_pl, data_val: each is a tuple (images, labels, edges, im_index, cls_index)
            for the training, pool and validation split, respectively
        """
        data_set = np.load(path, allow_pickle=True).item()
        images = data_set['image']
        labels = data_set['label']
        edges = data_set['edge']
        imageindex = data_set['ImageIndex']
        classindex = data_set['ClassIndex']
        benign_index = np.where(np.array(classindex) == 1)
        mali_index = np.where(np.array(classindex) == 2)
    
        choose_index_tr = np.concatenate([benign_index[0][select_benign_train], mali_index[0][select_mali_train]], axis=0)
        benign_index_left = np.delete(range(np.shape(benign_index[0])[0]), select_benign_train)
        mali_index_left = np.delete(range(np.shape(mali_index[0])[0]), select_mali_train)
    
        choose_index_pl = np.concatenate([benign_index[0][benign_index_left[:27]], mali_index[0][mali_index_left[:38]]],
                                         axis=0)
        choose_index_val = np.concatenate([benign_index[0][benign_index_left[-5:]], mali_index[0][mali_index_left[-5:]]],
                                          axis=0)
        data_train = extract_diff_data(images, labels, edges, imageindex, classindex, choose_index_tr)
        data_pl = extract_diff_data(images, labels, edges, imageindex, classindex, choose_index_pl)
        data_val = extract_diff_data(images, labels, edges, imageindex, classindex, choose_index_val)
    
        return data_train, data_pl, data_val
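
    # Example usage (illustrative sketch; the .npy file name below is an assumption, not a path
    # fixed by this script, while the two index lists mirror the defaults in prepare_pool_data):
    #
    #     data_train, data_pool, data_val = prepare_train_data(
    #         path_mom + "Data/glanddata.npy",
    #         select_benign_train=[0, 1, 2, 3, 4],
    #         select_mali_train=[2, 4, 5, 6, 7])
    #     x_train, y_train, edge_train = data_train[0], data_train[1], data_train[2]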
    
    
    def prepare_pool_data(path, aug=False):
        """Extract the pool data, i.e. the images left over after the fixed training selection.
        If aug is True, the pool images are zero-padded to a common size so that they can be batched.
        """
        data_set = np.load(path, allow_pickle=True).item()
        images = data_set['image']
        labels = data_set['label']
        edges = data_set['edge']
        imageindex = data_set['ImageIndex']
        classindex = data_set['ClassIndex']
        select_benign_train = [0, 1, 2, 3, 4]
        select_mali_train = [2, 4, 5, 6, 7]
    
        benign_index = np.where(np.array(classindex) == 1)
        mali_index = np.where(np.array(classindex) == 2)
    
        benign_index_left = np.delete(range(np.shape(benign_index[0])[0]), select_benign_train)
        mali_index_left = np.delete(range(np.shape(mali_index[0])[0]), select_mali_train)
    
        choose_index_pl = np.concatenate([benign_index[0][benign_index_left[:27]], mali_index[0][mali_index_left[:38]]],
                                         axis=0)
        data_pl = extract_diff_data(images, labels, edges, imageindex, classindex, choose_index_pl)
        if aug is True:
            targ_height_npy = 528  # this is for padding images
            targ_width_npy = 784  # this is for padding images
            x_image_val, y_label_val, y_edge_val = padding_training_data(data_pl[0], data_pl[1],
                                                                         data_pl[2], targ_height_npy,
                                                                         targ_width_npy)
            data_pl = [x_image_val, y_label_val, y_edge_val]
    
        return data_pl
    
    
    def prepare_skin_data(path, num_tr, combine=True):
        """
        choose_index_tr: the worst num_tr + the best num_tr images (combine=True),
        or the middle 2 * num_tr images (combine=False), ranked by the initial segmentation score.
        num_tr should be half of the total number of training images at the initial step
        (32 has been tried; 16 is the next setting to check).
        """
        val_num_im = 96
        tot_numeric_index = np.arange(900)
        if combine is True:
            tr_select_numeric_index = np.concatenate([tot_numeric_index[:num_tr], tot_numeric_index[-num_tr:]], axis=0)
        else:
            tr_select_numeric_index = tot_numeric_index[340:(340 + num_tr * 2)]
        val_select_numeric_index = tot_numeric_index[500:(500 + val_num_im)]
        pool_numeric_index = np.delete(tot_numeric_index,
                                       np.concatenate([tr_select_numeric_index, val_select_numeric_index], axis=0))
        im_seg_score = np.load('/home/s161488/Exp_Stat/Skin_Lesion/init_segment_score.npy')  # NOTE: hard-coded path, needs to be adapted manually
        sorted_index = np.argsort(im_seg_score)
    
        data_set = np.load(path, encoding='latin1', allow_pickle=True).item()
        images = np.array(data_set['image'])
        labels = np.array(data_set['label'])
        edges = np.array(data_set['edge'])
    
        labels = np.expand_dims(labels, axis=-1)
        edges = np.expand_dims(edges, axis=-1)
        tr_select_image_index = np.sort(sorted_index[tr_select_numeric_index])
        val_select_image_index = np.sort(sorted_index[val_select_numeric_index])
        pl_select_image_index = np.sort(sorted_index[pool_numeric_index])
    
        imindex = np.arange(np.shape(images)[0])
        clsindex = np.ones(np.shape(images)[0])
    
        data_tr = extract_diff_data(images, labels, edges, imindex, clsindex, tr_select_image_index)
        data_pl = extract_diff_data(images, labels, edges, imindex, clsindex, pl_select_image_index)
        data_val = extract_diff_data(images, labels, edges, imindex, clsindex, val_select_image_index)
    
        return data_tr[:3], data_pl[:3], data_val[:3]
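
    # Example usage (illustrative sketch; the .npy file name and num_tr = 16 are assumptions,
    # only the hard-coded segmentation-score path above is prescribed by this script):
    #
    #     data_tr, data_pl, data_val = prepare_skin_data(path_mom + "Data/skin_lesion.npy", num_tr=16, combine=True)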
    
    
    def prepare_test_data(path):
        """Load the test data from path. If the path does not point to a test file, fall back to the pool data."""
        if "_test" not in path:
            print("-------I am loading the data from the pool set------")
            return prepare_pool_data(path)
        data_set = np.load(path, allow_pickle=True).item()
        images = data_set['image']
        labels = data_set['label']
        edges = data_set['edge']
        imageindex = data_set['ImageIndex']
        classindex = data_set['ClassIndex']
        return images, labels, edges, imageindex, classindex
    
    
    def generate_batch(x_image_tr, y_label_tr, y_edge_tr, y_binary_mask_tr, batch_index, batch_size):
        """Slice the next batch out of the four training arrays and return it together with the advanced batch index."""
        im_group = [x_image_tr, y_label_tr, y_edge_tr, y_binary_mask_tr]
        im_batch = []
        for single_im in im_group:
            _im_batch = single_im[batch_index:(batch_size + batch_index)]
            im_batch.append(_im_batch)
        batch_index = batch_index + batch_size
        return im_batch[0], im_batch[1], im_batch[2], im_batch[3], batch_index
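
    # Example usage (illustrative sketch; batch_size = 2 is an arbitrary choice and the *_tr
    # arrays are assumed to come from padding_training_data, nothing here is fixed by this script):
    #
    #     batch_index = 0
    #     while batch_index + 2 <= np.shape(x_image_tr)[0]:
    #         x_b, y_b, e_b, m_b, batch_index = generate_batch(
    #             x_image_tr, y_label_tr, y_edge_tr, y_binary_mask_tr, batch_index, 2)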
    
    
    def padding_training_data(x_image, y_label, y_edge, target_height, target_width):
        """Each image has different size, so I need to pad it with zeros to make sure each image has the same size.
           Random crop, rotation and other augmentations can then be performed per batch instead of per image.
        """
        x_im_pad, y_la_pad, y_ed_pad = [], [], []
        num_image = np.shape(x_image)[0]
        for i in range(num_image):
            image_pad, label_pad, edge_pad = padding_zeros(x_image[i], y_label[i], y_edge[i], target_height, target_width)
            x_im_pad.append(image_pad)
            y_la_pad.append(label_pad)
            y_ed_pad.append(edge_pad)
        x_im_pad = np.reshape(x_im_pad, [num_image, target_height, target_width, 3])
        y_la_pad = np.reshape(y_la_pad, [num_image, target_height, target_width, 1])
        y_ed_pad = np.reshape(y_ed_pad, [num_image, target_height, target_width, 1])
        return x_im_pad, y_la_pad, y_ed_pad
    
    
    def padding_zeros(image, label, edge, target_height, target_width):
        im_h, im_w, _ = np.shape(image)
        delta_w = target_width - im_w
        delta_h = target_height - im_h
        top, bottom = delta_h // 2, delta_h - delta_h // 2
        left, right = delta_w // 2, delta_w - delta_w // 2
        image_pad = np.pad(image, ((top, bottom), (left, right), (0, 0)), mode='constant')
        label_pad = np.pad(label, ((top, bottom), (left, right)), mode='constant')
        edge_pad = np.pad(edge, ((top, bottom), (left, right)), mode='constant')
        return image_pad, label_pad, edge_pad
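
    # Worked example (sizes assumed for illustration): padding a 522 x 775 image to 528 x 784 gives
    # delta_h = 6 and delta_w = 9, so 3 rows of zeros go on top, 3 on the bottom, 4 columns on the
    # left and 5 on the right, which keeps the original content roughly centered.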
    
    
    def extract_diff_data(image, label, edge, im_index, cls_index, choose_index):
        new_data = [[] for _ in range(5)]
        old_data = [image, label, edge, im_index, cls_index]
        for i in choose_index:
            for single_new, single_old in zip(new_data, old_data):
                single_new.append(single_old[i])
        return new_data[0], new_data[1], new_data[2], new_data[3], new_data[4]
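
    # For example (toy call, purely illustrative): with choose_index = [0, 2] this returns plain
    # Python lists holding only the first and third image, label, edge, image index and class index.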
    
    
    def aug_train_data(image, label, edge, binary_mask, batch_size, aug, imshape):
        """This function is used for performing data augmentation. 
        image: placeholder. shape: [Batch_Size, im_h, im_w, 3], tf.float32
        label: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
        edge: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
        binary_mask: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
        aug: bool
        imshape: [targ_h, targ_w, ch]
        Outputs:
        image: [Batch_Size, targ_h, targ_w, 3]
        label: [Batch_Size, targ_h, targ_w, 1]
        edge: [Batch_Size, targ_h, targ_w, 1]
        binary_mask: [Batch_Size, targ_h, targ_w, 1]
        
        """
        image = tf.cast(image, tf.int64)
        bigmatrix = tf.concat([image, label, edge, binary_mask], axis=3)
        target_height = imshape[0].astype('int32')
        target_width = imshape[1].astype('int32')
        if aug is True:
            bigmatrix_crop = tf.random_crop(bigmatrix, size=[batch_size, target_height, target_width, 6])
            bigmatrix_crop = tf.cond(tf.less_equal(tf.reduce_sum(bigmatrix_crop[:, :, :, 5]), 10),
                                     lambda: tf.image.resize_image_with_crop_or_pad(bigmatrix, target_height, target_width),
                                     lambda: bigmatrix_crop)
            # if the random crop contains almost no annotated pixels (binary-mask sum <= 10),
            # the tf.cond above falls back to a center crop/pad of the full matrix
            k = tf.random_uniform(shape=[batch_size], minval=0, maxval=6.5, dtype=tf.float32)  # rotation angles in radians
            bigmatrix_rot = tf.contrib.image.rotate(bigmatrix_crop, angles=k)
            image_aug = tf.cast(bigmatrix_rot[:, :, :, 0:3], tf.float32)
            label_aug = bigmatrix_rot[:, :, :, 3]
            edge_aug = bigmatrix_rot[:, :, :, 4]
            binary_mask_aug = bigmatrix_rot[:, :, :, 5]
        else:
            bigmatrix_rot = tf.image.resize_image_with_crop_or_pad(bigmatrix, target_height, target_width)
            image_aug = tf.cast(tf.cast(bigmatrix_rot[:, :, :, 0:3], tf.uint8), tf.float32)
            label_aug = tf.cast(bigmatrix_rot[:, :, :, 3], tf.int64)
            edge_aug = tf.cast(bigmatrix_rot[:, :, :, 4], tf.int64)
            binary_mask_aug = tf.cast(bigmatrix_rot[:, :, :, 5], tf.int64)
        return image_aug, tf.expand_dims(label_aug, -1), tf.expand_dims(edge_aug, -1), tf.expand_dims(binary_mask_aug, -1)
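
    # Example usage (illustrative TF1 graph sketch; the placeholder shapes and the 480 x 480 crop
    # size are assumptions, not values fixed by this script):
    #
    #     image_pl = tf.placeholder(tf.float32, [5, None, None, 3])
    #     label_pl = tf.placeholder(tf.int64, [5, None, None, 1])
    #     edge_pl = tf.placeholder(tf.int64, [5, None, None, 1])
    #     mask_pl = tf.placeholder(tf.int64, [5, None, None, 1])
    #     im_aug, la_aug, ed_aug, bm_aug = aug_train_data(
    #         image_pl, label_pl, edge_pl, mask_pl, batch_size=5,
    #         aug=True, imshape=np.array([480, 480, 3]))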
    
    
    def collect_test_data(resize=True):
        test_a_path = path_mom + "/Data/glanddata_testa.npy"
        test_b_path = path_mom + "/Data/glanddata_testb.npy"
        image_tot, label_tot = [], []
        target_height, target_width = 528, 784
        for single_path in [test_a_path, test_b_path]:
            data_set = np.load(single_path, allow_pickle=True).item()
            images = data_set['image']
            y_label_pl = data_set['label']
            y_edge_pl = data_set['edge']
            x_image_val = []
            y_label_val = []
            if resize is True:
                for single_im, single_label in zip(images, y_label_pl):
                    for _im_, _path_ in zip([single_im, single_label], [x_image_val, y_label_val]):
                        _im_ = cv2.resize(_im_, dsize=(target_width, target_height), interpolation=cv2.INTER_CUBIC)
                        _path_.append(_im_)
            else:
                x_image_val, y_label_val, y_edge_val = padding_training_data(images, y_label_pl, y_edge_pl, target_height,
                                                                             target_width)
            image_tot.append(x_image_val)
            label_tot.append(y_label_val)
        image_tot = np.concatenate([image_tot[0], image_tot[1]], axis=0)
        label_tot = np.concatenate([label_tot[0], label_tot[1]], axis=0)
        print("The shape of the test images", np.shape(image_tot))
        return image_tot, label_tot
    
    
    def save_im():
        im_tot, la_tot = collect_test_data()
        rand_value = np.random.choice(np.arange(len(im_tot)), 3, replace=False)
        for i in rand_value:
            fig = plt.figure(figsize=(10, 4))
            im_ = im_tot[i]
            la_ = la_tot[i]
            la_judge = (la_ != 0)
            for iterr, single_im in enumerate([im_, la_, la_judge]):
                ax = fig.add_subplot(1, 3, iterr + 1)
                ax.imshow(single_im)
            plt.savefig('/home/blia/im_%d.pdf' % i, pad_inches=0, bbox_inches='tight')
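

    if __name__ == "__main__":
        # Minimal smoke test (a sketch; it assumes the two gland test .npy files referenced in
        # collect_test_data exist under path_mom and that the output directory is writable).
        save_im()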