
#include "BrainMeshGenerator.h"
#include "MersenneTwisterRNG.h"

#include "itkDiscreteGaussianImageFilter.h"
#include "itkVTKImageExport.h"
#include "itkSignedDanielssonDistanceMapImageFilter.h"

#include "vtkCellArray.h"
#include "vtkCellType.h"
#include "vtkContourFilter.h"
#include "vtkDelaunay3D.h"
#include "vtkIdList.h"
#include "vtkImageData.h"
#include "vtkImageImport.h"
#include "vtkPoints.h"
#include "vtkPointLocator.h"
#include "vtkSmartPointer.h"

#include "itkRescaleIntensityImageFilter.h"
#include "itkResampleImageFilter.h"
#include "itkImageFileWriter.h"

#include "createMesh3D.h"

#include "DynArray.h"

#include "LinearTetrahedralMesh.h"

// Smallest squared gradient norm of a distance map that is accepted when a
// point is projected back onto a boundary. Only referenced in the code paths
// compiled when SECOND_ORDER_DIST is non-zero.
#define MIN_DISTANCEMAG 1e-10

// Non-zero: boundary projections divide the step by the squared gradient norm
// of the distance map. Zero: the squared norm is taken to be 1.
#define SECOND_ORDER_DIST 0
// Non-zero: distance maps are Gaussian-smoothed before being handed to the
// B-spline interpolators. Zero: the raw distance maps are interpolated.
#define SMOOTH_DIST 1

// Default-construct the generator with its stock meshing parameters.
// Input images are supplied later through SetInputProbabilities() or
// SetLabelImage().
BrainMeshGenerator
::BrainMeshGenerator()
{
  // Step of the initial point lattice; also the reference length for the
  // inside tolerance and the convergence test.
  m_InitialSpacing = 4.0;

  // Relaxation loop: iteration cap, displacement scale per iteration, and
  // the factor applied to the target edge length.
  m_MaximumIterations = 500;
  m_TimeStep = 0.01;
  m_FScale = 1.1;

  // Order handed to the B-spline distance-map interpolators.
  m_SplineInterpolationOrder = 1;

  // Largest label value; recomputed whenever the internal images are rebuilt.
  m_MaxLabel = 0;
}

void
BrainMeshGenerator
::SetInputProbabilities(const DynArray<FloatImagePointer>& probs)
{

  if (probs.GetSize() == 0)
    muExceptionMacro(<< "No probabilities");

  FloatImageSizeType size = probs[0]->GetLargestPossibleRegion().GetSize();

  for (unsigned int k = 0; k < probs.GetSize(); k++)
  {
    FloatImageSizeType size_k = probs[k]->GetLargestPossibleRegion().GetSize();
    if (size_k != size)
      muExceptionMacro(<< "Image size mismatch");
  }

  m_InputProbabilities = probs;

  // Generate label image
std::cout << "Build label image" << std::endl;
  m_LabelImage = ByteImageType::New();
  m_LabelImage->CopyInformation(m_InputProbabilities[0]);
  m_LabelImage->SetRegions(m_InputProbabilities[0]->GetLargestPossibleRegion());
  m_LabelImage->Allocate();
  m_LabelImage->FillBuffer(0);

  FloatImageIndexType ind;

  for (ind[2] = 0; ind[2] < size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < size[0]; ind[0]++)
      {
        unsigned char maxL = 0;
        float maxP = 0;

        for (unsigned int i = 0; i < m_InputProbabilities.GetSize(); i++)
        {
          float p = m_InputProbabilities[i]->GetPixel(ind);

          if (p > maxP)
	  {
	    maxP = p;
	    maxL = i;
	  }
        }

        if (maxP > 0)
	  m_LabelImage->SetPixel(ind, (unsigned char)(maxL+1));
      }

  this->ComputeInternalImages();

}

void
BrainMeshGenerator
::SetLabelImage(const ByteImagePointer& labelImg)
{
  m_LabelImage = labelImg;

  this->ComputeInternalImages();
}

// Public entry point for mesh generation.
//
// Delegates to GenerateMeshAllAtOnce(); the per-region variant
// GenerateMeshSequentially() is kept below as a disabled alternative.
//
// Raises (through muExceptionMacro) when no input has been configured yet,
// i.e. the label image or the brain distance interpolator is still null.
vtkSmartPointer<vtkUnstructuredGrid>
BrainMeshGenerator
::GenerateMesh()
{
  // The meshing code dereferences both objects unconditionally
  if (m_LabelImage.IsNull() || m_BrainDistanceInterpolator.IsNull())
  {
    muExceptionMacro(
      << "No input image, call SetInputProbabilities or SetLabelImage first");
  }

std::cout << "Generating mesh with initial spacing = " << m_InitialSpacing << std::endl;
  return this->GenerateMeshAllAtOnce();
  //return this->GenerateMeshSequentially();
}

