
#include "FastTensorMatrixMath.h"

#include "vnl/vnl_math.h"

#include <iostream>

#include <cmath>

namespace FastTensorMatrixMath
{


MatrixType
expMap(const MatrixType& X)
{
//std::cout << "expMap" << std::endl;

  DiffusionTensor::MatrixEigenType eig(X);

  // Exponent of a diagonal matrix is simply exponent of the diagonal elements
  MatrixType expD(3, 3);
  expD.fill(0);
  for (int i = 0; i < 3; i++)
  {
    float f = expf(eig.D(i, i));
    if (vnl_math_isnan(f))
      f = 0;
    if (vnl_math_isinf(f))
      f = 0;
    expD(i, i) = f;
  }

//std::cout << "expMap RET" << std::endl;

  return eig.V * expD * eig.V.transpose();
}

MatrixType
logMap(const MatrixType& s)
{
//std::cout << "logMap" << std::endl;

  DiffusionTensor::MatrixEigenType eig(s);

  MatrixType logD(3, 3);
  logD.fill(0);
  for (int i = 0; i < 3; i++)
  {
    float di = eig.D(i, i);

    if (di < 1e-10)
      di = 1e-10;

    float f = logf(di);
    if (vnl_math_isnan(f))
      f = 0;
    if (vnl_math_isinf(f))
      f = 0;
    logD(i, i) = f;
  }

//std::cout << "logMap RET" << std::endl;

  return eig.V * logD * eig.V.transpose();
}

float innerProduct(
  const MatrixType& s1, const MatrixType& s2)
{
  MatrixType diff = logMap(s1) - logMap(s2);

  MatrixType A = diff * diff;

  float trace = 0;
  for (unsigned int i = 0; i < 3; i++)
    trace += A(i, i);

  return sqrtf(trace);
}

float norm(const MatrixType& s)
{
  return innerProduct(s, s);
}


MatrixType
mean(MatrixType* tarray, float* weights, unsigned int n)
{
  MatrixType defaultM(3, 3);
  defaultM.set_identity();
  defaultM *= 1e-10;

  if (tarray == 0 || weights == 0)
// Exception?
    return defaultM;

  if (n == 0)
    return defaultM;

  if (n == 1)
    return tarray[0];

/*
  float sumWeights = 0;
  for (unsigned int k = 0; k < n; k++)
    sumWeights += weights[k];

  if (fabs(sumWeights-1.0) > 1e-10)
    std::cerr << "[FastTensorMatrixMath::mean] sum of weights not one: " 
      << sumWeights << std::endl;
*/

  MatrixType logMu(3, 3);
  logMu.fill(0);

  for (unsigned int k = 0; k < n; k++)
  {
    MatrixType logSk = FastTensorMatrixMath::logMap(tarray[k]);

/*
    float traceLogSk = 0;
    for (int j = 0; j < 3; j++)
      traceLogSk += logSk(j, j);

    if (traceLogSk < 1e-10)
      continue;
*/

    logMu += (logSk * weights[k]);
  }

  return expMap(logMu);
}

} // namespace FastTensorMatrixMath
