/*=========================================================================

  Program:   Registration stand-alone
  Module:    $RCSfile: $
  Language:  C++
  Date:      $Date: 2008-07-30 22:35:51 +0900 (水, 30 7 2008) $
  Version:   $Revision: 7363 $

=========================================================================*/
#include <cstdio>
#include <cstring>
#include <cstdlib>

#include <iostream>

#include "itkImage.h"
#include "itkOrientImageFilter.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkTransformFileReader.h"
#include "itkTransformFileWriter.h"

#include "itkRegularStepGradientDescentOptimizer.h"
#include "itkImageFileWriter.h"
#include "itkImageRegistrationMethod.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkMattesMutualInformationImageToImageMetric.h"
#include "itkAffineTransform.h"
#include "itkResampleImageFilter.h"

#include "itkTimeProbesCollectorBase.h"

#include "itkAffineRegistration.h"

#include "itkCenteredAffineTransform.h"

using namespace std;

#define ORIENT

// #ifdef ORIENT
//   typedef itk::Image<short, 3> ImageType;
// #else
//   typedef itk::Image<short, 3> ImageType;
// #endif // ORIENT
//   typedef ImageType::Pointer ImagePointerType;





//  itkCommand.h provides itk::Command, the base class for observers that
//  can monitor the evolution of the registration process. No such observer
//  is currently defined or attached in this file's registration code.
//
#include "itkCommand.h"

/**
 * Register m_MovingImage to m_FixedImage with a 3-D affine transform.
 *
 * Pipeline:
 *  1. If m_InitialTransform is non-empty, read it with itk::TransformFileReader.
 *  2. Reorient both images to axial. This only changes the index-to-physical
 *     mapping, so the physical-space transform found here is also valid for
 *     the original images.
 *  3. Initialise an itk::AffineTransform centred on the fixed image centre and
 *     translated onto the moving image centre. If an initial transform of type
 *     MatrixOffsetTransformBase<double|float,3,3> was read, its matrix and
 *     offset are used instead.
 *  4. Optimise Mattes mutual information (m_HistogramBins bins,
 *     m_SpatialSamples samples) with a regular-step gradient descent
 *     (m_Iterations iterations). The translation parameters (indices 9-11)
 *     get optimizer scale 1 / m_TranslationScale^2.
 *  5. If m_OutputTransform is non-empty, write the final transform there.
 *  6. If m_ResampledImageFileName is non-empty, resample the original
 *     (un-reoriented) moving image onto the original fixed image grid, write
 *     it to that file and keep it in m_OutputImage.
 *
 * @param outputimage  Not used; the registered image is stored in
 *                     m_OutputImage (see GetRegisteredImage()).
 * @return EXIT_SUCCESS; EXIT_FAILURE if reading the initial transform, the
 *         registration, or writing the transform/image throws. Exceptions
 *         from the orient and resample filters are not caught here.
 */
