/*=========================================================================

 Program:   BRAINS (Brain Research: Analysis of Images, Networks, and Systems)
 Module:    $RCSfile: $
 Language:  C++
 Date:      $Date: 2006/03/29 14:53:40 $
 Version:   $Revision: 1.9 $
 
   Copyright (c) Iowa Mental Health Clinical Research Center. All rights reserved.
   See BRAINSCopyright.txt or http://www.psychiatry.uiowa.edu/HTML/Copyright.html 
   for details.
 
      This software is distributed WITHOUT ANY WARRANTY; without even 
      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
      PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include <algorithm>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>

#include <itkArray.h>
#include <itkImage.h>
#include <itkVectorImage.h>
#include <metaCommand.h>
#include <itkImageFileWriter.h>
#include <itkImageFileReader.h>
#include <itkExceptionObject.h>
#include <itkMetaDataObject.h>
#include <itkImageRegionIterator.h>
#include <itkImageRegionConstIterator.h>
#include <itkLabelStatisticsImageFilter.h>
#include <itkThresholdImageFilter.h>
#include <itkMaskImageFilter.h>
#include <itkNotImageFilter.h>
#include <itkScalarImageKmeansImageFilter.h>
#include <itkImageRegionIterator.h>


#include "KmeansClusterSamplesCLP.h"

int main (int argc, char **argv)
{
  int flairImage = 0;
    
  PARSE_ARGS;

/*  std::string inputT1Volume = "t1vol.img";
  std::string inputMaskVolume = "t1vol_mask.hdr";
  std::string inputFLAIRVolume = "flair1to_t1vol.hdr";
  std::string outputLabelVolume = "labeledImage_pd1.img";
 */
  int brainMask = 255;
  int nonBrainMask = 0;

  bool debug=true;
  if (debug) 
    {
    std::cout << "Input T1 Image: " <<  inputT1Volume << std::endl; 
    std::cout << "Input Mask Image: " <<  inputMaskVolume << std::endl; 
    std::cout << "Output K-Means Label Image: " <<  outputLabelVolume << std::endl;
	if (!inputFLAIRVolume.empty()) {
	  std::cout << "Input FLAIR Image: " << inputFLAIRVolume << std::endl;
	  flairImage = 1;
	}
  }
 
  
  bool violated=false;
  if (inputT1Volume.size() == 0) { violated = true; std::cout << "  --inputT1Volume Required! "  << std::endl; }
  if (inputMaskVolume.size() == 0) { violated = true; std::cout << "  --inputMaskVolume Required! "  << std::endl; }
  if (outputLabelVolume.size() == 0) { violated = true; std::cout << "  --outputLabelVolume Required! "  << std::endl; }
  if (violated) exit(1);
  
  
  typedef signed short       PixelType;
  const unsigned int          Dimension = 3;

  typedef itk::Image<PixelType, Dimension > ImageType;
  typedef itk::ImageFileReader< ImageType > ReaderType;
/*
      typedef itk::ThresholdImageFilter< ImageType > ThresholdFilterType;
	  typedef itk::LabelStatisticsImageFilter< ImageType, ImageType > LabelStatisticsFilterType;
    typedef LabelStatisticsFilterType::RealType StatisticRealType;
	  typedef itk::MaskImageFilter< ImageType, ImageType > MaskFilterType;
	    typedef itk::ScalarImageKmeansImageFilter< ImageType > KMeansFilterType;
  typedef KMeansFilterType::RealPixelType RealPixelType;
    typedef KMeansFilterType::OutputImageType LabelImageType;
	  typedef itk::LabelStatisticsImageFilter< LabelImageType, LabelImageType > LabelMapStatisticsFilterType;
*/
  ReaderType::Pointer T1Reader = ReaderType::New();
  T1Reader->SetFileName( inputT1Volume );

  ReaderType::Pointer maskReader = ReaderType::New();
  maskReader->SetFileName( inputMaskVolume );

