#include <vcl_cstdio.h>
#include <vcl_vector.h>

#include <vnl/vnl_math.h>

#include <AFCM/afcm.h>
#include <AFCM/afcm_util.h>
#include <3d_image/3d_image_util.h>
#include <AFCM/afcm_regression.h>

//===================================================================

// Drive the adaptive fuzzy c-means segmentation of img_y.
//
// The centroids are seeded by compute_init_centroid(). Each pass then
// recomputes the memberships, the centroids and (optionally) the gain field,
// recomputes the memberships once more with the updated values, and stops as
// soon as test_convergence() accepts the change between the two membership
// sets. The loop has no upper bound on the number of passes.
//
//   img_y            input image.
//   n_class, n_bin,
//   low_th, high_th  forwarded to compute_init_centroid().
//   bg_thresh        forwarded to compute_new_mem_fun_u().
//   gain_fit_option  1 or 2 turns the gain-field update on; any other value
//                    leaves gain_field_g as supplied by the caller.
//   gain_th          forwarded to compute_new_gain_field().
//   gain_min         not referenced in this function.
//   conv_thresh      forwarded to test_convergence().
//   gain_field_g     gain field; read every pass, rewritten when fitting is on.
//   mem_fun_u        memberships computed at the start of each pass.
//   mem_fun_un       memberships recomputed at the end of each pass.
//   centroid_v       class centroids, overwritten by this function.
void afcm_segmentation (const ImageType::Pointer& img_y, 
                        const int n_class, const int n_bin,
                        const float low_th, const float high_th,
                        const float bg_thresh,
                        const int gain_fit_option, 
                        const float gain_th, const float gain_min,
                        const float conv_thresh,
                        ImageType::Pointer& gain_field_g,
                        vcl_vector<ImageType::Pointer>& mem_fun_u, 
                        vcl_vector<ImageType::Pointer>& mem_fun_un, 
                        vcl_vector<float>& centroid_v)
{
  //Seed the centroids before the first pass.
  compute_init_centroid (img_y, n_class, n_bin, low_th, high_th, centroid_v);

  //The gain field is only refitted for the two supported options.
  const bool refit_gain = (gain_fit_option == 1 || gain_fit_option == 2);

  for (int pass = 0; ; ++pass) {
    vcl_printf ("\nIteration %d:\n", pass);

    //Step 1: memberships from the current centroids and gain field.
    compute_new_mem_fun_u (centroid_v, gain_field_g, img_y, bg_thresh, mem_fun_u);

    //Step 2: centroids from those memberships.
    compute_new_centroids (mem_fun_u, gain_field_g, img_y, centroid_v);

    //Step 3: optional gain-field refit.
    if (refit_gain) {
      compute_new_gain_field (mem_fun_u, img_y, gain_field_g, 
                              gain_fit_option, gain_th);
    }

    //Step 4: memberships again, now with the updated centroids / gain field.
    compute_new_mem_fun_u (centroid_v, gain_field_g, img_y, bg_thresh, mem_fun_un);

    //Step 5: stop once the two membership sets agree closely enough.
    if (test_convergence (mem_fun_u, mem_fun_un, conv_thresh))
      break;
  }
}

//===================================================================

