/**
\file  GeodesicSegmentation.h

\brief The header file containing the Geodesic segmentation class, used to apply an adaptive geodesic transform

Library Dependencies: ITK 4.7+ <br>
Header Dependencies: cbicaUtilities.h, cbicaLogging.h

https://www.cbica.upenn.edu/sbia/software/ <br>
software@cbica.upenn.edu

Copyright (c) 2016 University of Pennsylvania. All rights reserved. <br>
See COPYING file or https://www.cbica.upenn.edu/sbia/software/license.html

*/
#pragma once

#include "iostream"
#include <cmath>
#include <cstddef>
#include <limits.h>

#include "vtkImageData.h"
#include "itkImage.h"
#include "itkConnectedThresholdImageFilter.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkBSplineControlPointImageFilter.h"
#include "itkExpImageFilter.h"
#include "itkImageRegionIterator.h"
#include "itkOtsuThresholdImageFilter.h"
#include "itkShrinkImageFilter.h"
#include "itkMedianImageFunction.h"
#include "itkNeighborhoodIterator.h"
#include "itkMinimumMaximumImageCalculator.h"
#include "itkConnectedComponentImageFilter.h"
#include "itkBinaryThresholdImageFilter.h"
#include "itkThresholdImageFilter.h"

#include "fProgressDialog.h"

#include "PreprocessingPipelineClass.h"
#include "ApplicationBase.h"

// Sentinel "maximum" geodesic value written at seed voxels and compared against in the backward pass.
// Per the original author's note, this value does not work for float-typed images.
// NOTE(review): if the geodesic image uses a 16-bit short pixel type (e.g. the ImageTypeShort3D
// default), 100000 is outside its range, so the stored value is not 100000 -- confirm intent.
static const int MAX_VAL = 100000;

/**
\class GeodesicSegmentation

\brief Applies an adaptive Geodesic filter to image

A geodesic map is computed from a set of seed points by a forward and a backward raster sweep
over a 3x3x3 neighbourhood, optionally restricted to the voxels of a mask image.

Reference:

@inproceedings{gaonkar2014adaptive,
title={Adaptive geodesic transform for segmentation of vertebrae on CT images},
author={Gaonkar, Bilwaj and Shu, Liao and Hermosillo, Gerardo and Zhan, Yiqiang},
booktitle={SPIE Medical Imaging},
pages={903516--903516},
year={2014},
organization={International Society for Optics and Photonics}
}
*/
class GeodesicSegmentation : public ApplicationBase
{

public:
  GeodesicSegmentation();
  ~GeodesicSegmentation();

  //! Logs that clean-up was requested; nothing is released yet ([TBD]).
  void cleanUp()
  {

    cbica::Logging(loggerFile, "cleanUp called"); // [TBD]
  }

  // Default template arguments on function templates are a C++11 feature. Clang reports
  // __GNUC__ == 4, so it is matched explicitly (as is any compiler announcing C++11 through
  // __cplusplus); otherwise clang users would have to spell out the template argument.

  /**
  \brief Runs the geodesic transform over every voxel of the input image.
  \param Inp The input image.
  \param tumorPoints Seed points; the first three components of each are used as a voxel index.
  \return The computed geodesic image.
  */
  template<class ImageTypeGeodesic
#if (_MSC_VER >= 1800) || (__GNUC__ > 4) || defined(__clang__) || (__cplusplus >= 201103L)
    = ImageTypeShort3D
#endif
  >
  typename ImageTypeGeodesic::Pointer Run(typename ImageTypeGeodesic::Pointer Inp, VectorVectorDouble &tumorPoints);