// Build the tetrahedral mesh for the whole label image with one relaxation
// loop over all points.
//
// Outline of what the code below does:
//  1. Seed points:
//     - vertices of the iso-surface (value 1.0) of the mask of non-zero
//       labels, marked as fixed;
//     - a regular lattice with step m_InitialSpacing, kept where the brain
//       distance map is <= 0.1 * m_InitialSpacing, marked as free.
//  2. Every iteration:
//     - drop near-duplicate points and rebuild the mesh with createMesh3D();
//     - accumulate spring-like displacements along all cell edges and apply
//       them to the free points that stay inside the brain;
//     - project free points with positive brain distance back along the
//       distance gradient;
//     - adjust free edges whose end points lie on opposite sides of an
//       internal (per-label) distance map;
//     - clamp all points to the image bounds.
//  3. The loop ends right after re-meshing once m_MaximumIterations is
//     reached, or when the largest interior displacement is below 1% of
//     m_InitialSpacing.
//
// Returns the mesh created by the most recent createMesh3D() call. Point
// moves made after that call go through the vtkPoints object obtained from
// that mesh (meshPoints = mesh->GetPoints()).
//
// Requires m_LabelImage, m_BrainDistanceInterpolator, m_BrainDistanceImage
// and m_InternalDistanceInterpolators to be set up (see
// ComputeInternalImages()).
vtkSmartPointer<vtkUnstructuredGrid>
BrainMeshGenerator
::GenerateMeshAllAtOnce()
{
  FloatImageSizeType size =
    m_LabelImage->GetLargestPossibleRegion().GetSize();

  FloatImageSpacingType spacing = m_LabelImage->GetSpacing();

  // NOTE(review): rng is only referenced by the commented-out random jitter
  // further down.
  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  typedef itk::ImageRegionIteratorWithIndex<ByteImageType> LabelIteratorType;
  LabelIteratorType labelIt(
    m_LabelImage, m_LabelImage->GetLargestPossibleRegion());

  // Generate list of points
  // Lattice points are kept only where the brain distance is <= insideTol
std::cout << "Initial set of points" << std::endl;
  // Distance-map value up to which a point still counts as inside
  float insideTol = m_InitialSpacing * 0.1;

  vtkSmartPointer<vtkPoints> meshPoints = vtkSmartPointer<vtkPoints>::New();
  meshPoints->SetDataTypeToFloat();

  // One entry per mesh point: 1 = fixed (contour point), 0 = free to move
  DynArray<unsigned char> fixedPointMarkers;

#if 1
  // Use contour filter for boundary points
std::cout << "Contour filter..." << std::endl;
  // Float mask: 1.0 wherever the label is non-zero, 0 elsewhere
  FloatImagePointer contourMask = FloatImageType::New();
  contourMask->CopyInformation(m_LabelImage);
  contourMask->SetRegions(m_LabelImage->GetLargestPossibleRegion());
  contourMask->Allocate();
  contourMask->FillBuffer(0);

  for (labelIt.GoToBegin(); !labelIt.IsAtEnd(); ++labelIt)
  {
    if (labelIt.Get() == 0)
      continue;

    contourMask->SetPixel(labelIt.GetIndex(), 1.0);
  }

/*
// TODO downsample mask?/
  typedef itk::ResampleImageFilter<FloatImageType, FloatImageType> ResamplerType;
*/

  // Bridge the ITK mask into VTK: the exporter's callbacks are wired one by
  // one into a vtkImageImport.
  typedef itk::VTKImageExport<FloatImageType> ITKExportType;
  ITKExportType::Pointer itkexport = ITKExportType::New();
  itkexport->SetInput(contourMask);
  itkexport->Update();

  // See InsightApplications/Auxiliary/vtk/itkImageToVTKImageFilter
  vtkSmartPointer<vtkImageImport> vtkimport =
    vtkSmartPointer<vtkImageImport>::New();
  vtkimport->SetUpdateInformationCallback(
    itkexport->GetUpdateInformationCallback());
  vtkimport->SetPipelineModifiedCallback(
    itkexport->GetPipelineModifiedCallback());
  vtkimport->SetWholeExtentCallback(itkexport->GetWholeExtentCallback());
  vtkimport->SetSpacingCallback(itkexport->GetSpacingCallback());
  vtkimport->SetOriginCallback(itkexport->GetOriginCallback());
  vtkimport->SetScalarTypeCallback(itkexport->GetScalarTypeCallback());
  vtkimport->SetNumberOfComponentsCallback(itkexport->GetNumberOfComponentsCallback());
  vtkimport->SetPropagateUpdateExtentCallback(itkexport->GetPropagateUpdateExtentCallback());
  vtkimport->SetUpdateDataCallback(itkexport->GetUpdateDataCallback());
  vtkimport->SetDataExtentCallback(itkexport->GetDataExtentCallback());
  vtkimport->SetBufferPointerCallback(itkexport->GetBufferPointerCallback());
  vtkimport->SetCallbackUserData(itkexport->GetCallbackUserData());

  // Single iso-surface at mask value 1.0; only its vertices are used below
  vtkSmartPointer<vtkContourFilter> contourf =
    vtkSmartPointer<vtkContourFilter>::New();
  contourf->SetInput(vtkimport->GetOutput());
  contourf->SetNumberOfContours(1);
  contourf->SetValue(0, 1.0);
  contourf->ComputeNormalsOff();
  contourf->ComputeGradientsOff();

  contourf->Update();

  vtkSmartPointer<vtkPolyData> contourPD = contourf->GetOutput();

std::cout << "Contour filter: " << contourPD->GetNumberOfPoints() << " boundary points" << std::endl;

  // Every contour vertex becomes a fixed mesh point
  for (vtkIdType k = 0; k < contourPD->GetNumberOfPoints(); k++)
  {
    double x[3];
    contourPD->GetPoint(k, x);

    meshPoints->InsertNextPoint(x);
    fixedPointMarkers.Append(1);
  }

  // Clean up contour filtering vars
  contourMask = 0;
#endif

  // Insert equally spaced points inside regions, not at boundary
  // Physical extent covered by the lattice along each axis
  float lenx = (size[0]-1) * spacing[0];
  float leny = (size[1]-1) * spacing[1];
  float lenz = (size[2]-1) * spacing[2];

  // NOTE(review): only referenced by the commented-out jitter below
  float randomDisp = 0.1*m_InitialSpacing;

  for (float pz = 0; pz <= lenz; pz += m_InitialSpacing)
    for (float py = 0; py <= leny; py += m_InitialSpacing)
      for (float px = 0; px <= lenx; px += m_InitialSpacing)
      {
        // Sampling position: lattice node shifted by half a voxel
        FloatImagePointType p;
        p[0] = px + 0.5*spacing[0];
        p[1] = py + 0.5*spacing[1];
        p[2] = pz + 0.5*spacing[2];

/*
        for (int dim = 0; dim < 3; dim++)
        {
          float r = 2.0*rng->GenerateUniformRealOpenInterval() - 1.0;
          p[dim] += r * randomDisp;
        }
*/

        if (!m_BrainDistanceInterpolator->IsInsideBuffer(p))
          continue;

        float phi = m_BrainDistanceInterpolator->Evaluate(p);

        // Skip outside points
        if (phi > insideTol)
          continue;

        // NOTE(review): the distance map is sampled at the shifted position
        // p, but the unshifted lattice node is what gets inserted -- confirm
        // that this is intended.
        meshPoints->InsertNextPoint(px, py, pz);
        fixedPointMarkers.Append(0);
      }

  // Bounds for Delaunay triangulation
  // Lower bound 0, upper bound size * spacing on every axis
  // (assumes the image origin is at 0 -- TODO confirm)
  double bounds[6];
  for (int dim = 0; dim < 3; dim++)
  {
    bounds[2*dim] = 0.0;
    bounds[2*dim + 1] = size[dim]*spacing[dim];
  }

  vtkSmartPointer<vtkUnstructuredGrid> mesh;

  // Begin loop
  unsigned int numCells = 0;
  unsigned int numPoints = meshPoints->GetNumberOfPoints();
std::cout << "Initially " << numPoints << " points" << std::endl;

  float maxChange = m_InitialSpacing;

  unsigned int iter = 0;

  // Scratch coordinates for the two end points of an edge
  double x[3];
  double y[3];

  while (true)
  {
    ++iter;

std::cout << "---------------------------" << std::endl;
std::cout << "Loop iteration " << iter << std::endl;
    // Recompute Delaunay triangulation if change is higher than threshold
    // (both candidate conditions are commented out, so the block below runs
    // on every iteration)
    //if ((maxChange/m_InitialSpacing) > 0.1)
    //if ((iter % 5) == 1)
    {

std::cout << "Recompute Delaunay tesselation" << std::endl;
std::cout << "Current count: " << meshPoints->GetNumberOfPoints() << " points" << std::endl;
      vtkSmartPointer<vtkPoints> filteredPoints =
        vtkSmartPointer<vtkPoints>::New();
      filteredPoints->SetDataTypeToFloat();
      filteredPoints->Allocate(numPoints);

      // Incremental locator used to detect already-inserted points, with a
      // tolerance of a quarter of the initial spacing
      vtkSmartPointer<vtkPointLocator> pLoc =
        vtkSmartPointer<vtkPointLocator>::New();
      pLoc->SetTolerance(m_InitialSpacing*0.25);
      pLoc->InitPointInsertion(filteredPoints, bounds, numPoints);

      DynArray<unsigned char> filteredFixedMarkers;
      filteredFixedMarkers.Allocate(numPoints);

      // Drop duplicate points
      // A point is kept only when the locator reports it as not yet
      // inserted; its fixed/free marker is carried over in the same order.
std::cout << "Dropping duplicate points" << std::endl;
      for (int i = 0; i < numPoints; i++)
      {
        meshPoints->GetPoint(i, x);

        int id = pLoc->IsInsertedPoint(x);

        if (id < 0)
        {
          pLoc->InsertNextPoint(x);
          if (fixedPointMarkers[i] != 0)
            filteredFixedMarkers.Append(1);
          else
            filteredFixedMarkers.Append(0);
        }

      } // for i

std::cout << "Filter: " << filteredPoints->GetNumberOfPoints() << " remaining" << std::endl;

      meshPoints = pLoc->GetPoints();
      fixedPointMarkers = filteredFixedMarkers;

      mesh = createMesh3D(meshPoints, m_BrainDistanceImage);

      // From here on meshPoints refers to the points owned by the new mesh.
      // NOTE(review): fixedPointMarkers is indexed by mesh point id below,
      // which presumes createMesh3D keeps the point order -- confirm.
      meshPoints = mesh->GetPoints();

      numCells = mesh->GetNumberOfCells();
      numPoints = mesh->GetNumberOfPoints();
std::cout << "Done delaunay " << numCells << " cells, " << numPoints << " points" << std::endl;

      // Stop
      // Iteration cap: leave right after re-meshing, before any further
      // point moves.
      if (iter >= m_MaximumIterations)
        break;

    } // if max change > tol

std::cout << "Bar SS" << std::endl;
    // Compute SS (sum of squares) of the bar lengths
    // Bars are all point pairs of every cell, so a bar shared by several
    // cells is counted once per cell. The 1e-10 seed keeps the sum non-zero.
    float lengthSS = 1e-10;
    // Only used by the commented-out alternative rest length further down
    float sum_length = 0;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        meshPoints->GetPoint(ids->GetId(i), x);
        for (int j = i+1; j < n; j++)
        {
          meshPoints->GetPoint(ids->GetId(j), y);

          float barlength = 0;
          for (int k = 0; k < 3; k++)
          {
            float t = (x[k] - y[k]);
            barlength += t*t;
          }
          barlength = sqrtf(barlength);

          lengthSS += barlength*barlength;

          sum_length += barlength;
        }
      }

    }

    // Only reported, not used in the computation
    float root_lengthSS = sqrtf(lengthSS);