//  itk::ImageFileReader< LabelImageType >::Pointer labelImageReader = itk::ImageFileReader< LabelImageType >::New();
//  labelImageReader->SetFileName(outputLabelVolume);
//  labelImageReader->Update();

//  LabelImageType::Pointer kmeansLabelImage = labelImageReader->GetOutput();

  ReaderType::Pointer FLAIRReader = ReaderType::New();
  FLAIRReader->SetFileName( inputFLAIRVolume );

  const PixelType imageExclusion = -32000;
  const PixelType maskThresholdBelow = 1;     // someday with more generality?
  
  /* The Threshold Image Filter is used to produce the brain clipping mask. */
  typedef itk::ThresholdImageFilter< ImageType > ThresholdFilterType;
  ThresholdFilterType::Pointer brainMaskFilter = ThresholdFilterType::New();
  brainMaskFilter->SetInput( maskReader->GetOutput() );
  brainMaskFilter->ThresholdBelow( maskThresholdBelow );
  brainMaskFilter->Update();

  /* The Not Image Filter is used to produce the other clipping mask. */
  typedef itk::NotImageFilter< ImageType, ImageType > NotFilterType;
  NotFilterType::Pointer nonBrainMaskFilter = NotFilterType::New();
  nonBrainMaskFilter->SetInput( maskReader->GetOutput() );
  nonBrainMaskFilter->Update();

  /* The Statistics Image Filter lets us find the initial cluster means.
     Should this be limited to the excluded region of the clipped T1 image?  */
  typedef itk::LabelStatisticsImageFilter< ImageType, ImageType > LabelStatisticsFilterType;
  typedef LabelStatisticsFilterType::RealType StatisticRealType;
  LabelStatisticsFilterType::Pointer statisticsFilter = LabelStatisticsFilterType::New();
  statisticsFilter->SetInput( T1Reader->GetOutput() );
  statisticsFilter->SetLabelInput( maskReader->GetOutput() );
  statisticsFilter->Update();

  const PixelType imageMin = static_cast<PixelType> ( statisticsFilter->GetMinimum(brainMask) );
  const PixelType imageMax = static_cast<PixelType> ( statisticsFilter->GetMaximum(brainMask) );
  const StatisticRealType imageMean = statisticsFilter->GetMean(brainMask);
  const StatisticRealType imageSigma = statisticsFilter->GetSigma(brainMask);
  
  std::cout << "T1 Brain Minimum == " << imageMin << std::endl;
  std::cout << "T1 Brain Maximum == " << imageMax << std::endl;
  std::cout << "T1 Brain Mean == " << imageMean << std::endl;
  std::cout << "T1 Brain Sigma == " << imageSigma << std::endl;
  

  /* The Statistics Image Filter lets us find the initial cluster means.
     Should this be limited to the excluded region of the clipped T1 image?  */
/*  LabelStatisticsFilterType::Pointer nonBrainStatisticsFilter = LabelStatisticsFilterType::New();
  nonBrainStatisticsFilter->SetInput( T1Reader->GetOutput() );
  nonBrainStatisticsFilter->SetLabelInput( maskReader->GetOutput() );
  nonBrainStatisticsFilter->Update();

  const PixelType nonBrainImageMin = static_cast<PixelType> ( nonBrainStatisticsFilter->GetMinimum(nonBrainMask) );
  const PixelType nonBrainImageMax = static_cast<PixelType> ( nonBrainStatisticsFilter->GetMaximum(nonBrainMask) );
  const StatisticRealType nonBrainImageMean = nonBrainStatisticsFilter->GetMean(nonBrainMask);
  const StatisticRealType nonBrainImageSigma = nonBrainStatisticsFilter->GetSigma(nonBrainMask);
  
  //std::cout << "Background Minimum == " << nonBrainImageMin << std::endl;
  std::cout << "T1 Background Maximum == " << nonBrainImageMax << std::endl;
  std::cout << "T1 Background Minimum == " << nonBrainImageMin << std::endl;
  std::cout << "T1 Background Mean == " << nonBrainImageMean << std::endl;
  std::cout << "T1 Background Sigma == " << nonBrainImageSigma << std::endl;
  
  /* The Mask Image Filter applies the clipping mask by stepping 
     on the excluded region with the imageExclusion value. */
  typedef itk::MaskImageFilter< ImageType, ImageType > MaskFilterType;
  MaskFilterType::Pointer clippedBrainT1Filter = MaskFilterType::New();
  clippedBrainT1Filter->SetInput1( T1Reader->GetOutput() );
  clippedBrainT1Filter->SetInput2( brainMaskFilter->GetOutput() );
  clippedBrainT1Filter->SetOutsideValue( imageExclusion );
  clippedBrainT1Filter->Update();
  
  /* The Mask Image Filter applies the clipping mask by stepping 
     on the excluded region with the imageExclusion value. */