  /**
  \brief Runs the geodesic transform restricted to voxels whose mask value is greater than 1.
  \param Inp The input image.
  \param MaskImage Mask selecting the voxels to update; expected to share Inp's largest possible region.
  \param tumorPoints Seed points; the first three components of each are used as a voxel index.
  \return The computed geodesic image.
  */
  template<class ImageTypeGeodesic
#if (_MSC_VER >= 1800) || (__GNUC__ > 4) || defined(__clang__) || (__cplusplus >= 201103L)
    = ImageTypeShort3D
#endif
  >
  typename ImageTypeGeodesic::Pointer Run(typename ImageTypeGeodesic::Pointer Inp, typename ImageTypeGeodesic::Pointer MaskImage, VectorVectorDouble &tumorPoints);

private:
  // NOTE(review): declared inline but not defined in this header; presumably defined in the .cpp.
  inline void SetLongRunning(bool longRunning);

};

/**
\brief Convenience overload that processes every voxel of \p Inp.

A mask with the same geometry as \p Inp is built with every voxel set to 2. The masked
overload only updates voxels whose mask value is greater than 1, so the whole image is
processed. The work is then forwarded to that overload.

\param Inp The input image.
\param points Seed points, forwarded unchanged.
\return Whatever the masked overload returns.
*/
template<class ImageTypeGeodesic>
typename ImageTypeGeodesic::Pointer GeodesicSegmentation::Run(typename ImageTypeGeodesic::Pointer Inp, VectorVectorDouble &points)
{
  typedef typename ImageTypeGeodesic::Pointer GeodesicImagePointer;

  GeodesicImagePointer everyVoxel = ImageTypeGeodesic::New();
  everyVoxel->CopyInformation(Inp);
  everyVoxel->SetRequestedRegion(Inp->GetLargestPossibleRegion());
  everyVoxel->SetBufferedRegion(Inp->GetBufferedRegion());
  everyVoxel->Allocate();
  everyVoxel->FillBuffer(2); // any value > 1 marks the voxel as "inside"

  return Run< ImageTypeGeodesic >(Inp, everyVoxel, points);
}

