#include <vcl_cstdio.h>

#include "afcm_grid_regression.h"

//===================================================================
// Regression on grid block intensity

//: Least-squares fit of a first-order polynomial in the grid-center index
//  coordinates to the per-block WM centroid values.
//
//  Model (coefficient order matches centroid_linear_fit()):
//    y = B(0,0) + B(1,0)*x1 + B(2,0)*x2 + B(3,0)*x3
//  where y = centroid_v_grid[i][2] and (x1,x2,x3) = grid_center_index[i].
//  Blocks whose WM centroid is negative are excluded from the fit
//  (presumably a "no valid centroid" marker -- TODO confirm with caller).
//
//  \param centroid_v_grid   per-grid-block centroids; each entry is expected
//                           to hold 3 values, of which index 2 (WM) is used.
//  \param grid_center_index center index of each grid block (same length).
//  \param B                 output: 4x1 coefficient matrix, solved from the
//                           normal equations (X'X) B = X'y.
void grid_regression_linear (const vcl_vector<vcl_vector<float> >& centroid_v_grid,
                             const vcl_vector<ImageType::IndexType>& grid_center_index,
                             vnl_matrix<double>& B)
{
  vcl_printf ("    grid_regression_linear(): \n");
  const int n_coeff = 4;

  // Guard the [0] access: indexing an empty vector is undefined behavior.
  assert (centroid_v_grid.empty() || centroid_v_grid[0].size() == 3);
  assert (centroid_v_grid.size() == grid_center_index.size());

  //Determine the total number of qualified inputs.
  vcl_printf ("      WM centroid(s): ");
  int SZ = 0;
  for (unsigned int i=0; i<centroid_v_grid.size(); i++) {
    vcl_printf ("%.2f ", centroid_v_grid[i][2]);
    if (centroid_v_grid[i][2] >= 0)
      SZ++;
  }
  vcl_printf ("\n      Total # WM centroid(s) used = %d.\n", SZ);
  if (SZ < n_coeff)
    vcl_printf ("      Warning: %d sample(s) < %d coefficients; fit is under-determined.\n",
                SZ, n_coeff);

  //Build the response y and the design matrix X = [1 x1 x2 x3] row by row,
  //so every column of row c comes from the same (non-skipped) sample.
  vnl_matrix<double> y (SZ, 1);
  vnl_matrix<double> X (SZ, n_coeff);
  int c = 0;
  for (unsigned int i=0; i<centroid_v_grid.size(); i++) {
    if (centroid_v_grid[i][2] < 0)
      continue;

    assert (c < SZ);
    const int x_1 = grid_center_index[i][0];
    const int x_2 = grid_center_index[i][1];
    const int x_3 = grid_center_index[i][2];
    y(c, 0) = centroid_v_grid[i][2]; //only use the WM for now
    X(c, 0) = 1.0;
    X(c, 1) = x_1;
    X(c, 2) = x_2;
    X(c, 3) = x_3;
    c++;
  }

  //Solve the normal equation: (X'*X) * b = (X'*y)
  vnl_matrix<double> Xt = X.transpose();
  vnl_matrix<double> Xt_X = Xt * X;
  vnl_matrix<double> Xt_y = Xt * y;
  vnl_matrix<double> Xt_X_inv = vnl_matrix_inverse<double>(Xt_X);
  //b = inv(X'*X) * (X'*y);
  B = Xt_X_inv * Xt_y;

  vcl_printf ("B: \n");
  vcl_cerr << B;
}