std::cout << "root_lengthSS = " << root_lengthSS << std::endl;

    // Edge midpoint scratch vector
    VectorType mid(3, 0);

std::cout << "h func SS" << std::endl;
    // Compute SS of the size function at edge midpoints
    // Same pair traversal as above, evaluating EvaluateHFunction() at the
    // midpoint of every bar.
    float hmidSS = 1e-10;
    // Only used by the commented-out alternative rest length further down
    float sum_hmid = 1e-10;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        meshPoints->GetPoint(ids->GetId(i), x);
        for (int j = i+1; j < n; j++)
        {
          meshPoints->GetPoint(ids->GetId(j), y);

          for (int k = 0; k < 3; k++)
            mid[k] = (x[k] + y[k]) / 2.0;

          float hmid = this->EvaluateHFunction(mid[0], mid[1], mid[2]);

          sum_hmid += hmid;

          hmidSS += hmid*hmid;
        }
      }

    }

    // Only reported, not used in the computation
    float root_hmidSS = sqrtf(hmidSS);
std::cout << "root_hmidSS = " << root_hmidSS << std::endl;

    // Factor that turns a size-function value into a target bar length:
    // with it, the scaled sizes have the same sum of squares as the current
    // bar lengths.
    float scalef = sqrtf(lengthSS / hmidSS);

    // Edge direction scratch vector (x - y, normalised below)
    VectorType dir(3, 0);

    // Accumulated displacement of every point for this iteration
    MatrixType dispMatrix(numPoints, 3, 0.0);

    //
    // Compute displacements due to spring forces at each point
    //
    // F = (L0 - L) * unit(x-y)
    //
    // L0 = h_mid * scalef * Fscale, with scalef = sqrt(SS(L) / SS(h_mid))
    //

std::cout << "Spring disps" << std::endl;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        unsigned int index_x = ids->GetId(i);

        meshPoints->GetPoint(index_x, x);

        for (int j = i+1; j < n; j++)
        {
          unsigned int index_y = ids->GetId(j);


          meshPoints->GetPoint(index_y, y);

          for (int k = 0; k < 3; k++)
          {
            mid[k] = (x[k] + y[k]) / 2.0;
            dir[k] = (x[k] - y[k]);
          }

          float hmid = this->EvaluateHFunction(mid[0], mid[1], mid[2]);

          float length = dir.magnitude();

          // Degenerate bar: no usable direction
          if (length < 1e-10)
            continue;

          // Target (rest) length of this bar
          float length0 = hmid * scalef *  m_FScale;
          //float length0 = hmid / sum_hmid * sum_length;

          dir /= length;

          //forces(x) = (L0-L) * dir
          //forces(y) = -forces(x)

          // Positive when the bar is shorter than its target (end points are
          // pushed apart), negative when longer (end points are pulled
          // together, since the contraction guard below is disabled).
          float forcemag = length0 - length;

//TODO
// No contraction?
          //if (forcemag < 0.0)
            //continue;

          // Compute displacements
          // dx =  fmag*dir*dt
          // dy = -fmag*dir*dt
          // Only free points accumulate a displacement.

          for (int k = 0; k < 3; k++)
          {
            float disp_k = forcemag * dir[k] * m_TimeStep;

            if (fixedPointMarkers[index_x] == 0)
              dispMatrix(index_x, k) += disp_k;
            if (fixedPointMarkers[index_y] == 0)
              dispMatrix(index_y, k) -= disp_k;
          }

        }
      }

    }

    // Move the points according to the force displacements
    // A move is rejected (and its displacement zeroed) when the new position
    // leaves the distance-map buffer or lies outside the brain by more than
    // insideTol.
    for (int i = 0; i < numPoints; i++)
    {
//PP
      if (fixedPointMarkers[i] != 0)
        continue;

      meshPoints->GetPoint(i, x);

      for (int k = 0; k < 3; k++)
        x[k] += dispMatrix(i, k);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2];

      if (!m_BrainDistanceInterpolator->IsInsideBuffer(p))
      {
        for (int d = 0; d < 3; d++)
          dispMatrix(i, d) = 0.0;
        continue;
      }

      float phi = m_BrainDistanceInterpolator->Evaluate(p);

      if (phi > insideTol)
      {
        for (int d = 0; d < 3; d++)
          dispMatrix(i, d) = 0.0;
        continue;
      }

