//#if defined(_MSC_VER)
//#pragma warning ( disable : 4786 )
//#endif

#ifndef __itkSkullStripping_txx
#define __itkSkullStripping_txx

#include "itkSkullStripping.h"

using namespace std;
using namespace TCLAP;

// Single-precision value of pi used by the code in this file.
static const float PI = 3.1415926535897932;
// Coordinate magnitudes used to build the vert[] table below.
// X*X + Z*Z == 1, so every vertex assembled from (0, +/-X, +/-Z) permutations
// has unit length.
#define X .525731112119133606
#define Z .850650808352039932

// Define M_PI / M_PI_2 only when nothing included earlier has defined them.
#ifndef M_PI
#define M_PI 3.1415926535897932
#endif
#ifndef M_PI_2
#define M_PI_2 1.5707963267948966
#endif

// Vertices, triangles and edges of a single icosahedron.
// vert:   12 unit-length vertex positions (built from the X and Z macros).
static double vert[12][3] = {
  {-X, 0.0, Z}, {X, 0.0, Z}, {-X, 0.0, -Z}, {X, 0.0, -Z},
  {0.0, Z, X}, {0.0, Z, -X}, {0.0, -Z, X}, {0.0, -Z, -X},
  {Z, X, 0.0}, {-Z, X, 0.0}, {Z, -X, 0.0}, {-Z, -X, 0.0}
};
// triang: 20 faces, each given as three indices into vert.
static int triang[20][3] = {
  {0,4,1}, {0,9,4}, {9,5,4}, {4,5,8}, {4,8,1},
  {8,10,1}, {8,3,10}, {5,3,8}, {5,2,3}, {2,7,3},
  {7,10,3}, {7,6,10}, {7,11,6}, {11,0,6}, {0,1,6},
  {6,1,10}, {9,0,11}, {9,11,2}, {9,2,5}, {7,2,11}
};
// edge:   30 edges, each given as a pair of indices into vert. The order of
//         this table determines the numbering of the vertices that
//         TessellateIcosahedron() inserts along the edges.
static int edge[30][2] = {
  {0,1}, {0,4}, {0,6}, {0,9}, {0,11}, {1,4}, {1,6}, {1,8}, {1,10}, {2,3},
  {2,5}, {2,7}, {2,9}, {2,11}, {3,5}, {3,7}, {3,8}, {3,10}, {4,5}, {4,8},
  {4,9}, {5,8}, {5,9}, {6,7}, {6,10}, {6,11}, {7,10}, {7,11}, {8,10}, {9,11}
};




