/*=========================================================================

 Program:   GTRACT (Guided Tensor Restore Anatomical Connectivity Tractography)
 Module:    $RCSfile: $
 Language:  C++
 Date:      $Date: 2006/03/29 14:53:40 $
 Version:   $Revision: 1.9 $

   Copyright (c) University of Iowa Department of Radiology. All rights reserved.
   See GTRACT-Copyright.txt or http://mri.radiology.uiowa.edu/copyright/GTRACT-Copyright.txt 
   for details.
 
      This software is distributed WITHOUT ANY WARRANTY; without even 
      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
      PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <itkImage.h>
#include <itkDiffusionTensor3DReconstructionImageFilter.h>
#include <itkVectorIndexSelectionCastImageFilter.h>
#include <itkMedianImageFilter.h>
#include <itkResampleImageFilter.h>
#include <itkLinearInterpolateImageFunction.h>
#include <itkImageFileWriter.h>
#include <itkImageFileReader.h>
#include "itkVectorResampleImageFilter.h"
#include "itkIdentityTransform.h"
#include "itkMetaDataObject.h"
#include "itkIOCommon.h"
#include "gtractTensorCLP.h"

int main (int argc, char **argv)
{
  
  PARSE_ARGS;
  
  typedef signed short                      PixelType;
  typedef double                             TensorPixelType;
  typedef itk::VectorImage<PixelType,3>	      VectorImageType;
  typedef itk::Image<PixelType,3>	      IndexImageType;

  IndexImageType::SizeType MedianFilterSize;
  MedianFilterSize[0] = medianFilterSize[0];
  MedianFilterSize[1] = medianFilterSize[1];
  MedianFilterSize[2] = medianFilterSize[2];
  
  bool debug=true;
  applyMeasurementFrame = true;
  if (debug) 
    {
    std::cout << "=====================================================" << std::endl; 
    std::cout << "Input Image: " <<  inputVolume << std::endl; 
    std::cout << "Output Image: " <<  outputVolume << std::endl; 
    std::cout << "Resample Isotropic: " << resampleIsotropic <<std::endl;
    std::cout << "Voxel Size: " << voxelSize <<std::endl;
    std::cout << "Median Filter Size: " << MedianFilterSize <<std::endl;
    std::cout << "Threshold: " << backgroundSuppressingThreshold <<std::endl; 
    std::cout << "B0 Index: " << b0Index <<std::endl; 
    std::cout << "Apply Measurement Frame: " << applyMeasurementFrame <<std::endl;  
    std::cout << "=====================================================" << std::endl; 
    }
  
  bool violated=false;
  if (inputVolume.size() == 0) { violated = true; std::cout << "  --inputVolume Required! "  << std::endl; }
  if (outputVolume.size() == 0) { violated = true; std::cout << "  --outputVolume Required! "  << std::endl; }
  if (violated) exit(1);

  typedef itk::Vector<double, 3> VectorType;
  typedef itk::Matrix<double, 3,3> MatrixType;
    
  typedef itk::ImageFileReader<VectorImageType, 
                                itk::DefaultConvertPixelTraits< PixelType > > VectorImageReaderType;
  VectorImageReaderType::Pointer vectorImageReader = VectorImageReaderType::New();
  vectorImageReader->SetFileName( inputVolume );
	
  try 
    {			
    vectorImageReader->Update();	
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    throw;
    }
    
    
  /* Extract B0 Image */
  typedef itk::VectorIndexSelectionCastImageFilter<VectorImageType, IndexImageType> VectorSelectFilterType;
  typedef VectorSelectFilterType::Pointer 	VectorSelectFilterPointer;
  
  VectorSelectFilterPointer selectIndexImageFilter = VectorSelectFilterType::New();
  selectIndexImageFilter->SetIndex( b0Index );
  selectIndexImageFilter->SetInput( vectorImageReader->GetOutput() );
  try
    {
    selectIndexImageFilter->Update();
    }
  catch (itk::ExceptionObject e)
    {
    std::cout << e << std::endl;
    }
  
  /* Median Filter */
  IndexImageType::Pointer baseImage;
  if ( MedianFilterSize[0] > 0  ||  MedianFilterSize[1] > 0  ||  MedianFilterSize[2] > 0 )
    {
    typedef itk::MedianImageFilter< IndexImageType, IndexImageType > MedianFilterType;
    MedianFilterType::Pointer filter = MedianFilterType::New();
    filter->SetInput( selectIndexImageFilter->GetOutput() );
    filter->SetRadius( MedianFilterSize );
    filter->Update();
    baseImage = filter->GetOutput();
    }
  else
    {
    baseImage = selectIndexImageFilter->GetOutput();
    }
  
  
  /* Resample To Isotropic Images */
  IndexImageType::Pointer b0Image;
  if ( resampleIsotropic )
    {
    typedef itk::ResampleImageFilter<IndexImageType, IndexImageType>  ResampleFilterType;
    ResampleFilterType::Pointer resampler = ResampleFilterType::New();
    resampler->SetInput( baseImage );

    typedef itk::LinearInterpolateImageFunction<IndexImageType, double>  InterpolatorType;
    InterpolatorType::Pointer interpolator = InterpolatorType::New();
    resampler->SetInterpolator( interpolator );
    resampler->SetDefaultPixelValue( 0 ); 

    IndexImageType::SpacingType spacing;	
    spacing[0] = voxelSize;
    spacing[1] = voxelSize;
    spacing[2] = voxelSize;	
    resampler->SetOutputSpacing( spacing );

    // Use the same origin
    resampler->SetOutputOrigin( selectIndexImageFilter->GetOutput()->GetOrigin() );


    IndexImageType::SizeType 	inputSize	= baseImage->GetLargestPossibleRegion().GetSize();
    IndexImageType::SpacingType	inputSpacing	= baseImage->GetSpacing();
    typedef IndexImageType::SizeType::SizeValueType SizeValueType;
    IndexImageType::SizeType		size;	
    size[0] = static_cast<SizeValueType>(inputSize[0] * inputSpacing[0] / voxelSize);
    size[1] = static_cast<SizeValueType>(inputSize[1] * inputSpacing[1] / voxelSize);
    size[2] = static_cast<SizeValueType>(inputSize[2] * inputSpacing[2] / voxelSize);	
    resampler->SetSize( size );

    typedef itk::IdentityTransform< double, 3 >  TransformType;
    TransformType::Pointer transform = TransformType::New();	
    transform->SetIdentity();	
    resampler->SetTransform( transform );
    resampler->Update();
    b0Image = resampler->GetOutput();
    }
  else
    {
    b0Image = baseImage;
    }
  //b0Image->DisconnectPipeline();
          
  typedef itk::DiffusionTensor3DReconstructionImageFilter<PixelType, PixelType, TensorPixelType> TensorFilterType;
  TensorFilterType::Pointer tensorFilter = TensorFilterType::New();
  tensorFilter->SetReferenceImage( b0Image );
  tensorFilter->SetThreshold( backgroundSuppressingThreshold );
      
  std::string BValue_str;
  std::string BValue_keyStr("DWMRI_b-value");    
  itk::ExposeMetaData<std::string> (vectorImageReader->GetOutput()->GetMetaDataDictionary(), BValue_keyStr.c_str(), BValue_str);
  double BValue = atof(BValue_str.c_str());
  std::cout << "The BValue was found to be " << BValue_str << std::endl;
  tensorFilter->SetBValue(BValue); /* Required */
  tensorFilter->SetNumberOfThreads(1); /* Required */
  
  std::vector<std::vector<double> > msrFrame; 
  itk::ExposeMetaData<std::vector<std::vector<double> > >(
                vectorImageReader->GetOutput()->GetMetaDataDictionary(),
                "NRRD_measurement frame",msrFrame); 
 MatrixType measurementFrame;
  for (int i=0;i<3;i++)
    {
    for (int j=0;j<3;j++)
      {
      measurementFrame[i][j] = msrFrame[i][j];
      }
    }
  
  int vectorLength = vectorImageReader->GetOutput()->GetVectorLength();
  for (int i=0;i<vectorLength;i++)
    {
    if (i != b0Index)
      {
      // Don't omit the redeclaration of selectIndexImageFilter:
      VectorSelectFilterPointer selectIndexImageFilter = VectorSelectFilterType::New();
      selectIndexImageFilter->SetIndex( i ); // was SetIndex( b0Index );
      selectIndexImageFilter->SetInput( vectorImageReader->GetOutput() );
      selectIndexImageFilter->Update();
      
      /* Median Filter */
      if ( MedianFilterSize[0] > 0  ||  MedianFilterSize[1] > 0  ||  MedianFilterSize[2] > 0 )
        {
        typedef itk::MedianImageFilter< IndexImageType, IndexImageType > MedianFilterType;
	MedianFilterType::Pointer filter = MedianFilterType::New();
	filter->SetInput( selectIndexImageFilter->GetOutput() );
	filter->SetRadius( MedianFilterSize );
	filter->Update();
	baseImage = filter->GetOutput();
        }
      else
        {
        baseImage = selectIndexImageFilter->GetOutput();
        }
      //baseImage->DisconnectPipeline();  
      
      IndexImageType::Pointer gradientImage;
      if ( resampleIsotropic )
        {
        typedef itk::ResampleImageFilter<IndexImageType, IndexImageType>  ResampleFilterType;
	ResampleFilterType::Pointer resampler = ResampleFilterType::New();
	resampler->SetInput( baseImage );

	typedef itk::LinearInterpolateImageFunction<IndexImageType, double>  InterpolatorType;
	InterpolatorType::Pointer interpolator = InterpolatorType::New();
	resampler->SetInterpolator( interpolator );
	resampler->SetDefaultPixelValue( 0 ); 

	IndexImageType::SpacingType spacing;	
	spacing[0] = voxelSize;
	spacing[1] = voxelSize;
	spacing[2] = voxelSize;	
	resampler->SetOutputSpacing( spacing );

	// Use the same origin
	resampler->SetOutputOrigin( selectIndexImageFilter->GetOutput()->GetOrigin() );


	IndexImageType::SizeType 	inputSize	= baseImage->GetLargestPossibleRegion().GetSize();
	IndexImageType::SpacingType	inputSpacing	= baseImage->GetSpacing();
	typedef IndexImageType::SizeType::SizeValueType SizeValueType;
	IndexImageType::SizeType		size;	
	size[0] = static_cast<SizeValueType>(inputSize[0] * inputSpacing[0] / voxelSize);
	size[1] = static_cast<SizeValueType>(inputSize[1] * inputSpacing[1] / voxelSize);
	size[2] = static_cast<SizeValueType>(inputSize[2] * inputSpacing[2] / voxelSize);	
	resampler->SetSize( size );

	typedef itk::IdentityTransform< double, 3 >  TransformType;
        TransformType::Pointer transform = TransformType::New();	
	transform->SetIdentity();	
	resampler->SetTransform( transform );
        resampler->Update();
        gradientImage = resampler->GetOutput();
        }
      else
        {
        gradientImage = baseImage;
        }      
      //gradientImage->DisconnectPipeline();
//std::cout << "Gradient Image Object:  " << std::endl << gradientImage << std::endl;
      
      TensorFilterType::GradientDirectionType gradientDir;
      char tmpStr[64];
      std::string NrrdValue;
      sprintf(tmpStr,"DWMRI_gradient_%04d", i);    
      itk::ExposeMetaData<std::string> (vectorImageReader->GetOutput()->GetMetaDataDictionary(), tmpStr, NrrdValue);
      char tokTmStr[64];
      strcpy( tokTmStr, NrrdValue.c_str());
      VectorType tmpDir;
      tmpDir[0] = atof( strtok(tokTmStr, " ") );
      tmpDir[1] = atof( strtok(NULL, " ") );
      tmpDir[2] = atof( strtok(NULL, " ") );
      if ( applyMeasurementFrame )
        {
        std::cout << "Apply Measurement Frame: " << tmpDir << std::endl;
        tmpDir = measurementFrame * tmpDir;
        std::cout << "Applied Measurement Frame: " << tmpDir << std::endl;
        }
      gradientDir[0] = tmpDir[0]; gradientDir[1] = tmpDir[1]; gradientDir[2] = tmpDir[2];
      std::cout << "Gradient Directions:  " << gradientDir[0] << ",  " << gradientDir[1] << ",  "<< gradientDir[2] << std::endl;
      tensorFilter->AddGradientImage( gradientDir, gradientImage );
      }
    }
		
  tensorFilter->Update(  );

    
  /* Update the Meta data Header */
  itk::MetaDataDictionary newMeta = tensorFilter->GetOutput()->GetMetaDataDictionary();
  itk::MetaDataDictionary origMeta = vectorImageReader->GetOutput()->GetMetaDataDictionary();
  std::string NrrdValue;
  
  itk::ExposeMetaData<std::string> (origMeta, "DWMRI_b-value", NrrdValue);
  itk::EncapsulateMetaData<std::string> (newMeta, "DWMRI_b-value", NrrdValue);
  
  NrrdValue = "DWMRI";
  itk::EncapsulateMetaData<std::string>(newMeta,"modality",NrrdValue);
  
  for (int i=0;i<4;i++)
    {
    char tmpStr[64];
    sprintf(tmpStr, "NRRD_centerings[%d]", i);
    itk::ExposeMetaData<std::string>(origMeta, tmpStr, NrrdValue);  
    itk::EncapsulateMetaData<std::string>(newMeta, tmpStr, NrrdValue);
    sprintf(tmpStr, "NRRD_kinds[%d]", i);
    itk::ExposeMetaData<std::string>(origMeta, tmpStr, NrrdValue);  
    itk::EncapsulateMetaData<std::string>(newMeta, tmpStr, NrrdValue); 
    sprintf(tmpStr, "NRRD_space units[%d]", i);
    itk::ExposeMetaData<std::string>(origMeta, tmpStr, NrrdValue);  
    itk::EncapsulateMetaData<std::string>(newMeta, tmpStr, NrrdValue); 
    //sprintf(tmpStr, "NRRD_thicknesses[%d]", i);
    //itk::ExposeMetaData<std::string>(origMeta, tmpStr, NrrdValue);  
    //itk::EncapsulateMetaData<std::string>(newMeta, tmpStr, NrrdValue);
    } 
  //std::vector<std::vector<double> > msrFrame(3); 
  //itk::ExposeMetaData<std::vector<std::vector<double> > >(origMeta,"NRRD_measurement frame",msrFrame); 
  //itk::ExposeMetaData<std::vector<std::vector<double> > >(newMeta,"NRRD_measurement frame",msrFrame); 
  
  /*  
  for (int i=0;i<vectorLength;i++)
    {
    char tmpStr[64];
    sprintf(tmpStr,"DWMRI_gradient_%04d", i);
    itk::ExposeMetaData<std::string>(origMeta, tmpStr, NrrdValue);
    itk::EncapsulateMetaData<std::string>(newMeta, tmpStr, NrrdValue);
    }
  */  

  tensorFilter->GetOutput()->SetMetaDataDictionary(newMeta);

  
  typedef TensorFilterType::TensorImageType TensorImageType;
  TensorImageType::Pointer tensorImage = tensorFilter->GetOutput();
  
  typedef itk::ImageFileWriter<TensorImageType> WriterType;
  WriterType::Pointer nrrdWriter = WriterType::New();
  nrrdWriter->SetInput( tensorImage );
  std::cout << tensorImage << std::endl;
  nrrdWriter->SetFileName( outputVolume );
  try
    {
    nrrdWriter->Update();
    }
  catch (itk::ExceptionObject e)
    {
    std::cout << e << std::endl;
    }
    

}
