/**
\file Registration.h

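This file holds the declaration of the class Registration, its template member definitions, and the helper ITK command (observer) classes used while registering images.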

https://www.cbica.upenn.edu/sbia/software/ <br>
software@cbica.upenn.edu

Copyright (c) 2016 University of Pennsylvania. All rights reserved. <br>
See COPYING file or https://www.cbica.upenn.edu/sbia/software/license.html
*/
#pragma once

#include "itkAffineTransform.h"
#include "itkRegularStepGradientDescentOptimizer.h"
#include "itkMattesMutualInformationImageToImageMetric.h"
#include "itkMultiResolutionImageRegistrationMethod.h"
#include "itkNormalizedMutualInformationHistogramImageToImageMetric.h"
#include "ApplicationBase.h"


/**
\brief Observer that re-tunes the optimizer step lengths between resolution levels.

Attach it to an itk::MultiResolutionImageRegistrationMethod (TRegistration) for
itk::IterationEvent. On each event it sets the step lengths as follows:
- At level 0, it sets the RegularStepGradientDescentOptimizer step lengths to
  [0.00001, 0.0525].
- At every later level, it shrinks the maximum step by 4x and the minimum step by 10x.
*/
template <typename TRegistration>
class RegistrationInterfaceCommand : public itk::Command
{
public:
  typedef  RegistrationInterfaceCommand   Self;
  typedef  itk::Command                   Superclass;
  typedef  itk::SmartPointer<Self>        Pointer;
  itkNewMacro(Self);
protected:
  RegistrationInterfaceCommand() {}
public:
  typedef   TRegistration                              RegistrationType;
  typedef   RegistrationType *                         RegistrationPointer;
  typedef   itk::RegularStepGradientDescentOptimizer   OptimizerType;
  typedef   OptimizerType *                            OptimizerPointer;

  /** Adjusts the optimizer for the registration's current resolution level.
  Events that are not IterationEvents (or derived from it) are ignored. So are
  callers that are not a TRegistration and registrations whose optimizer is not
  a RegularStepGradientDescentOptimizer; these cases are skipped instead of
  dereferencing a null pointer. */
  void Execute(itk::Object * object, const itk::EventObject & event) override
  {
    if (!(itk::IterationEvent().CheckEvent(&event)))
    {
      return;
    }

    RegistrationPointer registration = dynamic_cast<RegistrationPointer>(object);
    if (registration == nullptr)
    {
      return;
    }

    OptimizerPointer optimizer = dynamic_cast<OptimizerPointer>(registration->GetModifiableOptimizer());
    if (optimizer == nullptr)
    {
      return;
    }

    if (registration->GetCurrentLevel() == 0)
    {
      // Coarsest level: start from the initial step-length bounds.
      optimizer->SetMaximumStepLength(0.0525);
      optimizer->SetMinimumStepLength(0.00001);
    }
    else
    {
      // Finer levels: take progressively smaller steps.
      optimizer->SetMaximumStepLength(optimizer->GetMaximumStepLength() * 0.25);
      optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() * 0.1);
    }
  }

  /** Const-caller overload required by itk::Command; the step lengths cannot be
  changed through a const object, so this does nothing. */
  void Execute(const itk::Object *, const itk::EventObject &) override
  {
  }
};

/**
\brief Observer that reports the iteration progress of an iterative ITK filter.

TFilter must provide GetElapsedIterations(), GetMaximumNumberOfIterations(),
GetCurrentConvergenceMeasurement() and GetConvergenceThreshold().
*/
template<class TFilter>
class CommandIterationUpdate : public itk::Command
{
public:
  typedef CommandIterationUpdate   Self;
  typedef itk::Command             Superclass;
  typedef itk::SmartPointer<Self>  Pointer;
  itkNewMacro(Self);
protected:
  CommandIterationUpdate() {}
public:

  /** Non-const caller overload; forwards to the const overload. */
  void Execute(itk::Object *caller, const itk::EventObject & event) override
  {
    Execute(static_cast<const itk::Object *>(caller), event);
  }

  /** Shows the current iteration and convergence values of the observed filter.
  Non-iteration events are ignored. So are callers that are not a TFilter;
  these are skipped instead of dereferencing a null pointer. */
  void Execute(const itk::Object * object, const itk::EventObject & event) override
  {
    // Same event test as the other commands in this file; it also accepts
    // events derived from itk::IterationEvent.
    if (!(itk::IterationEvent().CheckEvent(&event)))
    {
      return;
    }

    const TFilter * filter = dynamic_cast< const TFilter * >(object);
    if (filter == nullptr)
    {
      return;
    }

    std::string msg = "Iteration " + std::to_string(filter->GetElapsedIterations()) + " (of " + std::to_string(filter->GetMaximumNumberOfIterations()) + ").  ";
    msg += " Current convergence value = " + std::to_string(filter->GetCurrentConvergenceMeasurement()) + " (threshold = " + std::to_string(filter->GetConvergenceThreshold()) + ")";

    // NOTE(review): a modal box titled "Error" is raised for every iteration
    // event, so each iteration presumably waits for the user to dismiss it.
    // A non-modal progress channel is likely intended; confirm with callers.
    // Uses the StandardButtons overload (not the obsolete int-button one).
    QMessageBox::warning(nullptr, "Error", QString::fromStdString(msg), QMessageBox::Ok);
  }

};

/**
\brief Placeholder observer attached to the registration optimizer's IterationEvent.

It currently performs no action. It is the hook for per-iteration diagnostics,
e.g. optimizer->GetCurrentIteration(), GetValue() and GetCurrentPosition().
*/
class CommandIterationUpdateRegistration : public itk::Command
{
public:
  typedef  CommandIterationUpdateRegistration   Self;
  typedef  itk::Command             Superclass;
  typedef  itk::SmartPointer<Self>  Pointer;
  itkNewMacro(Self);
protected:
  CommandIterationUpdateRegistration() {}
public:
  typedef   itk::RegularStepGradientDescentOptimizer  OptimizerType;
  typedef   const OptimizerType *                     OptimizerPointer;

  /** Non-const caller overload; forwards to the const overload. */
  void Execute(itk::Object *caller, const itk::EventObject & event) override
  {
    Execute(static_cast<const itk::Object *>(caller), event);
  }

  /** Intentionally a no-op. To log progress, first check
  itk::IterationEvent().CheckEvent(&event). Then cast object to OptimizerPointer
  and print the values of interest. */
  void Execute(const itk::Object *, const itk::EventObject &) override
  {
  }
};

/**
\class Registration

\brief Class that handles affine registration between a fixed and a moving image.

This class uses standard ITK filters, optimizers and metrics within ITK's
multi-resolution registration framework. Run() currently configures a single
pyramid level.

Optimizer:    RegularStepGradientDescentOptimizer
Interpolator: LinearInterpolateImageFunction
Metric:       MattesMutualInformationImageToImageMetric

*/
class Registration : public ApplicationBase
{

public:
  Registration();
  ~Registration();

  /**
  \brief Affinely registers the moving image to the fixed image.

  \tparam ImageType          Type of both the fixed and the moving image.
  \tparam InternalImageType  Currently unused. It defaults to ImageType, so
                             callers may write Run<ImageType>(...). Existing
                             callers that pass it explicitly keep compiling.
  \param fixedImagePointer   Reference image; its buffered region is used for the metric.
  \param movingImagePointer  Image to be aligned onto the fixed image.
  \return The registration method object. Update() has already been attempted;
          exceptions are printed rather than rethrown. Its output holds the
          resulting transform.
  */
  template<class ImageType, class InternalImageType = ImageType>
  typename itk::MultiResolutionImageRegistrationMethod<ImageType, ImageType>::Pointer Run(typename ImageType::Pointer fixedImagePointer,
    typename ImageType::Pointer movingImagePointer);

  /**
  \brief Resamples the moving image onto the fixed image's grid using the
  transform produced by Run().
  */
  template<class ImageType>
  typename ImageType::Pointer ResampleTransform(typename itk::MultiResolutionImageRegistrationMethod<ImageType, ImageType>::Pointer registrationPointer, typename ImageType::Pointer fixedImagePointer, typename ImageType::Pointer movingImagePointer);

private:
  // NOTE(review): declared inline but not defined in this header; it is only
  // callable from the translation unit that defines it. Confirm this is intended.
  inline void SetLongRunning(bool longRunning);

};

/**
Affinely registers movingImagePointer to fixedImagePointer with Mattes mutual
information, a RegularStepGradientDescentOptimizer and linear interpolation.
It uses ITK's multi-resolution registration method with one level.

The affine transform dimension follows ImageType::ImageDimension, so the
method is not limited to 3D images. The optimization starts from the identity
transform.

InternalImageType is currently unused. It is kept so that callers who pass it
explicitly still compile.

Exceptions thrown by the registration are printed to std::cout and not
rethrown. The registration object is returned in either case.
*/
template<class ImageType, class InternalImageType>
typename itk::MultiResolutionImageRegistrationMethod<ImageType, ImageType>::Pointer Registration::Run(typename ImageType::Pointer fixedImagePointer,
  typename ImageType::Pointer movingImagePointer)
{
  progressUpdate(0);

  // MultiResolutionImageRegistrationMethod requires a transform of the same
  // dimension as the images, so derive it from ImageType instead of fixing 3.
  typedef itk::AffineTransform<double, ImageType::ImageDimension> TransformType;
  typedef itk::RegularStepGradientDescentOptimizer OptimizerType;
  typedef itk::LinearInterpolateImageFunction<ImageType, double> InterpolatorType;
  typedef itk::MattesMutualInformationImageToImageMetric<ImageType, ImageType> MetricType;
  typedef itk::MultiResolutionImageRegistrationMethod<ImageType, ImageType> RegistrationType;
  typedef itk::MultiResolutionPyramidImageFilter<ImageType, ImageType> FixedImagePyramidType;
  typedef itk::MultiResolutionPyramidImageFilter<ImageType, ImageType> MovingImagePyramidType;

  typename TransformType::Pointer transform = TransformType::New();
  OptimizerType::Pointer optimizer = OptimizerType::New();
  typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
  typename RegistrationType::Pointer registrar = RegistrationType::New();
  typename MetricType::Pointer metric = MetricType::New();
  typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
  typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();

  // Wire the registration components together.
  registrar->SetOptimizer(optimizer);
  registrar->SetInterpolator(interpolator);
  registrar->SetMetric(metric);
  registrar->SetTransform(transform);
  registrar->SetFixedImagePyramid(fixedImagePyramid);
  registrar->SetMovingImagePyramid(movingImagePyramid);

  registrar->SetFixedImage(fixedImagePointer);
  registrar->SetMovingImage(movingImagePointer);
  registrar->SetFixedImageRegion(fixedImagePointer->GetBufferedRegion());

  // Start from the identity transform. AffineTransform parameters are the
  // row-major matrix followed by the translation. In 3D this is
  // [1 0 0 0 1 0 0 0 1 | 0 0 0].
  transform->SetIdentity();
  registrar->SetInitialTransformParameters(transform->GetParameters());

  // Metric settings. The fixed seed makes the random spatial sampling repeatable.
  metric->SetNumberOfHistogramBins(128);
  metric->SetNumberOfSpatialSamples(30000);
  metric->ReinitializeSeed(76926294);
  metric->SetUseExplicitPDFDerivatives(false);

  optimizer->SetNumberOfIterations(200);
  optimizer->SetRelaxationFactor(0.9);

  // Per-iteration observer on the optimizer (currently a no-op hook).
  CommandIterationUpdateRegistration::Pointer observer = CommandIterationUpdateRegistration::New();
  optimizer->AddObserver(itk::IterationEvent(), observer);

  // Per-level observer that sets the optimizer step lengths.
  typedef RegistrationInterfaceCommand<RegistrationType> CommandType;
  typename CommandType::Pointer command = CommandType::New();
  registrar->AddObserver(itk::IterationEvent(), command);
  registrar->SetNumberOfLevels(1);

  try
  {
    registrar->Update();
  }
  catch (itk::ExceptionObject &ex) // non-const: ITK 4's operator<< takes ExceptionObject&
  {
    // Best-effort: report the failure and still hand the registration object back.
    std::cout << ex << std::endl;
  }

  progressUpdate(100);
  return registrar;
}


/**
Maps movingImagePointer into the space of fixedImagePointer, using the same
size, origin, spacing and direction. It uses the transform stored in the
registration's output, e.g. the object returned by Run().

Output pixels that fall outside the moving image receive the value 100.
NOTE(review): 100 is a magic value and is probably not background for most
intensity ranges; confirm whether 0 (or a caller-supplied value) was intended.
*/
template<class ImageType>
typename ImageType::Pointer Registration::ResampleTransform(typename itk::MultiResolutionImageRegistrationMethod<ImageType, ImageType>::Pointer registrationPointer, typename ImageType::Pointer fixedImagePointer, typename ImageType::Pointer movingImagePointer)
{
  using ResamplerType = itk::ResampleImageFilter<ImageType, ImageType>;

  // Transform computed by the registration and the target grid size.
  const auto *registeredTransform = registrationPointer->GetOutput()->Get();
  const auto gridSize = fixedImagePointer->GetLargestPossibleRegion().GetSize();

  typename ResamplerType::Pointer resampler = ResamplerType::New();
  resampler->SetInput(movingImagePointer);
  resampler->SetTransform(registeredTransform);

  // Output geometry mirrors the fixed image.
  resampler->SetSize(gridSize);
  resampler->SetOutputOrigin(fixedImagePointer->GetOrigin());
  resampler->SetOutputSpacing(fixedImagePointer->GetSpacing());
  resampler->SetOutputDirection(fixedImagePointer->GetDirection());

  resampler->SetDefaultPixelValue(100);
  resampler->Update();

  typename ImageType::Pointer resampled = resampler->GetOutput();
  return resampled;
}
