
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif

#include "itkOrientedImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkImageRegionIteratorWithIndex.h"

#include "itkOrientImageFilter.h"
#include "vcl_vector.h"
#include "vcl_cstdio.h"

typedef itk::OrientedImage<float, 3> ImageType;

// Removes a multiplicative gain (bias) field from `image`, in place:
//   image(x) = globalmean * image(x) / gain_field(x)
// Only voxels inside gain_field's requested region are visited. The two images
// are walked in lock-step over that same region, so the region must also be
// valid for `image` (the caller in this file builds gain_field with the same
// geometry as image — TODO confirm for any other caller).
// Voxels whose gain is exactly zero are left unchanged; dividing by it would
// write inf/NaN into the output.
void update_gain_to_image (const ImageType::Pointer& gain_field, 
                           const ImageType::Pointer& image,
                           float globalmean)
{
  vcl_printf ("    update_gain_to_image():\n");

  //Both iterators use gain_field's region. No voxel index is needed, so plain
  //region iterators (same traversal order) are sufficient.
  typedef itk::ImageRegionConstIterator < ImageType > ConstIteratorType;
  typedef itk::ImageRegionIterator < ImageType > IteratorType;
  const ImageType::RegionType region = gain_field->GetRequestedRegion();
  ConstIteratorType itg (gain_field, region);
  IteratorType it (image, region);

  for (itg.GoToBegin(), it.GoToBegin(); !itg.IsAtEnd(); ++itg, ++it) {
    const ImageType::PixelType gain = itg.Get();
    if (gain == 0)
      continue; //avoid division by zero; keep the original intensity

    it.Set (globalmean * it.Get() / gain);
  }
}


//Use B to compute a new fitting
void compute_quadratic_fit_img (const vnl_matrix<double>& B,
                                const ImageType::Pointer& fit_image)
{
  vcl_printf ("    compute_quadratic_fit_img(): \n");

  //Traverse through the fit_image and compute a new quadratic value via B.
  //Image coordinates into x1[], x2[], x3[].
  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (fit_image, fit_image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);
  assert (B.rows() == 10);

  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    ImageType::IndexType idx = iit.GetIndex();
    int x_1 = idx[0];
    int x_2 = idx[1];
    int x_3 = idx[2];
    double pixel = B(0,0) + B(1,0)*x_1 + B(2,0)*x_2 + B(3,0)*x_3 +
                   B(4,0)*x_1*x_2 + B(5,0)*x_1*x_3 + B(6,0)*x_2*x_3 +
                   B(7,0)*x_1*x_1 + B(8,0)*x_2*x_2 + B(9,0)*x_3*x_3;
    iit.Set (pixel);
  }
}



