#include <vcl_cstdio.h>
#include "afcm_regression.h"

#include "itkImageRegionIteratorWithIndex.h"

//===================================================================
// Regression on image intensity

//Least-squares fit of a first-order polynomial in the voxel index to the
//intensities of every voxel whose value is strictly greater than thresh.
//
//  Model:  I(x1,x2,x3) ~= b0 + b1*x1 + b2*x2 + b3*x3
//  where (x1,x2,x3) is the voxel index (idx[0], idx[1], idx[2]).
//
//  image  : 3-D input image; only its requested region is visited.
//  thresh : voxels with intensity <= thresh are ignored.
//  B      : output, assigned a 4x1 column [b0 b1 b2 b3]'.
//
//The normal equations (X'X) b = (X'y) are accumulated voxel by voxel in a
//single pass. The SZ x 4 design matrix X and its transpose are never built,
//so memory use does not grow with the number of voxels.
void img_regression_linear (const ImageType::Pointer& image,
                            const float thresh,
                            vnl_matrix<double>& B)
{
  vcl_printf ("    img_regression_linear(): \n");
  const int NB = 4; //number of basis terms: 1, x1, x2, x3

  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (image, image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);

  //Running sums of X'X (NB x NB) and X'y (NB x 1) over the selected voxels.
  vnl_matrix<double> Xt_X (NB, NB, 0.0);
  vnl_matrix<double> Xt_y (NB, 1, 0.0);
  int SZ = 0;
  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    const float pixel = iit.Get();
    if (pixel > thresh) {
      const ImageType::IndexType idx = iit.GetIndex();
      //One row of the (implicit) design matrix X.
      const double phi[NB] = { 1.0,
                               static_cast<double>(idx[0]),
                               static_cast<double>(idx[1]),
                               static_cast<double>(idx[2]) };
      for (int r = 0; r < NB; r++) {
        Xt_y (r, 0) += phi[r] * pixel;
        for (int c = 0; c < NB; c++)
          Xt_X (r, c) += phi[r] * phi[c];
      }
      SZ++;
    }
  }
  vcl_printf ("      # pixels > thresh (%f) = %d\n", thresh, SZ);
  if (SZ < NB)
    vcl_printf ("      WARNING: only %d pixels for %d unknowns; the fit is underdetermined.\n",
                SZ, NB);

  //Solve the normal equations: (x'*x) * b = (x'*y), i.e. b = inv(x'*x) * (x'*y).
  vnl_matrix<double> Xt_X_inv = vnl_matrix_inverse<double>(Xt_X);
  B = Xt_X_inv * Xt_y;

  vcl_printf ("B: \n");
  vcl_cerr << B;
}

//Use B to compute a new fitting
void compute_linear_fit_img (const vnl_matrix<double>& B, 
                             const ImageType::Pointer& fit_image)
{
  vcl_printf ("    compute_linear_fit_img(): \n");

  //Traverse through the fit_image and compute a new quadratic value via B.
  //Image coordinates into x1[], x2[], x3[].
  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (fit_image, fit_image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);
  assert (B.rows() == 4);

  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    ImageType::IndexType idx = iit.GetIndex();
    int x_1 = idx[0];
    int x_2 = idx[1];
    int x_3 = idx[2];
    double pixel = B(0,0) + B(1,0)*x_1 + B(2,0)*x_2 + B(3,0)*x_3;
    iit.Set (pixel);
  }
}

//Least-squares fit of a second-order polynomial in the voxel index to the
//intensities of every voxel whose value is strictly greater than thresh.
//
//  Model (coefficient order b0..b9):
//    I ~= b0 + b1*x1 + b2*x2 + b3*x3
//         + b4*x1*x2 + b5*x1*x3 + b6*x2*x3
//         + b7*x1*x1 + b8*x2*x2 + b9*x3*x3
//  where (x1,x2,x3) is the voxel index (idx[0], idx[1], idx[2]).
//
//  image  : 3-D input image; only its requested region is visited.
//  thresh : voxels with intensity <= thresh are ignored.
//  B      : output, assigned a 10x1 column [b0 ... b9]'.
//
//The normal equations (X'X) b = (X'y) are accumulated voxel by voxel in a
//single pass. The SZ x 10 design matrix and its transpose are never built.
//Coordinate products are formed in double precision, so large indices cannot
//overflow an int. The same pixel value is used both for the threshold test
//and as the regression target, so selection is consistent.
void img_regression_quadratic (const ImageType::Pointer& image,
                               const float thresh,
                               vnl_matrix<double>& B)
{
  vcl_printf ("    img_regression_quadratic(): \n");
  const int NB = 10; //number of basis terms

  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (image, image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);

  //Running sums of X'X (NB x NB) and X'y (NB x 1) over the selected voxels.
  vnl_matrix<double> Xt_X (NB, NB, 0.0);
  vnl_matrix<double> Xt_y (NB, 1, 0.0);
  int SZ = 0;
  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    const double pixel = iit.Get();
    if (pixel > thresh) {
      const ImageType::IndexType idx = iit.GetIndex();
      const double x_1 = static_cast<double>(idx[0]);
      const double x_2 = static_cast<double>(idx[1]);
      const double x_3 = static_cast<double>(idx[2]);
      //One row of the (implicit) design matrix X, in coefficient order.
      const double phi[NB] = { 1.0, x_1, x_2, x_3,
                               x_1 * x_2, x_1 * x_3, x_2 * x_3,
                               x_1 * x_1, x_2 * x_2, x_3 * x_3 };
      for (int r = 0; r < NB; r++) {
        Xt_y (r, 0) += phi[r] * pixel;
        for (int c = 0; c < NB; c++)
          Xt_X (r, c) += phi[r] * phi[c];
      }
      SZ++;
    }
  }
  vcl_printf ("      # pixels > thresh (%f) = %d\n", thresh, SZ);
  if (SZ < NB)
    vcl_printf ("      WARNING: only %d pixels for %d unknowns; the fit is underdetermined.\n",
                SZ, NB);

  //Solve the normal equations: (x'*x) * b = (x'*y), i.e. b = inv(x'*x) * (x'*y).
  vnl_matrix<double> Xt_X_inv = vnl_matrix_inverse<double>(Xt_X);
  B = Xt_X_inv * Xt_y;

  vcl_printf ("B: \n");
  vcl_cerr << B;
}

//Use B to compute a new fitting
void compute_quadratic_fit_img (const vnl_matrix<double>& B,
                                const ImageType::Pointer& fit_image)
{
  vcl_printf ("    compute_quadratic_fit_img(): \n");

  //Traverse through the fit_image and compute a new quadratic value via B.
  //Image coordinates into x1[], x2[], x3[].
  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (fit_image, fit_image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);
  assert (B.rows() == 10);

  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    ImageType::IndexType idx = iit.GetIndex();
    int x_1 = idx[0];
    int x_2 = idx[1];
    int x_3 = idx[2];
    double pixel = B(0,0) + B(1,0)*x_1 + B(2,0)*x_2 + B(3,0)*x_3 +
                   B(4,0)*x_1*x_2 + B(5,0)*x_1*x_3 + B(6,0)*x_2*x_3 +
                   B(7,0)*x_1*x_1 + B(8,0)*x_2*x_2 + B(9,0)*x_3*x_3;
    iit.Set (pixel);
  }
}