template <class TImage>
int myAffReg<TImage>::DoIt(TImage *outputimage)
{
  // The registered image is returned through m_OutputImage instead.
  (void) outputimage;

  typedef itk::OrientImageFilter<ImageType,ImageType> OrientFilterType;
  typedef itk::MattesMutualInformationImageToImageMetric<ImageType, ImageType>
    MetricType;
  typedef itk::RegularStepGradientDescentOptimizer
    OptimizerType;
  typedef itk::LinearInterpolateImageFunction<ImageType, double>
    InterpolatorType;
  typedef itk::ImageRegistrationMethod<ImageType,ImageType>
    RegistrationType;
  typedef itk::AffineTransform<double> TransformType;
  typedef OptimizerType::ScalesType OptimizerScalesType;
  typedef itk::ResampleImageFilter<ImageType,ImageType> ResampleType;
  typedef itk::LinearInterpolateImageFunction<ImageType, double> ResampleInterpolatorType;
  typedef itk::ImageFileWriter<ImageType> WriterType;
  typedef itk::ContinuousIndex<double, 3> ContinuousIndexType;

  // Add a time probe
  itk::TimeProbesCollectorBase collector;

  typedef itk::TransformFileReader TransformReaderType;
  TransformReaderType::Pointer initialTransform;

  if (m_InitialTransform != "")
    {
    initialTransform= TransformReaderType::New();
    initialTransform->SetFileName( m_InitialTransform );
    try
      {
      initialTransform->Update();
      }
    catch (itk::ExceptionObject &err)
      {
      std::cerr << err << std::endl;
      return  EXIT_FAILURE;
      }
    }

  // Reorient to axials to avoid issues with registration metrics not
  // transforming image gradients with the image orientation in
  // calculating the derivative of metric wrt transformation
  // parameters.
  //
  // Forcing image to be axials avoids this problem. Note, that
  // reorientation only affects the internal mapping from index to
  // physical coordinates.  The reoriented data spans the same
  // physical space as the original data.  Thus, the registration
  // transform calculated on the reoriented data is also the
  // transform for the original un-reoriented data.

  typename OrientFilterType::Pointer orientFixed = OrientFilterType::New();
  orientFixed->UseImageDirectionOn();
  orientFixed->SetDesiredCoordinateOrientationToAxial();
  orientFixed->SetInput (m_FixedImage);
  collector.Start( "Orient fixed volume" );
  orientFixed->Update();
  collector.Stop( "Orient fixed volume" );

  typename OrientFilterType::Pointer orientMoving = OrientFilterType::New();
  orientMoving->UseImageDirectionOn();
  orientMoving->SetDesiredCoordinateOrientationToAxial();
  orientMoving->SetInput (m_MovingImage);
  collector.Start( "Orient moving volume" );
  orientMoving->Update();
  collector.Stop( "Orient moving volume" );

  cout << "before optimizer" << endl;

  // Set up the optimizer
  typename OptimizerType::Pointer optimizer = OptimizerType::New();
  optimizer->SetNumberOfIterations ( m_Iterations );
  optimizer->SetMinimumStepLength ( .0005 );
  optimizer->SetMaximumStepLength ( 10.0 );
  optimizer->SetMinimize(true);

  // Parameters 0-8 are the matrix entries and 9-11 the translation of the
  // 3-D affine transform. Translations are in physical units, so they get a
  // smaller scale to balance them against the unitless matrix entries.
  typename TransformType::Pointer transform = TransformType::New();
  OptimizerScalesType scales( transform->GetNumberOfParameters() );
  scales.Fill ( 1.0 );
  for( unsigned j = 9; j < 12; j++ )
    {
    scales[j] = 1.0 / vnl_math_sqr(m_TranslationScale);
    }
  optimizer->SetScales( scales );

  cout << "before transform" << endl;

  // Initialize the transform: rotate about the fixed image centre and
  // translate that centre onto the moving image centre.
  typename TransformType::InputPointType centerFixed;
  typename ImageType::RegionType::SizeType sizeFixed = orientFixed->GetOutput()->GetLargestPossibleRegion().GetSize();
  ContinuousIndexType indexFixed;
  for ( unsigned j = 0; j < 3; j++ )
    {
    indexFixed[j] = (sizeFixed[j]-1) / 2.0;
    }
  orientFixed->GetOutput()->TransformContinuousIndexToPhysicalPoint ( indexFixed, centerFixed );

  typename TransformType::InputPointType centerMoving;
  typename ImageType::RegionType::SizeType sizeMoving = orientMoving->GetOutput()->GetLargestPossibleRegion().GetSize();
  ContinuousIndexType indexMoving;
  for ( unsigned j = 0; j < 3; j++ )
    {
    indexMoving[j] = (sizeMoving[j]-1) / 2.0;
    }
  orientMoving->GetOutput()->TransformContinuousIndexToPhysicalPoint ( indexMoving, centerMoving );

  transform->SetCenter( centerFixed );
  transform->Translate(centerMoving-centerFixed);
  std::cout << "Centering transform: "; transform->Print( std::cout );

  // If an initial transformation was provided, then use it instead.
  // (Should this be instead of the centering transform or composed
  // with the centering transform.)
  if (m_InitialTransform != ""
      && initialTransform->GetTransformList()->size() != 0)
    {
    typename TransformReaderType::TransformType::Pointer initial
      = *(initialTransform->GetTransformList()->begin());

    // most likely, the transform coming in is a subclass of
    // MatrixOffsetTransformBase
    typedef itk::MatrixOffsetTransformBase<double,3,3> DoubleMatrixOffsetType;
    typedef itk::MatrixOffsetTransformBase<float,3,3> FloatMatrixOffsetType;

    typename DoubleMatrixOffsetType::Pointer da
      = dynamic_cast<DoubleMatrixOffsetType*>(initial.GetPointer());
    typename FloatMatrixOffsetType::Pointer fa
      = dynamic_cast<FloatMatrixOffsetType*>(initial.GetPointer());

    if (da)
      {
      transform->SetMatrix( da->GetMatrix() );
      transform->SetOffset( da->GetOffset() );
      }
    else if (fa)
      {
      // Widen the float matrix element-wise to double.
      vnl_matrix<double> t(3,3);
      for (int i=0; i < 3; ++i)
        {
        for (int j=0; j <3; ++j)
          {
          t.put(i, j, fa->GetMatrix().GetVnlMatrix().get(i, j));
          }
        }

      transform->SetMatrix( t );
      transform->SetOffset( fa->GetOffset() );
      }
    else
      {
      std::cout << "Initial transform is an unsupported type." << std::endl;
      }

    std::cout << "Initial transform: "; transform->Print ( std::cout );
    }

  cout << "before metric" << endl;

  // Set up the metric
  typename MetricType::Pointer metric = MetricType::New();
  metric->SetNumberOfHistogramBins ( m_HistogramBins );
  metric->SetNumberOfSpatialSamples( m_SpatialSamples );

  cout << "before interpolator" << endl;

  // Create the interpolator
  typename InterpolatorType::Pointer interpolator = InterpolatorType::New();

  cout << "before registration" << endl;

  // Set up the registration on the reoriented (axial) images.
  typename RegistrationType::Pointer registration = RegistrationType::New();
  registration->SetTransform ( transform );
  registration->SetInitialTransformParameters ( transform->GetParameters() );
  registration->SetMetric ( metric );
  registration->SetOptimizer ( optimizer );
  registration->SetInterpolator ( interpolator );
  registration->SetFixedImage ( orientFixed->GetOutput() );
  registration->SetMovingImage ( orientMoving->GetOutput() );

  cout << " before registration update" << endl;

  try
    {
    registration->Update();
    }
  catch( itk::ExceptionObject & err )
    {
    std::cout << err << std::endl;
    std::cerr << err << std::endl;
    return  EXIT_FAILURE ;
    }
  catch ( ... )
    {
    // Do not fail silently on non-ITK exceptions.
    std::cerr << "Unknown exception caught during registration." << std::endl;
    return  EXIT_FAILURE ;
    }

  cout << " before transform" << endl;

  transform->SetParameters ( registration->GetLastTransformParameters() );

  if (m_OutputTransform != "")
    {
    typedef itk::TransformFileWriter TransformWriterType;
    typename TransformWriterType::Pointer outputTransformWriter = TransformWriterType::New();
    outputTransformWriter->SetFileName( m_OutputTransform );
    outputTransformWriter->SetInput( transform );
    try
      {
      outputTransformWriter->Update();
      }
    catch (itk::ExceptionObject &err)
      {
      std::cerr << err << std::endl;
      return  EXIT_FAILURE ;
      }
    }

  cout << " before resample" << endl;

  // Resample to the original coordinate frame (not the reoriented
  // axial coordinate frame) of the fixed image
  if (m_ResampledImageFileName != "")
    {
    typename ResampleType::Pointer resample = ResampleType::New();
    typename ResampleInterpolatorType::Pointer resampleInterpolator = ResampleInterpolatorType::New();
    resample->SetInput (m_MovingImage);
    resample->SetTransform ( transform );
    resample->SetInterpolator ( resampleInterpolator );
    resample->SetOutputParametersFromImage(m_FixedImage);

    collector.Start( "Resample" );
    resample->Update();
    collector.Stop( "Resample" );

    typename WriterType::Pointer resampledWriter = WriterType::New();
    resampledWriter->SetFileName ( m_ResampledImageFileName.c_str() );
    resampledWriter->SetInput ( resample->GetOutput() );

    try
      {
      collector.Start( "Write volume" );
      resampledWriter->Write();
      collector.Stop( "Write volume" );
      }
    catch( itk::ExceptionObject & err )
      {
      std::cerr << err << std::endl;
      return EXIT_FAILURE;
      }

    // Keep the resampled image itself as the registered output. Previously
    // an empty, uninitialised image of the fixed image's size (with default
    // origin/spacing/direction) was allocated here, so GetRegisteredImage()
    // never returned the registration result. DisconnectPipeline() is the
    // standard ITK way to detach a filter output so it can be kept on its own.
    // NOTE(review): assumes m_OutputImage is an ImageType::Pointer, as the
    // original ImageType::New() assignment implies.
    typename ImageType::Pointer registered = resample->GetOutput();
    registered->DisconnectPipeline();
    m_OutputImage = registered;
    }

  // Report the time taken by the registration
  collector.Report();

  return EXIT_SUCCESS;
}


