/*=========================================================================

Program:   General Rigid Registration
Language:  C++
Author:    Madhura A. Ingalhalikar
Date:      2007/07/11 15:24:18 

=========================================================================*/

#include <cstdlib>
#include <iostream>
#include <fstream>

#include <itkImage.h>
#include <metaCommand.h>
#include "itkImageRegistrationMethod.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkImage.h"
#include "itkVersorRigid3DTransform.h"
#include "itkCenteredTransformInitializer.h"
#include "itkVersorRigid3DTransformOptimizer.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkResampleImageFilter.h"
#include "itkCastImageFilter.h"
#include "itkTransformFileWriter.h"
#include "itkMattesMutualInformationImageToImageMetric.h"
#include "itkCommand.h"
#include "itkExtractImageFilter.h"
#include "itkVectorImage.h"
#include "itkVectorIndexSelectionCastImageFilter.h"
#include "itkOrientImageFilter.h"
#include "GeneralRigidRegistrationCLP.h"

//  The following section of code implements a Command observer
//  that will monitor the evolution of the registration process.
//

/**
 * Observer that prints optimizer progress on every itk::IterationEvent.
 *
 * Each line written to std::cout contains the current iteration number,
 * the current metric value, and the current parameter vector of a
 * VersorRigid3DTransformOptimizer. Events that are not IterationEvents,
 * and callers that are not a VersorRigid3DTransformOptimizer, are ignored.
 */
class CommandIterationUpdate : public itk::Command 
{
public:
  typedef  CommandIterationUpdate   Self;
  typedef  itk::Command             Superclass;
  typedef itk::SmartPointer<Self>  Pointer;
  itkNewMacro( Self );
protected:
  CommandIterationUpdate() {};
public:
  typedef itk::VersorRigid3DTransformOptimizer     OptimizerType;
  typedef   const OptimizerType   *    OptimizerPointer;

  /** Non-const entry point; forwards to the const overload. */
  void Execute(itk::Object *caller, const itk::EventObject & event)
    {
    Execute( static_cast< const itk::Object * >( caller ), event );
    }

  /** Prints iteration, metric value and position for IterationEvents. */
  void Execute(const itk::Object * object, const itk::EventObject & event)
    {
    // Filter on the event type first: it is cheap and most events are ignored.
    if( ! itk::IterationEvent().CheckEvent( &event ) )
      {
      return;
      }
    // The observer may be attached to something other than the expected
    // optimizer; never dereference a failed dynamic_cast.
    OptimizerPointer optimizer =
      dynamic_cast< OptimizerPointer >( object );
    if( !optimizer )
      {
      return;
      }
    std::cout << optimizer->GetCurrentIteration() << "   ";
    std::cout << optimizer->GetValue() << "   ";
    std::cout << optimizer->GetCurrentPosition() << std::endl;
    }
};


