%% Written by Behnaz Pirzamanbein (bepi@dtu.dk); last revised 2020-08-04
%% Moved out of the pipeline into this script to select the layer points interactively
% Output folder for the saved points. It is built with filesep so the path is
% valid on every OS; a hard-coded '\' creates oddly named files on Unix.
% The trailing separator is kept because later code may concatenate file names onto it.
save_folder = ['Results',filesep,'point',filesep];
if ~exist(save_folder,'dir')
    mkdir(save_folder)
end 

%% initial points layer
points = [];
points.contrast_enhansed_range = ce; % contrast-enhanced range (field name kept unchanged for existing readers)

fprintf('Suggestions:\n0  & 45  dg: first & last slice\n90 & 180 dg: first, middle & last slice or\n             first, 1/3, 2/3 & last slice\n')
show_matvol(Vol)

% Scroll to the desired slices in the figure, then press Enter in the command window
fprintf('Controls\n--------\n up-arrow    : Next slice in volume \n down-arrow  : Previous slice in volume \n right-arrow : +10 slices \n left-arrow  : -10 slices \n PgUp        : -50 slices \n PgDown      : +50 slices \n')
fprintf('Choose the slices by scrolling in the volume and press Enter \n DO NOT CLOSE THE FIGURE!')
input('');

answer = inputdlg('Enter space-separated slice numbers:','Slices',[1 50],{'100 2550'});
% inputdlg returns an empty cell when the dialog is cancelled
if isempty(answer)
    error('Slice selection was cancelled.');
end
% Parse the numbers without str2num, which runs the typed text through eval.
% Spaces, commas and tabs are accepted as separators.
slice_no = str2double(strsplit(strtrim(answer{1}),{' ',',',char(9)}));
if numel(slice_no) < 2 || any(~isfinite(slice_no)) || any(slice_no ~= round(slice_no))
    error('Enter at least two integer slice numbers, e.g. "100 2550" (got "%s").', answer{1});
end
% interpolation runs from each selected slice to the next one, so the order matters
if any(diff(slice_no) <= 0)
    error('Slice numbers must be strictly increasing (got "%s").', answer{1});
end
if slice_no(1) < 1 || slice_no(end) > size(Vol,3)
    error('Slice numbers must lie between 1 and %d (got "%s").', size(Vol,3), answer{1});
end
points.slice_no = slice_no;

% NOTE: despite its name this is the number of selected slices; the name is kept
% in case code run later in the same workspace reads it.
num_layers = length(points.slice_no);

%% select points, measure the length and distribute points
% For each of the two layers in layer_name, the user clicks the same number of
% points on every selected slice. The clicked points are rounded and resampled
% with distribute_points to m_l points. m_l is the rounded mean polyline length
% of the first layer and is reused for the second layer. The points are then
% interpolated between consecutive selected slices with slice_points_interpolation.
% NOTE(review): assumes slice_points_interpolation returns one row per slice
% for the inclusive range [start,end] it is given. The visualization at the end
% of this file indexes all_slice that way. TODO confirm.
n_slices = numel(points.slice_no);
if n_slices < 2
    error('At least two slices are needed to interpolate between them (got %d).', n_slices);
end
% Field that stores the raw clicked points. The original names are kept for
% 2-4 slices so existing readers of the saved files keep working.
legacy_fields = {'two_slice','three_slice','four_slice'};
if n_slices - 1 <= numel(legacy_fields)
    raw_field = legacy_fields{n_slices-1};
else
    raw_field = sprintf('slice%d', n_slices);
end

for l = 1:2
    flag_layer = layer_name{l};
    fprintf('%s layer:\n',flag_layer)

    % collect the clicked points on every selected slice
    pts = cell(1,n_slices);
    for k = 1:n_slices
        if k == 1
            fprintf('Go to slice %d and select points on the %s layer\n',points.slice_no(k),flag_layer)
        else
            fprintf('Go to slice %d and select %d points on the %s layer\n',points.slice_no(k),n_clicked,flag_layer)
        end
        input('');
        [x_click,y_click] = click_points;
        % Every slice needs the same number of points so they can be stacked
        % with cat(3,...). Ask again instead of crashing and losing all clicks.
        while (k == 1 && numel(x_click) < 2) || (k > 1 && numel(x_click) ~= n_clicked)
            if k == 1
                fprintf('At least 2 points are needed; select the points on slice %d again\n',points.slice_no(k))
            else
                fprintf('%d points were selected but %d are needed; select the points on slice %d again\n',numel(x_click),n_clicked,points.slice_no(k))
            end
            input('');
            [x_click,y_click] = click_points;
        end
        if k == 1
            n_clicked = numel(x_click);
        end
        pts{k} = round([x_click(:),y_click(:)]);
    end
    points.(raw_field) = cat(3,pts{:});

    if l == 1 % The first layer's length sets the number of points per slice; the same number is reused for the second layer.
        len0 = zeros(1,n_slices);
        for k = 1:n_slices
            d = diff(pts{k},1,1);
            len0(k) = sum(sqrt(d(:,1).^2 + d(:,2).^2));
        end
        m_l = round(mean(len0));
    end

    p_new = cell(1,n_slices);
    for k = 1:n_slices
        p_new{k} = distribute_points(pts{k},'number',m_l,1);
    end

    % Interpolate between consecutive selected slices. Every segment except the
    % last stops one slice early, so a shared boundary slice is not emitted twice.
    % (The old 4-slice code emitted slice_no(3) twice.)
    segments = cell(n_slices-1,1);
    for k = 1:n_slices-1
        last_slice = points.slice_no(k+1);
        if k < n_slices-1
            last_slice = last_slice - 1;
        end
        segments{k} = slice_points_interpolation(p_new{k+1},p_new{k},...
                            last_slice,points.slice_no(k),0);
    end
    points.all_slice = cat(1,segments{:});

    % polyline length of every interpolated slice, reported as a sanity check against m_l
    dx = diff(points.all_slice(:,1,:),1,3);
    dy = diff(points.all_slice(:,2,:),1,3);
    len1 = sum(sqrt(dx.^2 + dy.^2),3);
    fprintf('Approximated length ~ %d or %d \n',m_l,round(mean(len1)));

    %% save points
    if flag_save
        save(fullfile(save_folder,[flag_layer,'_points_',sample,'_',dg]),'points')
    end
end
close all

%% Visually check the interpolated points (uncomment to run; the loop covers slices slice_no(1) to slice_no(2)-1 only)
% for s = 1:points.slice_no(2)-points.slice_no(1)
%     fig1 = figure(1);
%     slice_no = points.slice_no(1)+s-1;
%     imagesc(Vol(:,:,slice_no)),colormap gray
%     title(['slice ',num2str(s)])
%     hold on 
%     plot(squeeze(points.all_slice(s,1,:)),squeeze(points.all_slice(s,2,:)),'-m','LineWidth',1)
%     axis off
%     drawnow
%     if ~mod(s,20)
%         close(fig1)
%     end
% end