void grid_regression_quadratic (const vcl_vector<float> & centroid_v_grid,
                                const vcl_vector<ImageType::IndexType>& grid_center_index,
                                vnl_matrix<double>& B)
{
  vcl_printf ("    grid_regression_quadratic(): \n");
  int i;

  //Put centroid index coordinates into x1[], x2[], x3[].
  assert (centroid_v_grid.size() == grid_center_index.size());
  
  //Determine the total number of qualified inputs.
  vcl_printf ("      WM centroid(s): ");
  int SZ = 0;
  for (i=0; i<centroid_v_grid.size(); i++) {
    vcl_printf ("%.2f ", centroid_v_grid[i]);
    if (centroid_v_grid[i] > 0) {
      SZ++;
    }
  }
  vcl_printf ("\n      Total # WM centroid(s) used = %d.\n", SZ);

  vnl_matrix<double> y (SZ,1);
  vnl_matrix<double> x1 (SZ,1);
  vnl_matrix<double> x2 (SZ,1);
  vnl_matrix<double> x3 (SZ,1);

  int c = 0;
  for (int i=0; i<centroid_v_grid.size(); i++) {
    if (centroid_v_grid[i] <= 0) 
      continue;

    assert (c < SZ);
    y(c, 0) = centroid_v_grid[i]; //only use the WM for now
    int x_1 = grid_center_index[i][0];
    int x_2 = grid_center_index[i][1];
    int x_3 = grid_center_index[i][2];
    x1(c, 0) = x_1;
    x2(c, 0) = x_2;
    x3(c, 0) = x_3;
    c++;
  }

  //Prepare the design matrix X
  vnl_matrix<double> X (SZ,10);
  X.set_column (0, 1.0);
  X.update (x1, 0, 1);
  X.update (x2, 0, 2);
  X.update (x3, 0, 3);
  ///vcl_cerr << X;  
  x1.clear();
  x2.clear();
  x3.clear();

  vnl_matrix<double> x1x2 (SZ,1);
  vnl_matrix<double> x1x3 (SZ,1);
  vnl_matrix<double> x2x3 (SZ,1); 
  c = 0;
  for (int i=0; i<centroid_v_grid.size(); i++) {
    if (centroid_v_grid[i] <= 0) 
      continue;

    assert (c < SZ);
    int x_1 = grid_center_index[i][0];
    int x_2 = grid_center_index[i][1];
    int x_3 = grid_center_index[i][2];
    x1x2 (c, 0) = x_1 * x_2;
    x1x3 (c, 0) = x_1 * x_3;
    x2x3 (c, 0) = x_2 * x_3;
    c++;
  }
  X.update (x1x2, 0, 4);
  X.update (x1x3, 0, 5);
  X.update (x2x3, 0, 6);
  x1x2.clear();
  x1x3.clear();
  x2x3.clear();
  
  vnl_matrix<double> x1x1 (SZ,1);
  vnl_matrix<double> x2x2 (SZ,1);
  vnl_matrix<double> x3x3 (SZ,1);
  c = 0;
  for (int i=0; i<centroid_v_grid.size(); i++) {
    if (centroid_v_grid[i] <= 0) 
      continue;
    assert (i < SZ);
    int x_1 = grid_center_index[i][0];
    int x_2 = grid_center_index[i][1];
    int x_3 = grid_center_index[i][2];
    x1x1 (i, 0) = x_1 * x_1;
    x2x2 (i, 0) = x_2 * x_2;
    x3x3 (i, 0) = x_3 * x_3;
    c++;
  }
  X.update (x1x1, 0, 7);
  X.update (x2x2, 0, 8);
  X.update (x3x3, 0, 9);
  x1x1.clear();
  x2x2.clear();
  x3x3.clear();

  ///vcl_printf ("X: \n");
  ///vcl_cerr << X;

  vnl_matrix<double> Xt = X.transpose();
  vnl_matrix<double> Xt_X = Xt * X; //(x'*x)
  X.clear();
  vnl_matrix<double> Xt_y = Xt * y; //(x'*y)
  Xt.clear();
  y.clear();
  //Solve for the linear normal equation: (x'*x) * b = (x'*y)
  vnl_matrix<double> Xt_X_inv = vnl_matrix_inverse<double>(Xt_X);  
  Xt_X.clear();
  //b = inv(x'*x) * (x'*y);
  B = Xt_X_inv * Xt_y;
  
  vcl_printf ("B: \n");
  vcl_cerr << B;
}



