
#include "DTImageInterpolator.h"
#include "DTImageDestructiveWarpFilter.h"

#include "itkBSplineInterpolateImageFunction.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkIterativeInverseDeformationFieldImageFilter.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkVectorIndexSelectionCastImageFilter.h"

//#include "TensorMatrixMath.h"
#include "FastTensorMatrixMath.h"

#include "InverseDeformationFilter.h"
#include "VectorBlurImageFilter.h"

// TODO: store jacobian image

/**
 * Default constructor.
 *
 * The filter starts with no deformation field and no inverse deformation
 * field attached, a Jacobian variance (m_JVariance) of 1.2, and output
 * trace normalization switched off.
 */
DTImageDestructiveWarpFilter
::DTImageDestructiveWarpFilter()
{
  // Tunable parameters
  m_NormalizeOutputTrace = false;
  m_JVariance = 1.2;

  // Fields are attached later via SetDeformationField() /
  // SetInverseDeformationField()
  m_InverseDeformationField = 0;
  m_DeformationField = 0;
}

/**
 * Destructor.
 *
 * Nothing is released explicitly here. The field members are presumably
 * ITK smart pointers (they are tested with IsNull() in GenerateData), so
 * they drop their references when they are destroyed.
 */
DTImageDestructiveWarpFilter
::~DTImageDestructiveWarpFilter()
{
  // intentionally empty
}


/**
 * Attach the deformation field used by GenerateData().
 *
 * For every output voxel, the displacement stored at the same index is
 * added to the voxel's physical position to find the input sample point.
 * The filter is always marked as modified, even when the same pointer is
 * passed again.
 *
 * @param def  deformation field (required before GenerateData runs)
 */
void
DTImageDestructiveWarpFilter
::SetDeformationField(DeformationFieldType* def)
{
  this->m_DeformationField = def;
  this->Modified();
}

/**
 * Optionally attach a precomputed inverse of the deformation field.
 *
 * When this is left unset (null), GenerateData() estimates the inverse
 * itself with itk::IterativeInverseDeformationFieldImageFilter. The filter
 * is always marked as modified, even when the same pointer is passed
 * again.
 *
 * @param def  inverse deformation field, or null to let the filter estimate it
 */
void
DTImageDestructiveWarpFilter
::SetInverseDeformationField(DeformationFieldType* def)
{
  this->m_InverseDeformationField = def;
  this->Modified();
}

