/*=========================================================================
  Program:   LabelSegPostProcess
  Module:    $RCSfile: Computation.cxx,v $
  Language:  C++
  Date:      $Date: 2011/04/15 11:19:42 $
  Version:   $Revision: 1.0 $
  Author:    Clement Vachet (cvachet@unc.edu)

  Copyright (c) Clement Vachet. All rights reserved.
  See NeuroLibCopyright.txt or http://www.ia.unc.edu/dev/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#include "Computation.h"

void Computation::Compute()
{
  std::cout<<"Computation..."<<std::endl;

  // Label values that CreateFinalLabelMap() writes for GM and CSF
  const unsigned int finalGMLabel = 2;
  const unsigned int finalCSFLabel = 3;

  // Processing chain; every stage consumes the maps built by the previous ones
  ExtractTissueLabels();   // binary WM / GM / CSF maps from the input segmentation
  CreateWMMap();           // WM map corrected with the auxiliary masks
  CreateGMMap();           // GM map corrected with the auxiliary masks
  ComputeNewAddGM();       // additional GM found in holes enclosed by WM+GM
  CreateFinalLabelMap();   // merge all maps into one label image

  // The final label map is always saved
  WriteImage(m_OutputLabelName, m_outputLabelImage);

  // Optional per-tissue outputs: an empty file name disables each of them
  if (!m_OutputWMName.empty())
    {
      WriteImage(m_OutputWMName, m_WMLabelImage);
    }

  if (!m_OutputGMName.empty())
    {
      // CreateFinalLabelMap() scaled m_GMLabelImage in place and merged the
      // additional GM into the final map, so GM is re-extracted from that map
      ExtractLabel(m_outputLabelImage, finalGMLabel, m_GMLabelImage);
      WriteImage(m_OutputGMName, m_GMLabelImage);
    }

  if (!m_OutputCSFName.empty())
    {
      ExtractLabel(m_outputLabelImage, finalCSFLabel, m_CSFLabelImage);
      WriteImage(m_OutputCSFName, m_CSFLabelImage);
    }
}

void Computation::ReadInputImages()
{
  std::cout<<"Reading input images..."<<std::endl;

  // The segmentation to correct, followed by the three auxiliary masks.
  // ReadImage() terminates the program if any of these cannot be read.
  ReadImage(m_InputLabelName,          m_inputLabelImage);
  ReadImage(m_AbsoluteWMMaskImageName, m_AbsoluteWMMaskImage);
  ReadImage(m_RemoveGMMaskImageName,   m_RemoveGMMaskImage);
  ReadImage(m_ExclusionMaskImageName,  m_ExclusionMaskImage);
}

void Computation::ExtractTissueLabels()
{
  if (DebugOn())
    {
      std::cout<<"Extracting tissue labels..."<<std::endl;
    }

  // One binary (1 inside / 0 outside) map per tissue class of the input
  // segmentation, using the user-configured label values
  ExtractLabel(m_inputLabelImage, GetWMLabel(),  m_WMLabelImage);
  ExtractLabel(m_inputLabelImage, GetGMLabel(),  m_GMLabelImage);
  ExtractLabel(m_inputLabelImage, GetCSFLabel(), m_CSFLabelImage);
}

void Computation::CreateWMMap()
{
  if (DebugOn()) std::cout<<"Creating WMMap..."<<std::endl;

  // Dilate AbsoluteWMMaskImage
  if (DilationOn())
    Dilation(m_AbsoluteWMMaskImage,m_AbsoluteWMMaskImage);

  //Combine WM and dilated AbsoluteWMMask:
  Combine(m_WMLabelImage,m_AbsoluteWMMaskImage,m_WMLabelImage);

  //Remove RemoveGMMask
  Subtract(m_WMLabelImage,m_RemoveGMMaskImage,m_WMLabelImage);
  ThreshMask(m_WMLabelImage,m_WMLabelImage);

  //Remove ExclusionMask
  Subtract(m_WMLabelImage,m_ExclusionMaskImage,m_WMLabelImage);
  ThreshMask(m_WMLabelImage,m_WMLabelImage);

  // Inside Filling()
  InsideFilling(m_WMLabelImage,m_WMLabelImage);
}

void Computation::CreateGMMap()
{
  if (DebugOn())
    {
      std::cout<<"Creating GMMap..."<<std::endl;
    }

  // GM map = (GM label merged with the RemoveGM mask) minus the exclusion
  // mask; ThreshMask() then resets every value outside [0,1] to 0
  Combine(m_GMLabelImage, m_RemoveGMMaskImage, m_GMLabelImage);
  Subtract(m_GMLabelImage, m_ExclusionMaskImage, m_GMLabelImage);
  ThreshMask(m_GMLabelImage, m_GMLabelImage);
}

void Computation::ComputeNewAddGM()
{
  if (DebugOn()) std::cout<<"Computing new additionnal GM..."<<std::endl;
  
  ImageType::Pointer m_WMGMImage = ImageType::New();
  m_WMGMImage->SetRegions(m_inputLabelImage->GetRequestedRegion());
  m_WMGMImage->CopyInformation(m_inputLabelImage);
  m_WMGMImage->Allocate();
  
  // Combine WMMap and GMMap
  Combine(m_WMLabelImage,m_GMLabelImage,m_WMGMImage);

  // Filling
  ImageType::Pointer m_FilledWMGMImage = ImageType::New();
  m_FilledWMGMImage->SetRegions(m_inputLabelImage->GetRequestedRegion());
  m_FilledWMGMImage->CopyInformation(m_inputLabelImage);
  m_FilledWMGMImage->Allocate();

  InsideFilling(m_WMGMImage,m_FilledWMGMImage);

  //New Additionnal GM
  m_newAddGMImage = ImageType::New();
  m_newAddGMImage->SetRegions(m_inputLabelImage->GetRequestedRegion());
  m_newAddGMImage->CopyInformation(m_inputLabelImage);
  m_newAddGMImage->Allocate();

  Subtract(m_FilledWMGMImage,m_WMGMImage,m_newAddGMImage);

  ImageType::Pointer m_newAddGM_MaskCSFImage = ImageType::New();
  m_newAddGM_MaskCSFImage->SetRegions(m_inputLabelImage->GetRequestedRegion());
  m_newAddGM_MaskCSFImage->CopyInformation(m_inputLabelImage);
  m_newAddGM_MaskCSFImage->Allocate();
  
  Mask(m_newAddGMImage,m_CSFLabelImage,m_newAddGM_MaskCSFImage);
  Subtract(m_newAddGMImage,m_newAddGM_MaskCSFImage,m_newAddGMImage);
}

void Computation::CreateFinalLabelMap()
{
  if (DebugOn()) std::cout<<"Creating final label map..."<<std::endl;

  m_outputLabelImage = ImageType::New();
  m_outputLabelImage->SetRegions(m_inputLabelImage->GetRequestedRegion());
  m_outputLabelImage->CopyInformation(m_inputLabelImage);
  m_outputLabelImage->Allocate();

  Multiply(m_GMLabelImage,2);
  Multiply(m_newAddGMImage,2);
  Combine(m_WMLabelImage,m_GMLabelImage,m_outputLabelImage);
  Combine(m_outputLabelImage,m_newAddGMImage,m_outputLabelImage);

  Subtract(m_CSFLabelImage,m_ExclusionMaskImage,m_CSFLabelImage);
  ThreshMask(m_CSFLabelImage,m_CSFLabelImage);
  Multiply(m_CSFLabelImage,3);
  Combine(m_outputLabelImage,m_CSFLabelImage,m_outputLabelImage);
}

void Computation::ReadImage(std::string _FileName, ImageType::Pointer &_outputImage)
{
  if (DebugOn()) std::cout<<"\tReading image "<<_FileName<<" ..."<<std::endl;

  // Read the whole volume now; any ITK failure terminates the program
  VolumeReaderType::Pointer imageReader = VolumeReaderType::New();
  imageReader->SetFileName(_FileName);
  try
    {
      imageReader->Update();
    }
  // Catch by reference, as in the other handlers of this file. Catching by
  // value copies the exception and slices derived ITK exceptions down to the
  // base ExceptionObject, losing their class name and details when printed.
  catch (ExceptionObject & err)
    {
      std::cerr<<"Exception object caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }
  _outputImage = imageReader->GetOutput();
}

void Computation::WriteImage(std::string _FileName, ImageType::Pointer _Image)
{
  if (DebugOn()) std::cout<<"Writing image "<<_FileName<<" ..."<<std::endl;
  
  VolumeWriterType::Pointer writer = VolumeWriterType::New();
  writer->SetFileName(_FileName);
  writer->SetInput(_Image);
  writer->UseCompressionOn();
  try
    {
      writer->Update();
    }
  catch (itk::ExceptionObject & err)
    {
      std::cerr<<"Exception object caught!"<<std::endl;
      std::cerr<<err<<std::endl;
      exit(EXIT_FAILURE);
    }
}

void Computation::ExtractLabel(ImageType::Pointer _inputImage, unsigned int _Label, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tExtracting label "<<_Label<<" ..."<<std::endl;
    }

  // Binary threshold on the single value _Label: voxels equal to it become 1,
  // every other voxel becomes 0. _outputImage is replaced by the filter output.
  threshFilterType::Pointer labelFilter = threshFilterType::New();
  labelFilter->SetInput(_inputImage);
  labelFilter->SetLowerThreshold(_Label);
  labelFilter->SetUpperThreshold(_Label);
  labelFilter->SetOutsideValue(0);
  labelFilter->SetInsideValue(1);

  try
    {
      labelFilter->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = labelFilter->GetOutput();
}

void Computation::Dilation(ImageType::Pointer _inputImage, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tDilation..."<<std::endl;
    }

  // Radius-1 structuring element (3x3x3 neighbourhood)
  StructuringElementType kernel;
  kernel.SetRadius(1);
  kernel.CreateStructuringElement();

  // Dilate the voxels whose value is 1; _outputImage becomes the filter output
  dilateFilterType::Pointer dilater = dilateFilterType::New();
  dilater->SetInput(_inputImage);
  dilater->SetDilateValue(1);
  dilater->SetKernel(kernel);

  try
    {
      dilater->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = dilater->GetOutput();
}

void Computation::Combine(ImageType::Pointer _inputImage1, ImageType::Pointer _inputImage2, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tCombining images..."<<std::endl;
    }

  // Voxel-wise merge into the already-allocated _outputImage: the first
  // image's value is kept, and only where it is zero does a non-zero value
  // from the second image fill in. All three buffered regions are walked in
  // lock-step, driven by the output. _outputImage may alias either input
  // because each voxel is read before it is written.
  IteratorType firstIt(_inputImage1, _inputImage1->GetBufferedRegion());
  IteratorType secondIt(_inputImage2, _inputImage2->GetBufferedRegion());
  IteratorType outIt(_outputImage, _outputImage->GetBufferedRegion());

  for (; !outIt.IsAtEnd(); ++firstIt, ++secondIt, ++outIt)
    {
      const ImagePixelType primary = firstIt.Get();
      const ImagePixelType fallback = secondIt.Get();
      outIt.Set((!primary && fallback) ? fallback : primary);
    }
}

void Computation::Subtract(ImageType::Pointer _inputImage1, ImageType::Pointer _inputImage2, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tSubtracting images..."<<std::endl;
    }

  // Voxel-wise _inputImage1 - _inputImage2. _outputImage is replaced by the
  // filter's output, so it does not need to be pre-allocated.
  subFilterType::Pointer difference = subFilterType::New();
  difference->SetInput1(_inputImage1);
  difference->SetInput2(_inputImage2);

  try
    {
      difference->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = difference->GetOutput();
}

void Computation::ThreshMask(ImageType::Pointer _inputImage, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tThreshMasking images..."<<std::endl;
    }

  // Keep voxels with values in [0,1] and set all other voxels to 0. This is
  // used after subtractions to discard negative or out-of-range results.
  maskThreshFilterType::Pointer rangeFilter = maskThreshFilterType::New();
  rangeFilter->SetInput(_inputImage);
  rangeFilter->SetOutsideValue(0);
  rangeFilter->ThresholdOutside(0,1);

  try
    {
      rangeFilter->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = rangeFilter->GetOutput();
}

void Computation::InsideFilling(ImageType::Pointer _inputImage, ImageType::Pointer &_outputImage)
{
  if (DebugOn()) std::cout<<"\tInside Filling..."<<std::endl;

  // Inverted image (1 - value), allocated on the input's grid.
  // NOTE(review): only a clean inversion for a binary 0/1 input.
  ImageType::Pointer InvertedImage = ImageType::New();
  InvertedImage->SetRegions(_inputImage->GetRequestedRegion());
  InvertedImage->CopyInformation(_inputImage);
  InvertedImage->Allocate();

  ConstIteratorType ConstInputIterator(_inputImage,_inputImage->GetRequestedRegion());
  IteratorType InvertedIterator(InvertedImage, _inputImage->GetRequestedRegion());
  ConstInputIterator.GoToBegin();
  InvertedIterator.GoToBegin();
  while ( !ConstInputIterator.IsAtEnd() )
    {
      InvertedIterator.Set(1 - ConstInputIterator.Get());
      ++ConstInputIterator;
      ++InvertedIterator;
    }

  // The largest connected component of the inverted image is treated as the
  // exterior background
  ImageType::Pointer LargestInvertedImage;
  ComputeLargestComponent(InvertedImage, LargestInvertedImage);

  // The remaining inverted components are background regions enclosed by
  // the input, i.e. its holes. Reuse the shared Subtract() helper instead of
  // duplicating its filter setup and error handling.
  ImageType::Pointer InsideImage;
  Subtract(InvertedImage, LargestInvertedImage, InsideImage);

  // Fill the holes: keep the input value and take the hole voxels where the
  // input is zero. Combine() writes into the pre-allocated _outputImage and
  // tolerates _outputImage aliasing _inputImage.
  Combine(_inputImage, InsideImage, _outputImage);
}

void Computation::ComputeLargestComponent(ImageType::Pointer _inputImage, ImageType::Pointer &_outputImage)
{
  // Connected-component labelling of _inputImage
  ConnectiveFilterType::Pointer components = ConnectiveFilterType::New();
  components->SetInput(_inputImage);
  try
    {
      components->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  // Relabel the components sorted by size, so label 1 is the largest one
  RelabelFilterType::Pointer sorter = RelabelFilterType::New();
  sorter->SetInput(components->GetOutput());

  // Binary mask of label 1 only. Updating this filter also runs the
  // relabelling stage upstream.
  threshFilterType::Pointer largestOnly = threshFilterType::New();
  largestOnly->SetInput(sorter->GetOutput());
  largestOnly->SetLowerThreshold(1);
  largestOnly->SetUpperThreshold(1);
  largestOnly->SetInsideValue(1);
  largestOnly->SetOutsideValue(0);
  try
    {
      largestOnly->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = largestOnly->GetOutput();
}


void Computation::Mask(ImageType::Pointer _inputImage1, ImageType::Pointer _inputImage2, ImageType::Pointer &_outputImage)
{
  if (DebugOn())
    {
      std::cout<<"\tMasking..."<<std::endl;
    }

  // Apply the project's mask filter with _inputImage1 as input 1 and
  // _inputImage2 as input 2; _outputImage is replaced by the filter output.
  // NOTE(review): presumably keeps _inputImage1 where _inputImage2 is
  // non-zero — confirm against the maskFilterType typedef.
  maskFilterType::Pointer masker = maskFilterType::New();
  masker->SetInput1(_inputImage1);
  masker->SetInput2(_inputImage2);

  try
    {
      masker->Update();
    }
  catch (ExceptionObject & err)
    {
      std::cerr << "ExceptionObject caught!" << std::endl;
      std::cerr << err << std::endl;
      exit(EXIT_FAILURE);
    }

  _outputImage = masker->GetOutput();
}

void Computation::Multiply(ImageType::Pointer _inputImage, ImagePixelType value)
{
  if (DebugOn())
    {
      std::cout<<"\tMultiplying image..."<<std::endl;
    }

  // Scale every voxel of the buffered region in place (no new image is made)
  for (IteratorType voxel(_inputImage, _inputImage->GetBufferedRegion()); !voxel.IsAtEnd(); ++voxel)
    {
      voxel.Set(value * voxel.Get());
    }
}