//: Least-squares fit of a second-order polynomial in the grid-center index
//  coordinates to the per-block WM centroid values.
//
//  Model (coefficient order matches centroid_quadratic_fit()):
//    y = B0 + B1*x1 + B2*x2 + B3*x3
//           + B4*x1*x2 + B5*x1*x3 + B6*x2*x3
//           + B7*x1*x1 + B8*x2*x2 + B9*x3*x3
//  where y = centroid_v_grid[i][2] and (x1,x2,x3) = grid_center_index[i].
//  Blocks whose WM centroid is negative are excluded from the fit
//  (presumably a "no valid centroid" marker -- TODO confirm with caller).
//
//  \param centroid_v_grid   per-grid-block centroids; each entry is expected
//                           to hold 3 values, of which index 2 (WM) is used.
//  \param grid_center_index center index of each grid block (same length).
//  \param B                 output: 10x1 coefficient matrix, solved from the
//                           normal equations (X'X) B = X'y.
void grid_regression_quadratic (const vcl_vector<vcl_vector<float> >& centroid_v_grid,
                                const vcl_vector<ImageType::IndexType>& grid_center_index,
                                vnl_matrix<double>& B)
{
  vcl_printf ("    grid_regression_quadratic(): \n");
  const int n_coeff = 10;

  // Guard the [0] access: indexing an empty vector is undefined behavior.
  assert (centroid_v_grid.empty() || centroid_v_grid[0].size() == 3);
  assert (centroid_v_grid.size() == grid_center_index.size());

  //Determine the total number of qualified inputs.
  vcl_printf ("      WM centroid(s): ");
  int SZ = 0;
  for (unsigned int i=0; i<centroid_v_grid.size(); i++) {
    vcl_printf ("%.2f ", centroid_v_grid[i][2]);
    if (centroid_v_grid[i][2] >= 0)
      SZ++;
  }
  vcl_printf ("\n      Total # WM centroid(s) used = %d.\n", SZ);
  if (SZ < n_coeff)
    vcl_printf ("      Warning: %d sample(s) < %d coefficients; fit is under-determined.\n",
                SZ, n_coeff);

  //Build y and the design matrix X in a single pass. Every column of row c
  //is written from the same sample, using the compacted row index c (not the
  //input index i), so rows stay aligned when samples are skipped and no write
  //goes past row SZ-1.
  vnl_matrix<double> y (SZ, 1);
  vnl_matrix<double> X (SZ, n_coeff);
  int c = 0;
  for (unsigned int i=0; i<centroid_v_grid.size(); i++) {
    if (centroid_v_grid[i][2] < 0)
      continue;

    assert (c < SZ);
    // Products are formed in double to avoid int overflow on large indices.
    const double x_1 = static_cast<int>(grid_center_index[i][0]);
    const double x_2 = static_cast<int>(grid_center_index[i][1]);
    const double x_3 = static_cast<int>(grid_center_index[i][2]);
    y(c, 0) = centroid_v_grid[i][2]; //only use the WM for now
    X(c, 0) = 1.0;
    X(c, 1) = x_1;
    X(c, 2) = x_2;
    X(c, 3) = x_3;
    X(c, 4) = x_1 * x_2;
    X(c, 5) = x_1 * x_3;
    X(c, 6) = x_2 * x_3;
    X(c, 7) = x_1 * x_1;
    X(c, 8) = x_2 * x_2;
    X(c, 9) = x_3 * x_3;
    c++;
  }

  ///vcl_printf ("X: \n");
  ///vcl_cerr << X;

  //Solve the normal equation: (X'*X) * b = (X'*y)
  vnl_matrix<double> Xt = X.transpose();
  vnl_matrix<double> Xt_X = Xt * X;
  vnl_matrix<double> Xt_y = Xt * y;
  vnl_matrix<double> Xt_X_inv = vnl_matrix_inverse<double>(Xt_X);
  //b = inv(X'*X) * (X'*y);
  B = Xt_X_inv * Xt_y;

  vcl_printf ("B: \n");
  vcl_cerr << B;
}

//===================================================================

//: Evaluate the first-order model B at every grid center.
//
//  For each grid block k with center index (x, y, z), stores
//    B(0,0) + B(1,0)*x + B(2,0)*y + B(3,0)*z
//  (narrowed to float) into centroid_vn_grid[k].
//
//  \param grid_center_index center index of each grid block (3-D).
//  \param B                 4x1 coefficients, as produced by grid_regression_linear().
//  \param centroid_vn_grid  output; must already have one slot per grid block.
void centroid_linear_fit (const vcl_vector<ImageType::IndexType>& grid_center_index,  
                          const vnl_matrix<double>& B, 
                          vcl_vector<float>& centroid_vn_grid)
{  
  vcl_printf ("    centroid_linear_fit(): \n");
  assert (B.rows() == 4);
  assert (grid_center_index.size() == centroid_vn_grid.size());

  const double b0 = B(0,0);
  const double b1 = B(1,0);
  const double b2 = B(2,0);
  const double b3 = B(3,0);

  const unsigned int n_blocks = grid_center_index.size();
  for (unsigned int k = 0; k < n_blocks; ++k) {
    assert (grid_center_index[k].GetIndexDimension() == 3);
    const int px = grid_center_index[k][0];
    const int py = grid_center_index[k][1];
    const int pz = grid_center_index[k][2];
    centroid_vn_grid[k] = b0 + b1*px + b2*py + b3*pz;
  }
}