/**
 * Run the registration by delegating to DoIt().
 *
 * @param outputimage  Passed through to DoIt().
 * @return The status returned by DoIt() (EXIT_SUCCESS or EXIT_FAILURE).
 */
template <class TImage>
int myAffReg<TImage>::myregistration(ImageType *outputimage)  // main
{
  // Forward DoIt()'s status. Previously it was discarded and EXIT_SUCCESS was
  // always returned, which hid registration and I/O failures from the caller.
  return this->DoIt(outputimage);
}

// Store the number of histogram bins handed to the Mattes mutual-information
// metric (SetNumberOfHistogramBins) in DoIt().
template <class TImage>
void myAffReg<TImage>::SetHistogramBin(int histbin)
{
  this->m_HistogramBins = histbin;
}

// Store the number of spatial samples handed to the Mattes mutual-information
// metric (SetNumberOfSpatialSamples) in DoIt().
template <class TImage>
void myAffReg<TImage>::SetSpatialSamples(int spatialsample)
{
  this->m_SpatialSamples = spatialsample;
}


// Store the maximum number of optimizer iterations used by DoIt().
template <class TImage>
void myAffReg<TImage>::SetIteration(int iter)
{
  this->m_Iterations = iter;
}

// Store the translation scale; DoIt() sets the optimizer scale of the three
// translation parameters to 1 / transscale^2.
template <class TImage>
void myAffReg<TImage>::SetTransScale(double transscale)
{
  this->m_TranslationScale = transscale;
}