// PP
// TEST
// if boundary (fixed marker on) sliding movement
/*
      if (fixedPointMarkers[i] != 0)
      {
        BSplineInterpolatorType::CovariantVectorType grad_phi =
          m_BrainDistanceInterpolator->EvaluateDerivative(p);
        float norm = grad_phi.GetNorm();
        if (norm > 1e-10)
        {
          grad_phi /= norm;

          float dotn = 0;
          for (int dim = 0; dim < 3; dim++)
            dotn += grad_phi[dim] *  dispMatrix(i, dim);

          for (int dim = 0; dim < 3; dim++)
            dispMatrix(i, dim) -= dotn * grad_phi[dim];
        }

        meshPoints->GetPoint(i, x);
        for (int k = 0; k < 3; k++)
          x[k] += dispMatrix(i, k);
      } // Slide
*/

      meshPoints->SetPoint(i, x);
    }

    // Compute max internal displacements due to truss forces
    // Considers every point whose current position is inside the buffer and
    // has brain distance below insideTol; this value drives the convergence
    // test at the end of the iteration.
std::cout << "Max interior change? " << std::endl;
    maxChange = 0;

    float aveChange = 0;

    for (int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2];

      if (!m_BrainDistanceInterpolator->IsInsideBuffer(p))
        continue;

      float phi = m_BrainDistanceInterpolator->Evaluate(p);

      if (phi < insideTol)
      {
        float d = 0;
        for (int k = 0; k < 3; k++)
        {
          float t = dispMatrix(i, k);
          d += t*t;
        }
        d = sqrtf(d);

        if (d > maxChange)
          maxChange = d;

        aveChange += d;
      }
    }

    // NOTE(review): the average divides by the total number of points, not
    // by the number of points that contributed above; it is only printed.
    aveChange /= numPoints;

std::cout << "Average change = " << aveChange << std::endl;

    //
    // Compute displacements that bring outside points back to boundary
    //
    // p = p - \phi* \frac{\nabla\phi}{|\nabla\phi|^2}
    //
    // (the squared gradient norm is replaced by 1 unless SECOND_ORDER_DIST
    // is enabled). These moves happen after maxChange was computed, so they
    // do not enter the convergence test.
    //

std::cout << "Move external" << std::endl;
    for (int i = 0; i < numPoints; i++)
    {
// TODO
// if boundary no need to move back???
// already taken care of with sliding bc???
      if (fixedPointMarkers[i] != 0)
        continue;

      meshPoints->GetPoint(i, x);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2];

      if (!m_BrainDistanceInterpolator->IsInsideBuffer(p))
        continue;

      float phi = m_BrainDistanceInterpolator->Evaluate(p);

      // Positive distance: the point lies outside the brain
      if (phi > 0)
      {
        BSplineInterpolatorType::CovariantVectorType grad_phi =
          m_BrainDistanceInterpolator->EvaluateDerivative(p);

#if SECOND_ORDER_DIST
        float magsq = grad_phi.GetSquaredNorm();
        if (magsq < MIN_DISTANCEMAG)
          continue;
#else
        float magsq = 1.0;
#endif

        for (int k = 0; k < 3; k++)
        {
          float disp_k = phi / magsq * grad_phi[k];
          x[k] -= disp_k;
          dispMatrix(i, k) -= disp_k;
        }

        meshPoints->SetPoint(i, x);
      }
    }

    // Adjust edges that cross internal boundary
    // For every internal distance map and every cell, look at pairs of free
    // points that are both farther than insideTol from the zero level and
    // have distance values of opposite sign.