//: Evaluate the second-order model B at every grid center.
//
//  For each grid block k with center index (x, y, z), stores
//    B0 + B1*x + B2*y + B3*z + B4*x*y + B5*x*z + B6*y*z
//       + B7*x*x + B8*y*y + B9*z*z
//  (narrowed to float) into centroid_vn_grid[k]. Coefficient order matches
//  the design-matrix columns used by grid_regression_quadratic().
//
//  \param grid_center_index center index of each grid block (3-D).
//  \param B                 10x1 coefficients.
//  \param centroid_vn_grid  output; must already have one slot per grid block.
void centroid_quadratic_fit (const vcl_vector<ImageType::IndexType>& grid_center_index,  
                             const vnl_matrix<double>& B, 
                             vcl_vector<float>& centroid_vn_grid)
{
  vcl_printf ("    centroid_quadratic_fit(): \n");
  assert (B.rows() == 10);
  assert (grid_center_index.size() == centroid_vn_grid.size());

  const unsigned int n_blocks = grid_center_index.size();
  for (unsigned int k = 0; k < n_blocks; ++k) {
    assert (grid_center_index[k].GetIndexDimension() == 3);
    const int px = grid_center_index[k][0];
    const int py = grid_center_index[k][1];
    const int pz = grid_center_index[k][2];
    // Terms are summed left to right in the same order as the model above.
    const double value = B(0,0) + B(1,0)*px + B(2,0)*py + B(3,0)*pz +
                         B(4,0)*px*py + B(5,0)*px*pz + B(6,0)*py*pz +
                         B(7,0)*px*px + B(8,0)*py*py + B(9,0)*pz*pz;
    centroid_vn_grid[k] = value;
  }
}

//: Test whether the first-order (slope) coefficients B(1..3,0) are all small.
//
//  Returns false, after printing which coefficient failed, if any
//  |B(k,0)| exceeds the threshold; true otherwise. The magnitude is tested,
//  so a large negative slope is not mistaken for convergence.
//
//  TODO: should use the max/mean pixel value or use Euclidean norm
//  before/after correction to test convergence!!
//
//  \param B coefficient column; at least 4 rows (linear or quadratic model).
bool test_1st_convergence (const vnl_matrix<double>& B)
{
  vcl_printf ("test_1st_convergence():\n");
  //double pixel = B(0,0) + B(1,0)*x_1 + B(2,0)*x_2 + B(3,0)*x_3;
  assert (B.rows() >= 4 && B.cols() >= 1);
  const float B_1st_thresh = 0.1f;

  for (unsigned int k = 1; k <= 3; ++k) {
    const double b = B(k,0);
    if (b > B_1st_thresh || b < -B_1st_thresh) {
      vcl_printf ("  |B(%u,0)| = |%f| > thresh %.2f.\n", k, b, B_1st_thresh);
      return false;
    }
  }

  return true;
}

//: Test whether a quadratic model has converged.
//
//  Requires test_1st_convergence(B) to pass, and additionally that every
//  second-order coefficient B(4..9,0) has magnitude at or below the
//  threshold. Returns true only when all checks pass (previously the
//  function unconditionally returned false at the end).
//
//  \param B 10x1 coefficient column of the quadratic model.
bool test_2nd_convergence (const vnl_matrix<double>& B)
{
  vcl_printf ("test_2nd_convergence():\n");

  if (!test_1st_convergence (B))
    return false;

  //double pixel = B(0,0) + B(1,0)*x_1 + B(2,0)*x_2 + B(3,0)*x_3 +
  //               B(4,0)*x_1*x_2 + B(5,0)*x_1*x_3 + B(6,0)*x_2*x_3 +
  //               B(7,0)*x_1*x_1 + B(8,0)*x_2*x_2 + B(9,0)*x_3*x_3;
  assert (B.rows() >= 10 && B.cols() >= 1);
  const float B_2nd_thresh = 0.01f;

  for (unsigned int k = 4; k <= 9; ++k) {
    const double b = B(k,0);
    if (b > B_2nd_thresh || b < -B_2nd_thresh) {
      vcl_printf ("  |B(%u,0)| = |%f| > thresh %.2f.\n", k, b, B_2nd_thresh);
      return false;
    }
  }

  return true;
}

