/*=========================================================================

 Program:   GTRACT (Guided Tensor Restore Anatomical Connectivity Tractography)
 Module:    $RCSfile: $
 Language:  C++
 Date:      $Date: 2006/03/29 14:53:40 $
 Version:   $Revision: 1.9 $

   Copyright (c) University of Iowa Department of Radiology. All rights reserved.
   See GTRACT-Copyright.txt or http://mri.radiology.uiowa.edu/copyright/GTRACT-Copyright.txt 
   for details.
 
      This software is distributed WITHOUT ANY WARRANTY; without even 
      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
      PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <itkImage.h>
#include <itkDiffusionTensor3DReconstructionImageFilter.h>
#include <itkVectorIndexSelectionCastImageFilter.h>
#include <itkMedianImageFilter.h>
#include <itkVector.h>
#include <itkMatrix.h>
#include <itkResampleImageFilter.h>
#include <itkLinearInterpolateImageFunction.h>
#include <itkImageFileWriter.h>
#include <itkImageFileReader.h>
#include "itkVectorResampleImageFilter.h"
#include "itkIdentityTransform.h"
#include "itkMetaDataObject.h"
#include "itkDiffusionTensor3D.h"
#include "itkIOCommon.h"
#include "computeTensorCLP.h"

// Command-line entry point.
//
// Reads a diffusion-weighted vector image from --inputVolume, collects the
// gradient directions ("DWMRI_gradient*" keys) and the b-value
// ("DWMRI_b-value" key) from the image meta-data dictionary, optionally
// multiplies every gradient by the "NRRD_measurement frame" matrix, runs
// itk::DiffusionTensor3DReconstructionImageFilter and writes the resulting
// tensor image to --outputVolume.
//
// Returns EXIT_SUCCESS when the tensor image was written, EXIT_FAILURE on
// invalid arguments, unreadable/malformed meta-data, or any ITK exception
// raised while reading, reconstructing or writing.
//
// NOTE(review): resampleIsotropic, voxelSize, medianFilterSize and b0Index
// are only echoed in the banner below; nothing in this function applies them.
int main (int argc, char **argv)
{

  PARSE_ARGS;

  typedef signed short                     PixelType;
  typedef itk::VectorImage<PixelType,3>    VectorImageType;
  typedef itk::Image<PixelType,3>          IndexImageType;
  typedef itk::DiffusionTensor3D<float>    TensorPixelType;
  typedef itk::Image<TensorPixelType,3>    TensorImageType;
  typedef itk::Matrix<double, 3,3>         MatrixType;

  // Guard the three indexed reads below against a short argument list.
  if (medianFilterSize.size() < 3)
    {
    std::cerr << "  --medianFilterSize requires three values! " << std::endl;
    return EXIT_FAILURE;
    }

  IndexImageType::SizeType MedianFilterSize;
  MedianFilterSize[0] = medianFilterSize[0];
  MedianFilterSize[1] = medianFilterSize[1];
  MedianFilterSize[2] = medianFilterSize[2];

  const bool debug = true;
  if (debug)
    {
    std::cout << "=====================================================" << std::endl;
    std::cout << "Input Image: " <<  inputVolume << std::endl;
    std::cout << "Output Image: " <<  outputVolume << std::endl;
    std::cout << "Resample Isotropic: " << resampleIsotropic <<std::endl;
    std::cout << "Voxel Size: " << voxelSize <<std::endl;
    std::cout << "Median Filter Size: " << MedianFilterSize <<std::endl;
    std::cout << "Threshold: " << backgroundSuppressingThreshold <<std::endl;
    std::cout << "B0 Index: " << b0Index <<std::endl;
    std::cout << "Apply Measurement Frame: " << applyMeasurementFrame <<std::endl;
    std::cout << "=====================================================" << std::endl;
    }

  // Report every missing required argument before giving up.
  bool violated = false;
  if (inputVolume.size() == 0)
    {
    violated = true;
    std::cerr << "  --inputVolume Required! "  << std::endl;
    }
  if (outputVolume.size() == 0)
    {
    violated = true;
    std::cerr << "  --outputVolume Required! "  << std::endl;
    }
  if (violated)
    {
    return EXIT_FAILURE;
    }

  // ---- Read the diffusion-weighted input image ----
  typedef itk::ImageFileReader<VectorImageType,
                                itk::DefaultConvertPixelTraits< PixelType > > VectorImageReaderType;
  VectorImageReaderType::Pointer vectorImageReader = VectorImageReaderType::New();
  vectorImageReader->SetFileName( inputVolume );

  try
    {
    vectorImageReader->Update();
    }
  catch (itk::ExceptionObject &ex)
    {
    // Re-throwing out of main() would terminate the process abnormally;
    // report the problem and exit with a failure code instead.
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  VectorImageType::Pointer img = vectorImageReader->GetOutput();

  typedef itk::DiffusionTensor3DReconstructionImageFilter<PixelType, PixelType, float> TensorFilterType;
  TensorFilterType::Pointer tensorFilter = TensorFilterType::New();

  itk::MetaDataDictionary imgMetaDictionary = img->GetMetaDataDictionary();

  // ---- Measurement frame: identical for every gradient, so read it once ----
  MatrixType measurementFrame;
  if (applyMeasurementFrame)
    {
    std::vector<std::vector<double> > msrFrame;
    const bool haveFrame =
      itk::ExposeMetaData<std::vector<std::vector<double> > >(
        imgMetaDictionary, "NRRD_measurement frame", msrFrame);

    bool frameIs3x3 = haveFrame && (msrFrame.size() >= 3);
    for (unsigned int i = 0; frameIs3x3 && i < 3; ++i)
      {
      frameIs3x3 = (msrFrame[i].size() >= 3);
      }
    if (!frameIs3x3)
      {
      std::cerr << "Measurement frame requested but no valid 3x3 "
                << "\"NRRD_measurement frame\" entry found in header file" << std::endl;
      return EXIT_FAILURE;
      }

    for (unsigned int i = 0; i < 3; ++i)
      {
      for (unsigned int j = 0; j < 3; ++j)
        {
        measurementFrame[i][j] = msrFrame[i][j];
        }
      }
    }

  // ---- Collect gradient directions and the b-value from the header ----
  TensorFilterType::GradientDirectionType vect3d;
  TensorFilterType::GradientDirectionContainerType::Pointer
  DiffusionVectors = TensorFilterType::GradientDirectionContainerType::New();

  bool readb0 = false;
  double b0 = 0;
  unsigned int numberOfImages = 0;

  const std::vector<std::string> imgMetaKeys = imgMetaDictionary.GetKeys();
  for (std::vector<std::string>::const_iterator itKey = imgMetaKeys.begin();
       itKey != imgMetaKeys.end(); ++itKey)
    {
    const bool isGradientKey = (itKey->find("DWMRI_gradient") != std::string::npos);
    const bool isBValueKey   = !isGradientKey
                               && (itKey->find("DWMRI_b-value") != std::string::npos);
    if (!isGradientKey && !isBValueKey)
      {
      continue;
      }

    // Only trust the value when the entry really holds a string; otherwise
    // a stale value from a previous key would be parsed.
    std::string metaString;
    if (!itk::ExposeMetaData<std::string> (imgMetaDictionary, *itKey, metaString))
      {
      std::cerr << "Unable to read header entry \"" << *itKey
                << "\" as a string" << std::endl;
      return EXIT_FAILURE;
      }

    if (isGradientKey)
      {
      double x = 0.0;
      double y = 0.0;
      double z = 0.0;
      if (sscanf(metaString.c_str(), "%lf %lf %lf\n", &x, &y, &z) != 3)
        {
        std::cerr << "Malformed gradient entry \"" << *itKey << "\": \""
                  << metaString << "\"" << std::endl;
        return EXIT_FAILURE;
        }

      vect3d[0] = x; vect3d[1] = y; vect3d[2] = z;
      std::cout << "Gradient:: " << numberOfImages <<"::"<< vect3d <<std::endl;

      if (applyMeasurementFrame)
        {
        std::cout << " Measurement Frame::  " << measurementFrame <<std::endl;
        vect3d = measurementFrame * vect3d;
        std::cout << "Gradients after Measurement Frame::  " << vect3d <<std::endl;
        }

      DiffusionVectors->InsertElement( numberOfImages, vect3d);
      ++numberOfImages;
      }
    else
      {
      if (sscanf(metaString.c_str(), "%lf", &b0) != 1)
        {
        std::cerr << "Malformed b-value entry \"" << *itKey << "\": \""
                  << metaString << "\"" << std::endl;
        return EXIT_FAILURE;
        }
      readb0 = true;
      std::cout << "B value ::" << b0 << std::endl;
      }
    }

  if (!readb0)
    {
    std::cerr << "BValue not specified in header file" << std::endl;
    return EXIT_FAILURE;
    }

  // ---- Tensor reconstruction ----
  try
    {
    tensorFilter->SetThreshold( backgroundSuppressingThreshold );
    tensorFilter->SetGradientImage( DiffusionVectors, img);
    tensorFilter->SetBValue(b0);
    tensorFilter->SetNumberOfThreads( 1 ); //required
    tensorFilter->Update( );
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  TensorImageType::Pointer tensorImage = tensorFilter->GetOutput();

  // ---- Write the tensor image ----
  typedef itk::ImageFileWriter<TensorImageType> WriterType;
  WriterType::Pointer nrrdWriter = WriterType::New();
  nrrdWriter->SetInput( tensorImage );
  nrrdWriter->SetFileName( outputVolume );
  try
    {
    nrrdWriter->Update();
    }
  catch (itk::ExceptionObject &ex)
    {
    // Caught by reference (no slicing); a failed write must not be
    // reported as success.
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}