void
DTImageDestructiveWarpFilter
::GenerateData()
{

  if (this->GetInput() == 0)
    return;

  if (m_DeformationField.IsNull())
    itkExceptionMacro(<< "No deformation  field defined");

  // TODO: speed up by using mask for foreground to skip bg warping

  DTImageSizeType size = this->GetInput()->GetLargestPossibleRegion().GetSize();
  DTImageSpacingType spacing = this->GetInput()->GetSpacing();

  // Allocate output
  this->GetOutput()->CopyInformation(this->GetInput());
  this->GetOutput()->SetRegions(this->GetInput()->GetLargestPossibleRegion());
  this->GetOutput()->Allocate();

  // Get inverse deformation
  DeformationFieldPointer inverseDeformation;

  if (m_InverseDeformationField.IsNull())
  {
    // Compute inverse if not given explicitly
#if 0
    InverseDeformationFilter::Pointer invFilter = InverseDeformationFilter::New();
    invFilter->SetNumberOfIterations(50);
    invFilter->SetTolerance(1e-2);
    invFilter->SetInput(m_DeformationField);
    invFilter->Update();
#else
    typedef itk::IterativeInverseDeformationFieldImageFilter<
      DeformationFieldType, DeformationFieldType>
      InverterType;
    InverterType::Pointer invFilter = InverterType::New();
    invFilter->SetNumberOfIterations(50);
    invFilter->SetStopValue(1e-2);
    invFilter->SetInput(m_DeformationField);
    invFilter->Update();
#endif
    inverseDeformation = invFilter->GetOutput();

#if 0
    // Smoothen the estimated inverse
    typedef VectorBlurImageFilter<DeformationFieldType, DeformationFieldType>
      DeformationSmootherType;

    DeformationSmootherType::Pointer defsmoother = DeformationSmootherType::New();
    defsmoother->SetKernelWidth(1.0);
    defsmoother->SetInput(invFilter->GetOutput());
    defsmoother->Update();

    inverseDeformation = defsmoother->GetOutput();
#endif

  }
  else
  {
    inverseDeformation = m_InverseDeformationField;
  }

  // Create DTI interpolator
  DTImageInterpolator dtInterpolator;
  dtInterpolator.SetInputImage(this->GetInput());

  // Create deformation field interpolators
  typedef itk::BSplineInterpolateImageFunction<FloatImageType, double>
    DeformationInterpolatorType;
  //typedef itk::LinearInterpolateImageFunction<FloatImageType, double>
  //  DeformationInterpolatorType;
  typedef itk::VectorIndexSelectionCastImageFilter<DeformationFieldType, FloatImageType> IndexSelectType;

  DeformationInterpolatorType::Pointer invDefInterpolators[3];
  for (int i = 0; i < 3; i++)
  {
    IndexSelectType::Pointer select_i = IndexSelectType::New();
    select_i->SetIndex(i);
    select_i->SetInput(inverseDeformation);
    select_i->Update();
  
    FloatImageType::Pointer displacement_i = select_i->GetOutput();
  
    DeformationInterpolatorType::Pointer tmp = DeformationInterpolatorType::New();
    tmp->SetInputImage(displacement_i);
    tmp->SetSplineOrder(2);

    invDefInterpolators[i] = tmp;
  }

  DiffusionTensor::MatrixType id(3, 3);
  id.set_identity();

  DiffusionTensor Teps;
  Teps.FromMatrix(id * 1e-10);

  float Kscale = 1.0 / sqrtf(m_JVariance);

  float maxJ = 0;
  float sumJ = 0;

  unsigned int countJ = 0;

  // Warp individual voxels
  DTImageIteratorType dtIt(
    this->GetOutput(), this->GetOutput()->GetLargestPossibleRegion());

  for (dtIt.GoToBegin(); !dtIt.IsAtEnd(); ++dtIt)
  {
    DTImageIndexType ind = dtIt.GetIndex();

    DTImagePointType point;
    this->GetInput()->TransformIndexToPhysicalPoint(ind, point);

    InverseDeformationFilter::VectorType invd =
      m_DeformationField->GetPixel(ind);

    // Find inverse mapping
    DTImagePointType inverse;

    inverse[0] = point[0] + invd[0];
    inverse[1] = point[1] + invd[1];
    inverse[2] = point[2] + invd[2];

    bool insideAllBuffers = true;
    for (int j = 0; j < 3; j++)
      if (!invDefInterpolators[j]->IsInsideBuffer(inverse))
      {
        insideAllBuffers = false;
        break;
      }

    if (!insideAllBuffers)
    {
      this->GetOutput()->SetPixel(ind, Teps);
      continue;
    }

    // Get the source to target mapping derivative at inverse point
    DiffusionTensor::MatrixType A(3, 3, 0.0);

    for (int j = 0; j < 3; j++)
    {
      DeformationInterpolatorType::CovariantVectorType d_j =
        invDefInterpolators[j]->EvaluateDerivative(inverse);
      A(j, 0) = d_j[0];
      A(j, 1) = d_j[1];
      A(j, 2) = d_j[2];
      A(j, j) += 1.0;
    }

    DiffusionTensor::MatrixQRType qr(A);
    float jacobian = qr.determinant();

    // Get the tensor at inverse point
    DiffusionTensor T = dtInterpolator.Evaluate(inverse);
    DiffusionTensor::MatrixType D = T.GetMatrix();

    //If A is close to id just use the original tensor
    DiffusionTensor::MatrixType diff = A - id;
    if (diff.frobenius_norm() < 1e-2)
    {
      this->GetOutput()->SetPixel(ind, T);
      continue;
    }

#if 0
    // Warp the tensor
    DiffusionTensor::MatrixType Dprime = A.transpose() * D * A;

    // Remove scaling effect
    DiffusionTensor::MatrixQRType qr(A);
    if (jacobian > 1e-10)
      Dprime /= pow(jacobian, 2.0 / 3.0);
#else
    // Compute rotation component of A through SVD
    DiffusionTensor::MatrixSVDType svd(A);

    // Rotation
    DiffusionTensor::MatrixType W = svd.U() * svd.V().transpose();
    // Strain
    //DiffusionTensor::MatrixType R =
    //  svd.V() * svd.W() *  svd.V().transpose();

    DiffusionTensor::MatrixType Dprime = W.transpose() * D * W;
#endif

    // Compute trace of original tensor
    float traceD = 0;
    for (unsigned int i = 0; i < 3; i++)
      traceD += D(i, i);

    // Compute determinant of original tensor
    //DiffusionTensor::MatrixQRType qrD(Dprime);
    //float detD = qrD.determinant();

    //
    // Homogenize tensor based on amount of deformation
    //

// Jacobian is volume multiplier [0, inf)
// J = 1 means no change
// K = max(J, 1)

    // Interpolation weight for homogenization
    //float K = fabs(jacobian);
    float K = (jacobian > 1) ? jacobian : 1;
    float alpha = expf(-0.5 * (K-1) * Kscale);

#if DEBUG
    if (jacobian > 0)
    {
      countJ++;
      sumJ += jacobian;
    }
    if (jacobian > maxJ)
      maxJ = jacobian;
#endif

    // E is isotropic tensor with eigenvalues = 2.0*MeanDiffusivity
    // E is isotropic tensor with eigenvalues = 2.0*GeometricMeanDiff
    DiffusionTensor::MatrixType E(3, 3);
    E.set_identity();
    E *= traceD / 3 * 2.0;
    //E *= pow(2.0 * detD, 1.0/3.0) + 1e-10;

    DiffusionTensor::MatrixType tt[2];
    tt[0] = Dprime;
    tt[1] = E;
    float ww[2];
    ww[0] = alpha;
    ww[1] = 1.0 - alpha;

    if (alpha < 1e-2)
      Dprime = E;
    else if (alpha > 0.99)
      ; // Do nothing, use original tensor
    else
      Dprime = FastTensorMatrixMath::mean(tt, ww, 2);
      //Dprime = TensorMatrixMath::mean(tt, ww, 2);

    DiffusionTensor warpedT;
    warpedT.FromMatrix(Dprime);

    this->GetOutput()->SetPixel(ind, warpedT);
  }

#if DEBUG
  std::cout << "Average J = " << sumJ / countJ << std::endl;
  std::cout << "Max J = " << maxJ << std::endl;
#endif

  if (m_NormalizeOutputTrace)
  {
    float maxTrace = 1e-20;
    for (dtIt.GoToBegin(); !dtIt.IsAtEnd(); ++dtIt)
    {
      DTImageIndexType ind = dtIt.GetIndex();
      DiffusionTensor D = this->GetOutput()->GetPixel(ind);
      DiffusionTensor::MatrixType Dmat = D.GetMatrix();
      float tr = 0.0;
      for (unsigned int i = 0; i < 3; i++)
        tr += Dmat(i, i);
      if (tr > maxTrace)
        maxTrace = tr;
    }
    for (dtIt.GoToBegin(); !dtIt.IsAtEnd(); ++dtIt)
    {
      DTImageIndexType ind = dtIt.GetIndex();
      DiffusionTensor D = this->GetOutput()->GetPixel(ind);
      D /= maxTrace;
      this->GetOutput()->SetPixel(ind, D);
    }
  }

}
