/*=========================================================================

Program:   DTI BSpline Registration
Language:  C++
Date:      2007/07/11 15:24:18 

=========================================================================*/

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>

#include <itkImage.h>
#include <metaCommand.h>
#include <itkOrientImageFilter.h>
#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>
#include <itkImageRegionIterator.h>
#include <itkVersorRigid3DTransformOptimizer.h>
#include <itkImageRegistrationMethod.h>
#include <itkMattesMutualInformationImageToImageMetric.h>
#include <itkLinearInterpolateImageFunction.h>
#include <itkBSplineDeformableTransform.h>
#include <itkVersorRigid3DTransform.h>
#include <itkLBFGSBOptimizer.h>
#include <itkTimeProbesCollectorBase.h>
#include <itkTransformFactory.h>
#include <itkTransformFileWriter.h>
#include <itkTransformFileReader.h>
#include <itkVectorImage.h>
#include <itkVectorIndexSelectionCastImageFilter.h>
#include <itkResampleImageFilter.h>
#include "DTIBsplineRegistrationCLP.h"


//  The following section of code implements a Command observer
//  that will monitor the evolution of the registration process.
//
class CommandIterationUpdate : public itk::Command 
{
public:
  typedef  CommandIterationUpdate   Self;
  typedef  itk::Command             Superclass;
  typedef  itk::SmartPointer<Self>  Pointer;
  itkNewMacro( Self );
protected:
  CommandIterationUpdate() {};
public:
  typedef itk::LBFGSBOptimizer   OptimizerType;
  typedef const OptimizerType   *OptimizerPointer;

  /** Non-const entry point required by itk::Command; forwards to the
   *  const overload below. */
  void Execute(itk::Object *caller, const itk::EventObject & event)
    {
    Execute( static_cast<const itk::Object *>( caller ), event );
    }

  /** Print "<iteration>   <cost value>   " on one line of std::cout.
   *
   *  Events other than itk::IterationEvent are ignored, as are callers
   *  that are not an itk::LBFGSBOptimizer (the dynamic_cast would yield a
   *  null pointer, which was previously dereferenced). */
  void Execute(const itk::Object * object, const itk::EventObject & event)
    {
    // Filter on the event type first: it is cheaper than the RTTI cast.
    if( ! itk::IterationEvent().CheckEvent( &event ) )
      {
      return;
      }
    OptimizerPointer optimizer =
      dynamic_cast< OptimizerPointer >( object );
    if( ! optimizer )
      {
      return;
      }
    std::cout << optimizer->GetCurrentIteration() << "   ";
    std::cout << optimizer->GetValue() << "   ";
    std::cout << std::endl;
    }
};


/****************************************************************
 Program: mimxCoRegisterBspline
 
 Purpose: Co-register two image datasets using a b-spline 
          transform

****************************************************************/