// Finds the axis-aligned bounding box, in voxel indices, of all voxels in
// image's requested region whose intensity is strictly greater than bg_thresh.
//
// On success, writes the inclusive box corners to (xmin, ymin, zmin) and
// (xmax, ymax, zmax), prints them, and returns true.
// If no voxel exceeds bg_thresh, prints an error and returns false. In that
// case the outputs hold INT_MAX (mins) and INT_MIN (maxes) and must not be used.
bool detect_bnd_box (const ImageType::Pointer& image, 
                     const float bg_thresh, 
                     int& xmin, int& ymin, int& zmin, 
                     int& xmax, int& ymax, int& zmax)
{
  vcl_printf ("    detect_bnd_box(): bg_thresh %f.\n", bg_thresh);

  //Read-only traversal: a const iterator is sufficient.
  typedef itk::ImageRegionConstIteratorWithIndex < ImageType > ConstIndexIteratorType;
  ConstIndexIteratorType iit (image, image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);

  xmin = INT_MAX;
  ymin = INT_MAX;
  zmin = INT_MAX;
  xmax = INT_MIN;
  ymax = INT_MIN;
  zmax = INT_MIN;
  bool found = false;

  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    if (iit.Get() <= bg_thresh)
      continue;

    found = true;
    const ImageType::IndexType idx = iit.GetIndex();
    const int x = idx[0];
    const int y = idx[1];
    const int z = idx[2];

    if (x < xmin) xmin = x;
    if (y < ymin) ymin = y;
    if (z < zmin) zmin = z;
    if (x > xmax) xmax = x;
    if (y > ymax) ymax = y;
    if (z > zmax) zmax = z;
  }

  if (!found) {
    //error: no pixel with intensity > bg_thresh.
    vcl_printf ("Error: no pixel with intensity > bg_thresh!\n");
    return false;
  }

  vcl_printf ("      (%d, %d, %d) - (%d, %d, %d).\n\n",
              xmin, ymin, zmin, xmax, ymax, zmax);
  return true;
}

