/*=====================================================================

Program: Gradient Rotation for Tensor Transformation
Language: C++
Date: 11/13/2007
Author: Madhura A Ingalhalikar

======================================================================*/



#include "itkDiffusionTensor3DReconstructionImageFilter.h"
#include "itkVectorImage.h"
#include "itkNrrdImageIO.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkMetaDataDictionary.h"
#include "itkImageSeriesReader.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkVectorIndexSelectionCastImageFilter.h"
#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "itkNeighborhoodAlgorithm.h"
#include "itkConstantBoundaryCondition.h"
#include "itkVectorImageInterpolateFunction.h"
#include "itkVectorImageLinearInterpolateFunction.h"
#include "itkTensorFractionalAnisotropyImageFilter.h"
#include "itkTimeProbe.h"
#include "itkVariableLengthVector.h"
#include "itkImageFileWriter.h"
#include "itkGradientRotateCLP.h"
#include <vnl/algo/vnl_svd.h>
#include <iostream>
#include <vnl/vnl_quaternion.h>
#include <vnl/vnl_vector.h>
#include <vnl/vnl_cross.h>

/**
 * Warps a diffusion-weighted image (DWI) through a deformation field and
 * writes a tensor image sampled on the deformation field's grid.
 *
 * Command-line arguments are consumed by the PARSE_ARGS macro (presumably
 * generated into itkGradientRotateCLP.h), which supplies
 * InputImageFilename, DeformationFilename and OutputImageFilename.
 *
 * For every voxel of the deformation field the program:
 *   1. builds a 3x3 matrix J from the neighbouring displacement vectors and
 *      forms F = J + I;
 *   2. maps the voxel to a physical point, adds the displacement, and
 *      linearly interpolates the DWI vector at that point;
 *   3. if the first DWI component exceeds `threshold`, derives a rotation R
 *      from F and two eigenvectors of a rough per-voxel tensor estimate,
 *      rotates every gradient direction by R, and runs the ITK tensor
 *      reconstruction filter on a one-voxel image holding the interpolated
 *      DWI vector;
 *   4. otherwise (or if the point is outside the DWI buffer) stores a
 *      zero tensor.
 *
 * @return EXIT_SUCCESS when the tensor image was written; EXIT_FAILURE if
 *         either input cannot be read, the header lacks a usable b-value or
 *         gradient list, the reconstruction filter throws, or the output
 *         cannot be written.
 */
