/*=========================================================================

Program:   General Affine Registration
Language:  C++
Date:      2007/07/11 15:24:18 

=========================================================================*/

#include <iostream>
#include <fstream>

#include <itkImage.h>
#include <metaCommand.h>
#include "itkImageRegistrationMethod.h"
#include "itkCenteredAffineTransform.h"
#include "itkMattesMutualInformationImageToImageMetric.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkRegularStepGradientDescentOptimizer.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkResampleImageFilter.h"
#include "itkCastImageFilter.h"
#include "itkCommand.h"
#include "itkVector.h"
#include "itkTransformFileWriter.h"
#include "itkCenteredTransformInitializer.h"
#include "itkExtractImageFilter.h"
#include "itkVectorImage.h"
#include "itkVectorIndexSelectionCastImageFilter.h"
#include "GeneralAffineRegistrationCLP.h"



//  The following section of code implements a Command observer
//  that monitors the progress of the registration process.

/**
 * Observer that prints the optimizer state on every itk::IterationEvent.
 *
 * For each iteration event it writes one line to std::cout containing the
 * current iteration number, the current metric value and the current
 * parameter vector, separated by three spaces.
 *
 * Events of any other type, and callers that are not an
 * itk::RegularStepGradientDescentOptimizer, are silently ignored.
 */
class CommandIterationUpdate : public itk::Command
{
public:
  typedef CommandIterationUpdate    Self;
  typedef itk::Command              Superclass;
  typedef itk::SmartPointer<Self>   Pointer;
  itkNewMacro( Self );

  typedef itk::RegularStepGradientDescentOptimizer  OptimizerType;
  typedef const OptimizerType *                     OptimizerPointer;

  /** Non-const overload: forwards to the const overload below. */
  void Execute(itk::Object *caller, const itk::EventObject & event)
    {
      Execute( static_cast< const itk::Object * >( caller ), event );
    }

  /**
   * Print iteration number, metric value and current position.
   *
   * \param object  the object that fired the event (expected to be the optimizer)
   * \param event   the event being delivered; only IterationEvent is handled
   */
  void Execute(const itk::Object * object, const itk::EventObject & event)
    {
      // Filter on the event type first: nothing else is of interest here.
      if( ! itk::IterationEvent().CheckEvent( &event ) )
        {
        return;
        }

      // dynamic_cast yields a null pointer when the caller is not the
      // expected optimizer type (or is null); bail out instead of
      // dereferencing it.
      OptimizerPointer optimizer =
        dynamic_cast< OptimizerPointer >( object );
      if( ! optimizer )
        {
        return;
        }

      std::cout << optimizer->GetCurrentIteration() << "   ";
      std::cout << optimizer->GetValue() << "   ";
      std::cout << optimizer->GetCurrentPosition() << std::endl;
    }

protected:
  CommandIterationUpdate() {}
};


/****************************************************************
 Program: AffineRegistration 
 
 Purpose: Co-register two image datasets using an affine
          transform, where the moving image is a DTI B0 image

****************************************************************/