namespace itk
{
namespace Statistics
{

// Centre of gravity in physical coordinates. Update() computes it and
// PolyDataToLabelMap() reads it as the flood-fill seed.
// NOTE(review): this is a non-static namespace-scope variable defined in a
// file guarded like a header; if the file is included from more than one
// translation unit this may produce duplicate definitions -- confirm.
itk::Point<float, 3> COG;


// List of point ids adjacent to one mesh vertex (see ComputeVertexNeighbors).
typedef std::vector<vtkIdType> NeighborhoodType;

// Default constructor: performs no work of its own.
template <class TInputImage, class TOutputImage>
SkullStripping<TInputImage, TOutputImage>::SkullStripping()
{}

// Destructor: performs no work of its own.
template <class TInputImage, class TOutputImage>
SkullStripping<TInputImage, TOutputImage>::~SkullStripping()
{}

// Prints the object by delegating to the superclass; nothing specific to
// SkullStripping is written.
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>::PrintSelf(std::ostream& os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);
}


// Erodes a label image with a ball-shaped structuring element.
//
// img      : label image to erode.
// ballsize : radius handed to the structuring element, the same value along
//            every image axis.
// Returns the output of the erosion filter after it has been updated.
template <class TInputImage, class TOutputImage>
typename SkullStripping<TInputImage, TOutputImage>::LabelImagePointer 
SkullStripping<TInputImage, TOutputImage>
::BinaryErodeFilter3D ( LabelImageType* img , unsigned int ballsize )
{
  typedef itk::BinaryBallStructuringElement<LabelPixelType, ImageDimension> KernalType;
  typedef itk::BinaryErodeImageFilter<LabelImageType, LabelImageType, KernalType> ErodeFilterType;

  // The kernel has ImageDimension axes, so set the radius with Fill() rather
  // than a loop hard-coded to three components.
  typename KernalType::SizeType radius;
  radius.Fill( ballsize );

  KernalType ball;
  ball.SetRadius( radius );
  ball.CreateStructuringElement();

  typename ErodeFilterType::Pointer erodeFilter = ErodeFilterType::New();
  erodeFilter->SetInput( img );
  erodeFilter->SetKernel( ball );
  erodeFilter->Update();
  return erodeFilter->GetOutput();
}

// Dilates a label image with a ball-shaped structuring element.
//
// img      : label image to dilate.
// ballsize : radius handed to the structuring element, the same value along
//            every image axis.
// Returns the output of the dilation filter after it has been updated.
template <class TInputImage, class TOutputImage>
typename SkullStripping<TInputImage, TOutputImage>::LabelImagePointer
SkullStripping<TInputImage, TOutputImage>
::BinaryDilateFilter3D ( LabelImageType *img , unsigned int ballsize )
{
  typedef itk::BinaryBallStructuringElement<LabelPixelType, ImageDimension> KernalType;
  typedef itk::BinaryDilateImageFilter<LabelImageType, LabelImageType, KernalType> DilateFilterType;

  // The kernel has ImageDimension axes, so set the radius with Fill() rather
  // than a loop hard-coded to three components.
  typename KernalType::SizeType radius;
  radius.Fill( ballsize );

  KernalType ball;
  ball.SetRadius( radius );
  ball.CreateStructuringElement();

  typename DilateFilterType::Pointer dilateFilter = DilateFilterType::New();
  dilateFilter->SetInput( img );
  dilateFilter->SetKernel( ball );
  dilateFilter->Update();
  return dilateFilter->GetOutput();
}

// Morphological opening: erosion followed by dilation, both with the same
// ball radius.
template <class TInputImage, class TOutputImage>
typename SkullStripping<TInputImage, TOutputImage>::LabelImagePointer 
SkullStripping<TInputImage, TOutputImage>
::BinaryOpeningFilter3D ( LabelImageType *img , unsigned int ballsize )
{
  // Keep the intermediate image referenced until the dilation has finished.
  typename LabelImageType::Pointer eroded = BinaryErodeFilter3D( img, ballsize );
  typename LabelImageType::Pointer opened = BinaryDilateFilter3D( eroded, ballsize );
  return opened;
}


// Morphological closing: dilation followed by erosion, both with the same
// ball radius.
template <class TInputImage, class TOutputImage>
typename SkullStripping<TInputImage, TOutputImage>::LabelImagePointer 
SkullStripping<TInputImage, TOutputImage>
::BinaryClosingFilter3D ( LabelImageType *img , unsigned int ballsize )
{
  // Keep the intermediate image referenced until the erosion has finished.
  typename LabelImageType::Pointer dilated = BinaryDilateFilter3D( img, ballsize );
  typename LabelImageType::Pointer closed = BinaryErodeFilter3D( dilated, ballsize );
  return closed;
}



// Rasterises a surface mesh into a filled label image.
//
// polyData : surface to rasterise; its point coordinates are interpreted as
//            physical coordinates of `label`.
// label    : pre-allocated label image. It is cleared and then overwritten
//            with the result (0 outside, 255 on/inside the surface).
//
// Steps: densely sample points on the surface and mark the voxels they fall
// into, close the resulting shell morphologically, flood-fill from the
// file-level centre of gravity COG through all voxels below 0.5, close once
// more, and copy the result into `label`.
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::PolyDataToLabelMap( vtkPolyData* polyData, LabelImageType *label)
{
  vtkPolyDataPointSampler* sampler = vtkPolyDataPointSampler::New();
  sampler->SetInput( polyData );
  sampler->SetDistance( 0.75 );
  sampler->GenerateEdgePointsOn();
  sampler->GenerateInteriorPointsOn();
  sampler->GenerateVertexPointsOn();
  sampler->Update();

  vtkPolyData* sampled = sampler->GetOutput();
  const vtkIdType nSampled = sampled->GetNumberOfPoints();

  std::cout << polyData->GetNumberOfPoints() << std::endl;
  std::cout << nSampled << std::endl;

  // mark every voxel that contains at least one sampled surface point
  label->FillBuffer( 0 );
  for (vtkIdType k = 0; k < nSampled; k++)
  {
    double *pt = sampled->GetPoint( k );
    typename LabelImageType::PointType pitk;
    pitk[0] = pt[0];
    pitk[1] = pt[1];
    pitk[2] = pt[2];
    typename LabelImageType::IndexType idx;
    label->TransformPhysicalPointToIndex( pitk, idx );

    if ( label->GetLargestPossibleRegion().IsInside(idx) )
    {
      label->SetPixel( idx, 255 );
    }
  }

  // The sampler was created with New() and is no longer needed; release it
  // so it is not leaked on every call.
  sampler->Delete();

  // do morphological closing
  int ballSize = 2;
  typename LabelImageType::Pointer closedLabel = BinaryClosingFilter3D( label, ballSize );

  // do flood fill using binary threshold image function
  typedef itk::BinaryThresholdImageFunction<LabelImageType> ImageFunctionType;
  typename ImageFunctionType::Pointer func = ImageFunctionType::New();
  func->SetInputImage( closedLabel );
  func->ThresholdBelow(0.5);

  // seed the fill at the centre of gravity
  typename LabelImageType::IndexType seed;
  label->TransformPhysicalPointToIndex( COG, seed );

  itk::FloodFilledImageFunctionConditionalIterator<LabelImageType, ImageFunctionType>
    floodFill( closedLabel, func, seed );
  
  for (floodFill.GoToBegin(); !floodFill.IsAtEnd(); ++floodFill)
  {
    typename LabelImageType::IndexType i = floodFill.GetIndex();
    closedLabel->SetPixel( i, 255 );
  }

  typename LabelImageType::Pointer finalLabel = BinaryClosingFilter3D( closedLabel, ballSize );

  // copy the result back into the caller's image, voxel by voxel
  itk::ImageRegionIteratorWithIndex<LabelImageType>
    itLabel (closedLabel, closedLabel->GetLargestPossibleRegion() );
  for (itLabel.GoToBegin(); !itLabel.IsAtEnd(); ++itLabel)
  {
     typename LabelImageType::IndexType i = itLabel.GetIndex();
     label->SetPixel( i, finalLabel->GetPixel(i) );
  }
}


// Examines the intensity histogram and returns its 95% quantile.
//
// histogram : one-dimensional intensity histogram (only dimension 0 is read).
// Returns the 0.95 quantile of dimension 0, cast to PixelType.
//
// The function also builds a background-suppressed, five-point-averaged copy
// of the bin frequencies. That smoothed histogram is not used for the return
// value yet; it is kept as the starting point for a real peak search.
template <class TInputImage, class TOutputImage>
typename SkullStripping<TInputImage, TOutputImage>::PixelType 
SkullStripping<TInputImage, TOutputImage>
::FindWhiteMatterPeak ( HistogramType* histogram )
{
  typename HistogramType::SizeType size = histogram->GetSize();
  std::cout << "Histogram size: " << size << std::endl;

  PixelType t95 = static_cast<PixelType> (histogram->Quantile(0, 0.95));
  std::cout << "t95 = " << t95 << std::endl;

  const ::size_t nBins = size[0];
  std::vector<typename ImageType::PixelType> intensity( nBins );
  std::vector<float> frequency( nBins );

  for (::size_t k = 0; k < nBins; k++)
  {
    typename HistogramType::IndexType hidx;
    hidx[0] = k;
    intensity[k] = static_cast<typename ImageType::PixelType>( histogram->GetHistogramMinFromIndex (hidx)[0] );
    frequency[k] = histogram->GetFrequency ( hidx );
  }

  // suppress the background values (the two lowest bins); guarded so that a
  // histogram with fewer than two bins is not written out of bounds
  for (::size_t k = 0; k < nBins && k < 2; k++)
  {
    frequency[k] = 0;
  }

  // do simple five point average; "k + 2 < nBins" avoids the unsigned
  // wrap-around that "nBins - 2" would produce for tiny histograms
  std::vector<unsigned long> smoothedfrequency( nBins );
  for (::size_t k = 2; k + 2 < nBins; k++)
  {
    double d = 0;
    for (::size_t m = k - 2; m <= k + 2; m++)
    {
      d += static_cast<double> (frequency[m]);
    }
    smoothedfrequency[k] = static_cast<unsigned long> ( d/5 );
    //std::cout << intensity[k] << " " << smoothedfrequency[k] << std::endl;
  }

  return t95;
}

// Collects the vertices that share a cell with iVertexId and appends them to
// pIdRet as a chain in which consecutive entries are joined by a mesh edge.
//
// iVertexId : vertex whose neighbours are wanted.
// pMesh     : mesh to query; BuildLinks() must already have been called so
//             that GetPointCells() works.
// pIdRet    : receives the neighbour ids (appended). Callers are expected to
//             pass an empty vector.
//
// The first two neighbours n0, n1 are swapped when
// ((n0 - v) x (n1 - v)) . v < 0, where v is the position of iVertexId. This
// fixes the direction of the chain relative to the vector from the origin to
// the vertex (presumably the mesh is centred on the origin when this is
// called -- confirm against the caller).
//
// If the vertex has no neighbours nothing is appended. If the first neighbour
// is not edge-connected to any other neighbour (which should not happen on a
// triangle mesh) the neighbours are appended in ascending id order instead.
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::ComputeVertexNeighbors(vtkIdType iVertexId, vtkPolyData* pMesh, std::vector<vtkIdType>& pIdRet)
{
  // gather every point of every incident cell, except the vertex itself
  std::set<vtkIdType> setNeighbors;
  vtkIdType* pIncidentCells;
  unsigned short iNumCells;
      
  pMesh->GetPointCells(iVertexId, iNumCells, pIncidentCells);
            
  for(int i=0; i<iNumCells; ++i)
    {
    vtkIdType* pIncidentPoints;
    vtkIdType iNumPoints;
    pMesh->GetCellPoints(pIncidentCells[i], iNumPoints, pIncidentPoints);
    for(vtkIdType j=0; j<iNumPoints; ++j)
      {
      if(pIncidentPoints[j]!=iVertexId)
        {
        setNeighbors.insert(pIncidentPoints[j]);
        }
      }
    }

  // isolated vertex: nothing to report
  if (setNeighbors.empty())
    {
    return;
    }
  
  // neighbour ids in ascending order
  const std::vector<vtkIdType> pIds( setNeighbors.begin(), setNeighbors.end() );
  
  // find first edge
  vtkIdType Id0 = pIds[0];
  vtkIdType Id1 = Id0;
  bool foundEdge = false;
  for (::size_t k = 1; k < pIds.size(); k++)
    {
    if ( pMesh->IsEdge( Id0, pIds[k]) )
      {
      Id1 = pIds[k];
      foundEdge = true;
      break;
      }
    }

  // no chain can be started: fall back to the unordered neighbour list
  // instead of reading an uninitialised id
  if (!foundEdge)
    {
    pIdRet.insert( pIdRet.end(), pIds.begin(), pIds.end() );
    return;
    }

  // figure out if Id0 and Id1 is in the right order;
  // GetPoint() returns a pointer to storage that later calls may overwrite,
  // so each position is copied immediately
  double pc[3];
  memcpy( pc, pMesh->GetPoint(iVertexId), 3*sizeof(double) );
  double p0[3];
  memcpy( p0, pMesh->GetPoint(Id0), 3*sizeof(double) );
  double p1[3];
  memcpy( p1, pMesh->GetPoint(Id1), 3*sizeof(double) );
  for (int m = 0; m <3; m++)
    {
    p0[m] -= pc[m];
    p1[m] -= pc[m];
    }
  double op[3];
  op[0] = p0[1]*p1[2]-p0[2]*p1[1];
  op[1] = p0[2]*p1[0]-p0[0]*p1[2];
  op[2] = p0[0]*p1[1]-p0[1]*p1[0];
  const double ip = op[0]*pc[0]+op[1]*pc[1]+op[2]*pc[2];
  if (ip < 0) // swap Id0 and Id1;
    {
    const vtkIdType tempId = Id0;
    Id0 = Id1;
    Id1 = tempId;
    }

  pIdRet.push_back( Id0 );
  pIdRet.push_back( Id1 );

  // extend the chain: repeatedly append the not-yet-listed neighbour that
  // shares an edge with the most recently appended one
  for (::size_t k = 2; k < pIds.size(); k++)
    {
    const vtkIdType currentId = pIdRet.back();
    bool extended = false;
    for (::size_t m = 0; m < pIds.size(); m++)
      {
      const vtkIdType candidate = pIds[m];

      bool alreadyListed = false;
      for (::size_t n = 0; n < pIdRet.size(); n++)
        {
        if (pIdRet[n] == candidate)
          {
          alreadyListed = true;
          break;
          }
        }
      if (alreadyListed)
        {
        continue;
        }

      if ( pMesh->IsEdge(currentId, candidate) )
        {
        pIdRet.push_back( candidate );
        extended = true;
        break;
        }
      }

    // the chain ended early (e.g. open fan); further passes would only repeat
    // the same failed search
    if (!extended)
      {
      break;
      }
    }
}
                    
// Stores in out[0..2] the point (x, y, z) scaled to unit length.
static inline void
SkullStrippingUnitPoint(double* out, double x, double y, double z)
{
  const double length = sqrt(x*x + y*y + z*z);
  out[0] = x/length;
  out[1] = y/length;
  out[2] = z/length;
}

// Stores in out[0..2] the unit-length projection of origin + a*d12 + b*d23,
// i.e. grid point (a, b) of a subdivided icosahedron face.
static inline void
SkullStrippingFacePoint(double* out, const double* origin,
                        const double* d12, const double* d23, int a, int b)
{
  SkullStrippingUnitPoint(out,
                          origin[0] + a*d12[0] + b*d23[0],
                          origin[1] + a*d12[1] + b*d23[1],
                          origin[2] + a*d12[2] + b*d23[2]);
}

// Builds a triangulated unit sphere by subdividing every icosahedron edge
// into `level` segments and projecting all points onto the unit sphere.
//
// level : number of segments per icosahedron edge; values below 1 are
//         treated as 1 (the plain icosahedron).
// Returns a new vtkPolyData holding the vertices and VTK_TRIANGLE cells.
// The object is created with New() and not released here, so the caller
// owns it and must Delete() it.
//
// Vertex numbering: the 12 icosahedron vertices, then the points inserted on
// the 30 edges (in edge[] order), then the points inserted inside the 20
// faces (in triang[] order).
template <class TInputImage, class TOutputImage>
vtkPolyData* 
SkullStripping<TInputImage, TOutputImage>
::TessellateIcosahedron(int level)
{
  // A level below 1 would divide by zero or request negative array sizes.
  if (level < 1)
    {
    level = 1;
    }

  // 12 corners + (level-1) points per edge + interior points per face
  const int interiorPerFace = (level - 1)*(level - 2)/2;
  const int n_vert = 12 + (level - 1)*30 + interiorPerFace*20;
  // every face is split into level*level triangles
  const int n_faces = 20*level*level;

  // all containers are std::vector so that nothing has to be freed by hand
  std::vector<double> all_vert(3*n_vert, 0.0);          // xyz per vertex
  std::vector<int> triangs(3*n_faces, -1);              // 3 vertex ids per triangle
  std::vector<double> corners(level > 1 ? 9*n_faces : 0, 0.0); // xyz per triangle corner

  // --- original icosahedron vertices
  for (int v = 0; v < 12; v++)
    {
    for (int c = 0; c < 3; c++)
      {
      all_vert[3*v+c] = vert[v][c];
      }
    }

  // --- vertices along the edges
  int k = 12;
  for (int e = 0; e < 30; e++) 
    {
    const double* a = vert[edge[e][0]];
    const double* b = vert[edge[e][1]];
    const double dx12 = (b[0] - a[0])/level;
    const double dy12 = (b[1] - a[1])/level;
    const double dz12 = (b[2] - a[2])/level;
    for (int n = 1; n < level; n++) 
      {
      SkullStrippingUnitPoint(&all_vert[3*k],
                              a[0] + n*dx12, a[1] + n*dy12, a[2] + n*dz12);
      k++;
      }
    }

  // --- per face: interior vertices and the corner positions of all triangles
  int nCorners = 0;
  for (int f = 0; f < 20; f++) 
    {
    const double* p1 = vert[triang[f][0]];
    const double* p2 = vert[triang[f][1]];
    const double* p3 = vert[triang[f][2]];
    double d12[3];
    double d23[3];
    for (int c = 0; c < 3; c++)
      {
      d12[c] = (p2[c] - p1[c])/level;
      d23[c] = (p3[c] - p2[c])/level;
      }

    // interior vertices (none unless level > 2)
    for (int n = 1; n <= level - 2; n++) 
      {
      for (int m = 1; m <= n; m++) 
        {
        SkullStrippingFacePoint(&all_vert[3*k], p1, d12, d23, n+1, m);
        k++;
        }
      }

    // level 1 reuses the triang[] table directly, see below
    if (level == 1)
      {
      continue;
      }

    for (int n = 1; n <= level; n++) 
      {
      for (int m = 1; m <= n; m++) 
        {
        // lower triangle
        SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n,   m);
        SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n-1, m-1);
        SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n,   m-1);
        if ( m != n ) 
          {
          // lower left triangle
          SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n,   m);
          SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n-1, m);
          SkullStrippingFacePoint(&corners[3*(nCorners++)], p1, d12, d23, n-1, m-1);
          }
        }
      }
    }

  // --- indexing of triangs
  if (level == 1) 
    {
    for (int f = 0; f < 20; f++)
      {
      for (int c = 0; c < 3; c++)
        {
        triangs[3*f+c] = triang[f][c];
        }
      }
    } 
  else 
    {
    // For every triangle corner find the lowest-numbered vertex whose
    // coordinates agree within epsilon. Corners and vertices are computed by
    // different routes, hence the tolerance instead of exact equality.
    const double epsilon = 0.00001;
    for (int j = 0; j < nCorners; j++) 
      {
      const double* corner = &corners[3*j];
      for (int i = 0; i < n_vert; i++) 
        {
        const double* v = &all_vert[3*i];
        if ( (fabs(v[0] - corner[0]) < epsilon) && 
             (fabs(v[1] - corner[1]) < epsilon) && 
             (fabs(v[2] - corner[2]) < epsilon) ) 
          {
          triangs[j] = i;
          break;
          }
        }

      // report corners that could not be matched to any vertex
      if (triangs[j] == -1)
        {
        std::cerr << " - " << j << " :" << corner[0] 
                  << "," << corner[1] << "," << corner[2] << std::endl;
        }
      }
    }

  // --- assemble the VTK mesh
  vtkPoints *pts = vtkPoints::New();
  vtkPolyData* polyData = vtkPolyData::New();

  polyData->SetPoints (pts);
  pts->SetNumberOfPoints(0);
  polyData->Allocate();
  for (int v = 0; v < n_vert; v++)
    {
    pts->InsertNextPoint(all_vert[3*v], all_vert[3*v+1], all_vert[3*v+2]);
    }
  // polyData holds its own reference to the points; drop ours
  pts->Delete();

  // InsertNextCell copies the ids, so one list can be reused for all cells
  vtkIdList *tids = vtkIdList::New();
  for (int t = 0; t < n_faces; t++)
    {
    tids->Reset();
    tids->InsertNextId(triangs[3*t]);
    tids->InsertNextId(triangs[3*t+1]);
    tids->InsertNextId(triangs[3*t+2]);
    polyData->InsertNextCell(VTK_TRIANGLE, tids);
    }
  tids->Delete();

  return polyData;
}

