#include <cstdlib>
#include <iostream>
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkListSample.h"
#include "itkMembershipSample.h"
#include "itkVector.h"
#include "itkConnectedComponentImageFilter.h"
#include "itkLabelStatisticsImageFilter.h"
#include "itkStatisticsImageFilter.h"
#include "itkRelabelComponentImageFilter.h"
#include "itkChangeLabelImageFilter.h"
#include "itkMaskImageFilter.h"
#include "itkPluginFilterWatcher.h"
#include "itkBinaryThresholdImageFilter.h"
#include "LesionStatsCLP.h"

int main(int argc, char * argv [])
{
  PARSE_ARGS;

  bool violated=false;
  if (inputT1Volume.size() == 0) { violated = true; std::cout << "  --inputT1Volume Required! "  << std::endl; }
  if (inputT2Volume.size() == 0) { violated = true; std::cout << "  --inputT2Volume Required! "  << std::endl; }
  if (inputFLAIRVolume.size() == 0) { violated = true; std::cout << "  --inputFLAIRVolume Required! "  << std::endl; }
  if (inputLesionVolume.size() == 0) { violated = true; std::cout << "  --inputLesionVolume Required! "  << std::endl; }
  if (inputMaskVolume.size() == 0) { violated = true; std::cout << "  --inputMaskVolume Required! "  << std::endl; }
  if (inputAtlasVolume.size() == 0) { violated = true; std::cout << "  --inputAtlasVolume Required! "  << std::endl; }
  if (outputLabelVolume.size() == 0) { violated = true; std::cout << "  --outputLabelVolume Required! "  << std::endl; }
  if (outputAtlasVolume.size() == 0) { violated = true; std::cout << "  --outputAtlasVolume Required! "  << std::endl; }
  if (violated) exit(1);

  //  typedef signed short       PixelType;
  typedef unsigned int       PixelType;
  const unsigned int          Dimension = 3;

  typedef itk::Image< PixelType,  Dimension >   ImageType;
  typedef itk::ImageFileReader< ImageType  >  ReaderType;
  typedef itk::ImageFileWriter< ImageType >  WriterType;
  ReaderType::Pointer lesionReader = ReaderType::New();
  ReaderType::Pointer t1Reader = ReaderType::New();
  ReaderType::Pointer t2Reader = ReaderType::New();
  ReaderType::Pointer flairReader = ReaderType::New();
  ReaderType::Pointer atlasReader = ReaderType::New();
  ImageType::Pointer image = ImageType::New();
  WriterType::Pointer writer = WriterType::New();

  lesionReader->SetFileName( inputLesionVolume.c_str() );
  t1Reader->SetFileName( inputT1Volume.c_str() );
  t2Reader->SetFileName( inputT2Volume.c_str() );
  flairReader->SetFileName( inputFLAIRVolume.c_str() );
  atlasReader->SetFileName( inputAtlasVolume.c_str() );
  writer->SetFileName (outputLabelVolume.c_str());

  typedef itk::ConnectedComponentImageFilter<ImageType, ImageType> ConnectedFilterType;
  ConnectedFilterType::Pointer connectedComponent = ConnectedFilterType::New();
  itk::PluginFilterWatcher watchConnected(connectedComponent,"Clustering Lesions",CLPProcessInformation);

  typedef itk::RelabelComponentImageFilter<ImageType, ImageType> RelabelFilterType;
  RelabelFilterType::Pointer relabelComponent = RelabelFilterType::New();
  itk::PluginFilterWatcher watchRelabelLesions(relabelComponent,"Relabeling Lesions",CLPProcessInformation);

  /* The Statistics Image Filter lets us find the initial cluster means.
     Should this be limited to the excluded region of the clipped T1 image?  */
  typedef itk::LabelStatisticsImageFilter< ImageType, ImageType > LabelStatisticsFilterType;
  typedef LabelStatisticsFilterType::RealType StatisticRealType;
  LabelStatisticsFilterType::Pointer statisticsFilterT1 = LabelStatisticsFilterType::New();
  LabelStatisticsFilterType::Pointer statisticsFilterT2 = LabelStatisticsFilterType::New();
  LabelStatisticsFilterType::Pointer statisticsFilterFLAIR = LabelStatisticsFilterType::New();

  itk::PluginFilterWatcher watchT1Stats(statisticsFilterT1,"Calculating T1 Statistics",CLPProcessInformation);
  itk::PluginFilterWatcher watchT2Stats(statisticsFilterT2,"Calculating T2 Statistics",CLPProcessInformation);
  itk::PluginFilterWatcher watchFLAIRStats(statisticsFilterFLAIR,"Calculating FLAIR Statistics",CLPProcessInformation);

  try {  
    connectedComponent->SetInput(lesionReader->GetOutput());
    relabelComponent->SetInput(connectedComponent->GetOutput());
    statisticsFilterT1->SetInput( t1Reader->GetOutput() );
    statisticsFilterT1->SetLabelInput( relabelComponent->GetOutput());
    statisticsFilterT1->Update();
    statisticsFilterT2->SetInput( t2Reader->GetOutput() );
    statisticsFilterT2->SetLabelInput( relabelComponent->GetOutput());
    statisticsFilterT2->Update();
    statisticsFilterFLAIR->SetInput( flairReader->GetOutput() );
    statisticsFilterFLAIR->SetLabelInput( relabelComponent->GetOutput());
    statisticsFilterFLAIR->Update();
    writer->SetInput(relabelComponent->GetOutput());
    writer->Update();
  }

  catch (itk::ExceptionObject &excep)
  {
    std::cerr << argv[0] << ": exception caught !" << std::endl;
    return EXIT_FAILURE;
  }

  unsigned long numLabels = statisticsFilterT1->GetNumberOfLabels();

  std::cout << "Number of labels: " << numLabels << std::endl;

  /*
  ** Threshold the atlas with the lesion volume.
  ** Get the unique values.
  ** Use the unique values to mask the atlas.
  ** Output the masked atlas.  
  */
  typedef itk::MaskImageFilter< ImageType, ImageType > MaskImageFilterType;
  MaskImageFilterType::Pointer maskFilter = MaskImageFilterType::New();

  maskFilter->SetInput1(atlasReader->GetOutput());
  maskFilter->SetInput2(lesionReader->GetOutput());
  maskFilter->SetOutsideValue(0);

  RelabelFilterType::Pointer relabelAtlasComponent = RelabelFilterType::New();
  itk::PluginFilterWatcher watchRelabelAtlas(relabelAtlasComponent,"Relabeling Atlas",CLPProcessInformation);

  relabelAtlasComponent->SetInput(maskFilter->GetOutput());

  LabelStatisticsFilterType::Pointer statisticsFilterAtlas = LabelStatisticsFilterType::New();
  statisticsFilterAtlas->SetInput(atlasReader->GetOutput());
  

  try
  {
    //statisticsFilterAtlas->SetLabelInput(maskFilter->GetOutput());
    statisticsFilterAtlas->SetLabelInput(relabelAtlasComponent->GetOutput());
    statisticsFilterAtlas->Update();
    
  }
  catch (itk::ExceptionObject &excep)
  {
    std::cerr << argv[0] << ": exception caught !" << std::endl;
    return EXIT_FAILURE;
  }

  unsigned long numAtlasLabels = statisticsFilterAtlas->GetNumberOfLabels();

  typedef itk::ChangeLabelImageFilter< ImageType, ImageType > ChangeLabelFilterType;
  ChangeLabelFilterType::Pointer changeLabelFilter = ChangeLabelFilterType::New();

  changeLabelFilter->SetInput(atlasReader->GetOutput());

  int flag = 0;

  for(unsigned int i=0;i<numAtlasLabels;i++)
  {
    changeLabelFilter->SetChange(statisticsFilterAtlas->GetMean(i),1);

    std::cout << "Region " << i << " Mean: " << statisticsFilterAtlas->GetMean(i) << std::endl;
    if(statisticsFilterAtlas->GetMean(i) == 1)
    {
      flag = 1;
    }
  }

  if(flag != 1)
  {
    changeLabelFilter->SetChange(1,0);
  }

  typedef itk::BinaryThresholdImageFilter< ImageType, ImageType > ThresholdImageFilterType;
  ThresholdImageFilterType::Pointer thresholdFilter = ThresholdImageFilterType::New();
  
  thresholdFilter->SetInput(changeLabelFilter->GetOutput());
  thresholdFilter->SetLowerThreshold(1);
  thresholdFilter->SetUpperThreshold(1);
  thresholdFilter->SetOutsideValue(0);
  thresholdFilter->SetInsideValue(1);

  MaskImageFilterType::Pointer maskAtlasFilter = MaskImageFilterType::New();

  maskAtlasFilter->SetInput1(atlasReader->GetOutput());
  //  maskAtlasFilter->SetInput(1,changeLabelFilter->GetOutput());
  maskAtlasFilter->SetInput2(thresholdFilter->GetOutput());
  maskAtlasFilter->SetOutsideValue(0);

  WriterType::Pointer atlasWriter = WriterType::New();
  atlasWriter->SetFileName(outputAtlasVolume.c_str());

  itk::PluginFilterWatcher watchAtlas(atlasWriter,"Generating ROI Atlas",CLPProcessInformation);

  try
  {
    atlasWriter->SetInput(maskAtlasFilter->GetOutput());
    atlasWriter->Update();
  }
  catch (itk::ExceptionObject &excep)
  {
    std::cerr << argv[0] << ": exception caught !" << std::endl;
    return EXIT_FAILURE;
  }

  /* Need to do this for each label in the ConnectedComponent .*/
  for(unsigned int i=1;i<numLabels;i++)
  {

    const PixelType imageMinT1 = static_cast<PixelType> ( statisticsFilterT1->GetMinimum(i) );
    const PixelType imageMaxT1 = static_cast<PixelType> ( statisticsFilterT1->GetMaximum(i) );
    const StatisticRealType imageMeanT1 = statisticsFilterT1->GetMean(i);
    const StatisticRealType imageSigmaT1 = statisticsFilterT1->GetSigma(i);
    const StatisticRealType imageVarianceT1 = statisticsFilterT1->GetVariance(i);
    const StatisticRealType imageCountT1 = statisticsFilterT1->GetCount(i);
    
    std::cout << "T1 Component " << i << " Minimum == " << imageMinT1 << std::endl;
    std::cout << "T1 Component " << i << " Maximum == " << imageMaxT1 << std::endl;
    std::cout << "T1 Component " << i << " Mean == " << imageMeanT1 << std::endl;
    std::cout << "T1 Component " << i << " Sigma == " << imageSigmaT1 << std::endl;
    std::cout << "T1 Component " << i << " Variance == " << imageVarianceT1 << std::endl;
    std::cout << "T1 Component " << i << " Count == " << imageCountT1 << std::endl;

    const PixelType imageMinT2 = static_cast<PixelType> ( statisticsFilterT2->GetMinimum(i) );
    const PixelType imageMaxT2 = static_cast<PixelType> ( statisticsFilterT2->GetMaximum(i) );
    const StatisticRealType imageMeanT2 = statisticsFilterT2->GetMean(i);
    const StatisticRealType imageSigmaT2 = statisticsFilterT2->GetSigma(i);
    const StatisticRealType imageVarianceT2 = statisticsFilterT2->GetVariance(i);
    const StatisticRealType imageCountT2 = statisticsFilterT2->GetCount(i);

    std::cout << "T2 Component " << i << " Minimum == " << imageMinT2 << std::endl;
    std::cout << "T2 Component " << i << " Maximum == " << imageMaxT2 << std::endl;
    std::cout << "T2 Component " << i << " Mean == " << imageMeanT2 << std::endl;
    std::cout << "T2 Component " << i << " Sigma == " << imageSigmaT2 << std::endl;
    std::cout << "T2 Component " << i << " Variance == " << imageVarianceT2 << std::endl;
    std::cout << "T2 Component " << i << " Count == " << imageCountT2 << std::endl;

    const PixelType imageMinFLAIR = static_cast<PixelType> ( statisticsFilterFLAIR->GetMinimum(i) );
    const PixelType imageMaxFLAIR = static_cast<PixelType> ( statisticsFilterFLAIR->GetMaximum(i) );
    const StatisticRealType imageMeanFLAIR = statisticsFilterFLAIR->GetMean(i);
    const StatisticRealType imageSigmaFLAIR = statisticsFilterFLAIR->GetSigma(i);
    const StatisticRealType imageVarianceFLAIR = statisticsFilterFLAIR->GetVariance(i);
    const StatisticRealType imageCountFLAIR = statisticsFilterFLAIR->GetCount(i);

    std::cout << "FLAIR Component " << i << " Minimum == " << imageMinFLAIR << std::endl;
    std::cout << "FLAIR Component " << i << " Maximum == " << imageMaxFLAIR << std::endl;
    std::cout << "FLAIR Component " << i << " Mean == " << imageMeanFLAIR << std::endl;
    std::cout << "FLAIR Component " << i << " Sigma == " << imageSigmaFLAIR << std::endl;
    std::cout << "FLAIR Component " << i << " Variance == " << imageVarianceFLAIR << std::endl;
    std::cout << "FLAIR Component " << i << " Count == " << imageCountFLAIR << std::endl;

  } 

  return EXIT_SUCCESS;
}
