
#include "VonMisesFisherDistribution.h"

#include "vnl/vnl_bessel.h"
#include "vnl/vnl_math.h"

#include <cmath>
#include <cstdlib>
#include <stdexcept>

#define USE_MERSENNE 1

#if USE_MERSENNE
#include "MersenneTwisterRNG.h"
#endif


/**
 * Draws a uniform variate from the closed interval [0, 1].
 *
 * Both endpoints can occur: the Mersenne Twister call samples the closed
 * interval, and the rand() fallback returns exactly 1 when rand() yields
 * RAND_MAX (and, in float, for values rounding up to it).
 */
static inline float _uniformVar()
{
#if USE_MERSENNE
  return MersenneTwisterRNG::GetGlobalInstance()->GenerateUniformRealClosedInterval();
#else
  const float denominator = static_cast<float>(RAND_MAX);
  return static_cast<float>(rand()) / denominator;
#endif
}

/**
 * Draws a standard exponential (rate 1) variate by inversion: -log(1 - U).
 *
 * _uniformVar() can return exactly 1, for which -log(0) would be +inf, so
 * such draws are rejected and redrawn. U == 0 is harmless (gives 0).
 */
static inline float _stdExpVar()
{
  float U = _uniformVar();
  while (U >= 1.0f)
    U = _uniformVar();
  return -logf(1.0f - U);
}

/**
 * Draws a standard normal N(0, 1) variate.
 *
 * With USE_MERSENNE the draw is delegated to the global Mersenne Twister
 * instance. Otherwise the Marsaglia polar method is applied to pairs of
 * _uniformVar() draws; of the two independent normals it yields, only one
 * is returned.
 */
static inline float _stdNormalVar()
{
#if USE_MERSENNE
  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();
  return rng->GenerateNormal(0.0, 1.0);
#else
  float x, y, r2;
  do
  {
    // Point uniformly distributed in the square [-1, 1]^2 ...
    x = 2.0f*_uniformVar() - 1.0f;
    y = 2.0f*_uniformVar() - 1.0f;
    r2 = x*x + y*y;
    // ... kept only if it falls strictly inside the unit disc, excluding
    // the origin (log(0) below).
  } while (r2 >= 1.0f || r2 == 0.0f);
  return x * sqrtf(-2.0f*logf(r2) / r2);
#endif
}

/**
 * Draws a Gamma(shape, 1) variate.
 *
 * - shape == 1: standard exponential.
 * - shape <  1: rejection sampler combining a uniform and an exponential
 *               draw.
 * - shape >  1: Marsaglia & Tsang (2000) squeeze/rejection method.
 *
 * @throws std::invalid_argument if shape is not > 0 (including NaN).
 */
static inline float _stdGammaVar(float shape)
{
  if (!(shape > 0.0f))
    throw std::invalid_argument("gamma: shape parameter must be > 0");

  if (shape == 1.0)
    return _stdExpVar();

  if (shape < 1.0)
  {
    while (true)
    {
      const float U = _uniformVar();
      // _uniformVar() can return exactly 1; that would make Y = +inf below
      // and the acceptance test (inf <= inf) would return inf.
      if (U >= 1.0f)
        continue;
      const float V = _stdExpVar();
      if (U <= (1.0 - shape))
      {
        const float X = powf(U, 1.0/shape);
        if (X <= V)
          return X;
      }
      else
      {
        const float Y = -logf((1.0-U)/shape);
        const float X = powf(1.0 - shape + shape*Y, 1.0/shape);
        if (X <= (V + Y))
          return X;
      }
    }
  }

  // shape > 1: Marsaglia & Tsang.
  const float b = shape - 1.0/3.0;
  const float c = 1.0 / sqrtf(9.0 * b);
  while (true)
  {
    float X, V;
    do
    {
      X = _stdNormalVar();
      V = 1.0 + c*X;
    } while (V <= 0.0);

    V = V*V*V;
    const float U = _uniformVar();

    const float Xsq = X*X;
    // Cheap squeeze test first, then the exact log-acceptance test.
    if (U < (1.0 - 0.0331*Xsq*Xsq))
      return b*V;
    if (logf(U) < (0.5*Xsq+b*(1.0-V+logf(V))))
      return b*V;
  }
}

/**
 * Draws a Beta(a, b) variate.
 *
 * For a <= 1 and b <= 1 Johnk's rejection algorithm is used. Otherwise the
 * variate is formed from two independent gamma variates as Ga / (Ga + Gb).
 *
 * @throws std::invalid_argument if a or b is not > 0 (including NaN).
 */
static inline float _betaVar(float a, float b)
{
  if (!(a > 0.0f) || !(b > 0.0f))
    throw std::invalid_argument("beta: shape parameters must be > 0");

  if ((a <= 1.0) && (b <= 1.0))
  {
    // Johnk's algorithm: X = U^(1/a), Y = V^(1/b), accept when X + Y <= 1.
    while (true)
    {
      const float U = _uniformVar();
      const float V = _uniformVar();

      const float X = powf(U, 1.0/a);
      const float Y = powf(V, 1.0/b);
      const float XpY = X + Y;

      // XpY == 0 (both uniforms zero, or underflow) would give 0/0 = NaN;
      // such draws are rejected.
      if ((XpY <= 1.0) && (XpY > 0.0))
        return X / XpY;
    }
  }

  const float Ga = _stdGammaVar(a);
  const float Gb = _stdGammaVar(b);
  return Ga / (Ga + Gb);
}