std::cout << "Move internal crossing" << std::endl;
    for (int label = 0; label < m_InternalDistanceInterpolators.GetSize(); label++)
    {
      for (int cell = 0; cell < numCells; cell++)
      {
        vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

        mesh->GetCellPoints(cell, ids);

        unsigned int n = ids->GetNumberOfIds();

        for (int i = 0; i < n; i++)
        {
          unsigned int index_x = ids->GetId(i);

          if (fixedPointMarkers[index_x] != 0)
            continue;

          meshPoints->GetPoint(index_x, x);

          FloatImagePointType px;
          px[0] = x[0];
          px[1] = x[1];
          px[2] = x[2];

          if (!m_InternalDistanceInterpolators[label]->IsInsideBuffer(px))
            continue;

          float phi_x =
            m_InternalDistanceInterpolators[label]->Evaluate(px);

          // Don't move this point if it is close enough to boundary
          if (fabs(phi_x) < insideTol)
            continue;

          for (int j = i+1; j < n; j++)
          {
            unsigned int index_y = ids->GetId(j);

            if (fixedPointMarkers[index_y] != 0)
              continue;

            meshPoints->GetPoint(index_y, y);

            FloatImagePointType py;
            py[0] = y[0];
            py[1] = y[1];
            py[2] = y[2];

            if (!m_InternalDistanceInterpolators[label]->IsInsideBuffer(py))
              continue;

            float phi_y =
              m_InternalDistanceInterpolators[label]->Evaluate(py);

            if (fabs(phi_y) < insideTol)
              continue;

            // Same side of this internal boundary: nothing to fix
            if (vnl_math_sgn(phi_x) == vnl_math_sgn(phi_y))
              continue;

            // Move the point that is outside and closest to boundary

            // If x is closer to boundary, move it and stop
            // (the break leaves the j loop, i.e. no more partners are
            // examined for this x)
            if ((fabs(phi_x) < fabs(phi_y)) && (phi_x > 0))
            {
              BSplineInterpolatorType::CovariantVectorType grad_phi =
                m_InternalDistanceInterpolators[label]->EvaluateDerivative(px);

#if SECOND_ORDER_DIST
              float magsq = grad_phi.GetSquaredNorm();
              if (magsq < MIN_DISTANCEMAG)
                break;
#else
              float magsq = 1.0;
#endif

              for (int k = 0; k < 3; k++)
              {
                float disp_k = phi_x / magsq * grad_phi[k];
                x[k] -= disp_k;
                dispMatrix(index_x, k) -= disp_k;
              }

              px[0] = x[0];
              px[1] = x[1];
              px[2] = x[2];

              // The new position is only committed when it is still inside
              // the brain distance buffer; dispMatrix was updated either way
              if (m_BrainDistanceInterpolator->IsInsideBuffer(px))
                meshPoints->SetPoint(index_x, x);

              break;
            }

// PP
            // Don't move any point in reference to this internal distance map
            // Process it in the other maps
            // (guard disabled: y is moved below whatever the sign of phi_y)
            //if (phi_y <= 0)
              //continue;

            // Move y
            BSplineInterpolatorType::CovariantVectorType grad_phi =
              m_InternalDistanceInterpolators[label]->EvaluateDerivative(py);

#if SECOND_ORDER_DIST
            float magsq = grad_phi.GetSquaredNorm();
            if (magsq < MIN_DISTANCEMAG)
              continue;
#else
            float magsq = 1.0;
#endif

            for (int k = 0; k < 3; k++)
            {
              float disp_k = phi_y / magsq * grad_phi[k];
              y[k] -= disp_k;
              dispMatrix(index_y, k) -= disp_k;
            }

            py[0] = y[0];
            py[1] = y[1];
            py[2] = y[2];

            // Same commit rule as for x above
            if (m_BrainDistanceInterpolator->IsInsideBuffer(py))
              meshPoints->SetPoint(index_y, y);
          } // for j
        } // for i


      } // for cell
    } // for label

std::cout << "Max interior change = " << maxChange << std::endl;

    // Make sure all points are within bounds
    // Applies to every point, fixed ones included.
    for (unsigned int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);
      for (unsigned int dim = 0; dim < 3; dim++)
      {
        if (x[dim] < bounds[2*dim])
          x[dim] = bounds[2*dim];
        if (x[dim] > bounds[2*dim+1])
          x[dim] = bounds[2*dim+1];
      }
      meshPoints->SetPoint(i, x);
    }

    // Converged when all interior nodes move less than tolerance
    // NOTE(review): the iteration test here cannot be true, because the
    // check after re-meshing already leaves the loop once iter reaches
    // m_MaximumIterations.
    bool converged = ((maxChange / m_InitialSpacing) < 1e-2);
    if (converged || iter > m_MaximumIterations)
      break;

  }

std::cout << "Done" << std::endl;

  return mesh;

}

void
BrainMeshGenerator
::ComputeInternalImages()
{

  // Threshold and compute distance transform for brain tissue
  // (wm, gm, and falx)
std::cout << "Whole brain dist map" << std::endl;
  ByteImageType::Pointer maskImg = ByteImageType::New();
  maskImg->CopyInformation(m_LabelImage);
  maskImg->SetRegions(m_LabelImage->GetLargestPossibleRegion());
  maskImg->Allocate();

  ByteImageIndexType ind;

  ByteImageSizeType size = m_LabelImage->GetLargestPossibleRegion().GetSize();

  maskImg->FillBuffer(0);
  for (ind[2] = 0; ind[2] < size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < size[0]; ind[0]++)
      {
        unsigned char c = m_LabelImage->GetPixel(ind);
        if (c == 1 || c == 2 || c == 4)
          maskImg->SetPixel(ind, 1);
      }

  typedef itk::SignedDanielssonDistanceMapImageFilter<
    ByteImageType, FloatImageType> DistanceMapFilterType;

  typedef itk::DiscreteGaussianImageFilter<FloatImageType, FloatImageType>
    SmootherType;

  {
    DistanceMapFilterType::Pointer distanceMapFilter =
      DistanceMapFilterType::New();

    distanceMapFilter->InsideIsPositiveOff();
    distanceMapFilter->SetInput(maskImg);
    distanceMapFilter->SquaredDistanceOff();
    distanceMapFilter->UseImageSpacingOn();

    distanceMapFilter->Update();

#if SMOOTH_DIST
    SmootherType::Pointer smoother = SmootherType::New();
    smoother->SetInput(distanceMapFilter->GetDistanceMap());
    smoother->SetVariance(4.0);
    smoother->Update();

    m_BrainDistanceInterpolator = BSplineInterpolatorType::New();
    m_BrainDistanceInterpolator->SetInputImage(smoother->GetOutput());
    m_BrainDistanceInterpolator->SetSplineOrder(
      m_SplineInterpolationOrder);
    m_BrainDistanceImage = smoother->GetOutput();
#else
    m_BrainDistanceInterpolator = BSplineInterpolatorType::New();
    m_BrainDistanceInterpolator->SetInputImage(distanceMapFilter->GetDistanceMap());
    m_BrainDistanceInterpolator->SetSplineOrder(m_SplineInterpolationOrder);
    m_BrainDistanceImage = distanceMapFilter->GetDistanceMap();
#endif
  }

  // Compute maximum label
  m_MaxLabel = 0;
  for (ind[2] = 0; ind[2] < size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < size[0]; ind[0]++)
      {
        unsigned int c = m_LabelImage->GetPixel(ind);
        if (c > m_MaxLabel)
          m_MaxLabel = c;
      }

/*
std::cout << "Max label = " << m_MaxLabel << std::endl;
  if (m_MaxLabel < 4)
    muExceptionMacro(<< "Bad labels, need wm, gm, ventricles, and falx");
*/

  // Threshold and compute distance transform for each class
  m_InternalDistanceInterpolators.Clear();
  for (unsigned int label = 1; label <= m_MaxLabel; label++)
  {
std::cout << "Internal brain dist map " << label << std::endl;
    maskImg->FillBuffer(0);
    for (ind[2] = 0; ind[2] < size[2]; ind[2]++)
      for (ind[1] = 0; ind[1] < size[1]; ind[1]++)
        for (ind[0] = 0; ind[0] < size[0]; ind[0]++)
        {
          if (m_LabelImage->GetPixel(ind) == label)
            maskImg->SetPixel(ind, 1);
        }

    DistanceMapFilterType::Pointer distanceMapFilter =
      DistanceMapFilterType::New();

    distanceMapFilter->InsideIsPositiveOff();
    distanceMapFilter->SetInput(maskImg);
    distanceMapFilter->SquaredDistanceOff();
    distanceMapFilter->UseImageSpacingOn();

    distanceMapFilter->Update();

#if SMOOTH_DIST
    SmootherType::Pointer smoother = SmootherType::New();
    smoother->SetInput(distanceMapFilter->GetDistanceMap());
    smoother->SetVariance(4.0);
    smoother->Update();

    BSplineInterpolatorType::Pointer bsplineInt =
      BSplineInterpolatorType::New();
    bsplineInt->SetInputImage(smoother->GetOutput());
    bsplineInt->SetSplineOrder(m_SplineInterpolationOrder);
#else
    BSplineInterpolatorType::Pointer bsplineInt =
      BSplineInterpolatorType::New();
    bsplineInt->SetInputImage(distanceMapFilter->GetDistanceMap());
    bsplineInt->SetSplineOrder(m_SplineInterpolationOrder);
#endif

    m_InternalDistanceInterpolators.Append(bsplineInt);
  }

}

