
#include "itkLinearInterpolateImageFunction.h"
#include "itkImage.h"
#include "itkImageFileReader.h"

#include "vtkUnstructuredGrid.h"
#include "vtkUnstructuredGridReader.h"

#include "vtkActor.h"
#include "vtkCommand.h"
#include "vtkDataSetMapper.h"
#include "vtkDataSetSurfaceFilter.h"
#include "vtkPolyDataMapper.h"
#include "vtkExtractEdges.h"
#include "vtkFeatureEdges.h"
#include "vtkIdList.h"
#include "vtkProperty.h"
#include "vtkRenderer.h"
#include "vtkRenderWindow.h"
#include "vtkRenderWindowInteractor.h"
#include "vtkSmartPointer.h"
#include "vtkTetra.h"
#include "vtkTubeFilter.h"

#include <iostream>

int main(int argc, char** argv)
{

  if (argc != 3)
  {
    std::cerr << "Usage: " << argv[0] << " <mesh.vtk> <mask>" << std::endl;
    return -1;
  }

  // Read mesh file
  vtkUnstructuredGridReader* meshReader = vtkUnstructuredGridReader::New();

  meshReader->SetFileName(argv[1]);
  meshReader->Update();

  vtkUnstructuredGrid* mesh = meshReader->GetOutput();

  int numCells = mesh->GetNumberOfCells();
  int numPoints = mesh->GetNumberOfPoints();

  std::cout << numCells << " cells" << std::endl;
  std::cout << numPoints << " points" << std::endl;

  // Read image
  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::ImageFileReader<FloatImageType> ImageReaderType;

  ImageReaderType::Pointer imgReader = ImageReaderType::New();
  imgReader->SetFileName(argv[2]);
  imgReader->Update();

  FloatImageType::Pointer maskImg = imgReader->GetOutput();

  typedef itk::LinearInterpolateImageFunction<FloatImageType, double>
    InterpolatorType;
  InterpolatorType::Pointer maskInterp = InterpolatorType::New();
  maskInterp->SetInputImage(maskImg);

  // Compute z centroid of masked mesh
  double zInt = 0;
  //double maskWeight = 1e-20;
  double maskWeight = 0;

  double pt[3];
  for (int i = 0; i < numPoints; i++)
  {
    mesh->GetPoint(i, pt);

    FloatImageType::PointType p;
    p[0] = pt[0];
    p[1] = pt[1];
    p[2] = pt[2];

    if (!maskInterp->IsInsideBuffer(p))
    {
      std::cout << "Point outside buffer: " << p << std::endl;
      continue;
    }

    double m = maskInterp->Evaluate(p);
    zInt += m * p[2];
    maskWeight += m;

  }

  double zc = zInt / maskWeight;

  std::cout << "Z centroid = " << zc << std::endl;

  vtkUnstructuredGrid* croppedMesh = vtkUnstructuredGrid::New();
  vtkUnstructuredGrid* croppedTissueMesh = vtkUnstructuredGrid::New();
  vtkUnstructuredGrid* croppedCSFMesh = vtkUnstructuredGrid::New();

  croppedMesh->SetPoints(mesh->GetPoints());
  croppedTissueMesh->SetPoints(mesh->GetPoints());
  croppedCSFMesh->SetPoints(mesh->GetPoints());

  for (int i = 0; i < numCells; i++)
  {
    vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();

    mesh->GetCellPoints(i, ids);

    unsigned int n = ids->GetNumberOfIds();
    if (n != 4)
    {
      std::cerr << "Cell " << i << " is not a tetra" << std::endl;
      continue;
    }

    FloatImageType::PointType centroid;
    centroid.Fill(0);

    double x[3];

    bool useCell = true;
    for (unsigned int j = 0; j < n; j++)
    {
      mesh->GetPoint(ids->GetId(j), x);

      centroid[0] += x[0];
      centroid[1] += x[1];
      centroid[2] += x[2];

      float z = x[2];

      if (z < zc)
      {
        useCell = false; 
        break;
      }
    }

    centroid [0] /= n;
    centroid [1] /= n;
    centroid [2] /= n;

    if (useCell)
    {
      FloatImageType::IndexType ind;
      bool transformOK = maskImg->TransformPhysicalPointToIndex(centroid, ind);

      //float pcsf = 0;
      //if (transformOK)
      //  pcsf = maskImg->GetPixel(ind) / 255.0;
      float pcsf = maskImg->GetPixel(ind) / 255.0;

      if (pcsf > 0.1)
        croppedCSFMesh->InsertNextCell(VTK_TETRA, 4, ids->GetPointer(0));
      else
        croppedTissueMesh->InsertNextCell(VTK_TETRA, 4, ids->GetPointer(0));

      croppedMesh->InsertNextCell(VTK_TETRA, 4, ids->GetPointer(0));
    }

  }

  croppedMesh->Squeeze();
  croppedTissueMesh->Squeeze();
  croppedCSFMesh->Squeeze();

  // Mesh surface
  vtkDataSetMapper* tissueMeshMapper = vtkDataSetMapper::New();
  tissueMeshMapper->SetInput(croppedTissueMesh);

  vtkDataSetMapper* csfMeshMapper = vtkDataSetMapper::New();
  csfMeshMapper->SetInput(croppedCSFMesh);

  // Get edges
  vtkDataSetSurfaceFilter* surfext = vtkDataSetSurfaceFilter::New();
  surfext->SetInput(croppedMesh);

  vtkFeatureEdges* featext = vtkFeatureEdges::New();
  featext->ColoringOff();
  featext->BoundaryEdgesOn();
  featext->ManifoldEdgesOn();
  featext->NonManifoldEdgesOn();
  featext->SetInput(surfext->GetOutput());

  vtkTubeFilter* tubes = vtkTubeFilter::New();
  tubes->SetInput(featext->GetOutput());
  tubes->SetRadius(0.2);
  tubes->SetNumberOfSides(6);

  vtkPolyDataMapper* edgeMapper = vtkPolyDataMapper::New();
  edgeMapper->SetResolveCoincidentTopologyToPolygonOffset();
  edgeMapper->SetInput(tubes->GetOutput());
  //edgeMapper->SetInput(featext->GetOutput());

  vtkActor* tissueMeshActor = vtkActor::New();
  tissueMeshActor->SetMapper(tissueMeshMapper);
  tissueMeshActor->GetProperty()->SetColor(1.0, 1.0, 0.0);
  tissueMeshActor->GetProperty()->BackfaceCullingOff();

  vtkActor* csfMeshActor = vtkActor::New();
  csfMeshActor->SetMapper(csfMeshMapper);
  csfMeshActor->GetProperty()->SetColor(0.9, 0.1, 1.0);
  csfMeshActor->GetProperty()->BackfaceCullingOff();

  vtkActor* edgeActor = vtkActor::New();
  edgeActor->SetMapper(edgeMapper);
  edgeActor->GetProperty()->SetColor(0.1, 0.6, 0.1);

  // Display
  vtkRenderer* ren = vtkRenderer::New();
  vtkRenderWindow* renWin = vtkRenderWindow::New();
  vtkRenderWindowInteractor* iren = vtkRenderWindowInteractor::New();

  ren->AddActor(tissueMeshActor);
  ren->AddActor(csfMeshActor);
  ren->AddActor(edgeActor);
  ren->SetBackground(0, 0, 0);

  renWin->AddRenderer(ren);
  renWin->SetSize(500, 500);

  iren->SetRenderWindow(renWin);

  iren->Initialize();
  renWin->Render();

  // Switch to trackball mode
  iren->SetKeyEventInformation(0,0,'t',0,0);
  iren->InvokeEvent(vtkCommand::CharEvent, NULL);

  iren->Start();
}