/*  MaskFilterType::Pointer clippedNonBrainT1Filter = MaskFilterType::New();
  clippedNonBrainT1Filter->SetInput1( T1Reader->GetOutput() );
  clippedNonBrainT1Filter->SetInput2( nonBrainMaskFilter->GetOutput() );
  clippedNonBrainT1Filter->SetOutsideValue( imageExclusion );
  clippedNonBrainT1Filter->Update();
  

  /* The Scalar Image Kmeans Image Filter will find a code image in 3 classes
     for the interior of the mask, plus a code for the exterior of the mask. */
  typedef itk::ScalarImageKmeansImageFilter< ImageType > KMeansFilterType;
  typedef KMeansFilterType::RealPixelType RealPixelType;
 KMeansFilterType::Pointer kmeansFilter = KMeansFilterType::New();
  kmeansFilter->SetInput( clippedBrainT1Filter->GetOutput() );

  unsigned int numberOfInitialClasses = 4;
  const unsigned int useNonContiguousLabels = 1;

  RealPixelType backgroundInitialMean = imageExclusion;
  RealPixelType bloodInitialMean = imageMax;    // ARTERIAL blood.
  const RealPixelType csfInitialMean = imageMean - 2*imageSigma;
  const RealPixelType whiteInitialMean = imageMean + imageSigma;
  const RealPixelType grayInitialMean = imageMean - imageSigma/5;

  kmeansFilter->AddClassWithInitialMean( backgroundInitialMean );
  kmeansFilter->AddClassWithInitialMean( csfInitialMean );
  kmeansFilter->AddClassWithInitialMean( grayInitialMean );
  kmeansFilter->AddClassWithInitialMean( whiteInitialMean );
  //kmeansFilter->AddClassWithInitialMean( bloodInitialMean );

  kmeansFilter->SetUseNonContiguousLabels( useNonContiguousLabels );

  try
    {
    kmeansFilter->Update();
    }
  catch( itk::ExceptionObject & excp )
    {
    std::cerr << "Problem encountered while running K-means segmentation ";
    std::cerr << excp << std::endl;
    return EXIT_FAILURE;
    }

  KMeansFilterType::ParametersType estimatedMeans = 
                                            kmeansFilter->GetFinalMeans();

  unsigned int numberOfClasses = estimatedMeans.Size();

  for ( unsigned int i = 0 ; i < numberOfClasses ; ++i )
    {
    std::cout << "Brain cluster[" << i << "] ";
    std::cout << "    estimated mean : " << estimatedMeans[i] << std::endl;
    }


  /* The Scalar Image Kmeans Image Filter will find a code image in 3 classes
     for the interior of the mask, plus a code for the exterior of the mask. */