// Store the path of an initial transform file read by DoIt(); an empty
// string means no initial transform is read.
template <class TImage>
void myAffReg<TImage>::SetInitTransform(std::string& inittrans)
{
  this->m_InitialTransform = inittrans;
}

// Store the path DoIt() writes the final transform to; an empty string
// disables writing.
template <class TImage>
void myAffReg<TImage>::SetOutputTransform(std::string& outtrans)
{
  this->m_OutputTransform = outtrans;
}

// Store the fixed (reference) image used by DoIt().
template <class TImage>
void myAffReg<TImage>::SetFixedImage(ImageType *fixedimage)
{
  this->m_FixedImage = fixedimage;
}

// Store the moving image that DoIt() registers onto the fixed image.
template <class TImage>
void myAffReg<TImage>::SetMovingImage(ImageType *movingimage)
{
  this->m_MovingImage = movingimage;
}

// Store the file name for the resampled moving image written by DoIt(); an
// empty string skips the resampling step.
template <class TImage>
void myAffReg<TImage>::SetOutputImageFilename(std::string& outimagename)
{
  this->m_ResampledImageFileName = outimagename;
}

// Return the image held in m_OutputImage (assigned by DoIt() when a resampled
// image file name is set).
template <class TImage>
TImage *myAffReg<TImage>::GetRegisteredImage()
{
  TImage *registered = this->m_OutputImage;
  return registered;
}

/**
 * Resample a moving image file onto the grid of a fixed image file with an
 * affine transform read from a transform file, and write the result.
 *
 * The output uses the fixed image's origin, spacing, direction and
 * largest-possible-region size. Voxels are linearly interpolated from the
 * moving image. Images are always read and written as 3-D short images,
 * independent of TImage. Only the matrix and offset of the first transform in
 * @p transformname are used, and only when it is an
 * itk::MatrixOffsetTransformBase<double|float,3,3>; for any other type a
 * message is printed and the identity is applied.
 *
 * @param fixedimagename   File defining the output grid.
 * @param movingimagename  File containing the image to resample.
 * @param outimagename     File the resampled image is written to.
 * @param transformname    File containing the transform to apply.
 * @return EXIT_SUCCESS; EXIT_FAILURE if reading an image or the transform
 *         fails, the transform file contains no transform, or writing fails.
 */
