#include "Computation.h"

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>

// Debug switch: FindBadPixels() prints per-vertex diagnostics only when this
// is 1; the current value -1 keeps them disabled.
#define DEBUG -1

// Default constructor: performs no explicit member initialisation.
Computation::Computation()
{
}

// Destructor: performs no explicit cleanup.
Computation::~Computation()
{
}

void Computation::ReadImage(std::string _FileName)
{
  std::cout<<"Reading image..."<<std::endl;
  
  VolumeReaderType::Pointer imageReader = VolumeReaderType::New();
  imageReader->SetFileName(_FileName.c_str());
  try
    {
      imageReader->Update() ;
    }
  catch (itk::ExceptionObject err)
    {
      std::cerr<<"Exception object caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }
  m_inputImage = imageReader->GetOutput();
}

void Computation::ReadMesh(std::string _FileName)
{
  std::cout<<"Reading mesh..."<<std::endl;
  
  std::ifstream Infile;
  std::string fileExtension;
  char Line[40], NbVertices[10];
  std::string line;
  std::size_t found1, found2, length;
  CovariantVectorType Vertex1, Vertex2, Vertex3;
  int CurrentPoint;
   
  m_vVertices.clear();  
  
  Infile.open(_FileName.c_str()); 
  
  int lastPoint = _FileName.rfind('.');
  fileExtension = _FileName.substr(lastPoint);

  if (fileExtension.compare(".meta") == 0 )
    {
      //Skips the header and gets the number of points
      while ( strncmp (Line, "NPoints =", 9) && strncmp (Line, "NPoints=", 8))
	Infile.getline (Line, 40);    
      SetNbVertices(atoi(strrchr(Line,' ')));
      
      Infile.getline ( Line, 40);

      //read the points in the file and set them as vertices
      for (int i = 0; i < GetNbVertices(); i++ )
	{
	  Infile >> CurrentPoint >> Vertex1[0] >> Vertex1[1] >> Vertex1[2];
	  m_vVertices.push_back(Vertex1);
	}
    }
  else if (fileExtension.compare(".vtk") == 0 )
    {
      while ( strncmp (Line, "POINTS", 6))
	Infile.getline (Line, 40);
      line = Line;
      found1 = line.find(' ');
      found2 = line.find(' ',found1+1);
      length = line.copy(NbVertices,found2-found1-1,found1+1);
      NbVertices[length]='\0';
      SetNbVertices(atoi(NbVertices));
      
      for (int i = 0; i < GetNbVertices()/3; i++ )
	{
	  Infile >> Vertex1[0] >> Vertex1[1] >> Vertex1[2] >> Vertex2[0] >> Vertex2[1] >> Vertex2[2] >> Vertex3[0] >> Vertex3[1] >> Vertex3[2];
	  m_vVertices.push_back(Vertex1);
	  m_vVertices.push_back(Vertex2);
	  m_vVertices.push_back(Vertex3);	  
	}
      if ((GetNbVertices() % 3) == 1)
	{
	  Infile >> Vertex1[0] >> Vertex1[1] >> Vertex1[2];
	  m_vVertices.push_back(Vertex1);
	}
      else if ((GetNbVertices() % 3) == 2)
	{
	  Infile >> Vertex1[0] >> Vertex1[1] >> Vertex1[2] >> Vertex2[0] >> Vertex2[1] >> Vertex2[2];
	  m_vVertices.push_back(Vertex1);
	  m_vVertices.push_back(Vertex2);
	}      
    }
  //close file
  Infile.close();
}

// Reads the list of bad-vertex ids from _FileName into m_vIndexBadVertices
// and stores their count with SetNbBadVertices().
// Header lines are skipped up to the first one starting with
// "NUMBER_OF_POINTS =" or "NUMBER_OF_POINTS="; the number after its last '='
// is the count. The next two lines are skipped, then that many integers are
// read. The program exits with EXIT_FAILURE if the file cannot be opened,
// the header is missing, or fewer ids than announced can be read.
void Computation::ReadTextFile(std::string _FileName)
{
  std::cout<<"Reading text file..."<<std::endl;

  m_vIndexBadVertices.clear();

  std::ifstream Infile(_FileName.c_str());
  if (!Infile)
    {
      std::cerr<<"Cannot open text file: "<<_FileName<<std::endl;
      exit(EXIT_FAILURE);
    }

  // Read whole lines into a std::string: no length limit, and the loop stops
  // at end of file instead of spinning forever on a failed stream.
  std::string Line;
  bool headerFound = false;
  while (!headerFound && std::getline(Infile, Line))
    headerFound = (Line.compare(0, 18, "NUMBER_OF_POINTS =") == 0) || (Line.compare(0, 17, "NUMBER_OF_POINTS=") == 0);
  if (!headerFound)
    {
      std::cerr<<"No NUMBER_OF_POINTS header found in text file: "<<_FileName<<std::endl;
      exit(EXIT_FAILURE);
    }
  // Both accepted spellings contain '='; atoi skips the blanks after it.
  SetNbBadVertices(atoi(Line.c_str() + Line.rfind('=') + 1));

  // Skip the two lines between the count header and the id list.
  std::getline(Infile, Line);
  std::getline(Infile, Line);

  int CurrentIndex;
  for (int i = 0; i < GetNbBadVertices(); i++)
    {
      if (!(Infile >> CurrentIndex))
        {
          std::cerr<<"Text file "<<_FileName<<" ended after "<<i<<" of "<<GetNbBadVertices()<<" ids"<<std::endl;
          exit(EXIT_FAILURE);
        }
      m_vIndexBadVertices.push_back(CurrentIndex);
    }
}

// Runs the full correction:
//  1. FindBadPixels() marks the voxels of the bad vertices in m_correctionImage.
//  2. If GetNeighborhood() is true, Dilation() grows that mask.
//  3. Output and processed images are allocated; the processed image starts as
//     a copy of the input.
//  4. MajorityVoting() passes are repeated until one changes nothing.
// NOTE(review): after every pass m_processedImage is assigned m_outputImage
// (a handle assignment, assuming both are ImageType::Pointer as their
// ImageType::New() initialisation suggests), so from the second pass on both
// names refer to the same image and voting updates it in place. Replacing this
// with a deep copy changes the update scheme; check termination before doing so.
void Computation::FixImage()
{
  std::cout<<"Computation..."<<std::endl;

  FindBadPixels();
  //----
  if (GetNeighborhood())
    Dilation();
  //----
  OutputImageAllocation();
  ProcessedImageCreation();

  bool changedInLastPass;
  do
    {
      changedInLastPass = MajorityVoting();
      m_processedImage = m_outputImage;
    }
  while (changedInLastPass);
}

void Computation::FindBadPixels()
{
  std::cout<<"Finding bad pixels..."<<std::endl;
  
  // Allocate correction image
  m_correctionImage = ImageType::New();
  m_correctionImage->SetOrigin(m_inputImage->GetOrigin());
  m_correctionImage->SetSpacing(m_inputImage->GetSpacing());
  m_correctionImage->SetRegions(m_inputImage->GetRequestedRegion());
  m_correctionImage->Allocate();

  ImageType::PointType origin = m_inputImage->GetOrigin();
  
  IteratorType CorrectionIterator(m_correctionImage, m_correctionImage->GetRequestedRegion());
  for (CorrectionIterator.GoToBegin(); !CorrectionIterator.IsAtEnd(); ++CorrectionIterator)
    CorrectionIterator.Set(0);
  
  CovariantVectorType CurrentBadVertex;
  ImageType::IndexType CurrentIndex;
  int Element[3] = {-1,-1,1};

  for (unsigned int i = 0; i < m_vIndexBadVertices.size(); i++)
    {
      CurrentBadVertex = m_vVertices[m_vIndexBadVertices[i]];
      for (int j = 0; j < 3; j++)
	CurrentIndex[j] = (int) (((Element[j]*CurrentBadVertex[j])-origin[j])/m_inputImage->GetSpacing()[j] + 0.5);
      m_correctionImage->SetPixel(CurrentIndex,1);

      if (DEBUG == 1)
	{
	  std::cout<<"\ni: "<<i<<std::endl;
	  std::cout<<"BadVertexId: "<<m_vIndexBadVertices[i]<<std::endl;
	  std::cout<<"BadVertex 3DPosition: "<<CurrentBadVertex[0]<<" "<<CurrentBadVertex[1]<<" "<<CurrentBadVertex[2]<<std::endl;
	  std::cout<<"BadVertex Index: "<<CurrentIndex[0]<<" "<<CurrentIndex[1]<<" "<<CurrentIndex[2]<<std::endl;
	  std::cout<<"BadVertex Value on Input Image: "<<m_inputImage->GetPixel(CurrentIndex)<<std::endl;
	}    
    }
}

//----
void Computation::Dilation()
{
  std::cout<<"Dilation..."<<std::endl;
  
  StructuringElementType structuringElement;
  structuringElement.SetRadius(1);  // 3x3x3 structuring element
  structuringElement.CreateStructuringElement();
  
  dilateFilterType::Pointer dilateFilter = dilateFilterType::New();
  dilateFilter->SetInput(m_correctionImage);
  dilateFilter->SetDilateValue(1);
  dilateFilter->SetKernel(structuringElement);
  try
    {
      dilateFilter->Update();
    }
  catch (itk::ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }   
  
  m_correctionImage = dilateFilter->GetOutput();
}
//----

// Allocates m_outputImage on the input's grid: origin, spacing, direction
// cosines and region. Pixel values are left uninitialised; MajorityVoting()
// overwrites every pixel.
// The direction must be copied too. Otherwise the written output loses the
// input's orientation. In ITK 4+, a filter that takes both images as inputs
// (such as the subtraction in CheckOutput()) can also reject them for not
// occupying the same physical space.
void Computation::OutputImageAllocation()
{
  std::cout<<"Allocating output image..."<<std::endl;
  
  m_outputImage = ImageType::New();
  m_outputImage->SetOrigin(m_inputImage->GetOrigin());
  m_outputImage->SetSpacing(m_inputImage->GetSpacing());
  m_outputImage->SetDirection(m_inputImage->GetDirection());
  m_outputImage->SetRegions(m_inputImage->GetRequestedRegion());
  m_outputImage->Allocate();
}

// Creates m_processedImage on the input's grid (origin, spacing, direction,
// region) and fills it with a voxel-by-voxel copy of m_inputImage. It is the
// image MajorityVoting() reads from on its first pass.
void Computation::ProcessedImageCreation()
{
  std::cout<<"Creating processed image..."<<std::endl;
  
  m_processedImage = ImageType::New();
  m_processedImage->SetOrigin(m_inputImage->GetOrigin());
  m_processedImage->SetSpacing(m_inputImage->GetSpacing());
  m_processedImage->SetDirection(m_inputImage->GetDirection());
  m_processedImage->SetRegions(m_inputImage->GetRequestedRegion());
  m_processedImage->Allocate();

  ConstIteratorType ConstInputIterator(m_inputImage,m_inputImage->GetRequestedRegion());
  IteratorType ProcessedIterator(m_processedImage,m_processedImage->GetRequestedRegion());
  for (ConstInputIterator.GoToBegin(), ProcessedIterator.GoToBegin(); !ConstInputIterator.IsAtEnd();
       ++ConstInputIterator, ++ProcessedIterator)
    ProcessedIterator.Set(ConstInputIterator.Get());
}

bool Computation::MajorityVoting()
{
  std::cout<<"Fixing image..."<<std::endl;

  bool Correction = false;
  
  ConstIteratorType ConstInputIterator(m_processedImage,m_processedImage->GetRequestedRegion());
  ConstIteratorType ConstCorrectionIterator(m_correctionImage,m_correctionImage->GetRequestedRegion());
  IteratorType OutputIterator(m_outputImage, m_processedImage->GetRequestedRegion());
 
  NeighborhoodIteratorType::RadiusType Radius;
  Radius.Fill(1);
  NeighborhoodIteratorType NeighborhoodInputIterator(Radius,m_processedImage,m_processedImage->GetRequestedRegion());
  NeighborhoodIteratorType::OffsetType offset1 = {{-1,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset2 = {{-1,-1,0}};
  NeighborhoodIteratorType::OffsetType offset3 = {{-1,-1,1}};
  NeighborhoodIteratorType::OffsetType offset4 = {{-1,0,-1}};
  NeighborhoodIteratorType::OffsetType offset5 = {{-1,0,0}};
  NeighborhoodIteratorType::OffsetType offset6 = {{-1,0,1}};
  NeighborhoodIteratorType::OffsetType offset7 = {{-1,1,-1}};
  NeighborhoodIteratorType::OffsetType offset8 = {{-1,1,0}};
  NeighborhoodIteratorType::OffsetType offset9 = {{-1,1,1}};
  
  NeighborhoodIteratorType::OffsetType offset10 = {{0,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset11 = {{0,-1,0}};
  NeighborhoodIteratorType::OffsetType offset12 = {{0,-1,1}};
  NeighborhoodIteratorType::OffsetType offset13 = {{0,0,-1}};
  NeighborhoodIteratorType::OffsetType offset14 = {{0,0,1}};
  NeighborhoodIteratorType::OffsetType offset15 = {{0,1,-1}};
  NeighborhoodIteratorType::OffsetType offset16 = {{0,1,0}};
  NeighborhoodIteratorType::OffsetType offset17 = {{0,1,1}};
  
  NeighborhoodIteratorType::OffsetType offset18 = {{1,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset19 = {{1,-1,0}};
  NeighborhoodIteratorType::OffsetType offset20 = {{1,-1,1}};
  NeighborhoodIteratorType::OffsetType offset21 = {{1,0,-1}};
  NeighborhoodIteratorType::OffsetType offset22 = {{1,0,0}};
  NeighborhoodIteratorType::OffsetType offset23 = {{1,0,1}};
  NeighborhoodIteratorType::OffsetType offset24 = {{1,1,-1}};
  NeighborhoodIteratorType::OffsetType offset25 = {{1,1,0}};
  NeighborhoodIteratorType::OffsetType offset26 = {{1,1,1}};
 
  // Computation
  // -  Copy input data to output
  for (ConstInputIterator.GoToBegin(), OutputIterator.GoToBegin(); !ConstInputIterator.IsAtEnd(); \
       ++ConstInputIterator, ++OutputIterator)
    OutputIterator.Set(ConstInputIterator.Get());
      
  ImagePixelType sum;
  OutputIterator.GoToBegin();
  ConstCorrectionIterator.GoToBegin();
  NeighborhoodInputIterator.GoToBegin();
  while (!OutputIterator.IsAtEnd())
      {
	if (ConstCorrectionIterator.Get() == 1)
	  {
	    sum = NeighborhoodInputIterator.GetPixel(offset1)+NeighborhoodInputIterator.GetPixel(offset2)+ \
	      NeighborhoodInputIterator.GetPixel(offset3)+NeighborhoodInputIterator.GetPixel(offset4)+ \
	      NeighborhoodInputIterator.GetPixel(offset5)+NeighborhoodInputIterator.GetPixel(offset6)+ \
	      NeighborhoodInputIterator.GetPixel(offset7)+NeighborhoodInputIterator.GetPixel(offset8)+ \
	      NeighborhoodInputIterator.GetPixel(offset9)+NeighborhoodInputIterator.GetPixel(offset10)+ \
	      NeighborhoodInputIterator.GetPixel(offset11)+NeighborhoodInputIterator.GetPixel(offset12)+ \
	      NeighborhoodInputIterator.GetPixel(offset13)+NeighborhoodInputIterator.GetPixel(offset14)+ \
	      NeighborhoodInputIterator.GetPixel(offset15)+NeighborhoodInputIterator.GetPixel(offset16)+ \
	      NeighborhoodInputIterator.GetPixel(offset17)+NeighborhoodInputIterator.GetPixel(offset18)+ \
	      NeighborhoodInputIterator.GetPixel(offset19)+NeighborhoodInputIterator.GetPixel(offset20)+ \
	      NeighborhoodInputIterator.GetPixel(offset21)+NeighborhoodInputIterator.GetPixel(offset22)+ \
	      NeighborhoodInputIterator.GetPixel(offset23)+NeighborhoodInputIterator.GetPixel(offset24)+ \
	      NeighborhoodInputIterator.GetPixel(offset25)+NeighborhoodInputIterator.GetPixel(offset26);
	    if ( (NeighborhoodInputIterator.GetCenterPixel() == 0) && (sum > 13) )
	      {
		OutputIterator.Set(1);
		Correction = true;
	      }
	    else if ( (NeighborhoodInputIterator.GetCenterPixel() == 1) && (sum < 13) )
	      {
		OutputIterator.Set(0);
		Correction = true;
	      }
	  }
	++OutputIterator;
	++ConstCorrectionIterator;
	++NeighborhoodInputIterator;
      }
  return Correction;
}

void Computation::ConnectivityEnforcement()
{
  std::cout<<"Connectivity enforcement..."<<std::endl;
  
  ImageType::IndexType nullIndex;
  nullIndex[0] = 0;
  nullIndex[1] = 0;
  nullIndex[2] = 0;
  
  ImagePixelType *data = &((*m_outputImage)[nullIndex]);
  ImageType::RegionType imageRegion = m_outputImage->GetBufferedRegion();
  int dim[3];
  dim[0] = imageRegion.GetSize(0);
  dim[1] = imageRegion.GetSize(1);
  dim[2] = imageRegion.GetSize(2);

  clear_edge(data, dim, 0);
  NoDiagConnect(data,dim);
}

// Enforces strict 6-connectedness: two non-zero voxels may not be connected
// through a diagonal only.
//
// Each scan visits every interior voxel (i,j,k) with a non-zero value. For
// each of the xy, xz and yz planes through it, it checks the four in-plane
// diagonal neighbours. Suppose a diagonal neighbour is non-zero while both
// voxels it shares with (i,j,k) in that plane are zero. Then one shared voxel
// is set to the value of (i,j,k): the x-neighbour for the xy and xz planes,
// the y-neighbour for the yz plane.
// Scans repeat until one makes no change. The loop always terminates: a voxel
// is written only while it is zero, and always with a non-zero value, so
// every correction reduces the number of zero voxels.
//
// image : dim[0]*dim[1]*dim[2] values addressed as
//         image[i + j*dim[0] + k*dim[0]*dim[1]]; modified in place.
//         The image has to be of type unsigned short.
// dim   : {dimx, dimy, dimz}.
// Border voxels (index 0 or dim-1 on any axis) are never used as the centre
// voxel, but they are read, and may be filled, as neighbours.
// Returns 1 unconditionally.
//
// NOTE(review): corrections are written in place during a scan, so later
// checks in the same scan see earlier fills. The loop nesting order (i outer,
// k inner) is therefore part of the behaviour. Index arithmetic is done in int
// and could overflow for very large volumes.
int Computation::NoDiagConnect(unsigned short *image, int *dim) 
{
  // Volume dimensions; x varies fastest in memory.
  int dimx = dim[0];
  int dimy = dim[1];
  int dimz = dim[2];
  bool correctionNeeded = true;
  // Scan counter; only used by the commented-out debug output below.
  int cnt = 0;

  while (correctionNeeded)
    {
      cnt++;
      //if (debug) cout << "NoDiag scan " << cnt << endl; 
      correctionNeeded = false;
      // Strides: dy is one xy plane (the z step), dx one row (the y step).
      int dy = dimx*dimy;
      int dx = dimx;
      
      for (int i = 1; i < dimx - 1; i++)
	{
	  for (int j = 1; j < dimy - 1; j++)
	    {
	      for (int k = 1; k < dimz - 1; k++) 
		{
		  unsigned short val = image[i + j * dimx + k * dy];
		  if (val != 0)
		    {
		      // xy plane: fill the x-neighbour (i-1 or i+1)
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+(j-1)*dx+k*dy] == 0) && (image[i-1+(j-1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+(j+1)*dx+k*dy] == 0) && (image[i+1+(j+1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+(j-1)*dx+k*dy] == 0) && (image[i+1+(j-1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+(j+1)*dx+k*dy] == 0) && (image[i-1+(j+1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
         
		      // xz plane: fill the x-neighbour (i-1 or i+1)
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i-1+j*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+1+j*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+1+j*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i-1+j*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      
		      // yz plane: fill the y-neighbour (j-1 or j+1)
		      if ((image[i+(j-1)*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+(j-1)*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j-1)*dx+k*dy] = val;
			}
		      if ((image[i+(j+1)*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+(j+1)*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j+1)*dx+k*dy] = val;
			}
		      if ((image[i+(j+1)*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+(j+1)*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j+1)*dx+k*dy] = val;
			}
		      if ((image[i+(j-1)*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+(j-1)*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j-1)*dx+k*dy] = val;
			}
		    }
		}
	    }
	}
    }  
  return 1;
}

// Sets every voxel on the outer boundary of the volume to clear_label.
// image : dims[0]*dims[1]*dims[2] values addressed as
//         image[x + y*dims[0] + z*dims[0]*dims[1]]; modified in place.
// Rows lying in the first/last y row or the first/last z plane are cleared
// completely. All other rows only have their first and last x voxel cleared.
void Computation::clear_edge(unsigned short *image, int *dims, int clear_label)
{
  const int nx = dims[0];
  const int ny = dims[1];
  const int nz = dims[2];

  for (int z = 0; z < nz; ++z)
    {
      const bool onZFace = (z == 0) || (z == nz - 1);
      for (int y = 0; y < ny; ++y)
        {
          unsigned short *row = image + (y + z * ny) * nx;
          const bool wholeRow = onZFace || (y == 0) || (y == ny - 1);
          if (!wholeRow)
            {
              row[0] = clear_label;
              row[nx - 1] = clear_label;
              continue;
            }
          for (int x = 0; x < nx; ++x)
            row[x] = clear_label;
        }
    }
}

bool Computation::CheckOutput()
{
  subFilterType::Pointer subFilter = subFilterType::New();
  subFilter->SetInput1(m_inputImage);
  subFilter->SetInput2(m_outputImage);
  try 
    {
      subFilter->Update();    
    }
  catch (itk::ExceptionObject & err)
    {
      std::cerr<<"ExceptionObject caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }

  // Computation of the maximum intensity
  minmaxImageCalculatorType::Pointer minmaxImageCalculator = minmaxImageCalculatorType::New();
  minmaxImageCalculator->SetImage(subFilter->GetOutput());
  minmaxImageCalculator->Compute();
  int IntensityMax = (int) minmaxImageCalculator->GetMaximum();
  int IntensityMin = (int) minmaxImageCalculator->GetMinimum();
  if (IntensityMax == 0 && IntensityMin == 0)
    return 0;
  else
    return 1;
}

void Computation::WriteCorrectionImage(std::string _FileName)
{
  WriteImage(_FileName,m_correctionImage);
}

void Computation::WriteOutputImage(std::string _FileName)
{
  WriteImage(_FileName,m_outputImage);
}

void Computation::WriteImage(std::string _FileName, ImageType::Pointer _Image)
{
  std::cout<<"Writing "<<_FileName<<" ..."<<std::endl;
  
  VolumeWriterType::Pointer imageWriter = VolumeWriterType::New();
  imageWriter->SetFileName(_FileName.c_str());
  imageWriter->SetInput(_Image);
  imageWriter->UseCompressionOn();
  try
    {
      imageWriter->Update();
    }
  catch (itk::ExceptionObject & err)
    {
      std::cerr<<"Exception object caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }
}
