#include <vcl_cstdio.h>

#include <3d_image/3d_image_util.h>
#include <3d_image/3d_image_io.h>
#include "afcm.h"
#include "afcm_grid.h"
#include "afcm_regression.h"
#include "afcm_grid_regression.h"

//Create a new image with the same region, spacing and origin as ref and
//allocate its pixel buffer. No fill is done here; callers FillBuffer()
//before the pixels are read.
static ImageType::Pointer afcm_grid_alloc_like (const ImageType::Pointer& ref)
{
  ImageType::Pointer img = ImageType::New();
  img->SetRegions (ref->GetLargestPossibleRegion());
  img->SetSpacing (ref->GetSpacing());
  img->SetOrigin (ref->GetOrigin());
  img->Allocate();
  return img;
}

//AFCM segmentation with grid-based gain field correction.
//
//The non-background bounding box of img_y is split into an n_grid^3 grid.
//AFCM (with gain option 0) is run independently in each grid block. A
//linear (gain_fit_option == 3) or quadratic (gain_fit_option == 4)
//regression is then fitted to the per-block class-2 centroids (WM, per the
//comment below). The fitted gain is pushed back into each block, and the
//loop repeats until the fitting SSD stops decreasing. Finally img_y is
//corrected with the combined gain field, and the original AFCM is run once
//more on the corrected image.
//
//  img_y          : input image; the Pointer is const but the pixel data
//                   is updated in place via update_gain_to_image().
//  n_class        : number of classes; must be 3 (asserted).
//  n_bin, low_th, high_th, gain_th, gain_min, conv_thresh :
//                   forwarded unchanged to afcm_segmentation().
//  bg_thresh      : background threshold used for the bounding box, the
//                   gain combination, the gain masking and AFCM.
//  gain_fit_option: 3 = linear, 4 = quadratic regression (asserted).
//  n_grid         : number of grid divisions per axis.
//  gain_field_g   : output gain field built from the per-grid fits, then
//                   masked with img_y.
//  mem_fun_u, mem_fun_un, centroid_v :
//                   outputs of the final afcm_segmentation() run.
void afcm_segmentation_grid (const ImageType::Pointer& img_y, 
                             const int n_class, const int n_bin,
                             const float low_th, const float high_th,
                             const float bg_thresh,
                             const int gain_fit_option, 
                             const float gain_th, const float gain_min,
                             const float conv_thresh,
                             const int n_grid,
                             ImageType::Pointer& gain_field_g,
                             vcl_vector<ImageType::Pointer>& mem_fun_u, 
                             vcl_vector<ImageType::Pointer>& mem_fun_un, 
                             vcl_vector<float>& centroid_v)
{
  //The per-grid log prints centroids C0..C2, and the regression works on
  //centroid index 2, so exactly 3 classes are required. Check this before
  //any centroid_v_grid[i][k] access.
  assert (n_class == 3);
  assert (gain_fit_option == 3 || gain_fit_option == 4);
  const bool linear_fit = (gain_fit_option == 3);

  //Detect the bounding box of the non-background brain block B.
  int xmin, ymin, zmin;
  int xmax, ymax, zmax;
  bool r = detect_bnd_box (img_y, bg_thresh, xmin, ymin, zmin, xmax, ymax, zmax);
  assert (r);
  (void) r; //Silence the unused-variable warning when NDEBUG is defined.
  
  //Space division: into nxnxn: 3x3x3 or 4x4x4 blocks.
  vcl_vector<ImageType::Pointer> img_y_grid;
  vcl_vector<ImageType::IndexType> grid_center_index;
  compute_grid_imgs (img_y, xmin, ymin, zmin, xmax, ymax, zmax, n_grid, 
                     img_y_grid, grid_center_index);

  //Per-grid storage: centroids, fitted centroids, gain fields and
  //membership functions (current u[] and updated un[]).
  const int total_grids = static_cast<int>(img_y_grid.size());
  vcl_vector<vcl_vector<float> > centroid_v_grid (total_grids);
  vcl_vector<float> centroid_vn_grid (total_grids);
  vcl_vector<ImageType::Pointer> gain_field_g_grid (total_grids);
  vcl_vector<vcl_vector<ImageType::Pointer> > mem_fun_u_grid (total_grids);
  vcl_vector<vcl_vector<ImageType::Pointer> > mem_fun_un_grid (total_grids);

  for (int i = 0; i < total_grids; i++) {
    gain_field_g_grid[i] = afcm_grid_alloc_like (img_y_grid[i]);

    mem_fun_u_grid[i].resize (n_class);
    mem_fun_un_grid[i].resize (n_class);
    for (int k = 0; k < n_class; k++) {
      mem_fun_u_grid[i][k] = afcm_grid_alloc_like (img_y_grid[i]);
      mem_fun_un_grid[i][k] = afcm_grid_alloc_like (img_y_grid[i]);
    }
  }

  //Grid Regression Iteration:
  bool conv = false;
  //SSD of the previous iteration. The initial value is only printed on
  //iteration 0. Convergence is never tested on iteration 0, because there
  //is no previous SSD yet, so the gain field is always computed at least once.
  double SSD_old = 100000000;
  vcl_vector<double> SSD_history;
  int iter = 0;
  vnl_matrix<double> B;

  do {
    vcl_printf ("\n============================================================\n");
    vcl_printf ("\nGrid Regression Iteration %d:\n", iter);
    vcl_printf ("\n============================================================\n");

    //Run the AFCM segmentation in each block.
    //Assume each block contains sufficient GM and WM.
    //Obtain the intensity centroids CGM and CWM for each block.
    for (int i = 0; i < total_grids; i++) {
      //Reset the gain field and membership functions.
      gain_field_g_grid[i]->FillBuffer (1.0f);
      for (int k = 0; k < n_class; k++) {
        mem_fun_u_grid[i][k]->FillBuffer (0.0f);
        mem_fun_un_grid[i][k]->FillBuffer (0.0f);
      }

      afcm_segmentation (img_y_grid[i], n_class, n_bin, low_th, high_th, 
                         bg_thresh, 0, 
                         gain_th, gain_min,
                         conv_thresh, gain_field_g_grid[i],
                         mem_fun_u_grid[i], mem_fun_un_grid[i], centroid_v_grid[i]);

      vcl_printf ("\n==============================\n");
      vcl_printf ("Iter %d Grid %d : ", iter, i);
      vcl_printf ("  C0 %f, C1 %f, C2 %f.\n", 
                  centroid_v_grid[i][0], centroid_v_grid[i][1], centroid_v_grid[i][2]);
      vcl_printf ("==============================\n");
    }

    vcl_printf ("\n============================================================\n");
    vcl_printf ("  Start gain field fitting (regression) for iter %d.\n", iter);
    vcl_printf ("============================================================\n");
    //Linear or quadratic regression on the WM of the nxnxn grids, 
    //with value at the center of each block.
    //WM is the centroid_v_grid[i][2].
    if (linear_fit) {
      grid_regression_linear (centroid_v_grid, grid_center_index, B);
      centroid_linear_fit (grid_center_index, B, centroid_vn_grid);
      ///conv = test_1st_convergence (B);
    }
    else {
      grid_regression_quadratic (centroid_v_grid, grid_center_index, B);
      centroid_quadratic_fit (grid_center_index, B, centroid_vn_grid);
      ///conv = test_2nd_convergence (B);
    }

    //Compute difference norm of the fitting.
    const double SSD = compute_diff_norm (centroid_v_grid, centroid_vn_grid);
    vcl_printf ("\n\n  Iter %d: SSD_old %4.0f, SSD %4.0f.\n", iter, SSD_old, SSD);

    //Test convergence: stop as soon as the SSD no longer decreases.
    SSD_history.push_back (SSD);
    if (iter > 0 && SSD >= SSD_old) {
      vcl_printf ("\n SSD >= SSD_old, converges, stop iteration.\n");
      conv = true;
    }
    else {
      conv = false;
      SSD_old = SSD;
    }

    if (!conv) {
      //Update each new gain_field_g_grid[i] from the fitted model B.
      for (int i = 0; i < total_grids; i++) {
        if (linear_fit)
          compute_linear_fit_img (B, gain_field_g_grid[i]);
        else
          compute_quadratic_fit_img (B, gain_field_g_grid[i]);
      }

      //Compute a global pixel mean to generate a new gain_field_g[].
      compute_gain_from_grids (gain_field_g_grid, img_y, bg_thresh, gain_field_g);

      //Update each new img_y_grid[i] for each grid: yn[] = y[] / g[].
      for (int i = 0; i < total_grids; i++)
        update_gain_to_image (gain_field_g_grid[i], img_y_grid[i]);
    }

    iter++;
  }
  while (!conv);

  vcl_printf ("\n==============================================================================\n");
  vcl_printf ("  Summary for grid division gain field correction:\n");
  //size() returns an unsigned size type; cast it explicitly to match %d.
  vcl_printf ("    %s fitting: totally %d iteration(s).\n", 
              linear_fit ? "Linear" : "Quadratic",
              static_cast<int>(SSD_history.size()));
  vcl_printf ("    SSD in iterations: ");
  for (unsigned int h = 0; h < SSD_history.size(); h++)
    vcl_printf ("%4.0f ", SSD_history[h]);
  vcl_printf ("\n==============================================================================\n");
  
  //Compute the final img_y[] from the gain_field_g[].
  update_gain_to_image (gain_field_g, img_y);

  ///Debug
  ///save_img_f16 ("temp_gain_corrected.mhd", img_y);

  //mask the final gain_field with img_y[].
  mask_gain_field (img_y, bg_thresh, gain_field_g);

  //Initialize the image of gain field g[] to 1.
  ImageType::Pointer gain_field_tmp = ImageType::New();  
  gain_field_tmp->CopyInformation (img_y);
  gain_field_tmp->SetRegions (img_y->GetLargestPossibleRegion());
  gain_field_tmp->Allocate();
  gain_field_tmp->FillBuffer (1.0f);

  //Run the original AFCM again on the gain corrected img_y[].
  afcm_segmentation (img_y, n_class, n_bin, low_th, high_th, bg_thresh, 
                     0, gain_th, gain_min, conv_thresh, 
                     gain_field_tmp, mem_fun_u, mem_fun_un, centroid_v);
}