/*  KMeansFilterType::Pointer kmeansNonBrainFilter = KMeansFilterType::New();
  kmeansNonBrainFilter->SetInput( clippedNonBrainT1Filter->GetOutput() );

  numberOfInitialClasses = 4;

  backgroundInitialMean = imageExclusion;
  const RealPixelType airInitialMean = imageMin;
  const RealPixelType fatInitialMean = imageMax;
  const RealPixelType muscleInitialMean = imageMean;

  kmeansNonBrainFilter->AddClassWithInitialMean( backgroundInitialMean );
  kmeansNonBrainFilter->AddClassWithInitialMean( airInitialMean );
  kmeansNonBrainFilter->AddClassWithInitialMean( muscleInitialMean );
  kmeansNonBrainFilter->AddClassWithInitialMean( fatInitialMean );
  kmeansNonBrainFilter->SetUseNonContiguousLabels( useNonContiguousLabels );

  

  try
    {
    kmeansNonBrainFilter->Update();
    }
  catch( itk::ExceptionObject & excp )
    {
    std::cerr << "Problem encountered while Background K-Means segmentation ";
    std::cerr << excp << std::endl;
    return EXIT_FAILURE;
    }

  estimatedMeans = kmeansNonBrainFilter->GetFinalMeans();

  numberOfClasses = estimatedMeans.Size();

  for ( unsigned int i = 0 ; i < numberOfClasses ; ++i )
    {
    std::cout << "Background cluster[" << i << "] ";
    std::cout << "    estimated mean : " << estimatedMeans[i] << std::endl;
    }

  /* Now remap the labels - background first followed by brain */
  typedef KMeansFilterType::OutputImageType LabelImageType;
  LabelImageType::Pointer kmeansLabelImage = LabelImageType::New();
  kmeansLabelImage->SetRegions( T1Reader->GetOutput()->GetLargestPossibleRegion() );
  kmeansLabelImage->SetSpacing( T1Reader->GetOutput()->GetSpacing() );
  kmeansLabelImage->SetDirection( T1Reader->GetOutput()->GetDirection() );
  kmeansLabelImage->SetOrigin( T1Reader->GetOutput()->GetOrigin() );
  kmeansLabelImage->Allocate( );
  kmeansLabelImage->FillBuffer( 0 );
  
  typedef itk::LabelStatisticsImageFilter< LabelImageType, LabelImageType > LabelMapStatisticsFilterType;

/*LabelMapStatisticsFilterType::Pointer statisticsNonBrainFilter = LabelMapStatisticsFilterType::New();
  statisticsNonBrainFilter->SetInput( kmeansNonBrainFilter->GetOutput() );
  statisticsNonBrainFilter->SetLabelInput( kmeansNonBrainFilter->GetOutput() );
  statisticsNonBrainFilter->Update();


  /* Background Tissues are Lower Label values */
  unsigned char currentLabel = 0;
