/*=========================================================================
  Program:   WMSegPostProcess
  Module:    $RCSfile: Computation.cxx,v $
  Language:  C++
  Date:      $Date: 2010/06/09 11:19:42 $
  Version:   $Revision: 1.4 $
  Author:    Clement Vachet (cvachet@email.unc.edu)

  Copyright (c) Clement Vachet. All rights reserved.
  See NeuroLibCopyright.txt or http://www.ia.unc.edu/dev/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#include "Computation.h"

// Constructor: stores the input and output image file names by handing
// them to the corresponding setters. No other work is done here.
Computation::Computation(string _InputImageName, string _OutputImageName)
{
  this->SetInputImage(_InputImageName);
  this->SetOutputImage(_OutputImageName);
}

// Destructor: performs no explicit cleanup.
Computation::~Computation()
{
}

// Reads the file named by m_InputImageName and stores the result in
// m_inputImage. If the reader throws, the exception is reported on stderr
// and the process terminates with EXIT_FAILURE.
void Computation::ReadImage()
{
  cout<<"Reading image..."<<endl;
  VolumeReaderType::Pointer imageReader = VolumeReaderType::New();
  imageReader->SetFileName(m_InputImageName);
  // Catch by reference, like the other handlers in this file: catching by
  // value copies the exception and slices derived reader exceptions down to
  // the base class, losing their specific description when printed.
  try
    {
      imageReader->Update();
    }
  catch (ExceptionObject & err)
    {
      cerr<<"Exception object caught!"<<endl;
      cerr<<err<<endl;
      exit(EXIT_FAILURE);
    }
  m_inputImage = imageReader->GetOutput();
}

void Computation::LargestComponent()
{
  cout<<"Computing largest component..."<<endl;
  
  ImageType::Pointer outputImage;
  ComputeLargestComponent(m_inputImage, outputImage);
  m_inputImage = outputImage;
}

// Extracts the biggest connected object of _inputImage into _outputImage.
// Pipeline: connectivity labelling (ConnectiveFilterType) -> relabelling so
// that labels are sorted by object size (RelabelFilterType) -> threshold that
// keeps exactly label 1 (presumably the largest object after sorting) as 1
// and sets everything else to 0.
// Any ITK exception is reported on stderr and terminates the process.
void Computation::ComputeLargestComponent(ImageType::Pointer _inputImage, ImageType::Pointer &_outputImage)
{
  // Stage 1: connectivity map, each connected object gets its own label
  ConnectiveFilterType::Pointer labeler = ConnectiveFilterType::New();
  labeler->SetInput(_inputImage);
  try
    {
      labeler->Update();
    }
  catch (ExceptionObject & err)
    {
      cerr << "ExceptionObject caught!" << endl;
      cerr << err << endl;
      exit(EXIT_FAILURE);
    }

  // Stage 2: reorder labels according to object size
  RelabelFilterType::Pointer sizeSorter = RelabelFilterType::New();
  sizeSorter->SetInput(labeler->GetOutput());

  // Stage 3: binary mask of label 1 only
  threshFilterType::Pointer largestSelector = threshFilterType::New();
  largestSelector->SetInput(sizeSorter->GetOutput());
  largestSelector->SetLowerThreshold(1);
  largestSelector->SetUpperThreshold(1);
  largestSelector->SetInsideValue(1);
  largestSelector->SetOutsideValue(0);
  try
    {
      largestSelector->Update();
    }
  catch (ExceptionObject & err)
    {
      cerr << "ExceptionObject caught!" << endl;
      cerr << err << endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = largestSelector->GetOutput();
}

void Computation::InsideFilling()
{
  cout<<"InsideFilling..."<<endl;
  
  ImageType::Pointer InvertedImage = ImageType::New();
  InvertedImage->SetRegions(m_inputImage->GetRequestedRegion());
  InvertedImage->SetSpacing(m_inputImage->GetSpacing());
  InvertedImage->Allocate();

  // Computing inverted image
  ConstIteratorType ConstInputIterator(m_inputImage,m_inputImage->GetRequestedRegion());
  IteratorType InvertedIterator(InvertedImage, m_inputImage->GetRequestedRegion());
  ConstInputIterator.GoToBegin();
  InvertedIterator.GoToBegin();
  while ( !ConstInputIterator.IsAtEnd() )
    {
      InvertedIterator.Set(1 - ConstInputIterator.Get());
      ++ConstInputIterator;
      ++InvertedIterator;
    }
  
  // Connected component on inverted image
  ImageType::Pointer LargestInvertedImage;
  ComputeLargestComponent(InvertedImage, LargestInvertedImage);
 
  // Substraction to get inside image
  subFilterType::Pointer subFilter = subFilterType::New();
  subFilter->SetInput1(InvertedImage);
  subFilter->SetInput2(LargestInvertedImage);
  try 
    {
      subFilter->Update();    
    }
  catch (ExceptionObject & err)
    {
      cerr<<"ExceptionObject caught!"<<endl;
      cerr<<err<<endl;
      exit(EXIT_FAILURE);
    }   

  // Combining images to fill the input image
  IteratorType InputIterator(m_inputImage, m_inputImage->GetRequestedRegion());
  ConstIteratorType ConstInsideIterator(subFilter->GetOutput(),subFilter->GetOutput()->GetRequestedRegion());
  InputIterator.GoToBegin();
  ConstInsideIterator.GoToBegin();
  while (!InputIterator.IsAtEnd())
    {
      if (!InputIterator.Get() && ConstInsideIterator.Get())
	InputIterator.Set(ConstInsideIterator.Get());
      ++InputIterator;
      ++ConstInsideIterator;
    }
}

void Computation::Smoothing()
{
  cout<<"Smoothing..."<<endl;

  // Creating output image
  ImageType::Pointer OutputImage = ImageType::New();
  OutputImage->SetRegions(m_inputImage->GetRequestedRegion());
  OutputImage->SetSpacing(m_inputImage->GetSpacing());
  OutputImage->Allocate();

  ConstIteratorType ConstInputIterator(m_inputImage,m_inputImage->GetRequestedRegion());
  IteratorType OutputIterator(OutputImage, m_inputImage->GetRequestedRegion());
    
  NeighborhoodIteratorType::RadiusType Radius;
  Radius.Fill(1);
  NeighborhoodIteratorType NeighborhoodInputIterator(Radius,m_inputImage,m_inputImage->GetRequestedRegion());
  NeighborhoodIteratorType::OffsetType offset1 = {{-1,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset2 = {{-1,-1,0}};
  NeighborhoodIteratorType::OffsetType offset3 = {{-1,-1,1}};
  NeighborhoodIteratorType::OffsetType offset4 = {{-1,0,-1}};
  NeighborhoodIteratorType::OffsetType offset5 = {{-1,0,0}};
  NeighborhoodIteratorType::OffsetType offset6 = {{-1,0,1}};
  NeighborhoodIteratorType::OffsetType offset7 = {{-1,1,-1}};
  NeighborhoodIteratorType::OffsetType offset8 = {{-1,1,0}};
  NeighborhoodIteratorType::OffsetType offset9 = {{-1,1,1}};
  
  NeighborhoodIteratorType::OffsetType offset10 = {{0,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset11 = {{0,-1,0}};
  NeighborhoodIteratorType::OffsetType offset12 = {{0,-1,1}};
  NeighborhoodIteratorType::OffsetType offset13 = {{0,0,-1}};
  NeighborhoodIteratorType::OffsetType offset14 = {{0,0,1}};
  NeighborhoodIteratorType::OffsetType offset15 = {{0,1,-1}};
  NeighborhoodIteratorType::OffsetType offset16 = {{0,1,0}};
  NeighborhoodIteratorType::OffsetType offset17 = {{0,1,1}};
  
  NeighborhoodIteratorType::OffsetType offset18 = {{1,-1,-1}};
  NeighborhoodIteratorType::OffsetType offset19 = {{1,-1,0}};
  NeighborhoodIteratorType::OffsetType offset20 = {{1,-1,1}};
  NeighborhoodIteratorType::OffsetType offset21 = {{1,0,-1}};
  NeighborhoodIteratorType::OffsetType offset22 = {{1,0,0}};
  NeighborhoodIteratorType::OffsetType offset23 = {{1,0,1}};
  NeighborhoodIteratorType::OffsetType offset24 = {{1,1,-1}};
  NeighborhoodIteratorType::OffsetType offset25 = {{1,1,0}};
  NeighborhoodIteratorType::OffsetType offset26 = {{1,1,1}};
 
  // Computation
  // -  Copy input data to output
  for (ConstInputIterator.GoToBegin(), OutputIterator.GoToBegin(); !ConstInputIterator.IsAtEnd(); \
       ++ConstInputIterator, ++OutputIterator)
    OutputIterator.Set(ConstInputIterator.Get());
      
  ImagePixelType sum;
  OutputIterator.GoToBegin();
  NeighborhoodInputIterator.GoToBegin();
  while (!OutputIterator.IsAtEnd())
      {
	sum = NeighborhoodInputIterator.GetPixel(offset1)+NeighborhoodInputIterator.GetPixel(offset2)+\
	  NeighborhoodInputIterator.GetPixel(offset3)+NeighborhoodInputIterator.GetPixel(offset4)+ \
	  NeighborhoodInputIterator.GetPixel(offset5)+NeighborhoodInputIterator.GetPixel(offset6)+ \
	  NeighborhoodInputIterator.GetPixel(offset7)+NeighborhoodInputIterator.GetPixel(offset8)+ \
	  NeighborhoodInputIterator.GetPixel(offset9)+NeighborhoodInputIterator.GetPixel(offset10)+ \
	  NeighborhoodInputIterator.GetPixel(offset11)+NeighborhoodInputIterator.GetPixel(offset12)+ \
	  NeighborhoodInputIterator.GetPixel(offset13)+NeighborhoodInputIterator.GetPixel(offset14)+ \
	  NeighborhoodInputIterator.GetPixel(offset15)+NeighborhoodInputIterator.GetPixel(offset16)+ \
	  NeighborhoodInputIterator.GetPixel(offset17)+NeighborhoodInputIterator.GetPixel(offset18)+ \
	  NeighborhoodInputIterator.GetPixel(offset19)+NeighborhoodInputIterator.GetPixel(offset20)+ \
	  NeighborhoodInputIterator.GetPixel(offset21)+NeighborhoodInputIterator.GetPixel(offset22)+ \
	  NeighborhoodInputIterator.GetPixel(offset23)+NeighborhoodInputIterator.GetPixel(offset24)+ \
	  NeighborhoodInputIterator.GetPixel(offset25)+NeighborhoodInputIterator.GetPixel(offset26);

	if ( (NeighborhoodInputIterator.GetCenterPixel() == 0) && (sum > 13) )
	  OutputIterator.Set(1);
	else if ( (NeighborhoodInputIterator.GetCenterPixel() == 1) && (sum < 13) )
	  OutputIterator.Set(0);

	++OutputIterator;
	++NeighborhoodInputIterator;
      }
  m_inputImage = OutputImage;
}

void Computation::ConnectivityEnforcement()
{
  cout<<"Connectivity enforcement..."<<endl;
  
  ImageType::IndexType nullIndex;
  nullIndex[0] = 0;
  nullIndex[1] = 0;
  nullIndex[2] = 0;
  
  ImagePixelType *data = &((*m_inputImage)[nullIndex]);
  ImageType::RegionType imageRegion = m_inputImage->GetBufferedRegion();
  int dim[3];
  dim[0] = imageRegion.GetSize(0);
  dim[1] = imageRegion.GetSize(1);
  dim[2] = imageRegion.GetSize(2);

  clear_edge(data, dim, 0);
  NoDiagConnect(data,dim);
}

// does not allow connection via diagonals only, enforces strict 6 connectedness
// image has to be of type unsigned short
// Enforces strict 6-connectedness: bridges connections that exist only
// through in-plane (edge) diagonals.
//
// The volume is scanned repeatedly over its interior voxels (indices 1 to
// dim-2 on each axis, so every neighbour access stays inside the buffer).
// For every non-zero voxel `val` at (i,j,k), and for each of the 4 diagonal
// directions in each of the xy, xz and yz planes: if the diagonal neighbour is
// non-zero while both face-adjacent voxels between them are 0, one of those
// two voxels is set to `val`. Scans repeat until a full pass makes no change.
// Each correction turns a 0 voxel into a non-zero one, so the loop terminates.
//
// Notes:
//  - Corrections are written in place during the scan, so the result depends
//    on the visiting order (i outermost, then j, then k).
//  - Only two-axis (edge) diagonals are checked; voxels that touch only at a
//    corner (all three coordinates differ) are not bridged here.
//
// image: voxel buffer indexed as i + j*dim[0] + k*dim[0]*dim[1] (x fastest);
//        it has to be of type unsigned short. Modified in place.
// dim:   extents {x, y, z}.
// Returns 1 unconditionally.
int Computation::NoDiagConnect(unsigned short *image, int *dim) 
{
  // volume extents along x, y, z
  int dimx = dim[0];
  int dimy = dim[1];
  int dimz = dim[2];
  bool correctionNeeded = true;
  int cnt = 0; // number of scans performed (only used by the disabled debug output)

  while (correctionNeeded)
    {
      cnt++;
      //if (debug) cout << "NoDiag scan " << cnt << endl; 
      correctionNeeded = false;
      // Strides: dy is the size of one z-slice (step in k), dx the length of one row (step in j)
      int dy = dimx*dimy;
      int dx = dimx;
      
      for (int i = 1; i < dimx - 1; i++)
	{
	  for (int j = 1; j < dimy - 1; j++)
	    {
	      for (int k = 1; k < dimz - 1; k++) 
		{
		  unsigned short val = image[i + j * dimx + k * dy];
		  if (val != 0)
		    {
		      // xy plane: diagonals (-1,-1), (+1,+1), (+1,-1), (-1,+1); the x-neighbour is filled
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+(j-1)*dx+k*dy] == 0) && (image[i-1+(j-1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+(j+1)*dx+k*dy] == 0) && (image[i+1+(j+1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+(j-1)*dx+k*dy] == 0) && (image[i+1+(j-1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+(j+1)*dx+k*dy] == 0) && (image[i-1+(j+1)*dx+k*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
         
		      // xz plane: diagonals (-1,-1), (+1,+1), (+1,-1), (-1,+1); the x-neighbour is filled
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i-1+j*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+1+j*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i+1+j*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+1+j*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+1+j*dx+k*dy] = val;
			}
		      if ((image[i-1+j*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i-1+j*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i-1+j*dx+k*dy] = val;
			}
		      
		      // yz plane: diagonals (-1,-1), (+1,+1), (+1,-1), (-1,+1); the y-neighbour is filled
		      if ((image[i+(j-1)*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+(j-1)*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j-1)*dx+k*dy] = val;
			}
		      if ((image[i+(j+1)*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+(j+1)*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j+1)*dx+k*dy] = val;
			}
		      if ((image[i+(j+1)*dx+k*dy] == 0) && (image[i+j*dx+(k-1)*dy] == 0) && (image[i+(j+1)*dx+(k-1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j+1)*dx+k*dy] = val;
			}
		      if ((image[i+(j-1)*dx+k*dy] == 0) && (image[i+j*dx+(k+1)*dy] == 0) && (image[i+(j-1)*dx+(k+1)*dy] != 0))
			{
			  correctionNeeded = true;
			  image[i+(j-1)*dx+k*dy] = val;
			}
		    }
		}
	    }
	}
    }  
  return 1;
}

// clears the edge of the image
// Sets every voxel on the outer one-voxel border of the volume to clear_label.
// The buffer is indexed as x + y*dims[0] + z*dims[0]*dims[1] (x fastest).
// Rows lying in the first/last y row or the first/last z slice are cleared
// completely; every other row only has its two x end points cleared.
void Computation::clear_edge(unsigned short *image, int *dims, int clear_label)
{
  const int nx = dims[0];
  const int ny = dims[1];
  const int nz = dims[2];
  const int rowStride = nx;
  const int sliceStride = nx * ny;

  for (int z = 0; z < nz; ++z)
    {
      const bool borderSlice = (z == 0) || (z == nz - 1);
      for (int y = 0; y < ny; ++y)
        {
          const int rowStart = rowStride * y + sliceStride * z;
          const bool borderRow = borderSlice || (y == 0) || (y == ny - 1);
          if (!borderRow)
            {
              // interior row: only its first and last voxel touch the border
              image[rowStart] = clear_label;
              image[rowStart + nx - 1] = clear_label;
              continue;
            }
          for (int x = 0; x < nx; ++x)
            image[rowStart + x] = clear_label;
        }
    }
}

void Computation::WriteImage()
{
  cout<<"Writing image..."<<endl;
  
  VolumeWriterType::Pointer writer = VolumeWriterType::New();
  writer->SetFileName(GetOutputImage());
  writer->SetInput(m_inputImage);
  writer->UseCompressionOn();
  try
    {
      writer->Update();
    }
  catch (itk::ExceptionObject & err)
    {
      std::cerr<<"Exception object caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }
}