int main( int argc, char *argv[] )
{

  PARSE_ARGS;
  
  
  std::cout << "Moving Image: " <<  MovingImageFilename << std::endl; 
  std::cout << "Fixed Image: " <<  FixedImageFilename << std::endl; 
  std::cout << "Output Transform: " <<  OutputXformFilename << std::endl; 
  std::cout << "Output Inverse Transform: " <<  OutputInverseXformFilename << std::endl;
  std::cout << "Resample Image: " << OutputResampledImageFilename << std::endl; 
  std::cout << "Iterations: " << Iterations <<std::endl;
  std::cout << "Max Step Length: " << MaxStepLength <<std::endl; 
  std::cout << "Min Step Length: " << MinStepLength <<std::endl; 
  std::cout << "No Of Histogram Bins: " << Bins <<std::endl; 
  std::cout << "No of Spatial Samples: " << SpatialSamples <<std::endl; 
  std::cout << "Gradient Tolerance for the optimizer: " << GradientTolerance <<std::endl; 


  //Instantiate required classes
  
  const    unsigned int    dimension = 3;
  typedef unsigned short PixelType;  
  typedef itk::Image<PixelType,dimension>             FixedImageType; 
  typedef itk::Image<PixelType,3>               MovingImageType;
  typedef itk::Vector<double,3> VectorType;  
  typedef itk::CenteredAffineTransform< double,dimension >      AffineTransformType;
  typedef itk::RegularStepGradientDescentOptimizer                 OptimizerType; 
  typedef itk::MattesMutualInformationImageToImageMetric< FixedImageType, MovingImageType >    MetricType;
  typedef itk::LinearInterpolateImageFunction< MovingImageType, double >   InterpolatorType;
  typedef itk::ImageRegistrationMethod<  FixedImageType,   MovingImageType >    RegistrationType;
 

  MetricType::Pointer               metric        = MetricType::New();
  AffineTransformType::Pointer      transform     = AffineTransformType::New();
  AffineTransformType::Pointer      invTransform  = AffineTransformType::New();
  OptimizerType::Pointer            optimizer     = OptimizerType::New();
  InterpolatorType::Pointer         interpolator  = InterpolatorType::New();
  RegistrationType::Pointer         registration  = RegistrationType::New();
  FixedImageType::Pointer     	    fixedImage    = FixedImageType::New();  
  
  metric->SetNumberOfHistogramBins( Bins );
  metric->SetNumberOfSpatialSamples( SpatialSamples );
  
  
  // Read Input Images
  
  typedef itk::ImageFileReader< FixedImageType  > FixedImageReaderType;
  typedef itk::ImageFileReader< MovingImageType > MovingImageReaderType;

  FixedImageReaderType::Pointer  fixedImageReader  = FixedImageReaderType::New();
  MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();

  fixedImageReader->SetFileName(  FixedImageFilename.c_str() );
  movingImageReader->SetFileName( MovingImageFilename.c_str() );
  MovingImageType::Pointer img;
  
   try 
    { 
    movingImageReader->Update();
    fixedImageReader->Update();
    
    } 
  catch( itk::ExceptionObject & ee ) 
    { 
    std::cerr << "Reader Exception Object caught !" << std::endl; 
    std::cerr << ee << std::endl; 
    return -1;
    } 
    
    img = movingImageReader->GetOutput();  
    FixedImageType::Pointer  im =  fixedImageReader->GetOutput();     
    std::cout << img << std::endl;
    std::cout << im << std::endl;
      
  registration->SetMetric(        metric        );
  registration->SetOptimizer(     optimizer     );
  registration->SetInterpolator(  interpolator  );  
  registration->SetTransform( transform ); 
  registration->SetFixedImage(    im    ); 
  registration->SetMovingImage(   img  );
 
  
  
  registration->SetFixedImageRegion( fixedImageReader->GetOutput()->GetBufferedRegion() );
  
  // Use centre of Mass 
  
  typedef itk::CenteredTransformInitializer< AffineTransformType, FixedImageType, MovingImageType >  TransformInitializerType;
  TransformInitializerType::Pointer initializer = TransformInitializerType::New();
  
  initializer->SetTransform(   transform );
  initializer->SetFixedImage(  im );
  initializer->SetMovingImage(img);
  initializer->MomentsOn();  
  initializer->InitializeTransform();
   
  registration->SetInitialTransformParameters( transform->GetParameters() ); 
  
  
    
  //SET OPTIMIZER 
  
  typedef OptimizerType::ScalesType ScalesType;
  ScalesType parametersScales( transform->GetNumberOfParameters() );

  // set initial parameters to 1.0

  parametersScales.Fill( 1.0 );

  // Set translation parameters to 0.001

  for (int j = 12; j < 15; j++ )
    {
    parametersScales[j] = 0.0001;
    }

  optimizer->SetScales( parametersScales ); 
  optimizer->SetNumberOfIterations( Iterations );
  optimizer->SetMaximumStepLength( MaxStepLength ); 
  optimizer->SetMinimumStepLength( MinStepLength );
  optimizer->SetGradientMagnitudeTolerance( GradientTolerance );
  optimizer->MinimizeOn();
  
  
  // Create the Command observer and register it with the optimizer. 
   
  CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
  optimizer->AddObserver( itk::IterationEvent(), observer );  
   
  
  
  try 
    { 
    registration->StartRegistration(); 
    } 
  catch( itk::ExceptionObject & err ) 
    { 
    std::cerr << "Registration Exception Object caught !" << std::endl; 
    std::cerr << err << std::endl; 
    return -1;
    } 
    
   //GET RESULTS
  
  RegistrationType::ParametersType finalParameters = registration->GetLastTransformParameters(); 
  
  transform->SetParameters( finalParameters );
  VectorType  trans = transform->GetTranslation();
  std::cout << trans << std::endl;
   
   //Write Transform

   typedef itk::TransformFileWriter TransformWriterType;
   TransformWriterType::Pointer transformWriter = TransformWriterType::New();
   transformWriter->SetFileName( OutputXformFilename );
   transformWriter->SetInput( transform );
  
  try
   {
   transformWriter->Update();
   }
  catch( itk::ExceptionObject & excp )
   {
   std::cerr << "Error while saving the transforms" << std::endl;
   std::cerr << excp << std::endl;
   return 0;
   }
   
   //Write Inverse Transform
   
    if ( OutputInverseXformFilename.length() > 0 )
  {
    transform->GetInverse( invTransform );
	TransformWriterType::Pointer inverseTransformWriter = TransformWriterType::New();
    inverseTransformWriter->SetFileName( OutputInverseXformFilename );
    inverseTransformWriter->SetInput( invTransform );
    try
      {
      inverseTransformWriter->Update();
      }
    catch( itk::ExceptionObject & errr ) 
      { 
      std::cerr << "ExceptionObject caught !" << std::endl; 
      std::cerr << errr << std::endl; 
    throw;
      } 
  }  
   
  
       
  //Resample DTI Image
  typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;

  AffineTransformType::Pointer finalTransform = AffineTransformType::New();

  finalTransform->SetCenter( transform->GetCenter() );

  finalTransform->SetParameters( finalParameters );

  ResampleFilterType::Pointer resampler = ResampleFilterType::New();

  resampler->SetTransform(finalTransform);
  resampler->SetInput( img );

  fixedImage = fixedImageReader->GetOutput();

  resampler->SetSize(    fixedImage->GetLargestPossibleRegion().GetSize() );
  resampler->SetOutputOrigin(  fixedImage->GetOrigin() );
  resampler->SetOutputSpacing( fixedImage->GetSpacing() );
  resampler->SetOutputDirection(fixedImage->GetDirection());
  resampler->SetDefaultPixelValue(0);
  
   try
    {
    resampler->Update();
    }
  catch( itk::ExceptionObject & errr ) 
    { 
    std::cerr << "ExceptionObject caught !" << std::endl; 
    std::cerr << errr << std::endl; 
    throw;
    } 

  //Write resampled Image
  typedef  unsigned short OutputPixelType;

  typedef itk::Image< OutputPixelType, 3 > OutputImageType;
 
                    
  typedef itk::ImageFileWriter< OutputImageType >  WriterType;

  WriterType::Pointer      writer =  WriterType::New();
 
  writer->SetFileName( OutputResampledImageFilename );

//  caster->SetInput( resampler->GetOutput() );
  writer->SetInput( resampler->GetOutput() );
  writer->Update();

return 0;
}