/*  for (unsigned int i=1; i<256; i++)
  {
    if ( statisticsNonBrainFilter->HasLabel( static_cast<unsigned char> ( i ) ) )
    {
      currentLabel++;
      LabelImageType::RegionType labelRegion = statisticsNonBrainFilter->GetRegion( static_cast<unsigned char> ( i ) );
      itk::	ImageRegionIterator<LabelImageType> it( kmeansNonBrainFilter->GetOutput(), labelRegion );
  
      it.GoToBegin();
      while( !it.IsAtEnd() )
      {
        if ( it.Get() == static_cast<unsigned char> ( i ) ) 
        {
          // Set Output Image
          kmeansLabelImage->SetPixel(it.GetIndex(), currentLabel);
        }
        ++it;
      } 
    }
  }
  
  /* Brain Tissues are Higher Label values */
  LabelMapStatisticsFilterType::Pointer statisticsBrainFilter = LabelMapStatisticsFilterType::New();
  statisticsBrainFilter->SetInput( kmeansFilter->GetOutput() );
  statisticsBrainFilter->SetLabelInput( kmeansFilter->GetOutput() );
  statisticsBrainFilter->Update();
  
  for (unsigned int i=1; i<256; i++)
  {
    if ( statisticsBrainFilter->HasLabel( static_cast<unsigned char> ( i ) ) )
    {
      currentLabel++;
      LabelImageType::RegionType labelRegion = statisticsBrainFilter->GetRegion( static_cast<unsigned char> ( i ) );
      itk::	ImageRegionIterator<LabelImageType> it( kmeansFilter->GetOutput(), labelRegion );
  
      it.GoToBegin();
      while( !it.IsAtEnd() )
      {
        if ( it.Get() == static_cast<unsigned char> ( i ) ) 
        {
          // Set Output Image
          kmeansLabelImage->SetPixel(it.GetIndex(), currentLabel);
        }
        ++it;
      } 
    }
  }




  /* Lesion Tissues are Highest Label value */
  //unsigned char currentLabel = 6;
  if (flairImage) {
  /* The Statistics Image Filter lets us find the initial cluster means on the FLAIR image.  */
    LabelStatisticsFilterType::Pointer flairStatisticsFilter = LabelStatisticsFilterType::New();
    flairStatisticsFilter->SetInput( FLAIRReader->GetOutput() );
    flairStatisticsFilter->SetLabelInput( maskReader->GetOutput() );
    flairStatisticsFilter->Update();

    const PixelType flairImageMin = static_cast<PixelType> ( flairStatisticsFilter->GetMinimum(brainMask) );
    const PixelType flairImageMax = static_cast<PixelType> ( flairStatisticsFilter->GetMaximum(brainMask) );
    const StatisticRealType flairImageMean = flairStatisticsFilter->GetMean(brainMask);
    const StatisticRealType flairImageSigma = flairStatisticsFilter->GetSigma(brainMask);
  
    std::cout << "FLAIR Brain Minimum == " << flairImageMin << std::endl;
    std::cout << "FLAIR Brain Maximum == " << flairImageMax << std::endl;
    std::cout << "FLAIR Brain Mean == " << flairImageMean << std::endl;
    std::cout << "FLAIR Brain Sigma == " << flairImageSigma << std::endl;

	/* The Mask Image Filter applies the clipping mask by stepping 
       on the excluded region with the imageExclusion value. */
    MaskFilterType::Pointer clippedBrainFLAIRFilter = MaskFilterType::New();
    clippedBrainFLAIRFilter->SetInput1( FLAIRReader->GetOutput() );
    clippedBrainFLAIRFilter->SetInput2( brainMaskFilter->GetOutput() );
    clippedBrainFLAIRFilter->SetOutsideValue( 0 );
    clippedBrainFLAIRFilter->Update();

	/* Threshold FLAIR Image for lesion class */
    PixelType lesionThresholdBelow = 1.5*flairImageSigma + flairImageMean;
	PixelType lesionThresholdAbove = 2*flairImageMean;
    std::cout << "Lesion > " << lesionThresholdBelow << std::endl;

    ThresholdFilterType::Pointer lesionThresholdFilter = ThresholdFilterType::New();
    lesionThresholdFilter->SetInput(clippedBrainFLAIRFilter->GetOutput());
    lesionThresholdFilter->ThresholdOutside(lesionThresholdBelow, lesionThresholdAbove);
    lesionThresholdFilter->Update();

    currentLabel++;

    ImageType::RegionType labelRegion = lesionThresholdFilter->GetOutput()->GetLargestPossibleRegion();
	itk::ImageRegionIterator<ImageType> it( lesionThresholdFilter->GetOutput(), labelRegion );
  
    it.GoToBegin();
    while( !it.IsAtEnd() )
    {
      if ( (it.Get())&&(it.GetIndex()[1]<110 ) ) 
	  {      
	    std::cout << "it.Get()! = 0 " << std::endl;
		std::cout << it.GetIndex() << std::endl;
	    // Set Output Image
        kmeansLabelImage->SetPixel(it.GetIndex(), currentLabel);
		//std::cout << static_cast<unsigned int> (kmeansLabelImage->GetPixel(it.GetIndex())) <<std::endl;
      }
      ++it;
	}
  }

  /* Write out the resulting Label Image */
  typedef itk::ImageFileWriter<LabelImageType> WriterType;
  WriterType::Pointer labelWriter = WriterType::New();
  labelWriter->SetInput( kmeansLabelImage );
  labelWriter->SetFileName( outputLabelVolume );
  labelWriter->Update( );

  return EXIT_SUCCESS;
    
    
}



