
#include "DTImageInterpolator.h"

#include <cmath>

#include "TensorMatrixMath.h"
#include "FastTensorMatrixMath.h"

#include "DTImageScalarSource.h"

#include "itkImageRegionIteratorWithIndex.h"

#include "DynArray.h"

/**
 * Construct an interpolator with no input image, log-Euclidean blending
 * enabled, and a default tensor equal to 1e-20 times the 3x3 identity.
 */
DTImageInterpolator
::DTImageInterpolator()
{
  m_UseLogEuclidean = true;
  m_InputImage = 0;

  // A vanishingly small isotropic tensor serves as the initial default.
  DiffusionTensor::MatrixType tinyIdentity(3, 3);
  tinyIdentity.set_identity();
  tinyIdentity *= 1e-20;
  m_DefaultTensor.FromMatrix(tinyIdentity);
}

/** Destructor; intentionally empty, this class performs no explicit cleanup. */
DTImageInterpolator
::~DTImageInterpolator()
{
}

void
DTImageInterpolator
::SetDefaultTensor(const DiffusionTensor& dt)
{
  m_DefaultTensor = dt;
}

void
DTImageInterpolator
::SetInputImage(const DTImageType* dti)
{
  m_InputImage = (DTImageType*)dti;

#if 0
  DTImageScalarSource::Pointer ssource = DTImageScalarSource::New();
  ssource->SetInput(m_InputImage);

  typedef itk::Image<float, 3> FloatImageType;

  FloatImageType::Pointer MDimg = ssource->GetMDImage();

  typedef itk::ImageRegionIteratorWithIndex<FloatImageType> IteratorType;

  IteratorType it(MDimg, MDimg->GetLargestPossibleRegion());
  
  it.GoToBegin();

  float minMD = it.Get();
  DiffusionTensor minT = m_InputImage->GetPixel(it.GetIndex());

  while (!it.IsAtEnd())
  {
    float v = it.Get();
    if (v < minMD)
    {
      minMD = v;
      minT = m_InputImage->GetPixel(it.GetIndex());
    }
    ++it;
  }

/*
  DiffusionTensor::MatrixType M(3, 3);
  M.set_identity();
  M *= minMD / 3.0 + 1e-20;

  m_DefaultTensor.FromMatrix(M);
*/
  m_DefaultTensor = minT;
#endif

}

/**
 * Interpolate the tensor field at a physical point.
 *
 * The point is mapped to a continuous voxel index, and the eight voxels
 * surrounding it are blended with trilinear weights. Neighbours outside the
 * image's largest possible region are skipped. When log-Euclidean mode is on,
 * the blend is FastTensorMatrixMath::mean; otherwise it is a plain weighted
 * sum of the tensor matrices.
 *
 * Weights are not renormalised when neighbours are skipped, so near the
 * border the missing corners simply contribute nothing.
 *
 * @param point  physical-space location to sample
 * @return the interpolated tensor. If none of the eight neighbours lies
 *         inside the image, the default tensor is returned instead.
 *         Previously an empty set was blended in that case.
 * Raises an error via muExceptionMacro when no input image has been set.
 */
DiffusionTensor
DTImageInterpolator
::Evaluate(DTImagePointType& point)
{
  if (m_InputImage == 0)
    muExceptionMacro(<< "No input image defined");

  // Work entirely in index space: the image's own transform maps the point,
  // so weights stay consistent with however the image maps physical space.
  typedef itk::ContinuousIndex<double, 3> ContinuousIndexType;
  ContinuousIndexType contInd;
  m_InputImage->TransformPhysicalPointToContinuousIndex(point, contInd);

  // Lower corner of the enclosing cell and the fractional offset within it.
  // floor() (not a truncating cast) keeps the lower corner correct for
  // negative indices, i.e. points just outside the low edge of the image,
  // so every weight stays in [0, 1].
  long base[3];
  double frac[3];
  for (unsigned int d = 0; d < 3; d++)
  {
    const double lower = std::floor(contInd[d]);
    base[d] = static_cast<long>(lower);
    frac[d] = contInd[d] - lower;
  }

  const DTImageType::RegionType& region =
    m_InputImage->GetLargestPossibleRegion();

  // Gather in-bounds neighbours and their trilinear weights.
  DynArray<DiffusionTensor::MatrixType> tensors;
  DynArray<float> weights;

  tensors.Allocate(8);
  weights.Allocate(8);

  // Bit 2 of 'corner' selects x+1, bit 1 selects y+1 and bit 0 selects z+1.
  // Corners are visited in the same x-major order as before.
  for (int corner = 0; corner < 8; corner++)
  {
    const int step[3] = { (corner >> 2) & 1, (corner >> 1) & 1, corner & 1 };

    DTImageIndexType ind;
    double w = 1.0;
    for (unsigned int d = 0; d < 3; d++)
    {
      ind[d] = base[d] + step[d];
      w *= step[d] ? frac[d] : (1.0 - frac[d]);
    }

    if (!region.IsInside(ind))
      continue;

    tensors.Append(m_InputImage->GetPixel(ind).GetMatrix());
    weights.Append(static_cast<float>(w));
  }

  // Nothing inside the image to interpolate from.
  if (tensors.GetSize() == 0)
    return m_DefaultTensor;

  DiffusionTensor::MatrixType mu;

  if (m_UseLogEuclidean)
  {
    // TensorMatrixMath::mean offers the same interface if the fast variant
    // ever needs to be swapped out.
    mu =
      FastTensorMatrixMath::mean(tensors.GetRawArray(), weights.GetRawArray(),
        tensors.GetSize());
  }
  else
  {
    mu = DiffusionTensor::MatrixType(3, 3);
    mu.fill(0);
    for (unsigned int k = 0; k < weights.GetSize(); k++)
      mu += tensors[k] * weights[k];
  }

  DiffusionTensor tensor;
  tensor.FromMatrix(mu);

  return tensor;
}
