
#include "LinearTetrahedralMesh.h"

#include "itkKdTreeGenerator.h"

#include "vtkCellType.h"
#include "vtkIdList.h"
#include "vtkPoints.h"
#include "vtkSmartPointer.h"
#include "vtkUnstructuredGridReader.h"

#include "vnl/algo/vnl_determinant.h"
#include "vnl/vnl_math.h"

#include "muException.h"

LinearTetrahedralMesh
::LinearTetrahedralMesh()
{
  // Begin with no k-d tree and no mesh attached; the k-d tree is built on
  // demand by FindClosestPoints().
  m_KdTree = 0;
  m_VTKMesh = 0;
}

LinearTetrahedralMesh
::~LinearTetrahedralMesh()
{
  // Drop the mesh reference first, then every structure derived from it
  // (k-d tree, its sample, per-element volumes and weight mappings).
  m_VTKMesh = 0;
  this->ClearMappings();
}

void
LinearTetrahedralMesh
::ClearMappings()
{
  // Discard all data derived from the current mesh. The mesh itself
  // (m_VTKMesh) is not touched.

  // The tree is released before the sample it was generated from.
  m_KdTree = 0;
  m_KdTreeSample = 0;

  // Per-element data produced by RecomputeMappings().
  m_ElementWeightMappings.Clear();
  m_ElementVolumes.Clear();
}

void
LinearTetrahedralMesh
::ReadVTKFile(const char* fn)
{
  vtkSmartPointer<vtkUnstructuredGridReader> ugridReader = 
    vtkSmartPointer<vtkUnstructuredGridReader>::New();
  ugridReader->SetFileName(fn);
  ugridReader->Update();

  vtkSmartPointer<vtkUnstructuredGrid> ug = ugridReader->GetOutput();

  this->SetVTKMesh(ug);
}

void
LinearTetrahedralMesh
::RecomputeMappings()
{
  // For every cell of m_VTKMesh, store its volume (m_ElementVolumes) and the
  // 4x4 matrix (m_ElementWeightMappings) that maps a homogeneous point
  // [x y z 1]^T to the cell's four linear shape function values.
  //
  // Side effect: tetrahedra whose vertex order gives a negative signed
  // volume are re-ordered in place in m_VTKMesh (vertices 0 and 3 swapped),
  // after which the cell links are rebuilt.
  m_ElementVolumes.Clear();
  m_ElementWeightMappings.Clear();

  if (m_VTKMesh.GetPointer() == 0)
  {
    return;
  }

  const unsigned int numElements = m_VTKMesh->GetNumberOfCells();

  m_ElementVolumes.Allocate(numElements);
  m_ElementWeightMappings.Allocate(numElements);

  // Volume stored for non-tetrahedral / degenerate cells, and the threshold
  // below which a tetrahedron is considered degenerate.
  const double tinyVolume = 1e-10;

  // Reused across cells; GetCellPoints() overwrites its contents each time.
  vtkSmartPointer<vtkIdList> ptIds = vtkSmartPointer<vtkIdList>::New();

  for (unsigned int el = 0; el < numElements; el++)
  {
    m_VTKMesh->GetCellPoints(el, ptIds);

    // Only 4-node cells are handled; others get a zero mapping.
    if (ptIds->GetNumberOfIds() != 4)
    {
      m_ElementVolumes.Append(tinyVolume);
      m_ElementWeightMappings.Append(MatrixType(4, 4, 0.0));
      continue;
    }

    // H = [ p0 p1 p2 p3 ; 1 1 1 1 ] : vertex coordinates as columns,
    // homogeneous row of ones at the bottom.
    MatrixType H(4, 4);
    for (unsigned int j = 0; j < 4; j++)
    {
      double point[3];
      m_VTKMesh->GetPoint(ptIds->GetId(j), point);
      H(0, j) = point[0];
      H(1, j) = point[1];
      H(2, j) = point[2];
      H(3, j) = 1.0;
    }

    // det(H) is six times the signed volume. Keep it in double precision;
    // truncating to float loses accuracy for small elements or large
    // coordinate offsets, which matters for the degeneracy test below.
    double V = vnl_determinant(H) / 6.0;

    // Vertex ordering can produce a negative signed volume
    if (V < 0.0)
    {
      // Swapping vertices 0 and 3 (one transposition) flips the sign of
      // det(H). Row 3 is all ones, so only rows 0-2 need swapping.
      for (unsigned int i = 0; i < 3; i++)
      {
        const double tmp = H(i, 0);
        H(i, 0) = H(i, 3);
        H(i, 3) = tmp;
      }
      V = -V;

      // Apply the same reordering to the cell's global point ids
      vtkIdType newIds[4];
      newIds[0] = ptIds->GetId(3);
      newIds[1] = ptIds->GetId(1);
      newIds[2] = ptIds->GetId(2);
      newIds[3] = ptIds->GetId(0);
      m_VTKMesh->ReplaceCell(el, 4, newIds);
    }

    if (V < tinyVolume)
    {
      // Degenerate tetrahedron: H is (near) singular, so store a tiny
      // volume and a tiny scaled identity instead of inverting it.
      m_ElementVolumes.Append(tinyVolume);

      MatrixType E(4, 4);
      E.set_identity();
      E *= tinyVolume;

      m_ElementWeightMappings.Append(E);
    }
    else
    {
      m_ElementVolumes.Append(V);
      m_ElementWeightMappings.Append(MatrixInverseType(H));
    }

  } // for el

  // Cells may have been re-ordered above; rebuild the point-to-cell links
  m_VTKMesh->BuildLinks();
}

