#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "itkConstNeighborhoodIterator.h"
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
#include "itkNeighborhoodIterator.h"

#include "itkSmoothingRecursiveGaussianImageFilter.h"

#include "itkPluginUtilities.h"

#include "GeodesicSegmentationCLP.h"

// Use an anonymous namespace to keep class types and function names
// from colliding when the module is used as a shared object module.
// Everything should be in an anonymous namespace except for the module
// entry point, e.g. main().
//
namespace
{

template <typename TPixel>
int DoIt( int argc, char * argv[], TPixel )
{
  PARSE_ARGS;

  typedef TPixel InputPixelType;
  typedef TPixel OutputPixelType;

  const unsigned int Dimension = 3;

  typedef itk::Image<InputPixelType,  Dimension> InputImageType;
  typedef itk::Image<OutputPixelType, Dimension> OutputImageType;

  typedef itk::ImageFileReader<InputImageType>  ReaderType;

  typename ReaderType::Pointer image = ReaderType::New();
  typename ReaderType::Pointer mask = ReaderType::New();
  
  image->SetFileName( inputVolume1.c_str() );
    ImageType::Pointer m_Inp = ImageType::New();
    m_Inp = image->GetOutput();
  mask->SetFileName(inputVolume2.c_str());
    ImageType::Pointer m_Init = ImageType::New();
    m_Init=mask->GetOutput();
    
    
    ImageType::Pointer m_Geos = ImageType::New();;
    m_Geos->CopyInformation(m_Inp);
    m_Geos->SetRequestedRegion( m_Inp->GetRequestedRegion() );
    m_Geos->SetBufferedRegion( m_Inp->GetBufferedRegion() );
    m_Geos->Allocate();
    m_Geos->FillBuffer(0);

    //For future use to specify prior probabilities in segmentation as done in the SPIE paper
    ImageType::Pointer m_Gamma = ImageType::New();;
    m_Gamma->CopyInformation(m_Inp);
    m_Gamma->SetRequestedRegion( m_Inp->GetRequestedRegion() );
    m_Gamma->SetBufferedRegion( m_Inp->GetBufferedRegion() );
    m_Gamma->Allocate();
    m_Gamma->FillBuffer(0);
    
      IteratorType InpIt(m_Inp, m_Inp->GetRequestedRegion());
      IteratorType InitIt(m_Init, m_Init->GetRequestedRegion());
      IteratorType GeosIt(m_Geos, m_Geos->GetRequestedRegion());
      IteratorType GamIt(m_Gamma, m_Gamma->GetRequestedRegion());

      m_logger.Write("Setting initial parameters (Gamma, zero-value images)...");
      //Geodesic code goes here Initialize gamma to fixed value - can be changed in the future to incorporate priors
      GamIt.GoToBegin();
      while (!GamIt.IsAtEnd())
      {
        GamIt.Set(1.0);
        ++GamIt;
      }
      //Set initial infinities
      GeosIt.GoToBegin();
      InitIt.GoToBegin();
      while (!GeosIt.IsAtEnd())
      {
        if (InitIt.Get() != 0)
          GeosIt.Set(MAX_VAL);
        ++GeosIt;
        ++InitIt;
      }
      //Setting up the neighborhood iterator
      typename TImageType::SizeType radius;
      radius[0] = 1;
      radius[1] = 1;
      radius[2] = 1;
      itk::NeighborhoodIterator<TImageType> ResNIt(radius, m_Geos, m_Geos->GetRequestedRegion());
      itk::NeighborhoodIterator<TImageType> InpNIt(radius, m_Inp, m_Inp->GetRequestedRegion());

      //The main loops
      m_logger.Write("Main loops execution: Forward pass...");
      
      ResNIt.GoToBegin();
      InpNIt.GoToBegin();
      GamIt.GoToBegin();
      InpIt.GoToBegin();
      GeosIt.GoToBegin();
      InitIt.GoToBegin();
      while (!InpIt.IsAtEnd())
      {
        double C_f_arr[14];
        C_f_arr[13] = GeosIt.Get();
        C_f_arr[0] = ResNIt.GetPixel(4) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(4))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(4))));
        C_f_arr[1] = ResNIt.GetPixel(10) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(10))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(10))));
        C_f_arr[2] = ResNIt.GetPixel(12) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(12))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(12))));
        C_f_arr[3] = ResNIt.GetPixel(1) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(1))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(1))));
        C_f_arr[4] = ResNIt.GetPixel(3) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(3))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(3))));
        C_f_arr[5] = ResNIt.GetPixel(9) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(9))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(9))));
        C_f_arr[6] = ResNIt.GetPixel(0) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(0))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(0))));
        C_f_arr[7] = ResNIt.GetPixel(7) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(7))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(7))));
        C_f_arr[8] = ResNIt.GetPixel(6) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(6))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(6))));
        C_f_arr[9] = ResNIt.GetPixel(15) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(15))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(15))));
        C_f_arr[10] = ResNIt.GetPixel(24) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(24))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(24))));
        C_f_arr[11] = ResNIt.GetPixel(21) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(21))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(21))));
        C_f_arr[12] = ResNIt.GetPixel(18) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(18))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(18))));
        //std::cout<<InpNIt.GetPixel(13);
        double minval = 20000000;
        for (int i = 0; i<14; ++i)
        {
          if (C_f_arr[i]<minval)
            minval = C_f_arr[i];
        }
        //if(minval>0)
        //{std::cout<<minval<<std::endl;}
        //std::cout<<GeosIt.GetIndex()<<std::endl;
        GeosIt.Set(minval);
        ++InpIt;
        ++GeosIt;
        ++ResNIt;
        ++InpNIt;
        ++GamIt;
      }
      InpIt.GoToReverseBegin();
      GeosIt.GoToReverseBegin();
      GamIt.GoToReverseBegin();
      ResNIt.GoToEnd();
      --ResNIt;
      InpNIt.GoToEnd();
      --InpNIt;

      m_logger.Write("Main loops execution: Backward pass...");
      while (!InpIt.IsAtReverseEnd())
      {
        double C_b_arr[14];
        C_b_arr[13] = GeosIt.Get();
        C_b_arr[0] = ResNIt.GetPixel(22) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(22))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(22))));
        C_b_arr[1] = ResNIt.GetPixel(16) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(16))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(16))));
        C_b_arr[2] = ResNIt.GetPixel(14) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(14))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(14))));
        C_b_arr[3] = ResNIt.GetPixel(25) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(25))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(25))));
        C_b_arr[4] = ResNIt.GetPixel(23) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(23))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(23))));
        C_b_arr[5] = ResNIt.GetPixel(17) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(17))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(17))));
        C_b_arr[6] = ResNIt.GetPixel(26) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(26))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(26))));
        C_b_arr[7] = ResNIt.GetPixel(19) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(19))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(19))));
        C_b_arr[8] = ResNIt.GetPixel(20) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(20))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(20))));
        C_b_arr[9] = ResNIt.GetPixel(11) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(11))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(11))));
        C_b_arr[10] = ResNIt.GetPixel(2) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(2))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(2))));
        C_b_arr[11] = ResNIt.GetPixel(5) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(5))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(5))));
        C_b_arr[12] = ResNIt.GetPixel(8) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(8))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(8))));
        double minval = 20000000;
        for (int i = 0; i<14; ++i)
        {
          if (C_b_arr[i]<minval)
            minval = C_b_arr[i];
        }
        //std::cout<<minval;
        GeosIt.Set(minval);
        --InpIt;
        --GeosIt;
        --ResNIt;
        --InpNIt;
        --GamIt;
      }
      for (GeosIt.GoToBegin(); !GeosIt.IsAtEnd(); ++GeosIt)
      {
        if (GeosIt.Get() >= MAX_VAL)
          GeosIt.Set(0);
      }
    
  typedef itk::ImageFileWriter<OutputImageType> WriterType;
  typename WriterType::Pointer writer = WriterType::New();
  writer->SetFileName( outputVolume.c_str() );
  writer->SetInput( Geos );
  writer->SetUseCompression(1);
  writer->Update();

  return EXIT_SUCCESS;
}

} // end of anonymous namespace

