
#include "DTImageInterpolator.h"
#include "DTImageWarpFilter.h"

#include "itkBSplineInterpolateImageFunction.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkVectorIndexSelectionCastImageFilter.h"

#include "InverseDeformationFilter.h"
#include "VectorBlurImageFilter.h"

DTImageWarpFilter
::DTImageWarpFilter()
{
  // Start with neither field set. GenerateData() insists on the forward
  // field and falls back to estimating the inverse when it is left unset.
  this->m_InverseDeformationField = 0;
  this->m_DeformationField = 0;
}

DTImageWarpFilter
::~DTImageWarpFilter()
{
  // Intentionally empty: no explicit cleanup is performed here.
}

void
DTImageWarpFilter
::SetDeformationField(DeformationFieldType* def)
{
  // Store the forward deformation used by GenerateData() to locate the
  // source point of every output voxel, then flag the filter as modified.
  this->m_DeformationField = def;
  this->Modified();
}

void
DTImageWarpFilter
::SetInverseDeformationField(DeformationFieldType* def)
{
  // Optional: when supplied, GenerateData() uses this field for the local
  // Jacobian instead of estimating an inverse itself.
  this->m_InverseDeformationField = def;
  this->Modified();
}

void
DTImageWarpFilter
::GenerateData()
{

  if (this->GetInput() == 0)
    return;

  if (m_DeformationField.IsNull())
    itkExceptionMacro(<< "No deformation  field defined");

//TODO: mask for foreground, skip bg warping

  DTImageSizeType size = this->GetInput()->GetLargestPossibleRegion().GetSize();
  DTImageSpacingType spacing = this->GetInput()->GetSpacing();

  // Allocate output
  this->GetOutput()->SetRegions(this->GetInput()->GetLargestPossibleRegion());
  this->GetOutput()->CopyInformation(this->GetInput());
  this->GetOutput()->Allocate();

  // Get inverse deformation if not given
  if (m_InverseDeformationField.IsNull())
  {
    InverseDeformationFilter::Pointer invFilter = InverseDeformationFilter::New();
    invFilter->SetNumberOfIterations(100);
    invFilter->SetTolerance(1e-2);
    invFilter->SetInput(m_DeformationField);
    invFilter->Update();

    m_InverseDeformationField = invFilter->GetOutput();
  }

#if 0
  // Smoothen inverse
  typedef VectorBlurImageFilter<DeformationFieldType, DeformationFieldType>
    DeformationSmootherType;

  DeformationSmootherType::Pointer defsmoother = DeformationSmootherType::New();
  defsmoother->SetKernelWidth(1.0);
  defsmoother->SetInput(m_InverseDeformationField);
  defsmoother->Update();

  m_InverseDeformationField = defsmoother->GetOutput();
#endif

  // Create DTI interpolator
  DTImageInterpolator dtInterpolator;
  dtInterpolator.SetInputImage(this->GetInput());

  // Create deformation field interpolators
  typedef itk::BSplineInterpolateImageFunction<FloatImageType, double>
    DeformationInterpolatorType;
  //typedef itk::LinearInterpolateImageFunction<FloatImageType, double>
  //  DeformationInterpolatorType;
  typedef itk::VectorIndexSelectionCastImageFilter<DeformationFieldType, FloatImageType> IndexSelectType;

//TODO: use interp for fw/bw mappings

  DeformationInterpolatorType::Pointer invDefInterpolators[3];
  for (int i = 0; i < 3; i++)
  {
    IndexSelectType::Pointer select_i = IndexSelectType::New();
    select_i->SetIndex(i);
    select_i->SetInput(m_InverseDeformationField);
    select_i->Update();
  
    FloatImageType::Pointer displacement_i = select_i->GetOutput();
  
    DeformationInterpolatorType::Pointer tmp = DeformationInterpolatorType::New();
    tmp->SetInputImage(displacement_i);
    tmp->SetSplineOrder(1);

    invDefInterpolators[i] = tmp;
  }

  DiffusionTensor::MatrixType id(3, 3);
  id.set_identity();

  DiffusionTensor Teps;
  Teps.FromMatrix(id * 1e-10);

  // Warp individual voxels
  DTImageIndexType ind;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        DTImagePointType point;
        this->GetOutput()->TransformIndexToPhysicalPoint(ind, point);

        InverseDeformationFilter::VectorType invd =
          m_DeformationField->GetPixel(ind);

        // Find inverse mapping
        DTImagePointType inverse;

        inverse[0] = point[0] + invd[0];
        inverse[1] = point[1] + invd[1];
        inverse[2] = point[2] + invd[2];

        if (!invDefInterpolators[0]->IsInsideBuffer(inverse))
        {
          this->GetOutput()->SetPixel(ind, Teps);
          continue;
        }

        // Get the source to target mapping derivative at inverse point
        DiffusionTensor::MatrixType A(3, 3, 0.0);

        for (int j = 0; j < 3; j++)
        {
          DeformationInterpolatorType::CovariantVectorType d_j =
            invDefInterpolators[j]->EvaluateDerivative(inverse);
          A(j, 0) = d_j[0];
          A(j, 1) = d_j[1];
          A(j, 2) = d_j[2];
          A(j, j) += 1.0;
        }

        DiffusionTensor::MatrixQRType qr(A);

        // Get the tensor at inverse point
        DiffusionTensor T = dtInterpolator.Evaluate(inverse);
        DiffusionTensor::MatrixType D = T.GetMatrix();

        //If A is close to id just use the original tensor
        DiffusionTensor::MatrixType diff = A - id;
        if (diff.frobenius_norm() < 1e-2)
        {
          this->GetOutput()->SetPixel(ind, T);
          continue;
        }

#if 0
        // Use local affine transform (w/ scaling)
        DiffusionTensor::MatrixType Dprime = A.transpose() * D * A;

        // Remove scaling effect
        DiffusionTensor::MatrixQRType qr(A);
        float detA = qr.determinant();
        if (detA > 1e-20)
          Dprime /= pow(detA, 2.0 / 3.0);
#else
        // Compute rotation component of A through SVD
        DiffusionTensor::MatrixSVDType svd(A);

        // Rotation
        DiffusionTensor::MatrixType W = svd.U() * svd.V().transpose();
        // Strain
        //DiffusionTensor::MatrixType R =
        //  svd.V() * svd.W() *  svd.V().transpose();

        DiffusionTensor::MatrixType Dprime = W.transpose() * D * W;
#endif

        DiffusionTensor warpedT;
        warpedT.FromMatrix(Dprime);

        this->GetOutput()->SetPixel(ind, warpedT);
      }

}