// Estimate initial class centroids from the intensity histogram of image.
//
// A histogram is obtained from compute_histogram(). Bins lying entirely
// inside [low_th, high_th] feed a Gaussian kernel estimator that is sampled at
// n_bin evenly spaced intensities across that range. The kernel width h is
// bisected inside [0, 50]: it is widened while CountMode() reports more than
// n_class modes and narrowed otherwise, until two successive widths differ by
// less than 0.01. The first (up to) n_class interior local maxima of the final
// estimator become the centroids.
//
//   image            input image.
//   n_class          number of centroids wanted.
//   n_bin            number of kernel-estimator samples; must be > 1.
//   low_th, high_th  intensity range used for the estimator.
//   centroid_v       output; cleared, then filled with the centroids found.
//                    It may end up holding fewer than n_class entries, in
//                    which case a warning is printed.
void compute_init_centroid (const ImageType::Pointer& image, 
                            const int n_class, const int n_bin,
                            const float low_th, const float high_th,
                            vcl_vector<float>& centroid_v)
{
  vcl_printf ("\ncompute_init_centroid():\n");
  vcl_printf ("  n_class %d, n_bin %d, low_t %f, high_t %f.\n",
              n_class, n_bin, low_th, high_th);

  vcl_vector<float> histVector;
  vcl_vector<float> binMin;
  vcl_vector<float> binMax;
  int nBinHistogram = 0;

  // compute_histogram() decides how many histogram bins there are.
  compute_histogram (image, histVector, binMax, binMin, nBinHistogram);
  assert (static_cast<int>(histVector.size()) == nBinHistogram);
  assert (static_cast<int>(binMin.size()) == nBinHistogram);
  assert (static_cast<int>(binMax.size()) == nBinHistogram);

  // n_bin sets how finely [low_th, high_th] is sampled for the kernel
  // estimator. At least two samples are needed: the spacing divides by
  // (n_bin-1) and the mode search below only looks at interior samples.
  assert (n_bin > 1);
  vcl_vector<float> kernalEstimator (n_bin);
  const float deltaX = (high_th-low_th)/(n_bin-1);
  vcl_vector<float> xVector (n_bin);
  for (int k = 0; k < n_bin; k++) {
    xVector[k] = low_th + k*deltaX;
  }

  // Bisection on the kernel width h.
  float h0 = 0;
  float h1 = 50;
  float h = (h0 + h1)/2;
  bool done = false;
  while (!done) {
    for (int k = 0; k < n_bin; k++) {
      kernalEstimator[k] = 0;
      for (int n = 0; n < nBinHistogram; n++) {
        // Ignore bins that stick out of [low_th, high_th].
        const float b = binMin[n];
        if (b < low_th)
          continue;
        float d = binMax[n];
        if (d > high_th)
          continue;
        // Gaussian weight of the bin centre as seen from sample k.
        d = (d  +  b) / 2.0;
        d = exp(-(xVector[k]-d)*(xVector[k]-d)/(2*h*h));
        kernalEstimator[k] = kernalEstimator[k] + d * histVector[n];
      }
    }

    // Too many modes: smooth more (raise the lower bound of h).
    // Otherwise: smooth less (lower the upper bound of h).
    const int C = CountMode (kernalEstimator);
    if (C > n_class)
      h0 = h;
    else
      h1 = h;

    const float hNew = (h0 + h1)/2;
    if (fabs(hNew-h) < 0.01)
      done = true;
    h = hNew;
  }

  centroid_v.clear();

  // Collect interior local maxima (a sample not smaller than either neighbour).
  const int kernalLength = static_cast<int>(kernalEstimator.size());
  assert (kernalLength > 0);
  assert (static_cast<int>(xVector.size()) == kernalLength);
  for (int k = 1; k < kernalLength-1; k++) {
    if (kernalEstimator[k] < kernalEstimator[k-1])
      continue;
    if (kernalEstimator[k] < kernalEstimator[k+1])
      continue;
    centroid_v.push_back (xVector[k]);
    if (static_cast<int>(centroid_v.size()) >= n_class)
      break;
  }

  // Report only what was actually found; indexing a fixed [0..2] would read
  // past the end of centroid_v whenever fewer than three modes exist.
  const int n_found = static_cast<int>(centroid_v.size());
  if (n_found < n_class)
    vcl_printf ("  Warning: found %d centroid(s), expected %d.\n", n_found, n_class);

  vcl_printf ("  centroid_v:");
  for (int c = 0; c < n_found; c++)
    vcl_printf ("%s C%d %f", (c == 0) ? "" : ",  ", c, centroid_v[c]);
  vcl_printf (".\n\n");
}

