Skip to content
Snippets Groups Projects
st3d.py 9.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • vand's avatar
    vand committed
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Thu Aug 29 11:30:17 2019
    
    @author: vand@dtu.dk
    """
    
    import numpy as np
    import scipy.io
    import scipy.ndimage
    import matplotlib.pyplot as plt
    
    #% STRUCTURE TENSOR 3D
    
    def structure_tensor(volume, sigma, rho):
        """ Structure tensor for 3D image data
        Arguments:
            volume: a 3D array of size N = slices(z)*rows(y)*columns(x)
            sigma: a noise scale, structures smaller than sigma will be 
                removed by smoothing
            rho: an integration scale giving the size over the neighborhood in 
                which the orientation is to be analysed
        Returns:
            an array with shape (6,N) containing elements of structure tensor 
                s_xx, s_yy, s_zz, s_xy, s_xz, s_yz ordered according to 
                volume.ravel(). 
        Author: vand@dtu.dk, 2019
        """
        # np.float was only an alias of the builtin float and is not available
        # in current NumPy releases; np.float64 is the same double precision
        volume = np.asarray(volume).astype(np.float64)

        # computing derivatives (scipy implementation truncates filter at 4 sigma)
        # axes are ordered (z, y, x), so the x-derivative is along the last axis
        Vx = scipy.ndimage.gaussian_filter(volume, sigma, order=[0,0,1], mode='nearest')
        Vy = scipy.ndimage.gaussian_filter(volume, sigma, order=[0,1,0], mode='nearest')
        Vz = scipy.ndimage.gaussian_filter(volume, sigma, order=[1,0,0], mode='nearest')

        def integrate(product):
            # integrating one element of structure tensor (scipy uses sequence of 1D)
            return scipy.ndimage.gaussian_filter(product, rho, mode='nearest').ravel()

        # rows are s_xx, s_yy, s_zz, s_xy, s_xz, s_yz
        S = np.vstack((integrate(Vx**2), integrate(Vy**2), integrate(Vz**2),
                       integrate(Vx*Vy), integrate(Vx*Vz), integrate(Vy*Vz)))
        return S
    
    def eig_special(S, full=False):
        """ Eigensolution for symmetric real 3-by-3 matrices
        Arguments:
            S: an array with shape (6,N) containing structure tensor, rows 
                ordered as s_xx, s_yy, s_zz, s_xy, s_xz, s_yz
            full: a flag indicating that all three eigenvectors should be returned
        Returns:
            val: an array with shape (3,N) containing eigenvalues sorted in 
                ascending order
            vec: an array with shape (3,N) containing eigenvector corresponding to 
                the smallest eigenvalue. If full, vec has shape (9,N) and contains 
                all three eigenvectors stacked as [v1, v2, v3], in the order of val
        More:        
            An analytic solution of eigenvalue problem for real symmetric matrix,
            using an affine transformation and a trigonometric solution of third
            order polynomial. See https://en.wikipedia.org/wiki/Eigenvalue_algorithm
            which refers to Smith's algorithm https://dl.acm.org/citation.cfm?id=366316
            Special cases are not treated: where the eigenvector formula evaluates
            to the zero vector (e.g. all off-diagonal elements are zero) the
            normalization divides by zero.
        Author: vand@dtu.dk, 2019
        """
        # TODO -- deal with special cases, decide treatment of full (i.e. maybe return 2 for full)
        sxx, syy, szz, sxy, sxz, syz = S

        # eigenvalues: shift by the mean of the diagonal and scale by p, so that
        # A = p*B + q*I, then solve the characteristic polynomial of B
        offdiag = sxy**2 + sxz**2 + syz**2
        q = (1/3)*(sxx+syy+szz)
        p = np.sqrt((1/6)*(np.sum((S[0:3] - q)**2, axis=0) + 2*offdiag))
        scalable = p!=0
        p_inv = np.zeros(p.shape) # stays 0 where p==0, avoiding division by 0
        p_inv[scalable] = 1/p[scalable]
        diagonal_shift = np.outer(np.array([1,1,1,0,0,0]),q) # q on diagonal rows only
        B = p_inv * (S - diagonal_shift)
        det_B = B[0]*B[1]*B[2] + 2*B[3]*B[4]*B[5] - B[3]**2*B[2]\
                - B[4]**2*B[1] - B[5]**2*B[0]
        phi = np.arccos(np.clip(det_B/2, -1, 1))/3 # clipping ensures -1 <= det_B/2 <= 1
        angle_offsets = np.array([[2*np.pi/3],[4*np.pi/3],[0]])
        val = q + 2*p*np.cos(phi.reshape((1,-1))+angle_offsets) # ordered eigenvalues

        # eigenvectors: lam is (N,) for one vector or (3,N) for all three, and
        # the expressions below broadcast accordingly
        lam = val if full else val[0]
        u = sxz*syz-(szz-lam)*sxy
        v = sxy*syz-(syy-lam)*sxz
        w = sxy*sxz-(sxx-lam)*syz
        vec = np.vstack((u*v, u*w, v*w))

        if not full: # vec is [x1 y1 z1] = v1
            return val, vec/np.sqrt(np.sum(vec**2, axis=0))

        # vec is [x1 x2 x3 y1 y2 y3 z1 z2 z3], regroup it to [v1, v2, v3]
        vec = vec[[0,3,6,1,4,7,2,5,8]]
        lengths = np.sqrt(np.vstack([np.sum(vec[k:k+3]**2, axis=0) for k in (0,3,6)]))
        vec = vec/np.repeat(lengths, 3, axis=0) # one length per group of three rows
        return val,vec
    
    def solve_flow(S): 
        """ Solving 2D optic flow, returns LLS optimal x and y for flow along z axis
            (A solution of a 2x2 linear system, one per column of S.)
        Arguments:
            S: an array with shape (6,N) containing 3D structure tensor
        Returns:
            xy: an array with shape (2,N) containing x and y components of the flow.
                Columns where the 2x2 system is singular (zero determinant) 
                are left as zeros.
        Author: vand@dtu.dk, 2019
        """
        # determinant of the 2x2 system [[s_xx, s_xy], [s_xy, s_yy]]
        det = S[0]*S[1]-S[3]**2
        solvable = det!=0 # a zero determinant means 0 or inf solutions
        # numerators of the x and y components
        numer = np.vstack((S[3]*S[5]-S[1]*S[4],
                           S[3]*S[4]-S[0]*S[5]))
        flow = np.zeros((2,S.shape[1]))
        flow[:,solvable] = numer[:,solvable]/det[solvable]
        return flow
    
    def tensor_vector_distance(S, u):
        """ Calculating pairwise distance between tensors and vectors
            For each tensor and each vector the quadratic form u'Su is evaluated.
        Arguments:
            S: an array with shape (6,N) containing tensor, rows ordered as
                s_xx, s_yy, s_zz, s_xy, s_xz, s_yz
            u: an array with shape (3,M) containing vectors, one vector per column
        Returns:
            dist: an array with shape (N,M) containing pairwise distances
        Author: vand@dtu.dk, 2019
        """
        # contribution of diagonal elements: s_xx*x^2 + s_yy*y^2 + s_zz*z^2
        squares = u**2
        on_diagonal = np.dot(S[0:3].T, squares)
        # contribution of off-diagonal elements, each appears twice in the 
        # symmetric tensor: 2*(s_xy*x*y + s_xz*x*z + s_yz*y*z)
        cross_terms = u[[0,0,1]]*u[[1,2,2]]
        off_diagonal = 2*np.dot(S[3:].T, cross_terms)
        return on_diagonal + off_diagonal
        
    #% INTERACTIVE VISUALIZATION FUNCTIONS - DO NOT WORK WITH INLINE FIGURES
    
    def arrow_navigation(event,z,Z):
        """ Updating slice index according to a pressed key
        Arguments:
            event: a key press event, only event.key is read
            z: current slice index
            Z: number of slices, valid indices are 0 to Z-1
        Returns:
            new slice index: up/down move by 1, right/left by 10 and 
            pagedown/pageup by 50 slices, limited to 0 when moving backward and 
            to Z-1 when moving forward. Any other key leaves z unchanged.
        """
        steps = {'up': 1, 'down': -1, 'right': 10, 'left': -10,
                 'pagedown': 50, 'pageup': -50}
        step = steps.get(event.key)
        if step is None: # not a navigation key
            return z
        if step > 0:
            # the last valid index is Z-1 for every forward key, pagedown included
            return min(z+step,Z-1)
        return max(z+step,0)
    
    def show_vol(V,cmap='gray'): 
        """
        Shows volumetric data slice by slice for interactive inspection.
        Keys handled by arrow_navigation change the displayed slice.
        Arguments:
             V: volume, the displayed slice is V[z]
             cmap: colormap passed to imshow
        @author: vand at dtu dot dk
        """
        n_slices = V.shape[0]
        z = (n_slices-1)//2 # start with the middle slice

        def redraw():
            ax.images[0].set_array(V[z])
            ax.set_title(f'slice z={z}')
            fig.canvas.draw()

        def on_key(event):
            nonlocal z
            z = arrow_navigation(event,z,n_slices)
            redraw()

        fig, ax = plt.subplots()
        # color limits are taken from the whole volume, so they are the same
        # for every slice
        lowest = np.min(V)
        highest = np.max(V)
        ax.imshow(V[z], cmap=cmap, vmin=lowest, vmax=highest)
        ax.set_title(f'slice z={z}')
        fig.canvas.mpl_connect('key_press_event', on_key)
    
    def show_vol_flow(V, fxy, s=5, double_arrow = False): 
        """
        Shows volumetric data and xy optical flow for interactive inspection.
        Keys handled by arrow_navigation change the displayed slice.
        Arguments:
             V: volume, the displayed slice is V[z]
             fxy: flow in x and y direction, indexed as fxy[component,z,row,column]
                 with component 0 being x and component 1 being y
             s: spacing of quiver arrows
             double_arrow: if True, a second set of arrows showing the negated 
                 flow is drawn
        @author: vand at dtu dot dk
        """
        def sampled_flow():
            # flow of the current slice at every s-th pixel, starting at s//2
            return fxy[0,z,s//2::s,s//2::s], fxy[1,z,s//2::s,s//2::s]

        def update_drawing():
            ax.images[0].set_array(V[z])
            fx, fy = sampled_flow()
            # set_UVC is the public way of changing arrow components
            arrows.set_UVC(fx, fy)
            if double_arrow:
                opposite_arrows.set_UVC(-fx, -fy)
            ax.set_title(f'slice z={z}')
            fig.canvas.draw()

        def key_press(event):
            nonlocal z
            z = arrow_navigation(event,z,Z)
            update_drawing()

        Z = V.shape[0] # z indexes the first axis of V, which bounds navigation
        z = (Z-1)//2
        # with indexing='ij' both grids have the shape of one slice, rows holds 
        # the row index and cols the column index of every pixel
        rows, cols = np.meshgrid(np.arange(V.shape[1]), np.arange(V.shape[2]), indexing='ij')
        rows = rows[s//2::s,s//2::s]
        cols = cols[s//2::s,s//2::s]
        fig, ax = plt.subplots()
        ax.imshow(V[z],cmap='gray')
        fx, fy = sampled_flow()
        # quiver takes horizontal positions first, so columns are passed as x and 
        # rows as y; angles='xy' makes arrows point from (x,y) to (x+u,y+v) in 
        # data coordinates, i.e. in the image axes set up by imshow
        arrows = ax.quiver(cols, rows, fx, fy, color='r', angles='xy')
        if double_arrow:
            opposite_arrows = ax.quiver(cols, rows, -fx, -fy, color='r', angles='xy')
        ax.set_title(f'slice z={z}')
        fig.canvas.mpl_connect('key_press_event', key_press)
        
    def fan_coloring(vec):
        """
        Fan-based colors for orientations in xy plane
        Arguments:
            vec: an array with shape (3,N) containing orientations
        Returns:
            rgba: an array with shape (N,4) containing rgba colors. Hue is given 
                by the direction in xy plane, taken modulo pi so that opposite 
                vectors get the same color. The color is blended towards 
                gray (0.5) with weight vec[2]**2, alpha is set to 1.
         @author:vand@dtu.dk
        """
        h = (vec[2]**2).reshape((vec.shape[1],1)) # no discontinuity and less gray
        # hue angle; arctan2 equals arctan(vec[0]/vec[1]) modulo pi, but does not
        # divide by zero when vec[1] is 0
        s = np.mod(np.arctan2(vec[0],vec[1]),np.pi)
        hue = plt.cm.hsv(s/np.pi)
        rgba = hue*(1-h) + 0.5*h
        rgba[:,3] = 1 # fixing alpha value
        return rgba
        
    def show_vol_orientation(V, vec, 
                             coloring = lambda v : np.c_[abs(v).T,np.ones((v.shape[1],1))], 
                             blending = lambda g,c : 0.5*(g+c)): 
        """
        Shows volumetric data and colored orientation for interactive inspection.
        Keys handled by arrow_navigation change the displayed slice.
        Arguments:
             V: volume, the displayed slice is V[z]
             vec: orientations, passed to coloring. The default coloring expects 
                 an array with shape (3,N).
             coloring: function turning vec into rgba colors, the result is 
                 reshaped to V.shape+(4,). The default uses absolute values of 
                 vector components as rgb and alpha 1.
             blending: function combining the gray rgba image of a slice with 
                 colors of that slice. The default is the average of the two.
        @author: vand at dtu dot dk
        """
        colors = coloring(vec).reshape(V.shape+(4,)) # one rgba color per voxel

        def blended_slice():
            # gray colormap applied to the current slice, mixed with its colors
            return blending(plt.cm.gray(V[z]), colors[z])

        def redraw():
            ax.images[0].set_array(blended_slice())
            ax.set_title(f'slice z={z}')
            fig.canvas.draw()

        def on_key(event):
            nonlocal z
            z = arrow_navigation(event,z,n_slices)
            redraw()

        n_slices = V.shape[0]
        z = (n_slices-1)//2 # start with the middle slice
        fig, ax = plt.subplots()
        ax.imshow(blended_slice())
        ax.set_title(f'slice z={z}')
        fig.canvas.mpl_connect('key_press_event', on_key)