float
BrainMeshGenerator
::EvaluateHFunction(float x, float y, float z)
{

#if 0

  return 1.0;

#else

  // Find closest boundary, either internal or external
  FloatImagePointType p;
  p[0] = x;
  p[1] = y;
  p[2] = z;

  if (!m_BrainDistanceInterpolator->IsInsideBuffer(p))
    return m_InitialSpacing + 1.0;

  float phi = m_BrainDistanceInterpolator->Evaluate(p);
  phi = fabs(phi);

  // If far enough outside brain return default size
  if (phi > m_InitialSpacing)
    return m_InitialSpacing + 1.0;

  phi = vnl_huge_val(1.0);

  for (unsigned int i = 0; i < m_InternalDistanceInterpolators.GetSize(); i++)
  {
    float phi_i =
     m_InternalDistanceInterpolators[i]->Evaluate(p);

    phi_i = fabs(phi_i);

    if (phi_i < phi)
      phi = phi_i;
  }

  // Bound h to get homogeneous tetrahedra in internal regions
  if (phi > m_InitialSpacing)
    phi = m_InitialSpacing;

  phi += 1.0;

  return phi;

#endif

}

float
BrainMeshGenerator
::EvaluateInternalHFunction(float x, float y, float z, unsigned int i)
{
#if 1
  FloatImagePointType p;
  p[0] = x;
  p[1] = y;
  p[2] = z;

  if (!m_InternalDistanceInterpolators[i]->IsInsideBuffer(p))
    return m_InitialSpacing;

  float phi = m_InternalDistanceInterpolators[i]->Evaluate(p);
  phi = fabs(phi);
  if (phi > m_InitialSpacing)
    phi = m_InitialSpacing;

  return phi;
#else
  return 1.0;
#endif
}

// Alternative mesh generation: handle one internal distance map at a time.
//
// A point locator shared by all regions is created over the image bounds,
// ProcessRegion() is called once per internal distance map with that
// locator, and the points held by the locator afterwards are meshed with
// createMesh3D() against the brain distance image.
// (ProcessRegion() is expected to insert each region's points into the
// shared locator -- TODO confirm.)
//
// Returns the mesh built from the combined point set.
vtkSmartPointer<vtkUnstructuredGrid>
BrainMeshGenerator
::GenerateMeshSequentially()
{
  FloatImageSizeType size =
    m_LabelImage->GetLargestPossibleRegion().GetSize();

  FloatImageSpacingType spacing = m_LabelImage->GetSpacing();

  // Bounds for Delaunay triangulation
  // Lower bound 0, upper bound size * spacing on every axis
  double bounds[6];
  for (int dim = 0; dim < 3; dim++)
  {
    bounds[2*dim] = 0.0;
    bounds[2*dim + 1] = size[dim]*spacing[dim];
  }

  // Initialize point locator
  vtkSmartPointer<vtkPoints> globPoints =
    vtkSmartPointer<vtkPoints>::New();
  globPoints->SetDataTypeToFloat();
  globPoints->Allocate(100000);

  // NOTE(review): the pre-allocation above (100000) and the estimate passed
  // to the locator below (1000000) differ by a factor of ten -- confirm.
  vtkSmartPointer<vtkPointLocator> globPLoc =
    vtkSmartPointer<vtkPointLocator>::New();
  //globPLoc->SetTolerance(m_InitialSpacing*0.25);
  globPLoc->SetTolerance(1.0);
  globPLoc->InitPointInsertion(globPoints, bounds, 1000000);

  // One pass per internal distance map, all sharing the same locator
  for (unsigned int i = 0; i < m_InternalDistanceInterpolators.GetSize(); i++)
    this->ProcessRegion(globPLoc, i);

  // Final Delaunay on combined points
std::cout << "\n=========\n" << std::endl;
std::cout << "Final Delaunay 3D" << std::endl;
std::cout << "# of points = " << globPLoc->GetPoints()->GetNumberOfPoints() << std::endl;

  vtkSmartPointer<vtkUnstructuredGrid> mesh =
    createMesh3D(globPLoc->GetPoints(), m_BrainDistanceImage);

  return mesh;
}