int main( int argc, char *argv[] )
{
  PARSE_ARGS;

  std::cout << "Input DTI nrrd Image: " << InputImageFilename << std::endl;
  std::cout << "Input Deformation Field: " << DeformationFilename << std::endl;
  std::cout << "Output Tensor Image: " << OutputImageFilename << std::endl;

  const unsigned int Dimension = 3;

  bool readb0 = false;
  double b0 = 0;
  // Voxels whose first DWI component is not above this value get a zero tensor.
  unsigned int threshold = 100;

  itk::TimeProbe clock;
  clock.Start();

  typedef unsigned short                      PixelType;
  typedef itk::VectorImage<unsigned short, 3> ImageType;
  typedef itk::VariableLengthVector<double>   realPixelType;

  // Matrix / vector helper types used for the rotation computation.
  typedef itk::Matrix<float,3,3>          MatrixType;
  typedef itk::Vector<double,3>           VectorType;
  typedef MatrixType::InternalMatrixType  vnlMatrixType;
  typedef MatrixType::ComponentType       CompType;
  typedef vnl_vector_fixed<CompType, 3>   vnlVectorType;
  typedef vnl_quaternion<CompType>        RotationType;

  // ---- Read the DWI image -------------------------------------------------

  itk::ImageFileReader<ImageType>::Pointer reader = itk::ImageFileReader<ImageType>::New();
  ImageType::Pointer img;
  reader->SetFileName(InputImageFilename.c_str());

  try
    {
    std::cout << "Reading DTI....." << std::endl;
    reader->Update();
    img = reader->GetOutput();
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    return EXIT_FAILURE;
    }

  const unsigned int vectorLength = img->GetVectorLength();

  // ---- Read the deformation field -----------------------------------------

  typedef float                                          VectorComponentType;
  typedef itk::Vector< VectorComponentType, Dimension >  VectorPixelType;
  typedef itk::Image< VectorPixelType, Dimension >       DeformationFieldType;
  typedef itk::ImageFileReader< DeformationFieldType >   FieldReaderType;

  FieldReaderType::Pointer fieldReader = FieldReaderType::New();
  fieldReader->SetFileName( DeformationFilename.c_str() );

  try
    {
    std::cout << "Reading Deformation Field........." << std::endl;
    fieldReader->Update();
    }
  catch (itk::ExceptionObject & fe)
    {
    std::cout << "Field Exception caught ! " << fe << std::endl;
    // Nothing below can work without the field, so stop here instead of
    // continuing with an unread image.
    return EXIT_FAILURE;
    }

  DeformationFieldType::Pointer deformationField = fieldReader->GetOutput();

  // ---- Collect b-value and gradient directions from the DWI header --------

  typedef itk::DiffusionTensor3DReconstructionImageFilter<
    PixelType, PixelType, double > TensorReconstructionImageFilterType;

  itk::MetaDataDictionary imgMetaDictionary = img->GetMetaDataDictionary();
  std::vector<std::string> imgMetaKeys = imgMetaDictionary.GetKeys();
  std::vector<std::string>::const_iterator itKey;
  std::string metaString;

  TensorReconstructionImageFilterType::GradientDirectionType vect3d;
  // Per-voxel rotated gradients (overwritten for every processed voxel).
  TensorReconstructionImageFilterType::GradientDirectionContainerType::Pointer
    DiffusionVectors = TensorReconstructionImageFilterType::GradientDirectionContainerType::New();
  // Gradients exactly as listed in the header.
  TensorReconstructionImageFilterType::GradientDirectionContainerType::Pointer
    OrigDiffusionVectors = TensorReconstructionImageFilterType::GradientDirectionContainerType::New();

  unsigned int numberOfImages = 0;

  for (itKey = imgMetaKeys.begin(); itKey != imgMetaKeys.end(); ++itKey)
    {
    itk::ExposeMetaData<std::string> (imgMetaDictionary, *itKey, metaString);
    if (itKey->find("DWMRI_gradient") != std::string::npos)
      {
      double x = 0.0;
      double y = 0.0;
      double z = 0.0;
      // A gradient entry must provide three components; anything else would
      // leave the direction partly undefined.
      if (sscanf(metaString.c_str(), "%lf %lf %lf\n", &x, &y, &z) != 3)
        {
        std::cerr << "Malformed gradient entry for key " << *itKey
                  << ": " << metaString << std::endl;
        return EXIT_FAILURE;
        }
      vect3d[0] = x; vect3d[1] = y; vect3d[2] = z;

      OrigDiffusionVectors->InsertElement( numberOfImages, vect3d );
      ++numberOfImages;
      }
    else if (itKey->find("DWMRI_b-value") != std::string::npos)
      {
      readb0 = true;
      b0 = atof(metaString.c_str());
      }
    }

  if (!readb0)
    {
    std::cerr << "BValue not specified in header file" << std::endl;
    return EXIT_FAILURE;
    }
  // b0 is used as a divisor in the per-voxel diffusivity estimate below.
  if (!(b0 > 0.0))
    {
    std::cerr << "BValue in header file is not a positive number" << std::endl;
    return EXIT_FAILURE;
    }
  // The per-voxel loop indexes the interpolated DWI vector (length
  // vectorLength) by gradient number and divides by the gradient count.
  if (numberOfImages == 0 || numberOfImages > vectorLength)
    {
    std::cerr << "Found " << numberOfImages << " gradient directions for a DWI with "
              << vectorLength << " components" << std::endl;
    return EXIT_FAILURE;
    }

  // ---- Interpolator over the DWI image ------------------------------------

  typedef double CoordRepType;
  typedef itk::VectorImageLinearInterpolateFunction<ImageType,CoordRepType> InterpolatorType;
  InterpolatorType::Pointer interp = InterpolatorType::New();

  // ---- Output tensor image, laid out like the deformation field -----------

  typedef itk::DiffusionTensor3D<double>          TensorPixelType;
  typedef itk::Image<TensorPixelType,Dimension>   TensorImageType;
  TensorImageType::Pointer tensor = TensorImageType::New();
  TensorImageType::RegionType region;
  region.SetSize(deformationField->GetRequestedRegion().GetSize());

  tensor->SetSpacing( deformationField->GetSpacing() );
  tensor->SetBufferedRegion(region);
  tensor->SetOrigin(deformationField->GetOrigin());
  tensor->SetLargestPossibleRegion(region);
  tensor->SetDirection(deformationField->GetDirection());
  tensor->Allocate();

  // Value written wherever no tensor is reconstructed.
  TensorPixelType edgePadding;
  edgePadding.Fill(0.0);

  typedef TensorPixelType::EigenVectorsMatrixType EigenVectorsType;
  typedef TensorPixelType::EigenValuesArrayType   EigenValuesType;

  // ---- Reconstruction filter (re-run once per processed voxel) ------------

  TensorReconstructionImageFilterType::Pointer filter = TensorReconstructionImageFilterType::New();
  // These settings do not depend on the voxel, so configure them once.
  filter->SetBValue(b0);
  filter->SetNumberOfThreads( 1 ); //required
  filter->SetThreshold(threshold);

  interp->SetInputImage(reader->GetOutput());

  // ---- Iterators ----------------------------------------------------------

  // Output tensor image.
  typedef itk::ImageRegionIteratorWithIndex<TensorImageType> OutputIteratorType;
  OutputIteratorType outputIt( tensor, region );

  // Displacement at the current voxel.
  typedef itk::ImageRegionIterator<DeformationFieldType> FieldIteratorType;
  FieldIteratorType fieldIt(deformationField, deformationField->GetRequestedRegion());

  // 3x3x3 neighbourhood over the deformation field, used to build the
  // Jacobian on the fly.
  typedef itk::ConstantBoundaryCondition< DeformationFieldType > boundaryConditionType;
  typedef itk::ConstNeighborhoodIterator<DeformationFieldType, boundaryConditionType> ConstNeighborhoodIteratorType;
  ConstNeighborhoodIteratorType::RadiusType radius;
  radius[0] = 1; radius[1] = 1; radius[2] = 1;

  // ---- One-voxel DWI image fed to the reconstruction filter ---------------

  ImageType::Pointer newImage = ImageType::New();
  newImage->SetVectorLength(vectorLength);
  ImageType::RegionType smallRegion;
  ImageType::RegionType::IndexType refIndex;
  ImageType::SizeType size;
  size[0] = 1; size[1] = 1; size[2] = 1;
  smallRegion.SetSize(size);

  typedef itk::Point<double,3> PointType;
  PointType point;
  DeformationFieldType::PixelType displacement;
  TensorImageType::IndexType index;
  vnlMatrixType J;

  // Identity matrix added to J to obtain F; the same for every voxel.
  vnlMatrixType iden;
  iden.set_identity();

  // Neighbours that fall outside the field read as a zero displacement.
  boundaryConditionType cbc;
  ConstNeighborhoodIteratorType bit( radius, deformationField, region );

  cbc.SetConstant(static_cast<VectorType>(0.0));
  bit.OverrideBoundaryCondition(&cbc);
  bit.GoToBegin();
  std::cout <<  " Computing Rotation and tensor... " << std::endl;

  // The three iterators are advanced together at the end of each pass.
  while ( ! bit.IsAtEnd() )
    {
    index = outputIt.GetIndex();
    displacement = fieldIt.Get();

    // Half the difference between the previous and next neighbour along
    // axis i, for displacement component j.
    // NOTE(review): the difference is taken as previous - next and is not
    // divided by the voxel spacing - confirm both are intended.
    for (int i = 0; i < 3; ++i)
      {
      for (int j = 0; j < 3; ++j)
        {
        J(j,i) = 0.5 * (bit.GetPrevious(i)[j] - bit.GetNext(i)[j]);
        }
      }

    // Matrix whose action on the eigenvectors defines the rotation below.
    vnlMatrixType F = (J + iden);

    // Physical location in the DWI that this output voxel samples.
    tensor->TransformIndexToPhysicalPoint( index, point );
    for (unsigned int j = 0; j < 3; j++)
      {
      point[j] += displacement[j];
      }

    if ( interp->IsInsideBuffer( point ) )
      {
      typedef InterpolatorType::OutputType OutputType;
      const OutputType interpolatedValue = interp->Evaluate(point);
      realPixelType value;
      value = static_cast<realPixelType>(interpolatedValue);

      // Rebuild the one-voxel image so it holds the interpolated DWI vector.
      newImage->TransformPhysicalPointToIndex(point, refIndex);
      smallRegion.SetIndex(refIndex);
      newImage->SetBufferedRegion(smallRegion);
      newImage->SetLargestPossibleRegion(smallRegion);
      newImage->Allocate();
      newImage->SetPixel(refIndex, value);

      if (value[0] > threshold)
        {
        realPixelType D(vectorLength);

        TensorReconstructionImageFilterType::GradientDirectionType grad;
        VectorType d;
        MatrixType sig;
        MatrixType covar;
        covar.Fill(0.0);

        // Average the outer products of the gradient directions scaled by a
        // per-gradient value D[l] computed from the log signal ratio.
        // NOTE(review): value[l] == 0 makes the logarithm infinite - confirm
        // whether such signals need to be clamped.
        for (unsigned int l = 0; l < numberOfImages; l++)
          {
          D[l] = -(1/b0)*log(value[l]/value[0]);

          grad = (OrigDiffusionVectors->ElementAt(l));

          d[0] = D[l]*grad[0]; d[1] = D[l]*grad[1]; d[2] = D[l]*grad[2];

          sig(0,0) = d[0]* d[0];  sig(0,1) = d[0]* d[1];  sig(0,2) = d[0]* d[2];
          sig(1,0) = d[1]* d[0];  sig(1,1) = d[1]* d[1];  sig(1,2) = d[1]* d[2];
          sig(2,0) = d[2]* d[0];  sig(2,1) = d[2]* d[1];  sig(2,2) = d[2]* d[2];

          covar += sig;
          }

        covar = covar/numberOfImages;

        // Pack the upper triangle into a tensor and take its eigenvectors.
        EigenVectorsType mat;
        EigenValuesType eigenValues;
        TensorPixelType tens;
        tens[0] = covar(0,0); tens[1] = covar(0,1); tens[2] = covar(0,2);
        tens[3] = covar(1,1); tens[4] = covar(1,2); tens[5] = covar(2,2);
        tens.ComputeEigenAnalysis(eigenValues, mat);

        // Row 2 is used as the largest eigenvector and row 1 as the second
        // (presumably eigenvalues come back in ascending order - confirm
        // against the ITK documentation).
        vnlVectorType ev1, ev2;
        vnlVectorType n1, n2, pn2;
        ev1[0] = mat(2,0); ev1[1] = mat(2,1); ev1[2] = mat(2,2);
        ev2[0] = mat(1,0); ev2[1] = mat(1,1); ev2[2] = mat(1,2);

        // First rotation: built from the axis n1 x ev1 and the angle between
        // ev1 and its image under F.
        n1 = (F * ev1).normalize();
        RotationType R1(vnl_cross_3d(n1,ev1).normalize(), angle(ev1,n1));
        vnlMatrixType r1 = R1.rotation_matrix_transpose();

        // Second rotation: about r1*ev1, using the part of F*ev2 that is
        // perpendicular to n1.
        // NOTE(review): the second constructor argument here is a dot
        // product, whereas R1 above passes angle(...) - confirm whether an
        // angle (e.g. acos of this value) was intended.
        n2 = (F * ev2).normalize();
        pn2 = n2 - dot_product(n1,n2)*n1;
        RotationType R2(r1*ev1, dot_product(r1*ev2, pn2.normalize()));
        vnlMatrixType r2 = R2.rotation_matrix_transpose();

        vnlMatrixType R = (r2 * r1);

        // Rotate every header gradient by R for this voxel.
        TensorReconstructionImageFilterType::GradientDirectionType oldDir;
        TensorReconstructionImageFilterType::GradientDirectionType newDir;
        for (unsigned int i = 0; i < numberOfImages; i++)
          {
          oldDir = (OrigDiffusionVectors->ElementAt(i));
          newDir[0] = R(0,0)* oldDir[0] + R(0,1)* oldDir[1] + R(0,2)* oldDir[2];
          newDir[1] = R(1,0)* oldDir[0] + R(1,1)* oldDir[1] + R(1,2)* oldDir[2];
          newDir[2] = R(2,0)* oldDir[0] + R(2,1)* oldDir[1] + R(2,2)* oldDir[2];
          DiffusionVectors->InsertElement( i, newDir );
          }

        try
          {
          // Kept inside the try block so that an exception raised while
          // attaching the gradients is reported like an update failure.
          filter->SetGradientImage( DiffusionVectors, newImage );
          filter->UpdateLargestPossibleRegion();
          }
        catch (itk::ExceptionObject &filterError)
          {
          std::cout << filterError << std::endl;
          return EXIT_FAILURE;
          }

        TensorImageType::Pointer tes = filter->GetOutput();
        TensorPixelType tp = tes->GetPixel(refIndex);
        outputIt.Set( tp );
        }
      else
        {
        outputIt.Set(edgePadding);
        }
      }
    else
      {
      outputIt.Set(edgePadding);
      }

    ++outputIt;
    ++fieldIt;
    ++bit;
    }

  clock.Stop();
  double timeTaken = clock.GetMeanTime();
  std::cout << "Time Taken= " << timeTaken << std::endl;

  // ---- Write the tensor image ---------------------------------------------

  typedef itk::ImageFileWriter< TensorImageType > TensorWriterType;
  TensorWriterType::Pointer tensorWriter = TensorWriterType::New();
  tensorWriter->SetFileName( OutputImageFilename );
  tensorWriter->SetInput(tensor);

  try
    {
    tensorWriter->Update();
    }
  catch (itk::ExceptionObject &writeError)
    {
    std::cerr << writeError << std::endl;
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}
  
  
  
  
  
  

  
  
  
  
