
#include "TensorMatrixMath.h"

#include "vnl/vnl_math.h"

#include <iostream>

#include <cmath>

namespace TensorMatrixMath
{

// Computes a factor X of A such that A == X*X'.
//
// The active branch takes the matrix square root of A through its
// eigensystem. A Cholesky-based alternative is kept in the inactive branch
// below and is not compiled.
static MatrixType _decompose(const MatrixType& A)
{

#if 1
  // Floor applied to the eigenvalues before the root is taken. This guards
  // against an A that is not non-negative definite (should not happen).
  const double minEigenvalue = 1e-10;

  DiffusionTensor::MatrixEigenType eigenSystem(A);

  for (int k = 0; k < 3; ++k)
  {
    if (eigenSystem.D(k, k) < minEigenvalue)
      eigenSystem.D(k, k) = minEigenvalue;
  }

  // square_root() yields V * sqrt(D) * V.transpose()
  return eigenSystem.square_root();
#else
  // Cholesky decomposition (disabled)
  DiffusionTensor::MatrixCholeskyType chol(A);

  //return chol.lower_triangle();
  return chol.L_badly_named_method();
#endif

}

// Exponential of X taken relative to the base matrix p.
//
// With g = _decompose(p), the input is transformed to
// Y = inv(g) * X * inv(g)', the eigenvalues of Y are exponentiated, and the
// result is carried back as (g*V) * exp(D) * (g*V)'.
// Presumably p is a symmetric positive-definite tensor and X a symmetric
// tangent matrix at p -- TODO confirm against callers.
//
// Returns p itself when the Frobenius norm of X is below 1e-8.
MatrixType
expMap(const MatrixType& p, const MatrixType& X)
{
  // A (nearly) zero X leaves the base point where it is
  const bool negligible = X.frobenius_norm() < 1e-8;
  if (negligible)
    return p;

  MatrixType root = _decompose(p);
  MatrixType rootInv = DiffusionTensor::MatrixInverseType(root);

  MatrixType local = rootInv * X * rootInv.transpose();
  DiffusionTensor::MatrixEigenType eig(local);

  // The exponential of a diagonal matrix is the exponential of each diagonal
  // entry. Entries whose exponential is NaN or infinite are replaced by 0.
  MatrixType expD(3, 3);
  expD.fill(0);
  for (unsigned int d = 0; d < 3; ++d)
  {
    float e = expf(eig.D(d, d));
    if (vnl_math_isnan(e) || vnl_math_isinf(e))
      e = 0;
    expD(d, d) = e;
  }

  MatrixType basis = root * eig.V;
  MatrixType result = basis * expD * basis.transpose();
  return result;
}

// Logarithm of X taken relative to the base matrix p.
//
// With g = _decompose(p), the input is transformed to
// Y = inv(g) * X * inv(g)', the logarithm of each eigenvalue of Y is taken,
// and the result is carried back as (g*V) * log(D) * (g*V)'.
// Presumably p and X are symmetric positive-definite tensors -- TODO confirm
// against callers.
MatrixType
logMap(const MatrixType& p, const MatrixType& X)
{
  MatrixType root = _decompose(p);
  MatrixType rootInv = DiffusionTensor::MatrixInverseType(root);

  MatrixType local = rootInv * X * rootInv.transpose();
  DiffusionTensor::MatrixEigenType eig(local);

  // Logarithm of the diagonal, entry by entry. Eigenvalues are floored at
  // 1e-10 before the log, and a NaN or infinite result is replaced by 0.
  MatrixType logD(3, 3);
  logD.fill(0);
  for (unsigned int d = 0; d < 3; ++d)
  {
    float lambda = eig.D(d, d);
    if (lambda < 1e-10)
      lambda = 1e-10;

    float l = logf(lambda);
    if (vnl_math_isnan(l) || vnl_math_isinf(l))
      l = 0;

    logD(d, d) = l;
  }

  MatrixType basis = root * eig.V;
  MatrixType result = basis * logD * basis.transpose();
  return result;
}

float innerProduct(
  const MatrixType& p, const MatrixType& X, const MatrixType& Y)
{
  MatrixType g = _decompose(p);
  MatrixType invg = DiffusionTensor::MatrixInverseType(g);

  MatrixType A =
    invg * X * DiffusionTensor::MatrixInverseType(p) * Y * invg.transpose();

  float trace = 0;
  for (unsigned int i = 0; i < 3; i++)
    trace += A(i, i);

  return trace;
}

// Length of X relative to the base matrix p: the square root of
// innerProduct(p, X, X).
float norm(const MatrixType& p, const MatrixType& X)
{
  const float squaredLength = innerProduct(p, X, X);
  return sqrtf(squaredLength);
}


// Weighted mean of n matrices, computed iteratively with logMap / expMap.
//
//   tarray  - array of n input matrices
//   weights - array of n weights, one per matrix
//   n       - number of entries in both arrays
//
// Starting from the identity, each pass sums weights[k] * logMap(mu,
// tarray[k]) into an update X and then replaces the estimate with
// expMap(mu, X). Iteration stops after at most 10 passes, when the Frobenius
// norm of the estimate falls below 1e-2, or as soon as norm(mu, X) is no
// longer greater than 1e-4.
//
// Special cases:
//   - null tarray or weights, or n == 0: returns 1e-10 times the identity
//   - n == 1: returns tarray[0]; the weight is ignored
//
// NOTE: the weights are used as given. They are neither normalised nor
// checked to sum to one.
MatrixType
mean(MatrixType* tarray, float* weights, unsigned int n)
{
  // Fallback result for unusable input
  MatrixType fallback(3, 3);
  fallback.set_identity();
  fallback *= 1e-10;

  // Exception?
  if (tarray == 0 || weights == 0 || n == 0)
    return fallback;

  if (n == 1)
    return tarray[0];

  const float tol = 1e-4;
  const unsigned int maxPasses = 10;

  MatrixType mu(3, 3);
  mu.set_identity();

  MatrixType step(3, 3);

  for (unsigned int pass = 1; pass <= maxPasses; ++pass)
  {
    if (mu.frobenius_norm() < 1e-2)
      break;

    // Accumulate the weighted update for the current estimate
    step.fill(0);
    for (unsigned int k = 0; k < n; ++k)
    {
      MatrixType term = TensorMatrixMath::logMap(mu, tarray[k]) * weights[k];

      // Zero any entry that is NaN or infinite so it cannot poison the sum
      for (int r = 0; r < 3; ++r)
      {
        for (int c = 0; c < 3; ++c)
        {
          float v = term(r, c);
          if (vnl_math_isnan(v) || vnl_math_isinf(v))
            term(r, c) = 0;
        }
      }

      step += term;
    }

    mu = TensorMatrixMath::expMap(mu, step);

    // Written as a negated '>' so that a NaN norm also ends the iteration
    if (!(TensorMatrixMath::norm(mu, step) > tol))
      break;
  }

  return mu;
}


} // namespace TensorMatrixMath