void
BrainMeshGenerator
::ProcessRegion(vtkPointLocator* globalPLoc, unsigned int which)
{
  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  FloatImageSizeType size =
    m_LabelImage->GetLargestPossibleRegion().GetSize();

  FloatImageSpacingType spacing = m_LabelImage->GetSpacing();

  // Generate list of points
  // Spaced uniformly with distance to brain boundary <= cube diagonal
std::cout << "Initial set of points" << std::endl;
  float insideTol = m_InitialSpacing * 0.1;

  vtkSmartPointer<vtkPoints> meshPoints = vtkSmartPointer<vtkPoints>::New();
  meshPoints->SetDataTypeToFloat();

//PP
// No need for explicit boundaries???
#if 1
  // Use contour filter for boundary points
std::cout << "Contour filter..." << std::endl;
  FloatImagePointer contourMask = FloatImageType::New();
  contourMask->CopyInformation(m_LabelImage);
  contourMask->SetRegions(m_LabelImage->GetLargestPossibleRegion());
  contourMask->Allocate();
  contourMask->FillBuffer(0);

  FloatImagePointer distImg =
    (FloatImageType*)m_InternalDistanceInterpolators[which]->GetInputImage();

  typedef itk::ImageRegionIteratorWithIndex<FloatImageType>
    DistanceIteratorType;
  DistanceIteratorType distIt(
    distImg, distImg->GetLargestPossibleRegion());

  for (distIt.GoToBegin(); !distIt.IsAtEnd(); ++distIt)
  {
    if (distIt.Get() > 0.01)
      continue;

    contourMask->SetPixel(distIt.GetIndex(), 1.0);
  }

  typedef itk::VTKImageExport<FloatImageType> ITKExportType;
  ITKExportType::Pointer itkexport = ITKExportType::New();
  itkexport->SetInput(contourMask);
  itkexport->Update();

  // See InsightApplications/Auxialiary/vtk/itkImageToVTKImageFilter
  vtkSmartPointer<vtkImageImport> vtkimport = 
    vtkSmartPointer<vtkImageImport>::New();
  vtkimport->SetUpdateInformationCallback(
    itkexport->GetUpdateInformationCallback());
  vtkimport->SetPipelineModifiedCallback(
    itkexport->GetPipelineModifiedCallback());
  vtkimport->SetWholeExtentCallback(itkexport->GetWholeExtentCallback());
  vtkimport->SetSpacingCallback(itkexport->GetSpacingCallback());
  vtkimport->SetOriginCallback(itkexport->GetOriginCallback());
  vtkimport->SetScalarTypeCallback(itkexport->GetScalarTypeCallback());
  vtkimport->SetNumberOfComponentsCallback(itkexport->GetNumberOfComponentsCallback());
  vtkimport->SetPropagateUpdateExtentCallback(itkexport->GetPropagateUpdateExtentCallback());
  vtkimport->SetUpdateDataCallback(itkexport->GetUpdateDataCallback());
  vtkimport->SetDataExtentCallback(itkexport->GetDataExtentCallback());
  vtkimport->SetBufferPointerCallback(itkexport->GetBufferPointerCallback());
  vtkimport->SetCallbackUserData(itkexport->GetCallbackUserData());

  vtkSmartPointer<vtkContourFilter> contourf = 
    vtkSmartPointer<vtkContourFilter>::New();
  contourf->SetInput(vtkimport->GetOutput());
  contourf->SetNumberOfContours(1);
  contourf->SetValue(0, 1.0);
  contourf->ComputeNormalsOff();
  contourf->ComputeGradientsOff();

  contourf->Update();

  vtkPolyData* contourPD = contourf->GetOutput();

std::cout << "Contour filter: " << contourPD->GetNumberOfPoints() << " boundary points" << std::endl;

  for (vtkIdType k = 0; k < contourPD->GetNumberOfPoints(); k++)
  {
    double x[3];
    contourPD->GetPoint(k, x);

    meshPoints->InsertNextPoint(x);
  }

  // Clean up contour filtering vars
  contourMask = 0;
#endif

  // Insert equally spaced points inside regions
  BSplineInterpolatorType::Pointer distInterp =
    m_InternalDistanceInterpolators[which];

  float lenx = (size[0]-1) * spacing[0];
  float leny = (size[1]-1) * spacing[1];
  float lenz = (size[2]-1) * spacing[2];

  float randomDisp = 0.1*m_InitialSpacing;

  for (float pz = 0; pz <= lenz; pz += m_InitialSpacing)
    for (float py = 0; py <= leny; py += m_InitialSpacing)
      for (float px = 0; px <= lenx; px += m_InitialSpacing)
      {
        FloatImagePointType p;
        p[0] = px;
        p[1] = py;
        p[2] = pz;

/*
        for (int dim = 0; dim < 3; dim++)
        {
          float r = 2.0*rng->GenerateUniformRealOpenInterval() - 1.0;
          p[dim] += r * randomDisp;
        }
*/

        if (!distInterp->IsInsideBuffer(p))
          continue;

        float phi = distInterp->Evaluate(p);

        // Skip outside points
        if (phi > insideTol)
          continue;

        meshPoints->InsertNextPoint(p[0], p[1], p[2]);
      }

  // Bounds for Delaunay triangulation
  double bounds[6];
  for (int dim = 0; dim < 3; dim++)
  {
    bounds[2*dim] = 0.0;
    bounds[2*dim + 1] = size[dim]*spacing[dim];
  }

  vtkSmartPointer<vtkUnstructuredGrid> mesh =
    //createMesh3D(meshPoints, m_BrainDistanceImage);
    createMesh3D(meshPoints, m_InternalDistanceInterpolators[which]->GetInputImage());

  LinearTetrahedralMesh femesh;

  // Begin loop
  unsigned int numCells = 0;
  unsigned int numPoints = meshPoints->GetNumberOfPoints();
std::cout << "Initially " << numPoints << " points" << std::endl;

  float maxChange = m_InitialSpacing;

  unsigned int iter = 0;

  double x[3];
  double y[3];

  while (true)
  {
    ++iter;

std::cout << "---------------------------" << std::endl;
std::cout << "Loop iteration " << iter << std::endl;
    // Recompute Delaunay triangulation if change is higher than threshold
    //if ((maxChange/m_InitialSpacing) > 0.1)
    //if ((iter % 5) == 1)
    {

std::cout << "Recompute Delaunay tesselation" << std::endl;
std::cout << "Current count: " << meshPoints->GetNumberOfPoints() << " points" << std::endl;

      // Drop duplicate points
std::cout << "Dropping duplicate points" << std::endl;
      vtkSmartPointer<vtkPoints> filteredPoints =
        vtkSmartPointer<vtkPoints>::New();
      filteredPoints->SetDataTypeToFloat();
      filteredPoints->Allocate(numPoints);

      vtkSmartPointer<vtkPointLocator> pLoc =
        vtkSmartPointer<vtkPointLocator>::New();
      //pLoc->SetTolerance(m_InitialSpacing*0.25);
      pLoc->SetTolerance(1.0);
      pLoc->InitPointInsertion(filteredPoints, bounds, numPoints);
      for (int i = 0; i < numPoints; i++)
      {
        meshPoints->GetPoint(i, x);

        int id = pLoc->IsInsertedPoint(x);

        if (id < 0)
          pLoc->InsertNextPoint(x);
      
      } // for i

std::cout << "After removing dupes: " << filteredPoints->GetNumberOfPoints() << " remaining" << std::endl;

      meshPoints = pLoc->GetPoints();

      mesh = createMesh3D(meshPoints, m_BrainDistanceImage);

      meshPoints = mesh->GetPoints();

      numCells = mesh->GetNumberOfCells();
      numPoints = mesh->GetNumberOfPoints();
std::cout << "Done delaunay " << numCells << " cells, " << numPoints << " points" << std::endl;

      // Stop
      if (iter >= m_MaximumIterations)
        break;

    } // if max change > tol

std::cout << "Bar SS" << std::endl;
    // Compute SS (sum of squares) of the bar lengths
    float lengthSS = 1e-10;
    float sum_length = 0;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        meshPoints->GetPoint(ids->GetId(i), x);
        for (int j = i+1; j < n; j++)
        {
          meshPoints->GetPoint(ids->GetId(j), y);

          float barlength = 0;
          for (int k = 0; k < 3; k++)
          {
            float t = (x[k] - y[k]);
            barlength += t*t;
          }
          barlength = sqrtf(barlength);

          lengthSS += barlength*barlength;

          sum_length += barlength;
        }
      }

    }

    float root_lengthSS = sqrtf(lengthSS);