// Default distribution: three dimensions, mean direction along the +z axis,
// concentration 1.
VonMisesFisherDistribution
::VonMisesFisherDistribution()
{
  VectorType zAxis(3, 0.0);
  zAxis[2] = 1.0;
  Initialize(zAxis, 1.0);
}

// Distribution with mean direction mu (normalized by Initialize) and
// concentration k. Initialize throws std::runtime_error on invalid input.
VonMisesFisherDistribution
::VonMisesFisherDistribution(const VectorType& mu, float k)
{
  Initialize(mu, k);
}

// Nothing to release explicitly.
VonMisesFisherDistribution
::~VonMisesFisherDistribution()
{
}

void
VonMisesFisherDistribution
::Initialize(const VectorType& mu, float k)
{
  m_Mu = mu;
  m_K = k;

  float muMag = mu.magnitude();
  if (muMag < 1e-20)
    throw std::runtime_error("von Mises-Fisher: undefined mean direction");
  m_Mu /= muMag;

  m_Dim = mu.size();

  if (m_Dim < 2)
    throw std::runtime_error("von Mises-Fisher: dimension needs to be >= 2");

  float d = m_Dim;

  float t1 = sqrtf(4.0*k*k + (d-1)*(d-1));

  m_B = (-2.0*k+t1)/(d-1);
  m_X0 = (1-m_B) / (1+m_B);
  m_M = (d-1)/2.0;
  m_C = k*m_X0 + (d-1)*logf(1-m_X0*m_X0);

  // Reference point in unit hypersphere
  VectorType ref(m_Dim, 0.0);
  ref[m_Dim-1] = 1;

  // Generate rotation from ref to mu
  VectorType t = ref - m_Mu;
  if (t.magnitude() < 1e-10)
  {
    m_RotM = MatrixType(m_Dim, m_Dim);
    m_RotM.set_identity();
  }
  else
  {
    //m_RotM = _getReflectMat(m_Mu, -ref) * _getReflectMat(ref, -ref);

    VectorType ta = m_Mu + ref;
    ta.normalize();
    // rotA = I - 2*ta*ta.transpose();
    MatrixType rotA(m_Dim, m_Dim, 0.0);
    for (unsigned int i = 0; i < m_Dim; i++)
      for (unsigned int j = 0; j < m_Dim; j++)
      {
        if (i == j)
          rotA(i, j) += 1.0;
        rotA(i, j) -= 2.0*ta[i]*ta[j];
      }

    VectorType tb = ref + ref;
    tb.normalize();
    // rotB = I - 2*tb*tb.transpose();
    MatrixType rotB(m_Dim, m_Dim, 0.0);
    for (unsigned int i = 0; i < m_Dim; i++)
      for (unsigned int j = 0; j < m_Dim; j++)
      {
        if (i == j)
          rotB(i, j) += 1.0;
        rotB(i, j) -= 2.0*tb[i]*tb[j];
      }

    m_RotM = rotA*rotB;
  }
}

/**
 * Draws one sample from the distribution with Wood's (1994) algorithm.
 *
 * 1. Rejection-sample W = <mu, X>. Z ~ Beta(m, m), with m = m_M, is mapped
 *    to W = (1 - (1+b) Z) / (1 - (1-b) Z). The candidate is accepted when
 *    k W + (d-1) log(1 - x0 W) - c >= log(U).
 * 2. Draw a direction uniformly on the (d-2)-sphere by normalizing d-1
 *    standard normal variates.
 * 3. Form s = (sqrt(1 - W^2) v, W), which is distributed around e_d, and
 *    rotate it onto the mean direction with m_RotM.
 *
 * Intermediates use double precision. 1 - W and 1 - W^2 come from closed
 * forms rather than subtraction, so accuracy holds for large k (W -> 1), and
 * sqrt(1 - W^2) can never become NaN.
 */
