
#include "itkBSplineInterpolateImageFunction.h"
#include "itkNumericTraits.h"
#include "itkVectorIndexSelectionCastImageFilter.h"

#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "vnl/algo/vnl_svd.h"

#include "InverseDeformationFilter.h"
#include "DynArray.h"

#include "muException.h"

#include <iostream>

/**
 * Default constructor.
 *
 * Installs the default solver settings used by GenerateData():
 *  - displacement vectors longer than 100.0 are rescaled to that length,
 *  - at most 100 Newton iterations are run per voxel,
 *  - the convergence tolerance is 1e-2.
 */
InverseDeformationFilter
::InverseDeformationFilter()
{
  this->m_MaximumDisplacementNorm = 100.0;

  this->m_Tolerance = 1e-2;
  this->m_NumberOfIterations = 100;
}

void
InverseDeformationFilter
::GenerateData()
{

  // Separate deformation field into 3 images (x, y, z) and
  // create B-spline displacement interpolators for each dimension
  typedef itk::BSplineInterpolateImageFunction<FloatImageType, double>
    InterpolatorType;
  typedef itk::VectorIndexSelectionCastImageFilter<
    DeformationType, FloatImageType> IndexSelectType;

//std::cout << "Building interpolators..." << std::endl;
  DynArray<InterpolatorType::Pointer> u;
  for (int i = 0; i < 3; i++)
  {
    IndexSelectType::Pointer select_i = IndexSelectType::New();
    select_i->SetIndex(i);
    select_i->SetInput(this->GetInput());
    select_i->Update();

    FloatImageType::Pointer displacement_i = select_i->GetOutput();

    InterpolatorType::Pointer tmp = InterpolatorType::New();
    tmp->SetInputImage(displacement_i);
    tmp->SetSplineOrder(3);

    u.Append(tmp);
  }

  // Image info
  DeformationRegionType region = this->GetInput()->GetLargestPossibleRegion();
  DeformationSpacingType spacing = this->GetInput()->GetSpacing();
  DeformationSizeType size = region.GetSize();

//std::cout << "Check mask image" << std::endl;
  if (!m_MaskImage.IsNull())
  {
    DeformationSizeType size_m = m_MaskImage->GetLargestPossibleRegion().GetSize();
    if (size != size_m)
      muExceptionMacro(<< "Input def and mask size mismatch");
  }

  // Allocate output deformation
  DeformationType* output = this->GetOutput();
  this->GetOutput()->CopyInformation(this->GetInput());
  this->GetOutput()->SetRegions(region);
  this->GetOutput()->Allocate();

  VectorType zerov;
  zerov.Fill(0.0);
  this->GetOutput()->FillBuffer(zerov);

  // Book-keeping
  typedef vnl_matrix<double> MatrixType;

  MatrixType du(3, 3, 0.0);
  MatrixType deltaI(3, 1, 0.0);
  MatrixType deltaP(3, 1, 0.0);

  InterpolatorType::PointType point;
  InterpolatorType::PointType inverse;
  InterpolatorType::PointType lastInverse;

  double tolSquared = m_Tolerance*m_Tolerance;

  //
  // Process each grid point in the output image
  //

  DeformationIndexType ind;

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        if (!m_MaskImage.IsNull() && (m_MaskImage->GetPixel(ind) == 0))
          continue;

//std::cout << "Inverse for " << ind << std::endl;

        output->TransformIndexToPhysicalPoint(ind, point);

        if (!u[0]->IsInsideBuffer(point))
          continue;

        // First guess is inv(x) = x - u(x)
        inverse[0] = point[0] - u[0]->Evaluate(point);
        inverse[1] = point[1] - u[1]->Evaluate(point);
        inverse[2] = point[2] - u[2]->Evaluate(point);

        // Check if it's inside valid image region 
        if (!u[0]->IsInsideBuffer(inverse))
          inverse = point;

        // Keep track of possible inverse mappings
        lastInverse = inverse;

        // Temporary variables
        double functionValue = 0;
        double functionDerivative = 0;
        double lastFunctionValue = itk::NumericTraits<double>::max();

        double f = 1.0;
        double a;

        double errorSquared = 0;

        unsigned int iter;
        for (iter = 0; iter < m_NumberOfIterations; iter++)
        {
          if (!u[0]->IsInsideBuffer(inverse))
          {
            // Mapped to outside of buffer, stop
            inverse = lastInverse;
            break;
          }

          // Displacement derivatives
          for (int j = 0; j < 3; j++)
          {
            InterpolatorType::CovariantVectorType du_j =
              u[j]->EvaluateDerivative(inverse);

            du(j, 0) = du_j[0];
            du(j, 1) = du_j[1];
            du(j, 2) = du_j[2];
            du(j, j) += 1.0;
          }

          // Compute 2-norm of difference between:
          // inv(x) + u(inv(x)) vs x
          // should be zero
          functionValue = 0;
          for (int j = 0; j < 3; j++)
          {
            double d = inverse[j] + u[j]->Evaluate(inverse) - point[j];

            deltaP(j, 0) = d;
            functionValue += d*d;
          }

/*
std::cout << "Iter " << iter << std::endl;
std::cout << "  point = " << point << std::endl;
std::cout << "  inverse = " << inverse << std::endl;
std::cout << "  match = " << functionValue << std::endl;
std::cout << "  deltaP = \n    " << deltaP << std::endl;
std::cout << "  du = \n    " << du << std::endl;
*/

          // Do Newton step if function value decreases
          if (functionValue < lastFunctionValue || f < 1.0)
          {
            vnl_svd<double> qr(du);
            deltaI = qr.solve(deltaP);

            errorSquared =
              deltaI(0, 0)*deltaI(0, 0) +
              deltaI(1, 0)*deltaI(1, 0) +
              deltaI(2, 0)*deltaI(2, 0);

            if (errorSquared < tolSquared && functionValue < tolSquared)
              break;

            lastInverse = inverse;

            lastFunctionValue = functionValue;

            functionDerivative =
              (deltaP(0, 0) * du(0, 0) * deltaI(0, 0) +
               deltaP(1, 0) * du(1, 1) * deltaI(1, 0) +
               deltaP(2, 0) * du(2, 2) * deltaI(2, 0)) * 2;

            inverse[0] -= deltaI(0, 0);
            inverse[1] -= deltaI(1, 0);
            inverse[2] -= deltaI(2, 0);

            f = 1.0;

            continue;
          }

          // Partial step
          a =  -functionDerivative /
            2 * (functionValue - lastFunctionValue - functionDerivative);

          if (a < 0.1)
            a = 0.1;
          if (a > 0.5)
            a = 0.5;
          f *= a;

          inverse[0] = lastInverse[0] - f * deltaI(0, 0);
          inverse[1] = lastInverse[1] - f * deltaI(1, 0);
          inverse[2] = lastInverse[2] - f * deltaI(2, 0);

        } // for iter

        if (iter >= m_NumberOfIterations)
        {
          // No convergence
          inverse = lastInverse;
        }

        VectorType v;
        v[0] = inverse[0] - point[0];
        v[1] = inverse[1] - point[1];
        v[2] = inverse[2] - point[2];

        for (int d = 0; d < 3; d++)
        {
          if (vnl_math_isinf(v[d]))
            v[d] = 0.0;
          if (vnl_math_isnan(v[d]))
            v[d] = 0.0;
        }

        double vnorm = v.GetNorm();
        if (vnorm > m_MaximumDisplacementNorm)
          v *= (m_MaximumDisplacementNorm / vnorm);

        output->SetPixel(ind, v);

      } // for ind

}