std::cout << "root_lengthSS = " << root_lengthSS << std::endl;

    VectorType mid(3, 0);

std::cout << "h func SS" << std::endl;
    // Compute SS of the size function at edge midpoints
    float hmidSS = 1e-10;
    float sum_hmid = 1e-10;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        meshPoints->GetPoint(ids->GetId(i), x);
        for (int j = i+1; j < n; j++)
        {
          meshPoints->GetPoint(ids->GetId(j), y);

          for (int k = 0; k < 3; k++)
            mid[k] = (x[k] + y[k]) / 2.0;

          float hmid =
            this->EvaluateInternalHFunction(mid[0], mid[1], mid[2], which);

          sum_hmid += hmid;

          hmidSS += hmid*hmid;
        }
      }

    }

    float root_hmidSS = sqrtf(hmidSS);
std::cout << "root_hmidSS = " << root_hmidSS << std::endl;

    float scalef = sqrtf(lengthSS / hmidSS);

    VectorType dir(3, 0);

    MatrixType dispMatrix(numPoints, 3, 0.0);

    //
    // Compute displacements due to spring forces at each point
    //
    // F = (L - L0) * unit(x-y)
    //
    // L0 is the root of SS of bar lengths, divided among edges
    // using weight = Fscale * h_mid / root_SS(h_mid)
    //

std::cout << "Spring disps" << std::endl;
    for (int cell = 0; cell < numCells; cell++)
    {
      vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

      mesh->GetCellPoints(cell, ids);

      unsigned int n = ids->GetNumberOfIds();

      for (int i = 0; i < n; i++)
      {
        unsigned int index_x = ids->GetId(i);

        meshPoints->GetPoint(index_x, x);

        for (int j = i+1; j < n; j++)
        {
          unsigned int index_y = ids->GetId(j);
 

          meshPoints->GetPoint(index_y, y);

          for (int k = 0; k < 3; k++)
          {
            mid[k] = (x[k] + y[k]) / 2.0;
            dir[k] = (x[k] - y[k]);
          }

          float hmid =
            this->EvaluateInternalHFunction(mid[0], mid[1], mid[2], which);

          float length = dir.magnitude();

          if (length < 1e-10)
            continue;

          float length0 = hmid * scalef *  m_FScale;
          //float length0 = hmid / sum_hmid * sum_length;

          dir /= length;

          //forces(x) = max(L0-L, 0) * dir
          //forces(y) = -forces(x)

          float forcemag = length0 - length;
// PP
// No contraction?
          //if (forcemag < 0.0)
            //continue;

          // Compute displacements
          // dx =  fmag*dir*dt
          // dy = -fmag*dir*dt

          for (int k = 0; k < 3; k++)
          {
            float disp_k = forcemag * dir[k] * m_TimeStep;

            dispMatrix(index_x, k) += disp_k;
            dispMatrix(index_y, k) -= disp_k;
          }

        }
      }

    }

    // Move the points according to the force displacements
    for (int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);

      for (int k = 0; k < 3; k++)
        x[k] += dispMatrix(i, k);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2];

      if (!distInterp->IsInsideBuffer(p))
      {
        for (int d = 0; d < 3; d++)
          dispMatrix(i, d) = 0.0;
        continue;
      }

      float phi = distInterp->Evaluate(p);

      if (phi > insideTol)
      {
        for (int d = 0; d < 3; d++)
          dispMatrix(i, d) = 0.0;
        continue;
      }

      meshPoints->SetPoint(i, x);
    }

    // Compute max internal displacements due to truss forces
std::cout << "Max interior change? " << std::endl;
    maxChange = 0;

    float aveChange = 0;

    for (int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2];

      if (!distInterp->IsInsideBuffer(p))
        continue;

      float phi = distInterp->Evaluate(p);

      if (phi < insideTol)
      {
        float d = 0;
        for (int k = 0; k < 3; k++)
        {
          float t = dispMatrix(i, k);
          d += t*t;
        }
        d = sqrtf(d);

        if (d > maxChange)
          maxChange = d;

        aveChange += d;
      }
    }

    aveChange /= numPoints;

std::cout << "Average change = " << aveChange << std::endl;

    //
    // Compute displacements that bring outside points back to boundary
    //
    // p = p - \phi* \frac{\nabla\phi}{|\nabla\phi|^2}
    //

std::cout << "Move external" << std::endl;

    for (int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);

      FloatImagePointType p;
      p[0] = x[0];
      p[1] = x[1];
      p[2] = x[2]; 

      if (!distInterp->IsInsideBuffer(p))
        continue;

      float phi = distInterp->Evaluate(p);

      if (phi > 0)
      {
        BSplineInterpolatorType::CovariantVectorType grad_phi =
          distInterp->EvaluateDerivative(p);

#if SECOND_ORDER_DIST
        float magsq = grad_phi.GetSquaredNorm();
        if (magsq < MIN_DISTANCEMAG)
          continue;
#else
        float magsq = 1.0;
#endif

        for (int k = 0; k < 3; k++)
        {
          float disp_k = phi / magsq * grad_phi[k];
          x[k] -= disp_k;
          dispMatrix(i, k) -= disp_k;
        }

        meshPoints->SetPoint(i, x);
      }
    }

    // Make sure all points are within bounds
    for (unsigned int i = 0; i < numPoints; i++)
    {
      meshPoints->GetPoint(i, x);
      for (unsigned int dim = 0; dim < 3; dim++)
      {
        if (x[dim] < bounds[2*dim])
          x[dim] = bounds[2*dim];
        if (x[dim] > bounds[2*dim+1])
          x[dim] = bounds[2*dim+1];
      }
      meshPoints->SetPoint(i, x);
    }

    // Converged when all interior nodes move less than tolerance
    bool converged = ((maxChange / m_InitialSpacing) < 1e-3);
    if (converged || iter > m_MaximumIterations)
      break;

  } // main loop

  // Insert final mesh points to the global point locator
  for (unsigned int i = 0; i < meshPoints->GetNumberOfPoints(); i++)
  {
    meshPoints->GetPoint(i, x);

    int id = globalPLoc->IsInsertedPoint(x);
    if (id < 0)
      globalPLoc->InsertNextPoint(x);
  }

}