// Splits the inclusive voxel box [xmin..xmax] x [ymin..ymax] x [zmin..zmax]
// of `image` into n_grid x n_grid x n_grid blocks. Each block's voxels are
// copied into a newly allocated image.
//
// Along each axis every block has size ceil(total / n_grid), except the last
// block, which takes the remainder. Blocks are numbered with z varying
// fastest, then y, then x.
//
// Both outputs are rebuilt from scratch (previous contents are discarded):
//   image_grid[i]        - block i. Its region index is the block start, so its
//                          voxel indices equal those of `image`. Origin, spacing
//                          and direction are copied from `image`.
//   grid_center_index[i] - integer index of the geometric center of block i.
// If n_grid <= 0, or the box is too small for this split (some last block
// would have zero or negative extent), an error is printed and both outputs
// are left empty.
void compute_grid_imgs (const ImageType::Pointer& image, 
                        const int xmin, const int ymin, const int zmin, 
                        const int xmax, const int ymax, const int zmax, 
                        const int n_grid, 
                        vcl_vector<ImageType::Pointer>& image_grid,
                        vcl_vector<ImageType::IndexType>& grid_center_index)
{
  vcl_printf ("    compute_grid_imgs(): %d * %d * %d grids.\n", 
              n_grid, n_grid, n_grid);

  //Clear both outputs so index i always refers to the same block in each.
  image_grid.clear();
  grid_center_index.clear();

  if (n_grid <= 0) {
    vcl_printf ("Error: n_grid must be positive!\n");
    return;
  }

  const int total_grids = n_grid * n_grid * n_grid;

  const int total_size_x = xmax-xmin+1;
  const int total_size_y = ymax-ymin+1;
  const int total_size_z = zmax-zmin+1;
  vcl_printf ("      size of region in interest: %d * %d * %d.\n", 
              total_size_x, total_size_y, total_size_z);

  const int grid_size_x = static_cast<int>(vcl_ceil (double(total_size_x) / n_grid));
  const int grid_size_y = static_cast<int>(vcl_ceil (double(total_size_y) / n_grid));
  const int grid_size_z = static_cast<int>(vcl_ceil (double(total_size_z) / n_grid));
  vcl_printf ("      regular grid size: %d * %d * %d.\n", 
              grid_size_x, grid_size_y, grid_size_z);

  const int grid_size_last_x = total_size_x - grid_size_x * (n_grid-1);
  const int grid_size_last_y = total_size_y - grid_size_y * (n_grid-1);
  const int grid_size_last_z = total_size_z - grid_size_z * (n_grid-1);
  vcl_printf ("      last slice, column, row grid size: %d * %d * %d.\n", 
              grid_size_last_x, grid_size_last_y, grid_size_last_z);

  //With ceil-sized regular blocks, a small box can leave the last block with
  //zero or negative extent (e.g. 4 voxels split 3 ways -> 2 + 2 + 0). Such a
  //size is not a valid region size; ITK's SizeType holds unsigned values, so a
  //negative value would wrap around to a huge size.
  if (grid_size_last_x <= 0 || grid_size_last_y <= 0 || grid_size_last_z <= 0) {
    vcl_printf ("Error: region of interest too small for %d grids per axis!\n", n_grid);
    return;
  }

  typedef itk::ImageRegionConstIterator < ImageType > ConstIteratorType;
  typedef itk::ImageRegionIterator < ImageType > IteratorType;

  ImageType::RegionType region;
  ImageType::IndexType index;
  ImageType::SizeType  size;

  image_grid.resize (total_grids);
  grid_center_index.reserve (total_grids);
  int i = 0;
  for (int x=0; x<n_grid; x++) {
    const int start_x = grid_size_x * x + xmin;
    const int grid_x = (x == n_grid-1) ? grid_size_last_x : grid_size_x; //last row

    for (int y=0; y<n_grid; y++) {
      const int start_y = grid_size_y * y + ymin;
      const int grid_y = (y == n_grid-1) ? grid_size_last_y : grid_size_y; //last column

      for (int z=0; z<n_grid; z++) {
        const int start_z = grid_size_z * z + zmin;
        const int grid_z = (z == n_grid-1) ? grid_size_last_z : grid_size_z; //last slice

        //create image_grid[i] with size grid_x * grid_y * grid_z.
        image_grid[i] = ImageType::New();

        //setup grid index (first voxel on X, Y, Z)
        index[0] = start_x;
        index[1] = start_y;
        index[2] = start_z;

        //setup grid size
        size[0] = grid_x;
        size[1] = grid_y;
        size[2] = grid_z;

        //setup grid center
        ImageType::IndexType grid_center;
        grid_center[0] = start_x + grid_x / 2;
        grid_center[1] = start_y + grid_y / 2;
        grid_center[2] = start_z + grid_z / 2;
        grid_center_index.push_back (grid_center);

        region.SetIndex (index);
        region.SetSize (size);
        image_grid[i]->SetRegions (region);

        //Copy the full geometry so grid voxels map to the same physical points.
        image_grid[i]->SetOrigin (image->GetOrigin());        
        image_grid[i]->SetSpacing (image->GetSpacing());
        image_grid[i]->SetDirection (image->GetDirection());
        image_grid[i]->Allocate();

        //copy image pixel values into image_grid[i].
        ConstIteratorType it (image, region);
        IteratorType itg (image_grid[i], image_grid[i]->GetRequestedRegion());

        float max_pixel = -FLT_MAX;
        for (it.GoToBegin(), itg.GoToBegin(); !it.IsAtEnd(); ++it, ++itg) {
          const ImageType::PixelType pixel = it.Get();
          itg.Set (pixel);
          if (pixel > max_pixel)
            max_pixel = pixel;
        }

        //Index components are long, so they are printed with %ld.
        vcl_printf ("      grid %d [%d * %d * %d] max_pixel %f, center (%ld, %ld, %ld).\n", 
                    i, grid_x, grid_y, grid_z, max_pixel,
                    static_cast<long>(grid_center[0]),
                    static_cast<long>(grid_center[1]),
                    static_cast<long>(grid_center[2]));

        ///debug: write image_grid[i] to file for further examination.
        ///save_img8 ("grid.mhd", image_grid[i]);

        i++;
      }
    }
  }
}



