#include <iostream>
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkListSample.h"
#include "itkMembershipSample.h"
#include "itkVector.h"
#include "itkConnectedComponentImageFilter.h"
#include "itkLabelStatisticsImageFilter.h"
#include "itkRelabelComponentImageFilter.h"
#include "itkJoinImageFilter.h"
#include "itkThresholdImageFilter.h"
#include "itkMaskImageFilter.h"
#include "itkHistogramMatchingImageFilter.h"
//#include "itkCurvatureAnisotropicDiffusionImageFilter.h"
#include "itkCurvatureAnisotropicDiffusionImageFilter.h"
#include "itkStatisticsImageFilter.h"
//#include "itkMinimumMaximumImageFilter.h"
#include "itkCastImageFilter.h"
#include "itkThresholdImageFilter.h"
#include "itkMaskImageFilter.h"
#include "itkPluginFilterWatcher.h"

#include "IntensityStandardizeCLP.h"

int main(int argc, char * argv [])
{
  PARSE_ARGS;

  bool violated=false;
  if (inputT1Mov.size() == 0) { violated = true; std::cout << "  --inputT1Mov Required! "  << std::endl; }
  if (inputT2Mov.size() == 0) { violated = true; std::cout << "  --inputT2Mov Required! "  << std::endl; }
  if (inputFLAIRMov.size() == 0) { violated = true; std::cout << "  --inputFLAIRMov Required! "  << std::endl; }
  if (inputMovMask.size() == 0) { violated = true; std::cout << "  --inputMovMask Required! "  << std::endl; }
  if (inputT1Ref.size() == 0) { violated = true; std::cout << "  --inputT1Ref Required! "  << std::endl; }
  if (inputT2Ref.size() == 0) { violated = true; std::cout << "  --inputT2Ref Required! "  << std::endl; }
  if (inputFLAIRRef.size() == 0) { violated = true; std::cout << "  --inputFLAIRRef Required! "  << std::endl; }
  if (inputRefMask.size() == 0) { violated = true; std::cout << "  --inputRefMask Required! "  << std::endl; }
  if (outputT1Mov.size() == 0) { violated = true; std::cout << "  --outputT1Mov Required! "  << std::endl; }
  if (outputT2Mov.size() == 0) { violated = true; std::cout << "  --outputT2Mov Required! "  << std::endl; }
  if (outputFLAIRMov.size() == 0) { violated = true; std::cout << "  --outputFLAIRMov Required! "  << std::endl; }
  if (violated) exit(1);

  typedef signed short       PixelType;
  const unsigned int          Dimension = 3;

  typedef itk::Image< PixelType,  Dimension >   ImageType;
  typedef itk::ImageFileReader< ImageType  >  ReaderType;

  ReaderType::Pointer movT1Reader = ReaderType::New();
  ReaderType::Pointer movT2Reader = ReaderType::New();
  ReaderType::Pointer movFLAIRReader = ReaderType::New();
  ReaderType::Pointer maskMovReader = ReaderType::New();
  ReaderType::Pointer refT1Reader = ReaderType::New();
  ReaderType::Pointer refT2Reader = ReaderType::New();
  ReaderType::Pointer refFLAIRReader = ReaderType::New();
  ReaderType::Pointer maskRefReader = ReaderType::New();

  typedef itk::ImageFileWriter< ImageType >  WriterType;
  WriterType::Pointer outputT1Writer = WriterType::New();
  WriterType::Pointer outputT2Writer = WriterType::New();
  WriterType::Pointer outputFLAIRWriter = WriterType::New();

  movT1Reader->SetFileName( inputT1Mov.c_str() );
  movT2Reader->SetFileName( inputT2Mov.c_str() );
  movFLAIRReader->SetFileName( inputFLAIRMov.c_str() );
  maskMovReader->SetFileName( inputMovMask.c_str() );
  refT1Reader->SetFileName( inputT1Ref.c_str() );
  refT2Reader->SetFileName( inputT2Ref.c_str() );
  refFLAIRReader->SetFileName( inputFLAIRRef.c_str() );
  maskRefReader->SetFileName( inputRefMask.c_str() );
  outputT1Writer->SetFileName (outputT1Mov.c_str());
  outputT2Writer->SetFileName (outputT2Mov.c_str());
  outputFLAIRWriter->SetFileName (outputFLAIRMov.c_str());

  /* Setup the plugin watcher to monitor the reading of images */
  double start = 0;
  double fraction = .0125;
  itk::PluginFilterWatcher watchT1MovRead(movT1Reader,"Reading Moving T1 Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchT2MovRead(movT2Reader,"Reading Moving T2 Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchFLAIRMovRead(movFLAIRReader,"Reading Moving FLAIR Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchMaskMovRead(maskMovReader,"Reading Moving Brain Mask",CLPProcessInformation,fraction,start);
  start+=fraction;

  itk::PluginFilterWatcher watchT1RefRead(refT1Reader,"Reading Reference T1 Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchT2RefRead(refT2Reader,"Reading Reference T2 Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchFLAIRRefRead(refFLAIRReader,"Reading Reference FLAIR Image",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchMaskRefRead(maskRefReader,"Reading Reference Brain Mask",CLPProcessInformation,fraction,start);
  start+=fraction;

typedef itk::HistogramMatchingImageFilter< ImageType, ImageType> HistogramMatchingFilterType;
  HistogramMatchingFilterType::Pointer histogramMatchingFilterT1 = HistogramMatchingFilterType::New();
  HistogramMatchingFilterType::Pointer histogramMatchingFilterT2 = HistogramMatchingFilterType::New();
  HistogramMatchingFilterType::Pointer histogramMatchingFilterFLAIR = HistogramMatchingFilterType::New();

  typedef HistogramMatchingFilterType::OutputImageType HistMatchType;
  typedef itk::Image< float,  3 >   SmoothedImageType;

  typedef itk::StatisticsImageFilter< ImageType > StatisticsFilterType;
  StatisticsFilterType::Pointer statisticsFilterT1 = StatisticsFilterType::New();
  StatisticsFilterType::Pointer statisticsFilterT2 = StatisticsFilterType::New();
  StatisticsFilterType::Pointer statisticsFilterFLAIR = StatisticsFilterType::New();

  fraction = .11;
  itk::PluginFilterWatcher watchT1Stats(statisticsFilterT1,"Calculating T1 Statistics",CLPProcessInformation,start,fraction);
  start+=fraction;
  itk::PluginFilterWatcher watchT2Stats(statisticsFilterT2,"Calculating T2 Statistics",CLPProcessInformation,start,fraction);
  start+=fraction;
  itk::PluginFilterWatcher watchFLAIRStats(statisticsFilterFLAIR,"Calculating FLAIR Statistics",CLPProcessInformation,start,fraction);
  start+=fraction;

//  typedef itk::CastImageFilter< SmoothedImageType,ImageType > CastFilterType;
  typedef itk::CastImageFilter< HistMatchType,ImageType > CastFilterType;
  CastFilterType::Pointer castFilterT1 = CastFilterType::New();
  CastFilterType::Pointer castFilterT2 = CastFilterType::New();
  CastFilterType::Pointer castFilterFLAIR = CastFilterType::New();

  unsigned long numOfBinsT1 = 0;
  unsigned long numOfBinsT2 = 0;
  unsigned long numOfBinsFLAIR = 0;

  const PixelType maskThresholdBelow = 1;     // someday with more generality?

/* The Threshold Image Filter is used to produce the brain clipping mask. */
  typedef itk::ThresholdImageFilter< ImageType > ThresholdFilterType;

  ThresholdFilterType::Pointer brainMaskFilterMov = ThresholdFilterType::New();
  brainMaskFilterMov->SetInput( maskMovReader->GetOutput() );
  brainMaskFilterMov->ThresholdBelow( maskThresholdBelow );
  brainMaskFilterMov->Update();

  ThresholdFilterType::Pointer brainMaskFilterRef = ThresholdFilterType::New();
  brainMaskFilterRef->SetInput( maskRefReader->GetOutput() );
  brainMaskFilterRef->ThresholdBelow( maskThresholdBelow );
  brainMaskFilterRef->Update();

  /* The Mask Image Filter applies the clipping mask by stepping 
     on the excluded region with the imageExclusion value. */
  typedef itk::MaskImageFilter< ImageType, ImageType > MaskFilterType;
  MaskFilterType::Pointer brainMovFilterT1 = MaskFilterType::New();
  MaskFilterType::Pointer brainMovFilterT2 = MaskFilterType::New();
  MaskFilterType::Pointer brainMovFilterFLAIR = MaskFilterType::New();
  brainMovFilterT1->SetInput1( movT1Reader->GetOutput() );
  brainMovFilterT2->SetInput1( movT2Reader->GetOutput() );
  brainMovFilterFLAIR->SetInput1( movFLAIRReader->GetOutput() );
  brainMovFilterT1->SetInput2( brainMaskFilterMov->GetOutput() );
  brainMovFilterT2->SetInput2( brainMaskFilterMov->GetOutput() );
  brainMovFilterFLAIR->SetInput2( brainMaskFilterMov->GetOutput() );
  brainMovFilterT1->SetOutsideValue( 0 );
  brainMovFilterT2->SetOutsideValue( 0 );
  brainMovFilterFLAIR->SetOutsideValue( 0 );
  brainMovFilterT1->Update();
  brainMovFilterT2->Update();
  brainMovFilterFLAIR->Update();

  MaskFilterType::Pointer brainRefFilterT1 = MaskFilterType::New();
  MaskFilterType::Pointer brainRefFilterT2 = MaskFilterType::New();
  MaskFilterType::Pointer brainRefFilterFLAIR = MaskFilterType::New();
  brainRefFilterT1->SetInput1( refT1Reader->GetOutput() );
  brainRefFilterT2->SetInput1( refT2Reader->GetOutput() );
  brainRefFilterFLAIR->SetInput1( refFLAIRReader->GetOutput() );
  brainRefFilterT1->SetInput2( brainMaskFilterRef->GetOutput() );
  brainRefFilterT2->SetInput2( brainMaskFilterRef->GetOutput() );
  brainRefFilterFLAIR->SetInput2( brainMaskFilterRef->GetOutput() );
  brainRefFilterT1->SetOutsideValue( 0 );
  brainRefFilterT2->SetOutsideValue( 0 );
  brainRefFilterFLAIR->SetOutsideValue( 0 );
  brainRefFilterT1->Update();
  brainRefFilterT2->Update();
  brainRefFilterFLAIR->Update();

  fraction = .18;
  itk::PluginFilterWatcher watchHistMatchT1(histogramMatchingFilterT1,"Histogram Matching T1s",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchHistMatchT2(histogramMatchingFilterT2,"Histogram Matching T2s",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchHistMatchFLAIR(histogramMatchingFilterFLAIR,"Histogram Matching FLAIRs",CLPProcessInformation,fraction,start);
  start+=fraction;

  fraction = .01;
  itk::PluginFilterWatcher watchT1Writer(outputT1Writer,"Writing Standardized T1",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchT2Writer(outputT2Writer,"Writing Standardized T2",CLPProcessInformation,fraction,start);
  start+=fraction;
  itk::PluginFilterWatcher watchFLAIRWriter(outputFLAIRWriter,"Writing Standardized FLAIR",CLPProcessInformation,fraction,start);
  
  try {  
    statisticsFilterT1->SetInput(brainMovFilterT1->GetOutput());
    statisticsFilterT2->SetInput(brainMovFilterT2->GetOutput());
    statisticsFilterFLAIR->SetInput(brainMovFilterFLAIR->GetOutput());
    statisticsFilterT1->Update();
    statisticsFilterT2->Update();
    statisticsFilterFLAIR->Update();
    numOfBinsT1 = statisticsFilterT1->GetMaximum() - statisticsFilterT1->GetMinimum() + 1;
    numOfBinsT2 = statisticsFilterT2->GetMaximum() - statisticsFilterT2->GetMinimum() + 1;
    numOfBinsFLAIR = statisticsFilterFLAIR->GetMaximum() - statisticsFilterFLAIR->GetMinimum() + 1;
    
    histogramMatchingFilterT1->SetSourceImage(brainMovFilterT1->GetOutput());
    histogramMatchingFilterT2->SetSourceImage(brainMovFilterT2->GetOutput());
    histogramMatchingFilterFLAIR->SetSourceImage(brainMovFilterFLAIR->GetOutput());
    histogramMatchingFilterT1->SetReferenceImage(brainRefFilterT1->GetOutput());
    histogramMatchingFilterT2->SetReferenceImage(brainRefFilterT2->GetOutput());
    histogramMatchingFilterFLAIR->SetReferenceImage(brainRefFilterFLAIR->GetOutput());
    histogramMatchingFilterT1->SetNumberOfHistogramLevels(numOfBinsT1); // Max value in moving image - min value in moving image + 1.
    histogramMatchingFilterT2->SetNumberOfHistogramLevels(numOfBinsT2); // Max value in moving image - min value in moving image + 1.
    histogramMatchingFilterFLAIR->SetNumberOfHistogramLevels(numOfBinsFLAIR); // Max value in moving image - min value in moving image + 1.
    histogramMatchingFilterT1->SetNumberOfMatchPoints(numOfBinsT1); // Equal to number of bins?
    histogramMatchingFilterT2->SetNumberOfMatchPoints(numOfBinsT2); // Equal to number of bins?
    histogramMatchingFilterFLAIR->SetNumberOfMatchPoints(numOfBinsFLAIR); // Equal to number of bins?
    histogramMatchingFilterT1->SetThresholdAtMeanIntensity(1);
    histogramMatchingFilterT2->SetThresholdAtMeanIntensity(1);
    histogramMatchingFilterFLAIR->SetThresholdAtMeanIntensity(1);
    castFilterT1->SetInput(histogramMatchingFilterT1->GetOutput());
    castFilterT2->SetInput(histogramMatchingFilterT2->GetOutput());
    castFilterFLAIR->SetInput(histogramMatchingFilterFLAIR->GetOutput());
    outputT1Writer->SetInput(castFilterT1->GetOutput());
    outputT2Writer->SetInput(castFilterT2->GetOutput());
    outputFLAIRWriter->SetInput(castFilterFLAIR->GetOutput());
    outputT1Writer->Update();
    outputT2Writer->Update();
    outputFLAIRWriter->Update();
  }

  catch (itk::ExceptionObject &excep)
  {
    std::cerr << argv[0] << ": exception caught !" << std::endl;
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;
}