template <class TImage>
int myAffReg<TImage>::mytransformation(std::string& fixedimagename, std::string& movingimagename, std::string& outimagename, std::string& transformname)  // main
{
  // from DoIt
  cout << "myAffReg<TImage>::mytransformation(ImageType *outputimage)" << endl;

  const     unsigned int   Dimension = 3;
  typedef   short  InputPixelType;
  typedef   short  OutputPixelType;

  typedef itk::Image< InputPixelType,  Dimension >   InputImageType;
  typedef itk::Image< OutputPixelType, Dimension >   OutputImageType;

  typedef itk::ImageFileReader< InputImageType  >  ReaderType;
  typedef itk::ImageFileWriter< OutputImageType >  WriterType;
  typedef itk::CenteredAffineTransform< double, Dimension >  TransformType;
  typedef itk::ResampleImageFilter<
                  InputImageType, OutputImageType >  FilterType;
  typedef itk::LinearInterpolateImageFunction<
                       InputImageType, double >  InterpolatorType;
  typedef itk::TransformFileReader TransformReaderType;

  // most likely, the transform coming in is a subclass of
  // MatrixOffsetTransformBase
  typedef itk::MatrixOffsetTransformBase<double,3,3> DoubleMatrixOffsetType;
  typedef itk::MatrixOffsetTransformBase<float,3,3> FloatMatrixOffsetType;

  ReaderType::Pointer fixedreader = ReaderType::New();
  ReaderType::Pointer movingreader = ReaderType::New();
  fixedreader->SetFileName( fixedimagename );
  movingreader->SetFileName ( movingimagename );

  WriterType::Pointer writer = WriterType::New();
  writer->SetFileName( outimagename );

  // Starts as the identity; its matrix and offset are replaced below by those
  // of the transform read from transformname, when that type is supported.
  TransformType::Pointer transform = TransformType::New();

  FilterType::Pointer filter = FilterType::New();
  InterpolatorType::Pointer interpolator = InterpolatorType::New();
  filter->SetInterpolator( interpolator );

  // Report I/O failures through the return code, as DoIt() does, instead of
  // letting itk::ExceptionObject escape.
  try
    {
    fixedreader->Update();
    movingreader->Update();
    }
  catch( itk::ExceptionObject & excep )
    {
    std::cerr << "Exception caught !" << std::endl;
    std::cerr << excep << std::endl;
    return EXIT_FAILURE;
    }

  // The output grid is the fixed image's grid.
  const InputImageType::SpacingType& spacing = fixedreader->GetOutput()->GetSpacing();
  const InputImageType::PointType& origin  = fixedreader->GetOutput()->GetOrigin();
  InputImageType::SizeType size = fixedreader->GetOutput()->GetLargestPossibleRegion().GetSize();

  filter->SetOutputDirection(fixedreader->GetOutput()->GetDirection());
  filter->SetOutputOrigin( origin );
  filter->SetOutputSpacing( spacing );
  filter->SetSize( size );
  filter->SetTransform(transform);

  filter->SetInput( movingreader->GetOutput() );
  writer->SetInput( filter->GetOutput() );

  cout << "size: " << size[0] << ", " << size[1] << ", " << size[2] << endl;
  cout << "spacing: " << spacing[0] << ", " << spacing[1] << ", " << spacing[2] << endl;
  cout << "origin: " << origin[0] << ", " << origin[1] << ", " << origin[2] << endl;

  // No SetCenter() call: the transform maps x -> M x + offset, and the
  // SetMatrix()/SetOffset() calls below fix both M and offset directly, so a
  // centre chosen beforehand would not change the resampled output.

  // read transform
  TransformReaderType::Pointer initialTransform = TransformReaderType::New();
  initialTransform->SetFileName(transformname);
  try
    {
    initialTransform->Update();
    }
  catch( itk::ExceptionObject & excep )
    {
    std::cerr << "Exception caught !" << std::endl;
    std::cerr << excep << std::endl;
    return EXIT_FAILURE;
    }

  // Dereferencing begin() of an empty list is undefined behaviour.
  if ( initialTransform->GetTransformList()->size() == 0 )
    {
    std::cerr << "No transform found in " << transformname << std::endl;
    return EXIT_FAILURE;
    }

  TransformReaderType::TransformType::Pointer initial
    = *(initialTransform->GetTransformList()->begin());

  DoubleMatrixOffsetType::Pointer da
    = dynamic_cast<DoubleMatrixOffsetType*>(initial.GetPointer());
  FloatMatrixOffsetType::Pointer fa
    = dynamic_cast<FloatMatrixOffsetType*>(initial.GetPointer());

  if (da)
    {
    transform->SetMatrix( da->GetMatrix() );
    transform->SetOffset( da->GetOffset() );
    }
  else if (fa)
    {
    // Widen the float matrix element-wise to double.
    vnl_matrix<double> t(3,3);
    for (int i=0; i < 3; ++i)
      {
      for (int j=0; j <3; ++j)
        {
        t.put(i, j, fa->GetMatrix().GetVnlMatrix().get(i, j));
        }
      }

    transform->SetMatrix( t );
    transform->SetOffset( fa->GetOffset() );
    }
  else
    {
    // Same reporting as DoIt(); the identity transform is applied.
    std::cout << "Initial transform is an unsupported type." << std::endl;
    }

  try
    {
    writer->Update();
    }
  catch( itk::ExceptionObject & excep )
    {
    std::cerr << "Exception caught !" << std::endl;
    std::cerr << excep << std::endl;
    // Previously fell through to EXIT_SUCCESS even though nothing was written.
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}