int main( int argc, char *argv[] )
{

  itk::ImageFileReader<ImageType>::Pointer imageReader = itk::ImageFileReader<ImageType>::New();
  imageReader->SetFileName( argv[1] );
  imageReader->Update();
  ImageType::Pointer image = imageReader->GetOutput();

  itk::ImageFileReader<ImageType>::Pointer wmReader = itk::ImageFileReader<ImageType>::New();
  wmReader->SetFileName( argv[2] );
  wmReader->Update();
  ImageType::Pointer wm = wmReader->GetOutput();

  
  // determine boundary box
  int xmin, ymin, zmin;
  int xmax, ymax, zmax;
  bool r = detect_bnd_box (wm, 128, xmin, ymin, zmin, xmax, ymax, zmax);

  //Space division: into nxnxn: 3x3x3 or 4x4x4 blocks.
  vcl_vector<ImageType::Pointer> img_y_grid;
  vcl_vector<ImageType::IndexType> grid_center_index;
  compute_grid_imgs (wm, xmin, ymin, zmin, xmax, ymax, zmax, 3, 
                     img_y_grid, grid_center_index);


  int nBlocks = img_y_grid.size();
  vcl_vector<float> centroid_v_grid;
  centroid_v_grid.resize( nBlocks );

  float globalmean = 0;
  float globalcount = 0;

  for (int k = 0; k < nBlocks; k++)
    {
    centroid_v_grid[k] = 0;
    ImageType::IndexType idx;
    idx.Fill( 0 );
    grid_center_index[k] = idx;

    itk::ImageRegionIteratorWithIndex<ImageType> it( img_y_grid[k], img_y_grid[k]->GetLargestPossibleRegion() );

    int nWM = 0;
    
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
      {
      if (it.Get() < 128)
        {
        continue;
        }
      nWM ++;
      globalcount ++;
      idx = it.GetIndex();
      centroid_v_grid[k] += image->GetPixel( idx );
      globalmean += image->GetPixel( idx );
      for (int m = 0; m < 3; m++)
        {
        grid_center_index[k][m] += idx[m];
        }
      }

    if (nWM > 0)
      {
      centroid_v_grid[k] /= static_cast<float>( nWM );
      for (int m = 0; m < 3; m++)
        {
        grid_center_index[k][m] = static_cast<long int>(static_cast<float>(grid_center_index[k][m])/static_cast<float>(nWM));
        }
      std::cout << centroid_v_grid[k] << std::endl;
      }
    }

  globalmean /= globalcount;
  std::cout << "global white matter mean: " << globalmean << std::endl;

  vnl_matrix<double> B;
  grid_regression_quadratic (centroid_v_grid, grid_center_index, B);

  ImageType::Pointer gain_field = ImageType::New();  
  gain_field->CopyInformation (image);
  gain_field->SetRegions (image->GetLargestPossibleRegion());
  gain_field->Allocate();
  gain_field->FillBuffer (1.0f);

  compute_quadratic_fit_img (B, gain_field);

  update_gain_to_image (gain_field, image, globalmean);

  itk::ImageFileWriter<ImageType>::Pointer w = itk::ImageFileWriter<ImageType>::New();
  w->SetFileName( argv[3] );
  w->SetInput( image );
  w->Update();

  for (int k = 0; k < nBlocks; k++)
    {
    centroid_v_grid[k] = 0;
    ImageType::IndexType idx;
    idx.Fill( 0 );
    grid_center_index[k] = idx;

    itk::ImageRegionIteratorWithIndex<ImageType> it( img_y_grid[k], img_y_grid[k]->GetLargestPossibleRegion() );

    int nWM = 0;
    
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
      {
      if (it.Get() < 128)
        {
        continue;
        }
      nWM ++;
      globalcount ++;
      idx = it.GetIndex();
      centroid_v_grid[k] += image->GetPixel( idx );
      globalmean += image->GetPixel( idx );
      for (int m = 0; m < 3; m++)
        {
        grid_center_index[k][m] += idx[m];
        }
      }

    if (nWM > 0)
      {
      centroid_v_grid[k] /= static_cast<float>( nWM );
      for (int m = 0; m < 3; m++)
        {
        grid_center_index[k][m] = static_cast<long int>(static_cast<float>(grid_center_index[k][m])/static_cast<float>(nWM));
        }
      std::cout << centroid_v_grid[k] << std::endl;
      }
    }

  return 0;
}


