
#include "DTImageInterpolator.h"
#include "DTImageResampleFilter.h"

#include "vnl/vnl_math.h"

#include <math.h>

// Helper: build the 3x3 matrix for a rotation of `angle` radians about the
// axis (x, y, z) using Rodrigues' formula
//   R = cos(a) I + sin(a) [k]_x + (1 - cos(a)) k k^T.
// The axis is expected to already be a unit vector; it is not normalized here.
inline static
DiffusionTensor::MatrixType
buildRotation(double x, double y, double z, double angle)
{
  const double cosA = cos(angle);
  const double sinA = sin(angle);
  const double oneMinusCos = 1.0 - cosA;

  // Symmetric outer-product terms (1 - cos) * k k^T
  const double xxC = x*x * oneMinusCos;
  const double yyC = y*y * oneMinusCos;
  const double zzC = z*z * oneMinusCos;
  const double xyC = x*y * oneMinusCos;
  const double xzC = x*z * oneMinusCos;
  const double yzC = y*z * oneMinusCos;

  // Skew-symmetric cross-product terms sin * [k]_x
  const double xs = x*sinA;
  const double ys = y*sinA;
  const double zs = z*sinA;

  DiffusionTensor::MatrixType rot(3, 3);

  // Diagonal
  rot(0, 0) = xxC + cosA;
  rot(1, 1) = yyC + cosA;
  rot(2, 2) = zzC + cosA;

  // Off-diagonal pairs: symmetric part minus / plus skew part
  rot(0, 1) = xyC - zs;
  rot(1, 0) = xyC + zs;

  rot(0, 2) = xzC + ys;
  rot(2, 0) = xzC - ys;

  rot(1, 2) = yzC - xs;
  rot(2, 1) = yzC + xs;

  return rot;
}

DTImageResampleFilter::
DTImageResampleFilter()
{
  // Default output geometry: identity orientation, origin at zero,
  // unit spacing, and an empty (zero-sized) grid until the caller sets one.
  this->m_OutputDirection.SetIdentity();
  this->m_OutputOrigin.Fill(0.0);
  this->m_OutputSpacing.Fill(1.0);
  this->m_OutputSize.Fill(0);
}

DTImageResampleFilter::
~DTImageResampleFilter()
{
  // No explicit cleanup is performed here.
}

void
DTImageResampleFilter::
GenerateData()
{

  if (this->GetInput() == 0)
    return;

  // Process transform
  DiffusionTensor::MatrixType M(3, 3);
  for (unsigned int i = 0; i < 3; i++)
    for (unsigned int j = 0; j < 3; j++)
    {
      M(i, j) = (m_T2STrafo->GetMatrix())[i][j];
    }

std::cout << "M = \n" << M << std::endl;

  DiffusionTensor::MatrixType Minverse = DiffusionTensor::MatrixInverseType(M);

std::cout << "Minverse = \n" << Minverse << std::endl;

  // Global tensor rotation
  DiffusionTensor::MatrixType globalR;
  {
    DiffusionTensor::MatrixSVDType svd(Minverse);
    globalR = svd.U() * svd.V().transpose();
    
std::cout << "globalR = \n" << globalR << std::endl;
  }   

  // Allocate space for output
  this->GetOutput()->SetRegions(m_OutputSize);
  this->GetOutput()->Allocate();
  this->GetOutput()->SetOrigin(m_OutputOrigin);
  this->GetOutput()->SetSpacing(m_OutputSpacing);
  this->GetOutput()->SetDirection(m_OutputDirection);

  // Set up interpolator
  DTImageInterpolator dtInterpolator;
  dtInterpolator.SetInputImage(this->GetInput());

  DTImageSizeType inputSize =
    this->GetInput()->GetLargestPossibleRegion().GetSize();
  DTImageSpacingType inputSpacing = this->GetInput()->GetSpacing();

  DTImageIndexType ind;

  for (ind[2] = 0; ind[2] < (long)m_OutputSize[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)m_OutputSize[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)m_OutputSize[0]; ind[0]++)
      {
        PointType p;
        this->GetOutput()->TransformIndexToPhysicalPoint(ind, p);

        PointType mappedp = m_T2STrafo->TransformPoint(p);

        DiffusionTensor T = dtInterpolator.Evaluate(mappedp);
        DiffusionTensor::MatrixType D = T.GetMatrix();

//std::cout << "D = \n" << D << std::endl;

#if 0
        DiffusionTensor::MatrixEigenType eig(D);

        VNLVectorType e1 = eig.get_eigenvector(2);
        VNLVectorType e2 = eig.get_eigenvector(1);

        if (e1.two_norm() == 0 || e2.two_norm() == 0)
        {
          continue;
        }

//std::cout << "e1 = " << e1 << std::endl;
//std::cout << "e2 = " << e2 << std::endl;

        e1.normalize();
        e2.normalize();

        // Adjust tensor with spatial transformation Minverse

// Think of Frenet frames, adjust two of them get all three new vectors
// First get new T (rotated e1), then N should be rotated to get close
// to the deformed e2

        VNLVectorType n1 = e1;
        n1.pre_multiply(Minverse);

        VNLVectorType n2 = e2;
        n2.pre_multiply(Minverse);

        if (n1.two_norm() == 0 || n2.two_norm() == 0)
        {
          continue;
        }

        n1.normalize();
        n2.normalize();

        // Compute R1, n1=R1*e1
        // cross3d is a function provided by vnl_vector
        VNLVectorType axisR1 = cross_3d(e1, n1);

        if (axisR1.two_norm() == 0)
          continue;

        axisR1.normalize();

        // angle is a function provided by vnl_vector
        double angleR1 = angle(e1, n1);

        DiffusionTensor::MatrixType R1 =
          buildRotation(axisR1[0], axisR1[1], axisR1[2], angleR1);

        // Project n2 to plane perp to n1=R1*e1
        VNLVectorType n2_proj = n2 - (n1 * dot_product(n1, n2));
        n2_proj.normalize();

        // Compute R2, rotates R1*e2 to proj of n2
        VNLVectorType axisR2 = e1;
        axisR2.pre_multiply(R1);

        if (axisR2.two_norm() == 0)
          continue;

        axisR2.normalize();

        VNLVectorType e2rot = e2;
        e2rot.pre_multiply(R1);
        double angleR2 = angle(e2rot, n2_proj);

        DiffusionTensor::MatrixType R2 =
          buildRotation(axisR2[0], axisR2[1], axisR2[2], angleR2);

        // Compose the rotations
        DiffusionTensor::MatrixType R = R2*R1;

        // Reorient tensor
        DiffusionTensor::MatrixType Dalign = R * D * R.transpose();

#else
        DiffusionTensor::MatrixType Dalign = globalR * D * globalR.transpose();
#endif

        DiffusionTensor rotT;
        rotT.FromMatrix(Dalign);

        this->GetOutput()->SetPixel(ind, rotT);

      }

}