int main( int argc, char * argv[] )
{
  PARSE_ARGS;

  itk::ImageIOBase::IOPixelType     pixelType;
  itk::ImageIOBase::IOComponentType componentType;

  try
    {
    itk::GetImageType(inputVolume, pixelType, componentType);

    // This filter handles all types on input, but only produces
    // signed types
    switch( componentType )
      {
      case itk::ImageIOBase::UCHAR:
        return DoIt( argc, argv, static_cast<unsigned char>(0) );
        break;
      case itk::ImageIOBase::CHAR:
        return DoIt( argc, argv, static_cast<signed char>(0) );
        break;
      case itk::ImageIOBase::USHORT:
        return DoIt( argc, argv, static_cast<unsigned short>(0) );
        break;
      case itk::ImageIOBase::SHORT:
        return DoIt( argc, argv, static_cast<short>(0) );
        break;
      case itk::ImageIOBase::UINT:
        return DoIt( argc, argv, static_cast<unsigned int>(0) );
        break;
      case itk::ImageIOBase::INT:
        return DoIt( argc, argv, static_cast<int>(0) );
        break;
      case itk::ImageIOBase::ULONG:
        return DoIt( argc, argv, static_cast<unsigned long>(0) );
        break;
      case itk::ImageIOBase::LONG:
        return DoIt( argc, argv, static_cast<long>(0) );
        break;
      case itk::ImageIOBase::FLOAT:
        return DoIt( argc, argv, static_cast<float>(0) );
        break;
      case itk::ImageIOBase::DOUBLE:
        return DoIt( argc, argv, static_cast<double>(0) );
        break;
      case itk::ImageIOBase::UNKNOWNCOMPONENTTYPE:
      default:
        std::cerr << "Unknown input image pixel component type: ";
        std::cerr << itk::ImageIOBase::GetComponentTypeAsString( componentType );
        std::cerr << std::endl;
        return EXIT_FAILURE;
        break;
      }
    }

  catch( itk::ExceptionObject & excep )
    {
    std::cerr << argv[0] << ": exception caught !" << std::endl;
    std::cerr << excep << std::endl;
    return EXIT_FAILURE;
    }
  return EXIT_SUCCESS;
}
