/*=====================================================================

Program: Gradient Rotation for Tensor Transformation
Language: C++
Date: 11/13/2007
Author: Madhura A Ingalhalikar

======================================================================*/



#include "itkDiffusionTensor3DReconstructionImageFilter.h"
#include "itkVectorImage.h"
#include "itkNrrdImageIO.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkMetaDataDictionary.h"
#include "itkImageSeriesReader.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkVectorIndexSelectionCastImageFilter.h"
#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "itkNeighborhoodAlgorithm.h"
#include "itkConstantBoundaryCondition.h"
#include "itkVectorImageInterpolateFunction.h"
#include "itkVectorImageLinearInterpolateFunction.h"
#include "itkTensorFractionalAnisotropyImageFilter.h"
#include "itkTimeProbe.h"
#include "itkVariableLengthVector.h"
#include "itkImageFileWriter.h"
#include "itkGradientRotateCLP.h"
#include <vnl/algo/vnl_svd.h>
#include <iostream>

int main( int argc, char *argv[] )
{
   
   PARSE_ARGS;
  
  std::cout << "Input DTI nrrd Image: " <<  InputImageFilename << std::endl; 
  std::cout << "Input Deformation Field" << DeformationFilename << std::endl;
  std::cout << "Output Tensor Image: " <<  OutputImageFilename << std::endl;
    

   const unsigned int Dimension = 3;
   
   bool readb0 = false;
   double b0 = 0;
   unsigned int threshold =80;
   
   itk::TimeProbe clock;
   clock.Start();
 
  
  typedef unsigned short                      PixelType;
  typedef itk::VectorImage<unsigned short, 3> ImageType;
  typedef itk::VariableLengthVector<double> realPixelType;
  
  //Some required typedef's
   
  typedef itk::Matrix<float,3,3>  MatrixType; 
  typedef itk::Vector<double,3> VectorType;
  typedef MatrixType::InternalMatrixType vnlMatrixType;
  

 
  //Read the DTI image data
  
  itk::ImageFileReader<ImageType>::Pointer reader = itk::ImageFileReader<ImageType>::New();
  ImageType::Pointer img; 
  reader->SetFileName(InputImageFilename);
 
  try
    {
     std::cout << "Reading DTI....." << std::endl;
    reader->Update();
    img = reader->GetOutput();
   
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    return EXIT_FAILURE;
    }
  
    std::cout << img << std::endl;
  int vectorLength = img->GetVectorLength();  

  //Read deformation field 
 
  typedef   float VectorComponentType;
  typedef   itk::Vector< VectorComponentType, Dimension > VectorPixelType;
  typedef   itk::Image< VectorPixelType,  Dimension >   DeformationFieldType;   
  typedef   itk::ImageFileReader< DeformationFieldType >  FieldReaderType;
   
  FieldReaderType::Pointer fieldReader = FieldReaderType::New();
  fieldReader->SetFileName(  DeformationFilename.c_str());
   
  try 
  {
   std::cout << "Reading Deformation Field........." << std::endl;
   fieldReader->Update();
     
  }
  catch(itk::ExceptionObject & fe)
  {
    std::cout << "Field Exception caught ! " << fe << std::endl;
  }
  
  DeformationFieldType::Pointer deformationField = fieldReader->GetOutput(); 
  
  std::cout << deformationField << std::endl;

 //Define tensorConstructionFilter 
 
  typedef itk::DiffusionTensor3DReconstructionImageFilter< 
  PixelType, PixelType, double > TensorReconstructionImageFilterType; 
  
  itk::MetaDataDictionary imgMetaDictionary = img->GetMetaDataDictionary();    
  std::vector<std::string> imgMetaKeys = imgMetaDictionary.GetKeys();
  std::vector<std::string>::const_iterator itKey; 
  std::string metaString;
  
  TensorReconstructionImageFilterType::GradientDirectionType vect3d;
  TensorReconstructionImageFilterType::GradientDirectionContainerType::Pointer 
  DiffusionVectors = TensorReconstructionImageFilterType::GradientDirectionContainerType::New();
   
  typedef itk::Image< PixelType, Dimension > ReferenceImageType;  
    
  
 //Define Interpolator 
 
typedef double CoordRepType;
typedef itk::VectorImageLinearInterpolateFunction<ImageType,CoordRepType> InterpolatorType;
InterpolatorType::Pointer interp = InterpolatorType::New();
//interp->SetInputImage( img );


 // Define Tensor output image
 
  typedef itk::DiffusionTensor3D<double> TensorPixelType;
  typedef itk::Image<TensorPixelType,Dimension> TensorImageType;
  TensorImageType::Pointer tensor = TensorImageType::New();
  TensorImageType::RegionType region;
  region.SetSize(deformationField->GetRequestedRegion().GetSize());
 
  tensor->SetSpacing( deformationField->GetSpacing()); 
  tensor->SetBufferedRegion(region);  
  tensor->SetRequestedRegion(region);
  tensor->SetOrigin(deformationField->GetOrigin());
  tensor->SetLargestPossibleRegion(region);
  tensor->SetDirection(deformationField->GetDirection());
  tensor->Allocate();
  TensorPixelType edgePadding; 
  edgePadding.Fill(0.0);
  
     
 
 // pointer to filter
  
  TensorReconstructionImageFilterType::Pointer filter = TensorReconstructionImageFilterType::New();
 
  interp->SetInputImage(reader->GetOutput());
  
 
  
  //iterator over tensor image
 
  typedef itk::ImageRegionIteratorWithIndex<TensorImageType> OutputIteratorType;
  OutputIteratorType outputIt( tensor, region );
  
  
   
 // iterator for the deformation field
 
  typedef itk::ImageRegionIterator<DeformationFieldType> FieldIteratorType;
  FieldIteratorType fieldIt(deformationField, deformationField->GetRequestedRegion());
  
 // Define Neighbourhood Iterator for computing the Jacobian on the fly  
   
  typedef itk::ConstantBoundaryCondition< DeformationFieldType > boundaryConditionType;
  typedef itk::ConstNeighborhoodIterator<DeformationFieldType, boundaryConditionType> ConstNeighborhoodIteratorType;
  ConstNeighborhoodIteratorType::RadiusType radius;
  radius[0] =1; radius[1] =1; radius[2] = 1;
  
     
  
  //Define an image of 1 voxel
  
  ImageType::Pointer newImage = ImageType::New();
  ImageType::RegionType smallRegion;
  ImageType::RegionType::IndexType refIndex;
  ImageType::SizeType size; 
  size[0] =1; size[1] =1; size[2] =1;
  smallRegion.SetSize(size);
  
 
 

  typedef itk::Point<double,3> PointType;
  PointType point;
  DeformationFieldType::PixelType displacement;
  TensorImageType::IndexType index; 
  TensorImageType::IndexType newIndex;
  vnlMatrixType J;
  
  
  //define boundary condition
    
    boundaryConditionType cbc;    
    ConstNeighborhoodIteratorType
    bit( radius,deformationField,region);
    
    cbc.SetConstant(static_cast<VectorType>(0.0));
    bit.OverrideBoundaryCondition(&cbc);
    bit.GoToBegin();
   std::cout <<  " Computing Rotation and tensor... " << std::endl;
   
   while ( ! outputIt.IsAtEnd() )
 {
     
   index = outputIt.GetIndex();    
   displacement = fieldIt.Get(); 
    
     // std::cout << "orig_D = " <<displacement<<std::endl;
                  
   //compute Jacobian     
  
    for (int i = 0; i < 3; ++i)
      {
      for (int j = 0; j < 3; ++j)
        {
       
	 J(j,i) =  0.5 * (bit.GetPrevious(i)[j] - bit.GetNext(i)[j]);
        
        }
      }
        
  
  
     //Use SVD for computing rotation matrix    
      
      vnlMatrixType iden;
      iden.set_identity(); 
     
      // Compute Singluar value Decompostion of the Jacobian to find the Rotation Matrix. 
      
      vnl_svd<float> svd(J + iden);      
      MatrixType  rotationMatrix(svd.U() * svd.V().transpose());
     
      
    
      tensor->TransformIndexToPhysicalPoint( index, point );
   
      
      for(unsigned int j = 0; j < 3; j++ )
      {
      point[j] += displacement[j];
      }
      
    
    
    
   if( interp->IsInsideBuffer( point ) )
      {
      typedef  InterpolatorType::OutputType OutputType;
      const OutputType interpolatedValue = interp->Evaluate(point);
      realPixelType value;
      value = static_cast<realPixelType>(interpolatedValue);
       //std::cout << "R= " << rotationMatrix <<std::endl;
      newImage->TransformPhysicalPointToIndex(point,refIndex);
      smallRegion.SetIndex(refIndex);
      newImage->SetBufferedRegion(smallRegion);
      newImage->SetLargestPossibleRegion(smallRegion);
      newImage->SetVectorLength(vectorLength);
      newImage->Allocate();
      newImage->SetPixel(refIndex,value);
      if (value[0] > threshold){
      
      unsigned int numberOfImages = 0;
                        
  for (itKey = imgMetaKeys.begin(); itKey != imgMetaKeys.end(); itKey ++)
    {
    double x,y,z;

    itk::ExposeMetaData<std::string> (imgMetaDictionary, *itKey, metaString);
    if (itKey->find("DWMRI_gradient") != std::string::npos)
      { 
        
      sscanf(metaString.c_str(), "%lf %lf %lf\n", &x, &y, &z);
      vect3d[0] = x; vect3d[1] = y; vect3d[2] = z;
      
      TensorReconstructionImageFilterType::GradientDirectionType newDir;
      newDir[0] = rotationMatrix(0,0)* vect3d[0] + rotationMatrix(0,1)* vect3d[1] + rotationMatrix(0,2)* vect3d[2];
      newDir[1] = rotationMatrix(1,0)* vect3d[0] + rotationMatrix(1,1)* vect3d[1] + rotationMatrix(1,2)* vect3d[2];
      newDir[2] = rotationMatrix(2,0)* vect3d[0] + rotationMatrix(2,1)* vect3d[1] + rotationMatrix(2,2)* vect3d[2];
      
      
    
         
      DiffusionVectors->InsertElement( numberOfImages, newDir);
      ++numberOfImages;
      }
    else if (itKey->find("DWMRI_b-value") != std::string::npos)
      {
        
      readb0 = true;
      b0 = atof(metaString.c_str());
      }
    }
   if(!readb0)
    {
    std::cerr << "BValue not specified in header file" << std::endl;
    return EXIT_FAILURE;
    }
    
  
      filter->SetGradientImage( DiffusionVectors, newImage);
     
      filter->SetBValue(b0);
      filter->SetNumberOfThreads( 1 ); //required
      
      
      filter->SetThreshold(threshold);
        
   try
    {
    filter->UpdateLargestPossibleRegion();
  
    }
  catch (itk::ExceptionObject &e)
    {
    std::cout << e << std::endl;
    return EXIT_FAILURE;
    } 
     
   
      TensorImageType::Pointer tes = filter->GetOutput();
      TensorPixelType tp =  tes->GetPixel(refIndex);
      outputIt.Set( tp );  
      }
      else
      {
     outputIt.Set(edgePadding);
      }
      
     }
     
     
   else
    {
     outputIt.Set(edgePadding);
    }
     ++outputIt; 
     ++fieldIt;
     ++bit;
     
     
  }
  
  clock.Stop();
  double timeTaken = clock.GetMeanTime();
  std::cout << "Time Taken= " << timeTaken <<std::endl;
      
  //// Writer............
  
  typedef itk::ImageFileWriter< TensorImageType> TensorWriterType;
  TensorWriterType::Pointer tensorWriter = TensorWriterType::New();
  tensorWriter->SetFileName( OutputImageFilename );
  tensorWriter->SetInput(tensor);
  std::cout << tensor << std::endl;
  
  tensorWriter->Update();
  
 
  
  return 0;
  
}
  
  
  
  
  
  

  
  
  
  