VonMisesFisherDistribution::VectorType
VonMisesFisherDistribution
::Generate()
{
  const double b = m_B;
  const double x0 = m_X0;
  const double dm1 = static_cast<double>(m_Dim) - 1.0;
  // 1 - x0 == 2b / (1 + b), since x0 = (1 - b) / (1 + b).
  const double oneMinusX0 = 2.0*b / (1.0 + b);

  double w = 0.0;
  double oneMinusWsq = 0.0;
  while (true)
  {
    const double z = _betaVar(m_M, m_M);
    const double u = _uniformVar();

    const double denom = 1.0 - (1.0 - b)*z;
    // w = (1 - (1+b) z) / denom, expressed through 1 - w = 2 b z / denom.
    const double oneMinusW = 2.0*b*z / denom;
    w = 1.0 - oneMinusW;
    // 1 - w^2 = (1 - w)(1 + w) = 4 b z (1 - z) / denom^2
    oneMinusWsq = 4.0*b*z*(1.0 - z) / (denom*denom);

    // 1 - x0 w rewritten as (1 - x0) + x0 (1 - w) to avoid cancellation when
    // x0 and w are both close to 1.
    const double t = m_K*w + dm1*log(oneMinusX0 + x0*oneMinusW) - m_C;
    // Negated comparison: a NaN t terminates the loop instead of spinning.
    if (!(t < log(u)))
      break;
  }

  // Uniform random direction in the d-1 coordinates orthogonal to e_d.
  const unsigned int last = m_Dim - 1;
  VectorType v(last, 0.0);
  do
  {
    for (unsigned int i = 0; i < last; i++)
      v[i] = _stdNormalVar();
  } while (v.magnitude() == 0.0);  // a zero vector has no direction
  v.normalize();

  const double a = (oneMinusWsq > 0.0) ? sqrt(oneMinusWsq) : 0.0;
  VectorType s(m_Dim, 0.0);
  for (unsigned int i = 0; i < last; i++)
    s[i] = a*v[i];
  s[last] = w;

  return m_RotM*s;
}

/**
 * Natural logarithm of the modified Bessel function of the first kind,
 * log I_nu(x), for nu >= 0 and x > 0.
 *
 * Evaluated in log space so that the huge values of I_nu at large x do not
 * overflow:
 *  - x > 50 and x > nu^2: Hankel large-argument asymptotic expansion
 *      I_nu(x) ~ e^x / sqrt(2 pi x) * sum_k (-1)^k a_k(nu) / x^k;
 *  - otherwise: the ascending power series
 *      I_nu(x) = sum_m (x/2)^(2m+nu) / (m! Gamma(m+nu+1)),
 *    accumulated with a running log-sum-exp.
 * The series needs on the order of x/2 terms, so it is slow only when both x
 * and nu are very large (very high dimensions).
 */
static double _logBesselI(double nu, double x)
{
  if (x != x)
    return x;  // NaN in, NaN out (the series loop below would not terminate)

  if (x > 50.0 && x > nu*nu)
  {
    const double mu = 4.0*nu*nu;
    double term = 1.0;
    double sum = 1.0;
    for (int k = 1; k < 100; ++k)
    {
      const double odd = 2.0*k - 1.0;
      const double next = -term * (mu - odd*odd) / (8.0*k*x);
      if (fabs(next) >= fabs(term))
        break;  // asymptotic series has started to diverge
      term = next;
      sum += term;
      if (fabs(term) <= 1e-17*fabs(sum))
        break;  // converged (terminates exactly for half-integer nu)
    }
    return x - 0.5*log(2.0*vnl_math::pi*x) + log(sum);
  }

  const double logHalfX = log(0.5*x);
  double logTerm = nu*logHalfX - lgamma(nu + 1.0);  // m = 0 term
  double logSum = logTerm;
  for (double m = 1.0; ; m += 1.0)
  {
    // t_m / t_{m-1} = (x/2)^2 / (m (m + nu))
    logTerm += 2.0*logHalfX - log(m) - log(m + nu);
    if (logTerm > logSum)
      logSum = logTerm + log(1.0 + exp(logSum - logTerm));
    else
      logSum += log(1.0 + exp(logTerm - logSum));
    // While terms grow, logSum - logTerm <= log(m + 1), so this only fires
    // once terms are decreasing and negligible (e^-40 ~ 4e-18).
    if (logTerm < logSum - 40.0)
      break;
  }
  return logSum;
}

/**
 * Evaluates the von Mises-Fisher density at x:
 *   f(x) = C_d(k) exp(k mu^T x),
 *   C_d(k) = k^{d/2-1} / ((2 pi)^{d/2} I_{d/2-1}(k)),
 * where I_nu is the modified Bessel function of the first kind.
 *
 * The normalizer is computed as a logarithm and combined with the exponent
 * before exponentiating, so large concentrations do not overflow. For k == 0
 * the uniform density Gamma(d/2) / (2 pi^{d/2}) (the k -> 0 limit) is used.
 * The normalizer depends only on |k|.
 *
 * NOTE(review): x is expected to be a unit vector with the same dimension as
 * the mean direction; it is not normalized here.
 */
float
VonMisesFisherDistribution
::EvaluateDensity(const VectorType& x)
{
  const double d = m_Dim;
  const double nu = 0.5*d - 1.0;   // Bessel order d/2 - 1 (half-integer for odd d)
  const double kappa = fabs(m_K);

  double logC;
  if (kappa == 0.0)
    logC = lgamma(0.5*d) - log(2.0) - 0.5*d*log(vnl_math::pi);
  else
    logC = nu*log(kappa) - 0.5*d*log(2.0*vnl_math::pi) - _logBesselI(nu, kappa);

  const double e = m_K * dot_product(m_Mu, x);
  return static_cast<float>(exp(logC + e));
}