template<class ImageTypeGeodesic>
typename ImageTypeGeodesic::Pointer GeodesicSegmentation::Run(typename ImageTypeGeodesic::Pointer Inp, typename ImageTypeGeodesic::Pointer MaskImage, VectorVectorDouble &tumorPoints)
{
  //--------------allocate a few images ------------------------
  messageUpdate("Geodesic Segmentation");
  progressUpdate(0);

  typename ImageTypeGeodesic::Pointer Init = ImageTypeGeodesic::New();
  Init->CopyInformation(Inp);
  Init->SetRequestedRegion(Inp->GetLargestPossibleRegion());
  Init->SetBufferedRegion(Inp->GetBufferedRegion());
  Init->Allocate();
  Init->FillBuffer(0);

  typename ImageTypeGeodesic::Pointer Geos = ImageTypeGeodesic::New();
  Geos->CopyInformation(Inp);
  Geos->SetRequestedRegion(Inp->GetLargestPossibleRegion());
  Geos->SetBufferedRegion(Inp->GetBufferedRegion());
  Geos->Allocate();
  Geos->FillBuffer(0);

  typename ImageTypeGeodesic::Pointer Gamma = ImageTypeGeodesic::New();
  Gamma->CopyInformation(Inp);
  Gamma->SetRequestedRegion(Inp->GetLargestPossibleRegion());
  Gamma->SetBufferedRegion(Inp->GetBufferedRegion());
  Gamma->Allocate();
  Gamma->FillBuffer(1);

  typename ImageTypeGeodesic::Pointer tumorMask = ImageTypeGeodesic::New();
  tumorMask->CopyInformation(Inp);
  tumorMask->SetRequestedRegion(Inp->GetLargestPossibleRegion());
  tumorMask->SetBufferedRegion(Inp->GetBufferedRegion());
  tumorMask->Allocate();
  tumorMask->FillBuffer(0);
  progressUpdate(10);
  qApp->processEvents();


  //---------------calculation of initial mask--------------------------
  typedef itk::ImageRegionIteratorWithIndex <ImageTypeGeodesic> IteratorType;
  for (unsigned int i = 0; i < tumorPoints.size(); i++)
  {
    // get index from the input points
    typename ImageTypeGeodesic::IndexType index;
    index[0] = tumorPoints[i][0];
    index[1] = tumorPoints[i][1];
    index[2] = tumorPoints[i][2];

    // initialize the mask and geodesic images 
    Init->SetPixel(index, static_cast<typename ImageTypeGeodesic::PixelType>(255));
    Geos->SetPixel(index, static_cast<typename ImageTypeGeodesic::PixelType>(MAX_VAL));
  }
  progressUpdate(20);
  qApp->processEvents();

  //---------------------------actual geodesic segmentation--------------------------
  typedef itk::ImageRegionIteratorWithIndex <ImageTypeGeodesic> IteratorType;
  IteratorType GeosIt(Geos, Geos->GetLargestPossibleRegion());
  IteratorType GamIt(Gamma, Gamma->GetLargestPossibleRegion());
  IteratorType MaskIt(MaskImage, MaskImage->GetLargestPossibleRegion());

  progressUpdate(22);

  //Setting up the neighborhood iterator
  typename ImageTypeGeodesic::SizeType radius;
  radius[0] = 1;
  radius[1] = 1;
  radius[2] = 1;
  itk::NeighborhoodIterator<ImageTypeGeodesic> ResNIt(radius, Geos, Geos->GetLargestPossibleRegion());
  itk::NeighborhoodIterator<ImageTypeGeodesic> InpNIt(radius, Inp, Inp->GetLargestPossibleRegion());

  //The main loops
  cbica::Logging(loggerFile, "Main loops execution : Forward pass");

  MaskIt.GoToBegin();

  progressUpdate(25);
  
  while (!MaskIt.IsAtEnd())
  {
    if (MaskIt.Get() > 1)
    {
      GamIt.SetIndex(MaskIt.GetIndex());
      GeosIt.SetIndex(MaskIt.GetIndex());

      double C_f_arr[14];

      // forward pass
      C_f_arr[13] = GeosIt.Get();
      C_f_arr[0] = ResNIt.GetPixel(4) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(4))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(4))));
      C_f_arr[1] = ResNIt.GetPixel(10) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(10))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(10))));
      C_f_arr[2] = ResNIt.GetPixel(12) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(12))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(12))));
      C_f_arr[3] = ResNIt.GetPixel(1) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(1))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(1))));
      C_f_arr[4] = ResNIt.GetPixel(3) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(3))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(3))));
      C_f_arr[5] = ResNIt.GetPixel(9) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(9))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(9))));
      C_f_arr[6] = ResNIt.GetPixel(0) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(0))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(0))));
      C_f_arr[7] = ResNIt.GetPixel(7) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(7))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(7))));
      C_f_arr[8] = ResNIt.GetPixel(6) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(6))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(6))));
      C_f_arr[9] = ResNIt.GetPixel(15) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(15))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(15))));
      C_f_arr[10] = ResNIt.GetPixel(24) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(24))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(24))));
      C_f_arr[11] = ResNIt.GetPixel(21) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(21))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(21))));
      C_f_arr[12] = ResNIt.GetPixel(18) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(18))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(18))));

      double minval = MAX_VAL * 200;
      for (int i = 0; i < 14; ++i)
      {
        if (C_f_arr[i] < minval)
          minval = C_f_arr[i];
      }

      GeosIt.Set(minval);
    }
    ++MaskIt;
    ++ResNIt;
    ++InpNIt;
  }
  progressUpdate(55);

  MaskIt.GoToReverseBegin();
  ResNIt.GoToEnd();
  InpNIt.GoToEnd();
  --ResNIt;
  --InpNIt;
  cbica::Logging(loggerFile, "Main loops execution : Backward pass");

  while (!MaskIt.IsAtReverseEnd())
  {
    if (MaskIt.Get() > 1)
    {
      GamIt.SetIndex(MaskIt.GetIndex());
      GeosIt.SetIndex(MaskIt.GetIndex());

      double C_b_arr[14];
      C_b_arr[13] = GeosIt.Get();
      C_b_arr[0] = ResNIt.GetPixel(22) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(22))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(22))));
      C_b_arr[1] = ResNIt.GetPixel(16) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(16))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(16))));
      C_b_arr[2] = ResNIt.GetPixel(14) + sqrt(1.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(14))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(14))));
      C_b_arr[3] = ResNIt.GetPixel(25) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(25))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(25))));
      C_b_arr[4] = ResNIt.GetPixel(23) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(23))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(23))));
      C_b_arr[5] = ResNIt.GetPixel(17) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(17))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(17))));
      C_b_arr[6] = ResNIt.GetPixel(26) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(26))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(26))));
      C_b_arr[7] = ResNIt.GetPixel(19) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(19))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(19))));
      C_b_arr[8] = ResNIt.GetPixel(20) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(20))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(20))));
      C_b_arr[9] = ResNIt.GetPixel(11) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(11))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(11))));
      C_b_arr[10] = ResNIt.GetPixel(2) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(2))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(2))));
      C_b_arr[11] = ResNIt.GetPixel(5) + sqrt(2.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(5))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(5))));
      C_b_arr[12] = ResNIt.GetPixel(8) + sqrt(3.0 + GamIt.Get()*((InpNIt.GetPixel(13) - InpNIt.GetPixel(8))*(InpNIt.GetPixel(13) - InpNIt.GetPixel(8))));
      double minval = MAX_VAL * 200;
      for (int i = 0; i < 14; ++i)
      {
        if (C_b_arr[i] < minval)
          minval = C_b_arr[i];
      }

      if (minval >= MAX_VAL)
      {
        GeosIt.Set(0);
      }
      else
      {
        GeosIt.Set(minval);
      }
    }
    --MaskIt;
    --ResNIt;
    --InpNIt;
  }

  progressUpdate(90);
  qApp->processEvents();
  //------------------------------geodesic thresholding-----------------------------
  //typedef itk::MinimumMaximumImageCalculator <ImageTypeGeodesic> ImageCalculatorFilterType;
  //typename ImageCalculatorFilterType::Pointer imageCalculatorFilter = ImageCalculatorFilterType::New();
  //imageCalculatorFilter->SetImage(Geos);
  //imageCalculatorFilter->Compute();
  //typename ImageTypeGeodesic::PixelType minValue = imageCalculatorFilter->GetMinimum();
  //typename ImageTypeGeodesic::PixelType maxValue = imageCalculatorFilter->GetMaximum();
  //typename ImageTypeGeodesic::PixelType average = (minValue + maxValue) / 2;
  //typename ImageTypeGeodesic::PixelType average_sqrt = std::sqrt(std::abs(average));
  //typename ImageTypeGeodesic::PixelType lowerLimit = minValue - average_sqrt / 2;
  //typename ImageTypeGeodesic::PixelType upperLimit = minValue + average_sqrt;

  //typedef itk::BinaryThresholdImageFilter< ImageTypeGeodesic, ImageTypeGeodesic > BinaryThresholderType;
  //typename BinaryThresholderType::Pointer binaryThresholdFilter = BinaryThresholderType::New();
  //binaryThresholdFilter->SetLowerThreshold(lowerLimit);
  //binaryThresholdFilter->SetUpperThreshold(upperLimit);
  //binaryThresholdFilter->SetInput(Geos);
  //binaryThresholdFilter->SetOutsideValue(static_cast<typename ImageTypeGeodesic::PixelType>(0));
  //binaryThresholdFilter->SetInsideValue(static_cast<typename ImageTypeGeodesic::PixelType>(255));
  //binaryThresholdFilter->Update();

  //typedef itk::ConnectedComponentImageFilter <ImageTypeGeodesic, ImageTypeGeodesic > ConnectedComponentImageFilterType;
  //typename ConnectedComponentImageFilterType::Pointer connected = ConnectedComponentImageFilterType::New();
  //connected->SetInput(thresholdFilter->GetOutput());
  //connected->Update();

  progressUpdate(100);
  
  cleanUp();
  return Geos;
}