template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::Update()
{

  cout << "SkullStripping::Update()" << endl;

//  std::string inputVolume; 
  std::string brainSurface; 
//  std::string brainMask; 
//  std::string maskedBrain;

//  int sphericalResolution = 20; 
  int sphericalResolution = 25;  
  int nIterations = 800; 
//  float lThreshold = 400; 
//  float uThreshold = 1600; 

  float lThreshold = 25;
  float uThreshold = 95;

  int postDilationRadius = 0;

// sphererical : 35, nItera : 1000/1500, lT: 40 uT: 80 -> Too small, but promising

  // parameters -- temporary setting
//  inputVolume = "aftersmooth_T1.mha";
  brainSurface = "brain.vtk";
//  brainMask = "mask.mha";
//  maskedBrain = "maskedbrain.mha";


   // parameter passing part: removed  !!!

  

  float constTanh1 = uThreshold/16.0;
  float constTanh2 = uThreshold/4.0;

  std::cout << "Constants for tanh: " << constTanh1 << ", " << constTanh2 << std::endl;

 /*
  ImageReaderType::Pointer imageReader  = ImageReaderType::New();

  imageReader->SetFileName( inputVolume.c_str() );
  imageReader->Update();

  ImageType::Pointer image = imageReader->GetOutput();
  */


  typename ImageType::SpacingType spacing = image->GetSpacing();

  // initialize label image
  typename LabelImageType::Pointer label = LabelImageType::New();
  label->CopyInformation( image );
  label->SetRegions( label->GetLargestPossibleRegion() );
  label->Allocate();

  typename LabelImageType::Pointer flabel = LabelImageType::New();
  flabel->CopyInformation( image );
  flabel->SetRegions( flabel->GetLargestPossibleRegion() );
  flabel->Allocate();
  flabel->FillBuffer( 0 );

  typename itk::ImageFileWriter<LabelImageType>::Pointer wlabel = itk::ImageFileWriter<LabelImageType>::New();

  // compute histogram
  typename GeneratorType::Pointer generator = GeneratorType::New();
  generator->SetInput( image );
  generator->SetNumberOfBins( 256 );
  generator->Compute();

  typename HistogramType::Pointer histogram = const_cast<HistogramType*>( generator->GetOutput() );
  PixelType t2 = static_cast<PixelType>(histogram->Quantile(0, 0.02));
  PixelType t98 = static_cast<PixelType> (histogram->Quantile(0, 0.98));
  PixelType tinit = static_cast<PixelType>(t2+0.1*static_cast<float>(t98-t2));  

  FindWhiteMatterPeak ( histogram );

    // compute brain size and center of gravity
  COG.Fill( 0.0 );
  unsigned long HeadVoxelCount = 0;
  itk::ImageRegionIteratorWithIndex<ImageType> it( image, image->GetLargestPossibleRegion() );
  for ( it.GoToBegin(); !it.IsAtEnd(); ++it )
    {
    PixelType a = it.Get();
    if (a < tinit || a > t98)
      {
      continue;
      }
    HeadVoxelCount ++;
    typename ImageType::IndexType idx = it.GetIndex();
    typename ImageType::PointType point;
    image->TransformIndexToPhysicalPoint( idx, point );
    for (::size_t k = 0; k < ImageDimension; k++)
      {
      COG[k] += point[k];
      }
    }
  

  float HeadVolume = static_cast<float>( HeadVoxelCount );
  for (::size_t k = 0; k < ImageDimension; k++)
    {
    COG[k] /= static_cast<float>( HeadVoxelCount );
    HeadVolume *= spacing[k];
    }

  // geometry you learn from middle school
  float radius = pow(HeadVolume*3.0/4.0/PI, 1.0/3.0);
  
  std::cout << "Threshold: \n";
  std::cout << "t2 = " << t2 << std::endl;
  std::cout << "tinit = " << tinit << std::endl;
  std::cout << "t98 = " << t98 << std::endl;
  std::cout << "number of head voxel = " << HeadVoxelCount << std::endl;
  std::cout << "volume of head = " << HeadVolume << std::endl;
  std::cout << "radius of head = " << radius << std::endl;
  std::cout << "COG = " << COG << std::endl;

  typename ImageType::IndexType COGIdx;
  image->TransformPhysicalPointToIndex( COG, COGIdx ); 
  std::cout << COGIdx << ": " << image->GetPixel( COGIdx ) << std::endl;

  // figure out aspects of the initial elipsoid
  typename ImageType::SizeType imageSize = image->GetLargestPossibleRegion().GetSize();
  typename ImageType::IndexType imageStart = image->GetLargestPossibleRegion().GetIndex();

  typename ImageType::IndexType Idx = COGIdx;
  PixelType tNonBackground = static_cast<PixelType>(0.25*static_cast<float>(tinit));
  int xStart = 0;
  int xEnd = 0;
  for (::size_t k = imageStart[0]; k < imageStart[0]+imageSize[0]; k++)
  {
    Idx[0] = k;
    if ( image->GetPixel(Idx) > tNonBackground )
    {
      xStart = k;
      break;
    }
  }
  for (int k = imageStart[0]+imageSize[0]-1; k >= imageStart[0]; k--)
  {
    Idx[0] = k;
    if ( image->GetPixel(Idx) > tNonBackground )
    {
      xEnd = k;
      break;
    }
  }
  float headWidth = static_cast<float>(xEnd-xStart)*spacing[0];

  Idx = COGIdx;
  int yStart = 0;
  int yEnd = 0;
  for (::size_t k = imageStart[1]; k < imageStart[1]+imageSize[1]; k++)
  {
    Idx[1] = k;
    if ( image->GetPixel(Idx) > tNonBackground )
    {
      yStart = k;
      break;
    }
  }
  for (int k = imageStart[1]+imageSize[1]-1; k >= imageStart[1]; k--)
  {
    Idx[1] = k;
    if ( image->GetPixel(Idx) > tNonBackground )
    {
      yEnd = k;
      break;
    }
  }
  float headLength = static_cast<float>(yEnd-yStart)*spacing[1];

  Idx = COGIdx;
  int zEnd=0;
  for (int k = imageStart[2]+imageSize[2]-1; k >= imageStart[2]; k--)
  {
    Idx[2] = k;
    if ( image->GetPixel(Idx) > tNonBackground )
    {
      zEnd = k;
      break;
    }
  }
  float headHeight = static_cast<float>(zEnd-COGIdx[2])*spacing[2]*2;

  std::cout << COG << " " << COGIdx << std::endl;
  std::cout << "spacing: " << spacing << std::endl;
  std::cout << xStart << " " << xEnd << "==" << headWidth << std::endl;
  std::cout << yStart << " " << yEnd << "==" << headLength << std::endl;
  std::cout << COGIdx[2] << " " << zEnd << "==" << headHeight << std::endl;

  // determain ellipsoid dimensions
  headLength *= (0.5/headWidth);
  headHeight *= (0.4/headWidth);
  headHeight = (headHeight > 0.5 ? headHeight : 0.5);
  headWidth = 0.5;

  std::cout << headWidth << " " << headLength << " " << headHeight << std::endl;

  vtkPolyData* polyData = TessellateIcosahedron( sphericalResolution );
  int nPoints = polyData->GetNumberOfPoints();
 
  // Build neighborhood structure
  std::vector<NeighborhoodType> NeighborhoodStructure;
  polyData->BuildLinks();
  for (int k = 0; k < nPoints; k++)
    {
    NeighborhoodType setIds;
    ComputeVertexNeighbors(k, polyData , setIds);
    NeighborhoodStructure.push_back( setIds );
    }

  // put mesh in the right position with right radius
  vtkPoints * allPoints = polyData->GetPoints();
  for (int k = 0; k < nPoints; k++)
    {
    double* point = polyData->GetPoint( k );
    point[0] = point[0]*radius*headWidth+COG[0];
    point[1] = point[1]*radius*headLength+COG[1];
    point[2] = point[2]*radius*headHeight+COG[2];
    allPoints->SetPoint( k, point[0], point[1], point[2] );
    }  

  typename itk::LinearInterpolateImageFunction<ImageType, double>::Pointer interpolator = 
    itk::LinearInterpolateImageFunction<ImageType, double>::New();
  interpolator->SetInputImage( image );

  // figure out average edge length and use it to adjust image terms 
  int count = 0;
  double AveInitialEdgeLength = 0;
  for (int k = 0; k < 1000; k++)  // for each point
  {
    NeighborhoodType nbr = NeighborhoodStructure[k];
    int nNeighbors = nbr.size();
    double pc[3];
    double * p = polyData->GetPoint( k );
    pc[0] = p[0]; pc[1] = p[1]; pc[2] = p[2]; 

    for (int m = 0; m < nNeighbors; m ++)
    {
      vtkIdType id = nbr[m];
      double p0[3];
      double p1[3];
      p = polyData->GetPoint( id );
      p0[0] = p[0]; p0[1] = p[1]; p0[2] = p[2]; 
      id = nbr[(m+1)%nNeighbors];
      p = polyData->GetPoint( id );
      p1[0] = p[0]; p1[1] = p[1]; p1[2] = p[2]; 

      for (int n = 0; n < 3; n++)
      {
        p0[n] -= pc[n];
        p1[n] -= pc[n];
      }

      AveInitialEdgeLength += sqrt(p0[0]*p0[0]+p0[1]*p0[1]+p0[2]*p0[2]);
      count ++;
    }
  }
  AveInitialEdgeLength /= static_cast<double>( count );
  std::cout << "Average initial Edge Length: " << AveInitialEdgeLength << std::endl;

  AveInitialEdgeLength /= 2.8;
  AveInitialEdgeLength = AveInitialEdgeLength*AveInitialEdgeLength;

  // do iteration

  double radiusMin = 3.33;
  double radiusMax = 10;

  double E = (radiusMin + radiusMax)/(2*radiusMin*radiusMax);
  double F = 6*radiusMin*radiusMax/(radiusMax-radiusMin);

  int iter = 0;
  int nSearchPoints = 40;
  double stepSize = 0.5;
  double relaxFactor = 0.75;

  std::vector<typename ImageType::PixelType> IntensityOnLine(nSearchPoints);

  unsigned char lblValue = 0;
  double change = 0;
  double change1 = 0;
  double change2 = 0;
  double change3 = 0;

  while (1)
    {
    if (iter == nIterations)
    {
      break;
    }
    if (change > 0 && change1 > 0 && change2 > 0 && change3 > 0 && (2*change/(change1+change3)) < 0.05)
    {
      break;
    }

    // write out something
    // if (iter % 25 == 0)
    if (0)
      {
      lblValue ++;
      char filename[1024];
      PolyDataToLabelMap( polyData, label );
      sprintf( filename, "deform%03d.mha", iter );
      wlabel->SetFileName( filename );
      wlabel->SetInput( label );
      wlabel->Update();

      itk::ImageRegionIteratorWithIndex<LabelImageType> itLabel( label, label->GetLargestPossibleRegion() );
      for (itLabel.GoToBegin(); !itLabel.IsAtEnd(); ++itLabel)
      {
        typename LabelImageType::IndexType i = itLabel.GetIndex();
        if (itLabel.Get() == 0 && flabel->GetPixel(i) != 0)
        {
          flabel->SetPixel( i, 0 );
        }
        else if (itLabel.Get() != 0 && flabel->GetPixel(i) == 0)
        {
          flabel->SetPixel( i, lblValue );
        }
      }


      sprintf( filename, "deform%03d.vtk", iter );
      vtkPolyDataWriter *w = vtkPolyDataWriter::New();
      w->SetFileName( filename );
      w->SetFileTypeToASCII();

      for (int k = 0; k < nPoints; k++)
        {
        double* point = polyData->GetPoint( k );
        point[0] = -point[0];
        point[1] = -point[1];
        allPoints->SetPoint( k, point[0], point[1], point[2] );
        }  
      w->SetInput(polyData);
      w->Update();
      w->Delete();

      for (int k = 0; k < nPoints; k++)
        {
        double* point = polyData->GetPoint( k );
        point[0] = -point[0];
        point[1] = -point[1];
        allPoints->SetPoint( k, point[0], point[1], point[2] );
        }  
      }
    
    change = 0;
    change1 = 0;
    change2 = 0;
    change3 = 0;
    for (int k = 0; k < nPoints; k++)  // for each point
      {
      double update[3];
      double update1[3];
      double update2[3];
      double update3[3];      

      // 1. compute normal at the point and average position of its neighbors
      double pc[3];
      double * p = polyData->GetPoint( k );
      pc[0] = p[0]; pc[1] = p[1]; pc[2] = p[2]; 

      NeighborhoodType nbr = NeighborhoodStructure[k];
      int nNeighbors = nbr.size();
      double normal[3];
      double average[3];
      double averageEdgeLength = 0;
      for (int m = 0; m < 3; m++)
        {
        normal[m] = 0;
        average[m] = 0;
        update[m] = 0;
        update1[m] = 0;
        update2[m] = 0;
        update3[m] = 0;
        }
      
      for (int m = 0; m < nNeighbors; m ++)
        {
        vtkIdType id = nbr[m];
        double p0[3];
        double p1[3];
        p = polyData->GetPoint( id );
        p0[0] = p[0]; p0[1] = p[1]; p0[2] = p[2]; 
        id = nbr[(m+1)%nNeighbors];
        p = polyData->GetPoint( id );
        p1[0] = p[0]; p1[1] = p[1]; p1[2] = p[2]; 
        
        for (int n = 0; n < 3; n++)
          {
          average[n] += p0[n];
          p0[n] -= pc[n];
          p1[n] -= pc[n];
          }
        
        averageEdgeLength += sqrt(p0[0]*p0[0]+p0[1]*p0[1]+p0[2]*p0[2]);

        double op[3];
        op[0] = p0[1]*p1[2]-p0[2]*p1[1];
        op[1] = p0[2]*p1[0]-p0[0]*p1[2];
        op[2] = p0[0]*p1[1]-p0[1]*p1[0];
        
        for (int n = 0; n < 3; n++)
          {
          normal[n] += op[n];
          }
        }
      
      average[0] /= static_cast<float>( nNeighbors );
      average[1] /= static_cast<float>( nNeighbors );
      average[2] /= static_cast<float>( nNeighbors );

      averageEdgeLength /= static_cast<float>( nNeighbors );

      float mag = sqrt(normal[0]*normal[0]+normal[1]*normal[1]+normal[2]*normal[2]);
      normal[0] /= mag;
      normal[1] /= mag;
      normal[2] /= mag;
      
      double diffvector[3];
      diffvector[0] = average[0]-pc[0];
      diffvector[1] = average[1]-pc[1];
      diffvector[2] = average[2]-pc[2];
      
      double normalcomponent = fabs(diffvector[0]*normal[0]+diffvector[1]*normal[1]+diffvector[2]*normal[2]);
      double diffnormal[3];
      diffnormal[0] = normalcomponent*normal[0];
      diffnormal[1] = normalcomponent*normal[1];
      diffnormal[2] = normalcomponent*normal[2];

      double difftangent[3];
      difftangent[0] = diffvector[0]-diffnormal[0];
      difftangent[1] = diffvector[1]-diffnormal[1];
      difftangent[2] = diffvector[2]-diffnormal[2];

      update1[0] = difftangent[0]*0.5*stepSize;
      update1[1] = difftangent[1]*0.5*stepSize;
      update1[2] = difftangent[2]*0.5*stepSize;

      double radiusOfCurvature = fabs(averageEdgeLength*averageEdgeLength/2.0);
      radiusOfCurvature /= normalcomponent;
      double w2 = 0.5*(1.0+tanh(F*(1/radiusOfCurvature-E)));

      update2[0] = 0.25*w2*diffnormal[0]*stepSize;
      update2[1] = 0.25*w2*diffnormal[1]*stepSize;
      update2[2] = 0.25*w2*diffnormal[2]*stepSize;
      
      float imageforceiter;
      float imageforce = 1.0;
      // xk = pc is a physical point
      for (int d = 0; d < 30; d++)
        {
        typename ImageType::PointType point;  
       
        // set point values
        for( int m = 0; m <3; m ++)
          {
          point[m] = pc[m]-d*normal[m]*stepSize;         
          }
        itk::ContinuousIndex<double, 3> cidx;    
        image->TransformPhysicalPointToContinuousIndex( point, cidx );
        typename ImageType::PixelType value;
        if (image->GetLargestPossibleRegion().IsInside(cidx))
        {
          value = static_cast<PixelType>( interpolator->EvaluateAtContinuousIndex( cidx ) );
        }
        else
        {
          value = 0;
        }

        double tanhvalue = tanh((static_cast <float> (value)-lThreshold)/constTanh1);
        //double tanhvalue = tanh((static_cast <float> (value*24)/static_cast <float> (t98))-4);
        imageforceiter = (0 < tanhvalue ? tanhvalue : 0);
        if ( iter > 50 && imageforceiter == 0)
        {
          imageforce = -0.05;
        }
        else if ( imageforceiter > 0 && imageforce > 0 )
        {
          imageforce *= imageforceiter;
        }

        //take care of areas around eyes 
        double tanhvalue2 = tanh((static_cast <float> (value)-uThreshold)/constTanh2);
        //double tanhvalue2 = tanh((static_cast <float> (value*6)/static_cast <float> (t98))-4);
        if ( tanhvalue2 > 0 )
        {
          imageforce = -(fabs(imageforce));
          update1[0] *= relaxFactor;
          update1[1] *= relaxFactor;
          update1[2] *= relaxFactor;
          update2[0] *= relaxFactor;
          update2[1] *= relaxFactor;
          update2[2] *= relaxFactor;
        }
      }
      imageforce *= (stepSize*1.25*AveInitialEdgeLength);
      update3[0] = imageforce*normal[0];
      update3[1] = imageforce*normal[1];
      update3[2] = imageforce*normal[2];
            
      if ( iter > 200 )
        {
        update[0] = update1[0]+update2[0]+update3[0]*0.75;
        update[1] = update1[1]+update2[1]+update3[1]*0.75;
        update[2] = update1[2]+update2[2]+update3[2]*0.75;
        }
      else
        {
        update[0] = (update1[0]+update2[0])+update3[0];
        update[1] = (update1[1]+update2[1])+update3[1];
        update[2] = (update1[2]+update2[2])+update3[2];
        }

      pc[0] += update[0];
      pc[1] += update[1];
      pc[2] += update[2];

      allPoints->SetPoint( k, pc[0], pc[1], pc[2] );

      change += sqrt(update[0]*update[0]+update[1]*update[1]+update[2]*update[2]);
      change1 += sqrt(update1[0]*update1[0]+update1[1]*update1[1]+update1[2]*update1[2]);
      change2 += sqrt(update2[0]*update2[0]+update2[1]*update2[1]+update2[2]*update2[2]);
      change3 += sqrt(update3[0]*update3[0]+update3[1]*update3[1]+update3[2]*update3[2]);

      }

      if (iter % 25 == 0)
      {
        std::cout << "EOI " << iter << ": ";
      std::cout << "  C = " << change <<  "  C1 = " << change1 <<  "  C = " << change2 <<  "  C3 = " << change3 << std::endl;
      }

      iter++;

    }
  std::cout << "Total iterations: " << iter << std::endl;

  PolyDataToLabelMap( polyData, label );

  allPoints = polyData->GetPoints();
  for (int k = 0; k < nPoints; k++)
  {
    double* point = polyData->GetPoint( k );
    point[0] = -point[0];
    point[1] = -point[1];
    allPoints->SetPoint( k, point[0], point[1], point[2] );
  }  

  vtkPolyDataWriter *wPoly = vtkPolyDataWriter::New();
  wPoly->SetFileTypeToASCII();
  wPoly->SetFileName(brainSurface.c_str());
  wPoly->SetInput(polyData);    
  wPoly->Update();

  // binary dilation with radius 2
  typename LabelImageType::Pointer imgDilate;
  if (postDilationRadius > 0)
  {
    imgDilate = BinaryDilateFilter3D( label, postDilationRadius );
  }
  else
  {
    imgDilate = label;
  }
  //wlabel->SetFileName( brainMask.c_str() );
  wlabel->SetFileName( m_MaskFilename.c_str() );
  wlabel->SetInput( imgDilate );
  wlabel->Update();

  // masked out brain region
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
    typename ImageType::IndexType idx = it.GetIndex();
    if (imgDilate->GetPixel(idx) == 0)
      {
      it.Set( 0 );
      }
    }

  typename itk::ImageFileWriter<ImageType>::Pointer wImage = itk::ImageFileWriter<ImageType>::New();
  //wImage->SetFileName( maskedBrain.c_str() );
  wImage->SetFileName( m_MaskedBrainFilename.c_str() );
  wImage->SetInput( image );

  try
    {
    wImage->Update();
    }
  catch (itk::ExceptionObject * ob)
    {
    std::cerr << ob << std::endl;
    }
}



