#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 29 11:30:17 2019

@author: vand@dtu.dk
"""

import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.transform

import st2d
    
#%% ST AND ORIENTATIONS - VISUALIZATION OPTIONS
# Compute the structure tensor of a drawn-fibre image and show five views of
# the resulting orientation and anisotropy fields.
plt.close('all')

filename = 'example_data_2D/drawn_fibres_B.png'
sigma = 0.5  # first scale parameter passed to st2d.structure_tensor
rho = 2  # integration radius passed to st2d.structure_tensor

image = skimage.io.imread(filename)
S = st2d.structure_tensor(image, sigma, rho)
val, vec = st2d.eig_special(S)

# visualization
figsize = (10, 5)
fig, ax = plt.subplots(1, 5, figsize=figsize, sharex=True, sharey=True)
ax[0].imshow(image, cmap=plt.cm.gray)
st2d.plot_orientations(ax[0], image.shape, vec)
ax[0].set_title('Orientation as arrows')
# Angle of vec mapped to [0, 1] (dividing by pi) for the cyclic hsv colormap.
orientation_st_rgba = plt.cm.hsv((np.arctan2(vec[1], vec[0])/np.pi).reshape(image.shape))
ax[1].imshow(plt.cm.gray(image)*orientation_st_rgba)
ax[1].set_title('Orientation as color on image')
ax[2].imshow(orientation_st_rgba)
ax[2].set_title('Orientation as color')
# Anisotropy is 1 - val[0]/val[1]. Where val[1] == 0 (e.g. perfectly flat
# regions) the ratio is undefined; report 0 there instead of producing NaN
# and a divide-by-zero RuntimeWarning.
eigen_ratio = np.divide(val[0], val[1], out=np.ones(np.shape(val[1])), where=(val[1] != 0))
anisotropy = (1 - eigen_ratio).reshape(image.shape)
ax[3].imshow(anisotropy)
ax[3].set_title('Degree of anisotropy')
ax[4].imshow(plt.cm.gray(anisotropy)*orientation_st_rgba)
ax[4].set_title('Orientation and anisotropy')
plt.show()

#%% ST AND ORIENTATIONS - HISTOGRAM OPTIONS

# Orientation histogram (Cartesian and polar) for the 10X image.
filename = 'example_data_2D/10X.png'
sigma = 0.5
rho = 15
N = 180  # number of angle bins for orientation histogram

# computation: RGB -> gray (uint8), structure tensor, eigenvectors, histogram
image = skimage.io.imread(filename)[:, :, 0:3].mean(axis=2).astype(np.uint8)
S = st2d.structure_tensor(image, sigma, rho)
val, vec = st2d.eig_special(S)
# Histogram only counts angles in [0, pi]; this relies on eig_special giving
# vec[1] >= 0 (NOTE(review): confirm in st2d).
angles = np.arctan2(vec[1], vec[0])
distribution = np.histogram(angles, bins=N, range=(0.0, np.pi))[0]

# visualization: input image and color-coded orientation overlay
figsize = (10, 5)
fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True, sharey=True)
image_ax, overlay_ax = ax
image_ax.imshow(image, cmap=plt.cm.gray)
image_ax.set_title('Input image')
orientation_st_rgba = plt.cm.hsv((angles/np.pi).reshape(image.shape))
overlay_ax.imshow(plt.cm.gray(image)*orientation_st_rgba)
overlay_ax.set_title('Orientation as color on image')

# visualization: bar histogram over the half circle (180 deg) and polar view
fig, ax = plt.subplots(1, 2, figsize=figsize)
hist_ax, polar_ax = ax
bin_width = np.pi/N
bin_centers = (np.arange(N)+0.5)*np.pi/N
colors = plt.cm.hsv(bin_centers/np.pi)
hist_ax.bar(bin_centers, distribution, width=bin_width, color=colors)
hist_ax.set_xlabel('angle')
hist_ax.set_xlim([0, np.pi])
hist_ax.set_aspect(np.pi/hist_ax.get_ylim()[1])
hist_ax.set_xticks([0, np.pi/2, np.pi])
hist_ax.set_xticklabels(['0', 'pi/2', 'pi'])
hist_ax.set_ylabel('count')
hist_ax.set_title('Histogram over angles')
st2d.polar_histogram(polar_ax, distribution)
polar_ax.set_title('Polar histogram')
plt.show()

#%% ST AND ORIENTATIONS - WEIGHTED HISTOGRAM OPTIONS

# Orientation histograms for the drawn field image: one over all pixels and
# one restricted to bright (fibre) pixels.
filename = 'example_data_2D/drawn_field.png'
sigma = 0.5
rho = 15
N = 90  # number of angle bins for orientation histogram
fibre_threshold = 175  # pixels brighter than this count in the weighted histogram

# computation
image = skimage.io.imread(filename)
image = np.mean(image[:, :, 0:3], axis=2).astype(np.uint8)  # RGB(A) -> gray
S = st2d.structure_tensor(image, sigma, rho)
val, vec = st2d.eig_special(S)
# Histogram only counts angles in [0, pi]; this relies on eig_special giving
# vec[1] >= 0 (NOTE(review): confirm in st2d).
angles = np.arctan2(vec[1], vec[0])
distribution = np.histogram(angles, bins=N, range=(0.0, np.pi))[0]
# 1.0 for fibre pixels, 0.0 elsewhere. The builtin float is used because
# the np.float alias was removed in NumPy 1.24.
fibre_weights = (image > fibre_threshold).ravel().astype(float)
distribution_weighted = np.histogram(angles, bins=N, range=(0.0, np.pi), weights=fibre_weights)[0]

# visualization
figsize = (10, 5)
fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True, sharey=True)
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Input image')
orientation_st_rgba = plt.cm.hsv((angles/np.pi).reshape(image.shape))
ax[1].imshow(plt.cm.gray(image)*orientation_st_rgba)
ax[1].set_title('Orientation as color on image')

fig, ax = plt.subplots(2, 2, figsize=figsize)
bin_centers = (np.arange(N)+0.5)*np.pi/N  # half circle (180 deg)
colors = plt.cm.hsv(bin_centers/np.pi)
# One row per histogram: all orientations on top, fibre-weighted below.
for (hist_ax, polar_ax), counts, label in zip(
        ax,
        (distribution, distribution_weighted),
        ('all orientations', 'orientations at fibres')):
    hist_ax.bar(bin_centers, counts, width=np.pi/N, color=colors)
    hist_ax.set_xlabel('angle')
    hist_ax.set_xlim([0, np.pi])
    hist_ax.set_aspect(np.pi/hist_ax.get_ylim()[1])
    hist_ax.set_xticks([0, np.pi/2, np.pi])
    hist_ax.set_xticklabels(['0', 'pi/2', 'pi'])
    hist_ax.set_ylabel('count')
    hist_ax.set_title(f'Histogram over angles - {label}')
    st2d.polar_histogram(polar_ax, counts)
    polar_ax.set_title(f'Polar histogram - {label}')
plt.show()

#%% YET ANOTHER EXAMPLE - OCT IMAGE

# Orientation histogram (Cartesian and polar) for the OCT image.
filename = 'example_data_2D/OCT_im_org.png'
sigma = 0.5
rho = 5
N = 180  # number of angle bins for orientation histogram

# computation
# NOTE(review): unlike the other cells there is no RGB -> gray conversion, so
# this presumably expects a single-channel file -- confirm.
image = skimage.io.imread(filename)
S = st2d.structure_tensor(image, sigma, rho)
val, vec = st2d.eig_special(S)
# Histogram only counts angles in [0, pi]; this relies on eig_special giving
# vec[1] >= 0 (NOTE(review): confirm in st2d).
angles = np.arctan2(vec[1], vec[0])
distribution = np.histogram(angles, bins=N, range=(0.0, np.pi))[0]

# visualization: input image and color-coded orientation overlay
figsize = (10, 5)
fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True, sharey=True)
image_ax, overlay_ax = ax
image_ax.imshow(image, cmap=plt.cm.gray)
image_ax.set_title('Input image')
orientation_st_rgba = plt.cm.hsv((angles/np.pi).reshape(image.shape))
overlay_ax.imshow(plt.cm.gray(image)*orientation_st_rgba)
overlay_ax.set_title('Orientation as color on image')

# visualization: bar histogram over the half circle (180 deg) and polar view
fig, ax = plt.subplots(1, 2, figsize=figsize)
hist_ax, polar_ax = ax
bin_width = np.pi/N
bin_centers = (np.arange(N)+0.5)*np.pi/N
colors = plt.cm.hsv(bin_centers/np.pi)
hist_ax.bar(bin_centers, distribution, width=bin_width, color=colors)
hist_ax.set_xlabel('angle')
hist_ax.set_xlim([0, np.pi])
hist_ax.set_aspect(np.pi/hist_ax.get_ylim()[1])
hist_ax.set_xticks([0, np.pi/2, np.pi])
hist_ax.set_xticklabels(['0', 'pi/2', 'pi'])
hist_ax.set_ylabel('count')
hist_ax.set_title('Histogram over angles')
st2d.polar_histogram(polar_ax, distribution)
polar_ax.set_title('Polar histogram')
plt.show()
    
#%% INVESTIGATING THE EFFECT OF RHO

# Show how the integration radius rho changes the estimated orientations.
filename = 'example_data_2D/short_fibres.png'
image = skimage.io.imread(filename)
image = np.mean(image[:, :, 0:3], axis=2)
# Normalize intensities to [0, 1].
image -= np.min(image)
image /= np.max(image)
s = 128  # quiver arrow spacing
sigma = 0.5
figsize = (10, 5)

rhos = [2, 10, 20, 50]  # integration radii to compare

# The intensity overlay does not depend on rho, so compute it once.
intensity_rgba = plt.cm.gray(image)

# Iterate the list directly so the loop stays correct if rhos is edited.
for rho in rhos:

    # computation
    S = st2d.structure_tensor(image, sigma, rho)
    val, vec = st2d.eig_special(S)

    # visualization
    fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True, sharey=True)
    ax[0].imshow(image, cmap=plt.cm.gray)
    st2d.plot_orientations(ax[0], image.shape, vec, s=s)
    ax[0].set_title(f'Rho = {rho}, arrows')
    orientation_st_rgba = plt.cm.hsv((np.arctan2(vec[1], vec[0])/np.pi).reshape(image.shape))
    # Brighten the intensity to [0.5, 1] so dark pixels still show their color.
    ax[1].imshow((0.5+0.5*intensity_rgba)*orientation_st_rgba)
    ax[1].set_title(f'Rho = {rho}, color')

    # Uncomment to save each figure:
    # basis_filename = filename.split('/')[-1].split('.')[0]
    # fig.savefig(basis_filename + '_rho_' + str(rho) + '.png')
    plt.show()



#%% INVESTIGATING THE EFFECT OF SCALING + RHO 

# Show how downsampling the image (with rho scaled accordingly) changes the
# estimated orientations.
filename = 'example_data_2D/short_fibres.png'
downsampling_range = 4
figsize = (10, 5)

# Read the image and average the first three (RGB) channels once; only the
# rescaling below depends on the downsampling level.
full_res_image = np.mean(skimage.io.imread(filename)[:, :, 0:3], axis=2)

for k in range(downsampling_range):

    # downsampling and computation
    scale = 2**k
    f = 1/scale  # image scale factor
    s = 128//scale  # quiver arrow spacing
    sigma = 0.5  # would it make sense to scale this too?
    rho = 50/scale  # scaling the integration radius
    # full_res_image is 2D, so rescale needs no channel argument. The
    # `multichannel` keyword used previously is deprecated since
    # scikit-image 0.19 and rejected by later releases.
    image = skimage.transform.rescale(full_res_image, f)
    # Normalize to [0, 1] without modifying full_res_image in place.
    image = image - np.min(image)
    image = image / np.max(image)
    S = st2d.structure_tensor(image, sigma, rho)
    val, vec = st2d.eig_special(S)

    # visualization
    fig, ax = plt.subplots(1, 2, figsize=figsize, sharex=True, sharey=True)
    ax[0].imshow(image, cmap=plt.cm.gray)
    st2d.plot_orientations(ax[0], image.shape, vec, s=s)
    ax[0].set_title(f'Downsample = {scale}, arrows')
    intensity_rgba = plt.cm.gray(image)
    orientation_st_rgba = plt.cm.hsv((np.arctan2(vec[1], vec[0])/np.pi).reshape(image.shape))
    # Brighten the intensity to [0.5, 1] so dark pixels still show their color.
    ax[1].imshow((0.5+0.5*intensity_rgba)*orientation_st_rgba)
    ax[1].set_title(f'Downsample = {scale}, color')

    # Uncomment to save each figure:
    # basis_filename = filename.split('/')[-1].split('.')[0]
    # fig.savefig(basis_filename + '_scale_' + str(scale) + '.png')
    plt.show()
    
#%% COMPARING DOMINANT ORIENTATION AND OPTICAL FLOW
# Compare the dominant orientation from the structure tensor with the
# direction obtained from st2d.solve_flow, both as arrows and as color.
image = skimage.io.imread('example_data_2D/drawn_fibres_B.png')

# computing structure tensor, orientation and optical flow
sigma = 0.5
rho = 5
S = st2d.structure_tensor(image, sigma, rho)
val, vec = st2d.eig_special(S)  # dominant orientation
fx = st2d.solve_flow(S)  # optical flow

# The flow direction is drawn as the vector (fx, 1) at every pixel. Raveling
# fx gives a (2, image.size) array whether solve_flow returns (N,) or (1, N).
flow_vec = np.vstack((np.ravel(fx), np.ones(image.size)))

# visualization
figsize = (10, 10)
fig, ax = plt.subplots(2, 2, figsize=figsize)

ax[0][0].imshow(image, cmap=plt.cm.gray)
st2d.plot_orientations(ax[0][0], image.shape, vec)
ax[0][0].set_title('Orientation from structure tensor, arrows')
ax[0][1].imshow(image, cmap=plt.cm.gray)
st2d.plot_orientations(ax[0][1], image.shape, flow_vec)
ax[0][1].set_title('Orientation from optical flow, arrows')
intensity_rgba = plt.cm.gray(image)
orientation_st_rgba = plt.cm.hsv((np.arctan2(vec[1], vec[0])/np.pi).reshape(image.shape))
# Angle of (fx, 1), computed the same way as for vec: arctan2(y, x).
orientation_of_rgba = plt.cm.hsv((np.arctan2(1, fx)/np.pi).reshape(image.shape))
ax[1][0].imshow(intensity_rgba*orientation_st_rgba)
ax[1][0].set_title('Dominant orientation from structure tensor, color')
ax[1][1].imshow(intensity_rgba*orientation_of_rgba)
ax[1][1].set_title('Orientation from optical flow, color')
plt.show()