int main( int argc, char *argv[] )
{
  PARSE_ARGS;
  
  std::cout << "Moving Image: " <<  MovingImageFilename << std::endl; 
  std::cout << "Fixed Image: " <<  FixedImageFilename << std::endl; 
  std::cout << "Output Transform: " <<  OutputXformFilename << std::endl; 
  std::cout << "Resample Image: " << OutputResampledImageFilename << std::endl; 
  std::cout << "Iterations: " << Iterations <<std::endl;
  std::cout << "Max Step Length: " << MaxStepLength <<std::endl; 
  std::cout << "Min Step Length: " << MinStepLength <<std::endl; 
  

  
  const    unsigned int    Dimension = 3;
  typedef  signed short          PixelType;

  typedef itk::Image< PixelType, Dimension >  FixedImageType;
  typedef itk::Image< PixelType, 3 >  MovingImageType;  
  typedef itk::VersorRigid3DTransform< double > RigidTransformType;
  typedef itk::VersorRigid3DTransformOptimizer       OptimizerType;
  typedef itk::MattesMutualInformationImageToImageMetric< FixedImageType, MovingImageType >    MetricType;
  typedef itk::LinearInterpolateImageFunction< MovingImageType,double >    InterpolatorType;
  typedef itk::ImageRegistrationMethod< FixedImageType, MovingImageType >    RegistrationType;

  
  MetricType::Pointer         metric        = MetricType::New();
  OptimizerType::Pointer      optimizer     = OptimizerType::New();
  InterpolatorType::Pointer   interpolator  = InterpolatorType::New();
  RigidTransformType::Pointer  transform	= RigidTransformType::New();
  RegistrationType::Pointer   registration  = RegistrationType::New();

   // Read Input Images

  typedef itk::ImageFileReader< FixedImageType  > FixedImageReaderType;
  typedef itk::ImageFileReader< MovingImageType > MovingImageReaderType;

  FixedImageReaderType::Pointer  fixedImageReader  = FixedImageReaderType::New();
  MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();

  fixedImageReader->SetFileName( FixedImageFilename.c_str());
  movingImageReader->SetFileName( MovingImageFilename.c_str() );
  MovingImageType::Pointer movingImage;
  FixedImageType::Pointer fixedImage;

  try 
    { 
    movingImageReader->Update();
    fixedImageReader->Update();
    
    } 
  catch( itk::ExceptionObject & ee ) 
    { 
    std::cerr << "Exception Object caught !" << std::endl; 
    std::cerr << ee << std::endl; 
    return -1;
    } 
    
    
    
 
    movingImage = movingImageReader->GetOutput();
  
    fixedImage = fixedImageReader->GetOutput();
   
    
     

  registration->SetMetric(        metric        );
  registration->SetOptimizer(     optimizer     );
  registration->SetInterpolator(  interpolator  );  
  registration->SetTransform( transform ); 
  

  registration->SetFixedImage(    fixedImage    );
  registration->SetMovingImage(   movingImage   );
    
  registration->SetFixedImageRegion( fixedImage->GetBufferedRegion() ); 

  
  // Use centre of Mass
  
  
  typedef itk::CenteredTransformInitializer< RigidTransformType, FixedImageType, MovingImageType >  TransformInitializerType;
  TransformInitializerType::Pointer initializer = TransformInitializerType::New();
  
  initializer->SetTransform(   transform );
  initializer->SetFixedImage( fixedImage);
  initializer->SetMovingImage( movingImage);
 
  initializer->MomentsOn();  
  initializer->InitializeTransform();


  
  typedef RigidTransformType::VersorType  VersorType;
  typedef VersorType::VectorType     VectorType;

  VersorType     rotation;
  VectorType     axis;
  
  axis[0] = 0.0;
  axis[1] = 0.0;
  axis[2] = 1.0;

  const double angle = 0;

  rotation.Set(  axis, angle  );

  transform->SetRotation( rotation );
 
  registration->SetInitialTransformParameters( transform->GetParameters() );
  
  //SET OPTIMIZER

  typedef OptimizerType::ScalesType       OptimizerScalesType;
  OptimizerScalesType optimizerScales( transform->GetNumberOfParameters() );
  const double translationScale = 1.0 / 1000.0;
  optimizerScales[0] = 1.0;
  optimizerScales[1] = 1.0;
  optimizerScales[2] = 1.0;
  optimizerScales[3] = translationScale;
  optimizerScales[4] = translationScale;
  optimizerScales[5] = translationScale;
  optimizer->SetScales( optimizerScales );
  optimizer->SetMaximumStepLength( MaxStepLength ); 
  optimizer->SetMinimumStepLength( MinStepLength );
  optimizer->SetNumberOfIterations( Iterations );


  // Create the Command observer and register it with the optimizer.  
  CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
  optimizer->AddObserver( itk::IterationEvent(), observer );
  
  try 
    { 
    registration->StartRegistration(); 
    } 
  catch( itk::ExceptionObject & err ) 
    { 
    std::cerr << "ExceptionObject caught !" << std::endl; 
    std::cerr << err << std::endl; 
    return -1;
    } 
  
  OptimizerType::ParametersType finalParameters = 
                    registration->GetLastTransformParameters();


  const double versorX              = finalParameters[0];
  const double versorY              = finalParameters[1];
  const double versorZ              = finalParameters[2];
  const double finalTranslationX    = finalParameters[3];
  const double finalTranslationY    = finalParameters[4];
  const double finalTranslationZ    = finalParameters[5];

  const unsigned int numberOfIterations = optimizer->GetCurrentIteration();

  const double bestValue = optimizer->GetValue();

  // Print out results
 
  std::cout << std::endl << std::endl;
  std::cout << "Result = " << std::endl;
  std::cout << " versor X      = " << versorX  << std::endl;
  std::cout << " versor Y      = " << versorY  << std::endl;
  std::cout << " versor Z      = " << versorZ  << std::endl;
  std::cout << " Translation X = " << finalTranslationX  << std::endl;
  std::cout << " Translation Y = " << finalTranslationY  << std::endl;
  std::cout << " Translation Z = " << finalTranslationZ  << std::endl;
  std::cout << " Iterations    = " << numberOfIterations << std::endl;
  std::cout << " Metric value  = " << bestValue          << std::endl;    

  transform->SetParameters( finalParameters );

  RigidTransformType::MatrixType matrix = transform->GetRotationMatrix();
 // RigidTransformType::OffsetType offset = transform->GetOffset();

  std::cout << "Matrix = " << std::endl << matrix << std::endl;
  //std::cout << "Offset = " << std::endl << offset << std::endl; 

//Write Transform
  typedef itk::TransformFileWriter TransformWriterType;
   TransformWriterType::Pointer transformWriter = TransformWriterType::New();
   transformWriter->SetFileName( OutputXformFilename );
   transformWriter->SetInput( transform );
  
  try
   {
   transformWriter->Update();
   }
  catch( itk::ExceptionObject & excp )
   {
   std::cerr << "Error while saving the transforms" << std::endl;
   std::cerr << excp << std::endl;
   return 0;
   }
  
//Resample Image
  typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;

  RigidTransformType::Pointer finalTransform = RigidTransformType::New();

  finalTransform->SetCenter( transform->GetCenter() );

  finalTransform->SetParameters( finalParameters );

  ResampleFilterType::Pointer resampler = ResampleFilterType::New();

  resampler->SetTransform( finalTransform );
  resampler->SetInput( movingImage);

  resampler->SetSize(    fixedImage->GetLargestPossibleRegion().GetSize() );
  resampler->SetOutputOrigin(  fixedImage->GetOrigin() );
  resampler->SetOutputSpacing( fixedImage->GetSpacing() );
  resampler->SetDefaultPixelValue( 100 );

  try
    {
    resampler->Update();
    }
  catch( itk::ExceptionObject & errr ) 
    { 
    std::cerr << "ExceptionObject caught !" << std::endl; 
    std::cerr << errr << std::endl; 
    throw;
    } 
  
  //Write resampled Image
 
  typedef itk::Image< PixelType, Dimension > OutputImageType;
 
                   
  typedef itk::ImageFileWriter< OutputImageType >  WriterType;

  WriterType::Pointer      writer =  WriterType::New();

  writer->SetFileName( OutputResampledImageFilename );

  writer->SetInput( resampler->GetOutput()   );
  writer->Update();

  


  return 0;
}