// Compute the membership images mem_fun_u[k] for every class k.
//
// For a pixel with intensity y and gain g, the residual of class l is
//   r_l = y - centroid_v[l] * g
// and the membership of class k is
//   u_k = (1 / r_k^2) / sum_l (1 / r_l^2).
// When a residual is exactly zero the formula would divide by zero; in that
// case the matching class gets membership 1 and every other class gets 0.
//
//   centroid_v    class centroids; indexed for every class in mem_fun_u.
//   gain_field_g  gain field, iterated in lock-step with img_y.
//   img_y         input image.
//   bg_thresh     pixels with intensity below this are skipped, leaving the
//                 membership value already stored for them untouched.
//   mem_fun_u     output; one image per class. The number of classes is
//                 taken from mem_fun_u.size().
void compute_new_mem_fun_u (const vcl_vector<float>& centroid_v,
                            const ImageType::Pointer& gain_field_g, 
                            const ImageType::Pointer& img_y,
                            const float bg_thresh,
                            vcl_vector<ImageType::Pointer>& mem_fun_u)
{
  typedef itk::ImageRegionConstIterator< ImageType > ConstIteratorType;
  typedef itk::ImageRegionIterator< ImageType > IteratorType;

  vcl_printf ("  compute_new_mem_fun_u(): \n");
  const int n_class = mem_fun_u.size();
  
  for (int k = 0; k < n_class; k++) {
    //iterate through each pixel j:
    ConstIteratorType ity (img_y, img_y->GetRequestedRegion());
    ConstIteratorType itg (gain_field_g, gain_field_g->GetRequestedRegion());
    IteratorType itu (mem_fun_u[k], mem_fun_u[k]->GetRequestedRegion());

    for (ity.GoToBegin(), itg.GoToBegin(), itu.GoToBegin(); 
         !ity.IsAtEnd(); 
         ++ity, ++itg, ++itu) {
      //Skip background pixels.
      const float img_y_j = ity.Get();
      if (img_y_j < bg_thresh)
        continue;

      const float gain_field_g_j = itg.Get();

      //Residual of this pixel with respect to class k.
      const double residual_k = img_y_j - centroid_v[k] * gain_field_g_j;
      if (residual_k == 0) {
        //The pixel is matched exactly by class k: full membership.
        itu.Set (1);
        continue; //Done for the current pixel.
      }
      const double numerator = 1 / (residual_k * residual_k);

      //Accumulate the normalization term over all classes. If some class
      //matches the pixel exactly, that class owns the pixel.
      double denominator = 0;
      bool exact_match_elsewhere = false;
      for (int l = 0; l < n_class; l++) {
        const double residual_l = img_y_j - centroid_v[l] * gain_field_g_j;
        if (residual_l == 0) {
          exact_match_elsewhere = true;
          break;
        }
        denominator += 1 / (residual_l * residual_l);
      }

      if (exact_match_elsewhere) {
        //Another class has membership 1 here, so class k gets 0.
        //This must end the pixel (not just the class loop), otherwise the
        //0 would be overwritten by the ratio below.
        itu.Set (0);
        continue;
      }

      assert (denominator != 0);
      itu.Set (numerator / denominator);
    }
  }
}

// Recompute the centroid of every class from its membership image.
//
// For class k, summing over all pixels j:
//   centroid_v[k] = sum_j (u_kj^2 * g_j * y_j) / sum_j (u_kj^2 * g_j^2)
// If both sums are zero the centroid is set to 0; if only the denominator is
// zero an error is printed and the centroid is set to FLT_MAX.
//
//   mem_fun_u     one membership image per class; its size gives the number
//                 of classes.
//   gain_field_g  gain field, iterated in lock-step with img_y.
//   img_y         input image.
//   centroid_v    in/out; must already hold at least mem_fun_u.size() entries,
//                 which are overwritten.
void compute_new_centroids (const vcl_vector<ImageType::Pointer>& mem_fun_u, 
                            const ImageType::Pointer& gain_field_g, 
                            const ImageType::Pointer& img_y, 
                            vcl_vector<float>& centroid_v)
{
  typedef itk::ImageRegionConstIterator< ImageType > ConstIteratorType;

  vcl_printf ("  compute_new_centroids(): ");
  const int n_class = mem_fun_u.size();

  //centroid_v[k] is written below for every class without resizing.
  assert (centroid_v.size() >= mem_fun_u.size());

  for (int k = 0; k < n_class; k++) {
    //iterate through each pixel j:
    ConstIteratorType ity (img_y, img_y->GetRequestedRegion());
    ConstIteratorType itg (gain_field_g, gain_field_g->GetRequestedRegion());
    ConstIteratorType itu (mem_fun_u[k], mem_fun_u[k]->GetRequestedRegion());

    double numerator = 0;
    double denominator = 0;
    for (ity.GoToBegin(), itg.GoToBegin(), itu.GoToBegin(); !ity.IsAtEnd(); ++ity, ++itg, ++itu) {
      const float mem_fun_u_kj = itu.Get();
      assert (vnl_math_isnan (mem_fun_u_kj) == false);
      const float gain_field_g_j = itg.Get();
      assert (vnl_math_isnan (gain_field_g_j) == false);
      const float img_y_j = ity.Get();
      assert (vnl_math_isnan (img_y_j) == false);

      numerator += mem_fun_u_kj * mem_fun_u_kj * gain_field_g_j * img_y_j;
      assert (vnl_math_isnan (numerator) == false);
      denominator += mem_fun_u_kj * mem_fun_u_kj * gain_field_g_j * gain_field_g_j;
      assert (vnl_math_isnan (denominator) == false);
    }

    if (denominator == 0) {
      if (numerator == 0)
        centroid_v[k] = 0;
      else {
        vcl_printf ("  Error: divide by 0!\n");
        centroid_v[k] = FLT_MAX;
      }
    }
    else {
      centroid_v[k] = numerator / denominator;
    }
  }

  //Print every centroid that exists; a fixed [0..2] would read past the end
  //of centroid_v when there are fewer than three classes.
  const int n_centroid = static_cast<int>(centroid_v.size());
  for (int c = 0; c < n_centroid; c++)
    vcl_printf ("%sC%d %f", (c == 0) ? "" : ",   ", c, centroid_v[c]);
  vcl_printf (".\n");
}