int main (int argc, char **argv)
{
  PARSE_ARGS;

  /* Echo the parsed command-line settings */
  std::cout << "Moving Image: " <<  MovingImageFilename << std::endl; 
  std::cout << "Fixed Image: " <<  FixedImageFilename << std::endl; 
  std::cout << "Output Transform: " <<  OutputFilename << std::endl; 
  std::cout << "Input Transform: " <<  InputTransform << std::endl; 
  std::cout << "Resample Image: " << OutputResampledImageFilename << std::endl;
  std::cout << "Deformation Image: " << DeformationFilename << std::endl; 
  std::cout << "Grid X Size: " << GridXSize <<std::endl;
  std::cout << "Grid Y Size: " << GridYSize <<std::endl;
  std::cout << "Grid Z Size: " << GridZSize <<std::endl;
  std::cout << "Border X Size: " << BorderXSize <<std::endl;
  std::cout << "Border Y Size: " << BorderYSize <<std::endl;
  std::cout << "Border Z Size: " << BorderZSize <<std::endl;
  std::cout << "Corrections: " << NumberOfCorrections <<std::endl;
  std::cout << "Evaluations: " << NumberOfEvaluations <<std::endl;
  std::cout << "Histogram: " << HistogramBins <<std::endl;
  std::cout << "Scale: " << SpatialScale <<std::endl;  
  std::cout << "Convergence: " << Convergence <<std::endl;  
  std::cout << "Tolerance: " << Tolerance <<std::endl;  
  std::cout << "Iterations: " << Iterations <<std::endl;
  std::cout << "Index: " << MovingImageIndex <<std::endl;
  std::cout << "Bound X: " << BoundXDeformations <<std::endl;
  std::cout << "Bound Y: " << BoundYDeformations <<std::endl;
  std::cout << "Bound Z: " << BoundZDeformations <<std::endl;
  std::cout << "X Bound: " << xLowerBound << " " << xUpperBound <<std::endl;
  std::cout << "Y Bound: " << yLowerBound << " " << yUpperBound <<std::endl;
  std::cout << "Z Bound: " << zLowerBound << " " << zUpperBound <<std::endl;
  std::cout << "Reorient Anatomical Image: " << OrientAnatomicalImage <<std::endl;   

  typedef itk::Image<signed short, 3>            ImageType;
  typedef itk::VectorImage<signed short,3>       MovingImageType; 
  typedef itk::ImageFileReader<MovingImageType>  MovingImageReaderType; 
  typedef itk::ImageFileReader<ImageType>        ImageReaderType; 

  /* Read the (multi-component) Moving Image */
  MovingImageReaderType::Pointer movingImageReader =  MovingImageReaderType::New();
  movingImageReader->SetFileName( MovingImageFilename.c_str() );
  try 
    {     
    movingImageReader->Update();  
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    throw;
    }

  MovingImageType::Pointer img = movingImageReader->GetOutput();

  /* Extract component MovingImageIndex as a scalar image */
  typedef itk::VectorIndexSelectionCastImageFilter<MovingImageType, ImageType> VectorSelectFilterType;
  typedef VectorSelectFilterType::Pointer VectorSelectFilterPointer;  
  VectorSelectFilterPointer SelectIndexImageFilter = VectorSelectFilterType::New();
  SelectIndexImageFilter->SetIndex( MovingImageIndex );
  SelectIndexImageFilter->SetInput( img );
  try
    {
    SelectIndexImageFilter->Update();
    }
  catch (itk::ExceptionObject &ex)
    {
    // Propagate like every other stage; continuing would register an
    // invalid (e.g. out-of-range index) image.
    std::cout << ex << std::endl;
    throw;
    }

  /* Optionally reorient to RIP and zero the origin */
  typedef itk::OrientImageFilter< ImageType, ImageType>  OrientFilterType;
  ImageType::Pointer movingImage;
  if ( OrientAnatomicalImage )
    {
    OrientFilterType::Pointer orientImageFilter = OrientFilterType::New();
    orientImageFilter->SetInput( SelectIndexImageFilter->GetOutput() );
    orientImageFilter->SetDesiredCoordinateOrientation(itk::SpatialOrientation::ITK_COORDINATE_ORIENTATION_RIP);
    orientImageFilter->UseImageDirectionOn();
    orientImageFilter->Update();
    movingImage = orientImageFilter->GetOutput();
    ImageType::PointType  zeroOrigin  = movingImage->GetOrigin();
    zeroOrigin.Fill(0);
    movingImage->SetOrigin( zeroOrigin );
    }
  else
    {
    movingImage = SelectIndexImageFilter->GetOutput();
    }

  /* Read the Fixed Image */
  ImageReaderType::Pointer fixedImageReader =  ImageReaderType::New();
  fixedImageReader->SetFileName( FixedImageFilename.c_str() );
  try 
    {     
    fixedImageReader->Update();  
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    throw;
    }

  ImageType::Pointer fixedImage;
  if ( OrientAnatomicalImage )
    {
    OrientFilterType::Pointer orientImageFilter = OrientFilterType::New();
    orientImageFilter->SetInput( fixedImageReader->GetOutput() );
    orientImageFilter->SetDesiredCoordinateOrientation(itk::SpatialOrientation::ITK_COORDINATE_ORIENTATION_RIP);
    orientImageFilter->UseImageDirectionOn();
    orientImageFilter->Update();
    fixedImage = orientImageFilter->GetOutput();
    ImageType::PointType  zeroOrigin  = fixedImage->GetOrigin();
    zeroOrigin.Fill(0);
    fixedImage->SetOrigin( zeroOrigin );
    }
  else
    {
    fixedImage = fixedImageReader->GetOutput();
    }

  /* Now Setup the Registration */
  if ( SpatialScale <= 0 )
    {
    // numberOfSamples below divides by SpatialScale.
    std::cerr << "Error: Scale must be positive." << std::endl;
    return EXIT_FAILURE;
    }
  ImageType::SizeType fixedImageSize = fixedImage->GetBufferedRegion().GetSize();
  const unsigned int numberOfSamples = fixedImage->GetBufferedRegion().GetNumberOfPixels() / 
                                       SpatialScale;

  typedef itk::ImageRegistrationMethod< ImageType, ImageType > RegistrationType;
  RegistrationType::Pointer   registration  = RegistrationType::New();

  typedef itk::LBFGSBOptimizer                OptimizerType;
  OptimizerType::Pointer      optimizer     = OptimizerType::New();
  registration->SetOptimizer(     optimizer     );

  typedef itk::MattesMutualInformationImageToImageMetric<
                ImageType,
                ImageType >                 MetricType;
  MetricType::Pointer         metric      = MetricType::New();
  metric->SetNumberOfHistogramBins( HistogramBins );
  metric->SetNumberOfSpatialSamples( numberOfSamples );
  metric->ReinitializeSeed( 76926294 );   // fixed seed => reproducible sampling
  registration->SetMetric(        metric        );

  typedef itk:: LinearInterpolateImageFunction<
        ImageType,
        double          >                    InterpolatorType;
  InterpolatorType::Pointer  interpolator  = InterpolatorType::New();
  registration->SetInterpolator(  interpolator  );

  static const unsigned int SpaceDimension = 3;
  static const unsigned int SplineOrder = 3;
  typedef double CoordinateRepType;
  typedef itk::BSplineDeformableTransform<
               CoordinateRepType,
               SpaceDimension,
               SplineOrder >              TransformType;
  TransformType::Pointer finalTransform = TransformType::New( );
  registration->SetTransform( finalTransform );

  registration->SetFixedImage(  fixedImage   );
  registration->SetMovingImage(   movingImage   );
  registration->SetFixedImageRegion( fixedImage->GetBufferedRegion() );

  /*** Setup the B-Spline Parameters ***/
  TransformType::RegionType    bsplineRegion;
  TransformType::SizeType      gridSizeOnImage;
  TransformType::SizeType      gridBorderSize;
  TransformType::SizeType      totalGridSize;

  gridSizeOnImage[0] = GridXSize;
  gridSizeOnImage[1] = GridYSize;
  gridSizeOnImage[2] = GridZSize;
  gridBorderSize[0]  = BorderXSize;    // Border for spline order = 3 ( 1 lower, 2 upper )
  gridBorderSize[1]  = BorderYSize;    // Border for spline order = 3 ( 1 lower, 2 upper )
  gridBorderSize[2]  = BorderZSize;    // Border for spline order = 3 ( 1 lower, 2 upper )

  totalGridSize = gridSizeOnImage + gridBorderSize;
  bsplineRegion.SetSize( totalGridSize );

  TransformType::SpacingType spacing = fixedImage->GetSpacing();
  TransformType::OriginType origin = fixedImage->GetOrigin();

  for (unsigned int r = 0; r < SpaceDimension; r++)
    {
    // (gridSize - 1) is a divisor below; 0 or 1 would divide by zero / underflow.
    if ( gridSizeOnImage[r] < 2 )
      {
      std::cerr << "Error: grid size must be at least 2 in every dimension." << std::endl;
      return EXIT_FAILURE;
      }
    // Whole number of image voxels between neighbouring grid nodes.
    const double voxelsPerNode =
      std::floor( static_cast<double>(fixedImageSize[r] - 1)  /
                  static_cast<double>(gridSizeOnImage[r] - 1) );
    if ( voxelsPerNode < 1.0 )
      {
      // Would produce a zero grid spacing (grid denser than the image).
      std::cerr << "Error: grid size exceeds the fixed image size." << std::endl;
      return EXIT_FAILURE;
      }
    spacing[r] *= voxelsPerNode;
    // Shift the origin one node back for the lower border node.
    origin[r]  -=  spacing[r];
    }

  finalTransform->SetGridSpacing( spacing );
  finalTransform->SetGridOrigin( origin );
  finalTransform->SetGridRegion( bsplineRegion );

  /* Optional rigid bulk transform read from InputTransform */
  if ( InputTransform.length() > 0 )
    {
    std::cout << "Using Bulk Transform: " << InputTransform << std::endl;
    itk::TransformFileReader::Pointer transformReader =  itk::TransformFileReader::New();
    transformReader->SetFileName( InputTransform.c_str() );
    try 
      {
      transformReader->Update( );  
      }
    catch (itk::ExceptionObject &ex)
      {
      std::cout << ex << std::endl;
      throw;
      }

    if ( transformReader->GetTransformList()->empty() )
      {
      std::cerr << "Error: no transform found in " << InputTransform << std::endl;
      return EXIT_FAILURE;
      }

    typedef itk::VersorRigid3DTransform< double >     BulkTransformType;
    BulkTransformType::Pointer bulkTransform = BulkTransformType::New();

    // Use the SAME list entry for the type check and for the parameters
    // (previously back() was checked but begin() was copied). The first
    // entry is also where this program writes the bulk transform.
    const std::string readTransformType =
      transformReader->GetTransformList()->front()->GetTransformTypeAsString();
    if ( readTransformType != "VersorRigid3DTransform_double_3_3" )
      {
      std::cout << "Error: Invalid Bulk Transform Type! " << std::endl;
      std::cout << "Only the VersorRigid3DTransform_double_3_3 transform is currently supported." << std::endl;
      // Do not silently continue with an identity bulk transform.
      return EXIT_FAILURE;
      }
    bulkTransform->SetIdentity();
    // Fixed parameters (rotation center) first, per ITK convention.
    bulkTransform->SetFixedParameters(
              transformReader->GetTransformList()->front()->GetFixedParameters() );
    bulkTransform->SetParameters(
              transformReader->GetTransformList()->front()->GetParameters() );
    finalTransform->SetBulkTransform(   bulkTransform   );
    }

  const unsigned int numberOfParameters = finalTransform->GetNumberOfParameters();
  TransformType::ParametersType parameters( numberOfParameters );
  parameters.Fill( 0.0 );

  finalTransform->SetParameters( parameters );
  registration->SetInitialTransformParameters( finalTransform->GetParameters() );

  OptimizerType::BoundSelectionType boundSelect( numberOfParameters );
  OptimizerType::BoundValueType     upperBound( numberOfParameters );
  OptimizerType::BoundValueType     lowerBound( numberOfParameters );

  /* User Specifies the Deformation Bounds in X,Y,Z (default unbounded).
   * BSplineDeformableTransform stores its parameters in blocks:
   * all X coefficients, then all Y, then all Z -- NOT interleaved. */
  const unsigned int numberOfNodes = numberOfParameters / SpaceDimension;
  for ( unsigned int node = 0; node < numberOfNodes; ++node )
    {
    const unsigned int ix = node;
    const unsigned int iy = node + numberOfNodes;
    const unsigned int iz = node + 2 * numberOfNodes;
    boundSelect[ix] = BoundXDeformations;
    boundSelect[iy] = BoundYDeformations;
    boundSelect[iz] = BoundZDeformations;
    lowerBound[ix]  = xLowerBound;
    lowerBound[iy]  = yLowerBound;
    lowerBound[iz]  = zLowerBound;
    upperBound[ix]  = xUpperBound;
    upperBound[iy]  = yUpperBound;
    upperBound[iz]  = zUpperBound;
    }

  optimizer->SetBoundSelection( boundSelect );
  optimizer->SetUpperBound( upperBound );
  optimizer->SetLowerBound( lowerBound );

  optimizer->SetCostFunctionConvergenceFactor( Convergence );
  optimizer->SetProjectedGradientTolerance( Tolerance );
  optimizer->SetMaximumNumberOfIterations( Iterations );
  optimizer->SetMaximumNumberOfEvaluations( NumberOfEvaluations );
  optimizer->SetMaximumNumberOfCorrections( NumberOfCorrections );

  // Add a time probe
  itk::TimeProbesCollectorBase collector;
  std::cout << std::endl << "Starting Registration" << std::endl;

  // Create the Command observer and register it with the optimizer.
  CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
  optimizer->AddObserver( itk::IterationEvent(), observer );

  try
    {
    collector.Start( "Registration" );
    registration->StartRegistration();
    collector.Stop( "Registration" );
    }
  catch( itk::ExceptionObject & err )
    {
    std::cerr << "ExceptionObject caught !" << std::endl;
    std::cerr << err << std::endl;
    throw;
    }

  OptimizerType::ParametersType finalParameters =
            registration->GetLastTransformParameters();

  collector.Report();

  /* This call is required to copy the parameters */
  finalTransform->SetParametersByValue( finalParameters );

  std::cout << finalTransform << std::endl;

  /* Write bulk (if any) followed by the B-spline transform */
  itk::TransformFileWriter::Pointer transformWriter =  itk::TransformFileWriter::New();
  transformWriter->SetFileName( OutputFilename.c_str() );
  if ( InputTransform.length() > 0 )
    {
    transformWriter->AddTransform( finalTransform->GetBulkTransform( ) );
    }
  transformWriter->AddTransform( finalTransform );
  try 
    {
    transformWriter->Update( );  
    }
  catch (itk::ExceptionObject &ex)
    {
    std::cout << ex << std::endl;
    throw;
    }

  /* Resample the moving image onto the fixed image grid */
  typedef itk::ResampleImageFilter< ImageType, ImageType >    ResampleFilterType;
  ResampleFilterType::Pointer resampler = ResampleFilterType::New();

  resampler->SetTransform( finalTransform );
  resampler->SetInput( movingImage );
  resampler->SetSize(    fixedImage->GetLargestPossibleRegion().GetSize() );
  resampler->SetOutputOrigin(  fixedImage->GetOrigin() );
  resampler->SetOutputSpacing( fixedImage->GetSpacing() );
  resampler->SetOutputDirection(fixedImage->GetDirection());
  resampler->SetDefaultPixelValue(0);

  try
    {
    resampler->Update();
    }
  catch( itk::ExceptionObject & errr ) 
    { 
    std::cerr << "ExceptionObject caught !" << std::endl; 
    std::cerr << errr << std::endl; 
    throw;
    } 

  /* Write resampled Image */
  typedef  signed short OutputPixelType;
  typedef itk::Image< OutputPixelType, 3 > OutputImageType;
  typedef itk::ImageFileWriter< OutputImageType >  WriterType;

  WriterType::Pointer  writer =  WriterType::New();
  writer->SetFileName( OutputResampledImageFilename );
  writer->SetInput( resampler->GetOutput()   );
  try
    {
    writer->Update();
    }
  catch( itk::ExceptionObject & excp )
    {
    std::cerr << "ExceptionObject caught !" << std::endl;
    std::cerr << excp << std::endl;
    throw;
    }

  /* Generate the explicit deformation field resulting from the registration,
   * sampled on the fixed image grid (same geometry as the resampler output). */
  typedef itk::Vector< float, 3>  VectorType;
  typedef itk::Image< VectorType, 3 >  DeformationFieldType;

  DeformationFieldType::Pointer field = DeformationFieldType::New();
  DeformationFieldType::RegionType region = fixedImage->GetLargestPossibleRegion();
  field->SetRegions( region );
  field->SetOrigin( fixedImage->GetOrigin() );
  field->SetSpacing( fixedImage->GetSpacing() );
  // Without the direction, index->physical mapping would disagree with the
  // fixed image (and the resampler above) for non-axis-aligned data.
  field->SetDirection( fixedImage->GetDirection() );
  field->Allocate();

  typedef itk::ImageRegionIterator< DeformationFieldType > FieldIterator;
  FieldIterator fi( field, region );

  TransformType::InputPointType  fixedPoint;
  TransformType::OutputPointType movingPoint;
  DeformationFieldType::IndexType index;
  VectorType displacement;

  for ( fi.GoToBegin(); ! fi.IsAtEnd(); ++fi )
    {
    index = fi.GetIndex();
    field->TransformIndexToPhysicalPoint( index, fixedPoint );
    movingPoint = finalTransform->TransformPoint( fixedPoint );
    for ( unsigned int d = 0; d < SpaceDimension; ++d )
      {
      displacement[d] = movingPoint[d] - fixedPoint[d];
      }
    fi.Set( displacement );
    }

  typedef itk::ImageFileWriter< DeformationFieldType >  FieldWriterType;
  FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
  fieldWriter->SetInput( field );
  fieldWriter->SetFileName( DeformationFilename );
  try
    {
    fieldWriter->Update();
    }
  catch( itk::ExceptionObject & excp )
    {
    std::cerr << "Exception thrown " << std::endl;
    std::cerr << excp << std::endl;
    return EXIT_FAILURE;
    }

  // Return (not exit()) so main's locals, including ITK smart pointers, are destroyed.
  return EXIT_SUCCESS;
}