/**
 * Set the volume that the skull-stripping filter operates on.
 *
 * Prints a trace line to standard output and then makes `image` refer to
 * \a inputvolume. No pixel data is copied by this call.
 *
 * \param inputvolume  Pointer to the input image to use.
 */
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::SetInputVolume(ImageType *inputvolume)
{
  std::cout << "SkullStripping::SetInputVolume" << std::endl;

  // Assign the caller's volume directly. Creating a fresh image with
  // ImageType::New() beforehand would be wasted work, because that object
  // would be replaced by this assignment without ever being used.
  image = inputvolume;
}

/**
 * Record the file name used when the masked brain volume is written.
 *
 * Prints a trace line to standard output and copies \a maskedbrainfilename
 * into m_MaskedBrainFilename.
 *
 * \param maskedbrainfilename  Output file name for the masked brain image.
 */
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::SetOutputVolumeFilename(std::string& maskedbrainfilename)
{
  std::cout << "SkullStripping::SetOutputVolumeFilename" << std::endl;

  m_MaskedBrainFilename.assign( maskedbrainfilename );
}

/**
 * Record the file name used when the brain mask is written.
 *
 * Prints a trace line to standard output and copies \a maskfilename into
 * m_MaskFilename.
 *
 * \param maskfilename  Output file name for the label mask image.
 */
template <class TInputImage, class TOutputImage>
void
SkullStripping<TInputImage, TOutputImage>
::SetOutputMaskFilename(std::string& maskfilename)
{
  std::cout << "SkullStripping::SetOutputMaskFilename" << std::endl;

  m_MaskFilename.assign( maskfilename );
}




} // namespace Statistics
}  // namespace itk

#endif