// Compute a new gain field g[]:
//   Initially, we assume g[]=1 is know and fixed in our case.
//   Here we update it by a regression fit of the white matter (mem_fun_u[2])
void compute_new_gain_field (const vcl_vector<ImageType::Pointer>& mem_fun_u, 
                             const ImageType::Pointer& img_y, 
                             ImageType::Pointer& gain_field_g,
                             const int gain_fit_option,
                             const float gain_th)
{
  assert (gain_fit_option == 1 || gain_fit_option == 2);
  vcl_printf ("  compute_new_gain_field():\n");
  vcl_printf ("    %s fitting, gain_th %f.\n",
              (gain_fit_option==1) ? "linear" : "quadratic", 
              gain_th);

  //Quadratic regression fiting to get the parameter B
  vnl_matrix<double> B;

  if (gain_fit_option == 1) {
    img_regression_linear (mem_fun_u[2], gain_th, B);
    //Use B to compute a new gain_field_g[]
    compute_linear_fit_img (B, gain_field_g);
  }
  else if (gain_fit_option == 2) {
    img_regression_quadratic (mem_fun_u[2], gain_th, B);
    //Use B to compute a new gain_field_g[]
    compute_quadratic_fit_img (B, gain_field_g);
  }
}

// Test convergence.
bool test_convergence (const vcl_vector<ImageType::Pointer>& mem_fun_u, 
                       const vcl_vector<ImageType::Pointer>& mem_fun_un, 
                       const float conv_thresh)
{
  vcl_printf ("  test_convergence(): ");

  const int n_class = mem_fun_u.size();
  float max_value = 0;

  for (int k = 0; k < n_class; k++) {
    //iterate through each pixel j:
    typedef itk::ImageRegionConstIterator< ImageType > ConstIteratorType;
    ConstIteratorType it (mem_fun_u[k], mem_fun_u[k]->GetRequestedRegion());
    ConstIteratorType itn (mem_fun_un[k], mem_fun_un[k]->GetRequestedRegion());

    for (it.GoToBegin(), itn.GoToBegin(); !it.IsAtEnd(); ++it, ++itn) {
      float mem_fun_u_kj = it.Get();
      assert (vnl_math_isnan (mem_fun_u_kj) == false);
      float mem_fun_un_kj = itn.Get();
      assert (vnl_math_isnan (mem_fun_un_kj) == false);

      ///float diff = member_fun_u[k][j] - member_fun_un[k][j];
      float diff = mem_fun_u_kj - mem_fun_un_kj;
      diff = vcl_fabs (diff);
      if (diff > max_value)
        max_value = diff;
    }    
  }

  vcl_printf ("max_value %f (conv_th %f).\n", max_value, conv_thresh);

  if (max_value < conv_thresh)
    return true;
  else
    return false;
}