LinearTetrahedralMesh::VectorType
LinearTetrahedralMesh
::ComputeShapeFunctions(unsigned int el, const VectorType& globalP)
{
  // Evaluate the linear shape functions of element `el` at global point
  // `globalP` (first three components used). For non-degenerate
  // tetrahedra the stored mapping is the inverse of the vertex matrix built
  // in RecomputeMappings(), so the result holds the point's barycentric
  // coordinates with respect to the element's four vertices.
  // `el` is not bounds-checked here.
  VectorType homogeneous(4);
  homogeneous[0] = globalP[0];
  homogeneous[1] = globalP[1];
  homogeneous[2] = globalP[2];
  homogeneous[3] = 1.0; // homogeneous coordinate

  return m_ElementWeightMappings[el] * homogeneous;
}

LinearTetrahedralMesh::MatrixType
LinearTetrahedralMesh
::ComputeShapeFunctionDerivatives(unsigned int el)
{
  // Return the 4x3 matrix of spatial derivatives of element `el`'s shape
  // functions: row i holds dN_i/dx, dN_i/dy, dN_i/dz. For linear elements
  // these are the first three columns of the stored mapping; the last
  // column (constant terms) is dropped. `el` is not bounds-checked.

  // Read-only access: bind by reference instead of deep-copying the matrix
  // on every call.
  const MatrixType& mapM = m_ElementWeightMappings[el];

  // Every entry is assigned below, so no zero-fill is needed.
  MatrixType delN(4, 3);
  for (unsigned int i = 0; i < 4; i++)
    for (unsigned int j = 0; j < 3; j++)
      delN(i, j) = mapM(i, j);

  return delN;
}

float
LinearTetrahedralMesh
::ComputeElementVolume(unsigned int el)
{
  // Lookup of the volume cached by RecomputeMappings(); non-tetrahedral and
  // degenerate cells were stored with a tiny placeholder volume (1e-10).
  // `el` is not bounds-checked.
  const float volume = m_ElementVolumes[el];
  return volume;
}

LinearTetrahedralMesh::PointIDVectorType
LinearTetrahedralMesh
::FindClosestPoints(PointType& query, unsigned int n)
{
  // Return the ids of the `n` mesh points nearest to `query`.
  // A k-d tree over all mesh points is built lazily on the first call and
  // cached (together with its sample) until ClearMappings() resets it.
  // Returns an empty list when there is no mesh, the mesh has no points,
  // or n == 0.
  PointIDVectorType neighbors;

  // Guard: without a mesh the tree cannot be built (null dereference), and
  // a zero-neighbor request has nothing to return.
  if (m_VTKMesh.GetPointer() == 0 || n == 0)
  {
    return neighbors;
  }

  // Build Kd tree if needed
  if (m_KdTree.IsNull())
  {
    const unsigned int numPoints = m_VTKMesh->GetNumberOfPoints();

    // Guard: do not build a tree over an empty sample
    if (numPoints == 0)
    {
      return neighbors;
    }

    SampleType::Pointer sample = SampleType::New();
    sample->SetMeasurementVectorSize(3);

    for (unsigned int i = 0; i < numPoints; i++)
    {
      double x[3];
      m_VTKMesh->GetPoint(i, x);

      PointType mv;
      for (unsigned int dim = 0; dim < 3; dim++)
        mv[dim] = x[dim];

      sample->PushBack(mv);
    }

    // Keep the sample alive alongside the tree built from it
    m_KdTreeSample = sample;

    const unsigned int bucketSize = 100;

    typedef itk::Statistics::KdTreeGenerator<SampleType> TreeGeneratorType;
    TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
    treeGenerator->SetSample(sample);
    treeGenerator->SetBucketSize(bucketSize);
    treeGenerator->Update();

    m_KdTree = treeGenerator->GetOutput();
  }

  // Do nearest neighbor search
  m_KdTree->Search(query, n, neighbors);

  return neighbors;
}
