#include <vcl_cstdio.h>

#include "itkBinaryErodeImageFilter.h"
#include "itkBinaryDilateImageFilter.h"
#include "itkBinaryBallStructuringElement.h" 

#include "itkGrayscaleErodeImageFilter.h"
#include "itkGrayscaleDilateImageFilter.h"
#include "itkBinaryBallStructuringElement.h" 

#include "itkMedianImageFilter.h"

#include "itkScalarImageToListAdaptor.h"
#include "itkListSampleToHistogramGenerator.h"
#include "itkDiscreteGaussianImageFilter.h"
#include "itkRescaleIntensityImageFilter.h"
#include "itkGradientAnisotropicDiffusionImageFilter.h"

#include <3d_image/3d_image_util.h>
#include <3d_image/3d_image_io.h>


// Morphological binary closing of `image`: a dilation followed by an erosion,
// both using the same ball-shaped structuring element of the given radius.
// Pixels with value 1 are the foreground handled by both stages.
// On return, `image` points at the output of the final erosion.
void BinaryClosingImage (ImageType::Pointer& image, const int radius)
{
  vcl_printf ("BinaryClosingImage(): \n");

  typedef itk::BinaryBallStructuringElement<float,3> BallType;
  typedef itk::BinaryDilateImageFilter< ImageType, ImageType, BallType > DilateStageType;
  typedef itk::BinaryErodeImageFilter< ImageType, ImageType, BallType >  ErodeStageType;

  const ImageType::PixelType fgValue = 1;

  // Shared kernel for both stages.
  BallType ball;
  ball.SetRadius (radius);
  ball.CreateStructuringElement ();

  // First pass: dilate the input image.
  DilateStageType::Pointer firstPass = DilateStageType::New ();
  firstPass->SetKernel (ball);
  firstPass->SetDilateValue (fgValue);
  firstPass->SetInput (image);

  // Second pass: erode the dilated image.
  ErodeStageType::Pointer secondPass = ErodeStageType::New ();
  secondPass->SetKernel (ball);
  secondPass->SetErodeValue (fgValue);
  secondPass->SetInput (firstPass->GetOutput ());

  secondPass->Update ();
  image = secondPass->GetOutput ();

  vcl_printf ("\t done.\n\n");
}

// Morphological binary opening of `image`: an erosion followed by a dilation,
// both using the same ball-shaped structuring element of the given radius.
// Pixels with value 1 are the foreground handled by both stages.
// On return, `image` points at the output of the final dilation.
void BinaryOpenningImage (ImageType::Pointer& image, const int radius)
{
  vcl_printf ("BinaryOpenningImage(): \n");

  typedef itk::BinaryBallStructuringElement<float,3> BallType;
  typedef itk::BinaryErodeImageFilter< ImageType, ImageType, BallType >  ErodeStageType;
  typedef itk::BinaryDilateImageFilter< ImageType, ImageType, BallType > DilateStageType;

  const ImageType::PixelType fgValue = 1;

  // Shared kernel for both stages.
  BallType ball;
  ball.SetRadius (radius);
  ball.CreateStructuringElement ();

  // First pass: erode the input image.
  ErodeStageType::Pointer firstPass = ErodeStageType::New ();
  firstPass->SetKernel (ball);
  firstPass->SetErodeValue (fgValue);
  firstPass->SetInput (image);

  // Second pass: dilate the eroded image.
  DilateStageType::Pointer secondPass = DilateStageType::New ();
  secondPass->SetKernel (ball);
  secondPass->SetDilateValue (fgValue);
  secondPass->SetInput (firstPass->GetOutput ());

  secondPass->Update ();
  image = secondPass->GetOutput ();

  vcl_printf ("\t done.\n\n");
}


// Builds a 1-D intensity histogram of `image` with ITK's
// ListSampleToHistogramGenerator (marginal scale 10.0).
//
// nBin (in/out): number of bins. If it is 0 on entry, it is first set to
//   static_cast<int>(max - min + 1), where min/max are taken over the image's
//   largest possible region.
// histVector, binMin, binMax (out): cleared, then filled with one entry per
//   histogram bin: the bin's frequency and its lower/upper bounds.
void compute_histogram (const ImageType::Pointer& image, 
                        vcl_vector<float>& histVector,
                        vcl_vector<float>& binMax,
                        vcl_vector<float>& binMin,
                        int& nBin)
{
  vcl_printf ("compute_histogram(): \n");

  typedef itk::Statistics::ScalarImageToListAdaptor< ImageType > AdaptorType;
  typedef ImageType::PixelType HistogramMeasurementType;
  typedef itk::Statistics::ListSampleToHistogramGenerator<
                AdaptorType, HistogramMeasurementType > GeneratorType;
  typedef GeneratorType::HistogramType HistogramType;

  // Caller asked us to pick the bin count: one bin per unit of the
  // image's intensity range.
  if (nBin == 0) {
    typedef itk::ImageRegionIterator< ImageType > IteratorType;
    IteratorType scan (image, image->GetLargestPossibleRegion());
    ImageType::PixelType lo = scan.Get();
    ImageType::PixelType hi = lo;
    for (scan.GoToBegin(); !scan.IsAtEnd(); ++scan) {
      const ImageType::PixelType v = scan.Get();
      if (v < lo)
        lo = v;
      if (v > hi)
        hi = v;
    }
    nBin = static_cast<int> (hi - lo + 1);
  }

  // Feed the image to the histogram generator as a list sample.
  AdaptorType::Pointer adaptor = AdaptorType::New();
  adaptor->SetImage (image);

  HistogramType::SizeType binCounts;
  binCounts.Fill (nBin);

  GeneratorType::Pointer generator = GeneratorType::New();
  generator->SetListSample (adaptor);
  generator->SetNumberOfBins (binCounts);
  generator->SetMarginalScale (10.0);
  generator->Update();

  HistogramType::ConstPointer histogram = generator->GetOutput();
  const unsigned int numBins = histogram->Size();

  histVector.clear();
  binMax.clear();
  binMin.clear();

  // Copy bounds and frequency of each bin along dimension 0.
  for (unsigned int bin = 0; bin < numBins; ++bin) {
    const float lower = histogram->GetBinMin (0, bin);
    const float upper = histogram->GetBinMax (0, bin);
    const float freq  = histogram->GetFrequency (bin, 0);
    binMin.push_back (lower);
    binMax.push_back (upper);
    histVector.push_back (freq);
  }
}

// Histogram-equalizes `image` in place.
//
// The intensity range [bMin, bMax] over the largest possible region is divided
// into nBin = static_cast<int>(bMax - bMin + 1) bins (histogram computed by
// compute_histogram). Each pixel is then replaced by
//   255 * (cumulative frequency up to its bin) / (total frequency),
// so output intensities lie in [0, 255].
void HistogramEqualization (const ImageType::Pointer& image)
{
  vcl_printf ("HistogramEqualization(): \n");

  typedef itk::ImageRegionIterator< ImageType > IteratorType;

  const ImageType::RegionType region = image->GetLargestPossibleRegion();
  if (region.GetNumberOfPixels() == 0) {
    // Nothing to equalize; reading the first pixel below would be invalid.
    vcl_printf ("\t done.\n\n");
    return;
  }

  IteratorType it (image, region);

  // Intensity range of the image.
  ImageType::PixelType bMin = it.Get();
  ImageType::PixelType bMax = it.Get();
  for (it.GoToBegin(); !it.IsAtEnd(); ++it) {
    const ImageType::PixelType d = it.Get();
    if (bMin > d)
      bMin = d;
    if (bMax < d)
      bMax = d;
  }

  int nBin = static_cast<int> (bMax-bMin+1);
  vcl_vector<float> histVector;
  vcl_vector<float> binMax;
  vcl_vector<float> binMin;

  compute_histogram (image, histVector, binMax, binMin, nBin);

  // Size the lookup table from what compute_histogram actually produced, so
  // it can never be indexed past histVector.
  const int nMap = static_cast<int> (histVector.size());
  if (nMap == 0) {
    vcl_printf ("\t done.\n\n");
    return;
  }

  // Cumulative histogram, normalized to [0, 255].
  vnl_vector<float> intensityMap (nMap);
  intensityMap[0] = histVector[0];
  for (int k = 1; k < nMap; k++) {
    intensityMap[k] = intensityMap[k-1] + histVector[k];
  }

  const double totCount = intensityMap[nMap-1];
  if (totCount > 0) { // an all-zero histogram would otherwise give 0/0 = NaN
    for (int k = 0; k < nMap; k++) {
      intensityMap[k] = 255 * intensityMap[k] / totCount;
    }
  }

  // Width of one intensity bin. With a single bin (e.g. a constant image) the
  // expression (bMax-bMin)/(nBin-1) would divide by zero and yield a NaN index.
  // In that case every pixel maps to bin 0.
  ImageType::PixelType binWidth = 0;
  if (nMap > 1)
    binWidth = (bMax - bMin) / (nMap - 1);

  for (it.GoToBegin(); !it.IsAtEnd(); ++it) {
    int idx = 0;
    if (binWidth > 0)
      idx = static_cast<int> ((it.Get() - bMin) / binWidth);
    // Clamp against floating-point rounding at the ends of the range.
    if (idx < 0)
      idx = 0;
    if (idx > nMap - 1)
      idx = nMap - 1;
    it.Set (intensityMap[idx]);
  }

  vcl_printf ("\t done.\n\n");
}


// Applies ITK's median filter to `image` with a neighborhood radius of
// `radius` pixels along every axis. On return, `image` points at the
// filtered output.
void BinaryMedianFilter (ImageType::Pointer& image, const int radius)
{
  vcl_printf ("BinaryMedianFilter(): \n");
  typedef itk::MedianImageFilter<ImageType, ImageType >  FilterType;

  FilterType::Pointer filter = FilterType::New();

  // Fill every dimension of the radius rather than hard-coding three axes.
  // A hard-coded loop would write out of bounds for a 2-D ImageType and leave
  // axes uninitialized for a higher-dimensional one.
  ImageType::SizeType indexRadius;
  indexRadius.Fill (radius);

  filter->SetRadius( indexRadius );

  filter -> SetInput( image );
  filter -> Update ( );
  image = filter -> GetOutput ( );
  vcl_printf ("\t done.\n\n");
}

// Blurs `image` with ITK's DiscreteGaussianImageFilter.
// sigma: standard deviation; the filter receives variance sigma*sigma and a
//   maximum kernel width of sigma*5 (implicitly converted to the filter's
//   integer parameter).
// On return, `image` points at the filtered output.
void SmoothImage (ImageType::Pointer& image, const float sigma)
{
  vcl_printf ("SmoothImage(): sigma = %f\n", sigma);

  typedef itk::DiscreteGaussianImageFilter< ImageType, ImageType > GaussianType;
  GaussianType::Pointer gauss = GaussianType::New();

  const float variance = sigma * sigma;
  const float kernelCap = sigma * 5;

  gauss->SetVariance (variance);
  gauss->SetMaximumKernelWidth (kernelCap);
  gauss->SetInput (image);
  gauss->Update ();

  image = gauss->GetOutput ();
  vcl_printf ("\t done.\n\n");
}

// Edge-preserving smoothing of `image` with ITK's
// GradientAnisotropicDiffusionImageFilter, using fixed parameters:
// 20 iterations, time step 0.25/8, conductance 3.
// On return, `image` points at the filtered output.
void GradientAnisotropicSmooth (ImageType::Pointer& image)
{
  // Values tried earlier: 50 or 5 iterations, conductance 2.
  const int nIterations = 20;
  // Rule of thumb: keep the time step at or below 1/2^N, where N is the
  // image dimensionality.
  const float timeStep = float(0.25/8);
  const int conductance = 3;

  vcl_printf ("GradientAnisotropicSmooth(): iter %d, time_step %f, conductance %d.\n",
              nIterations, timeStep, conductance);

  typedef itk::GradientAnisotropicDiffusionImageFilter<
                ImageType, ImageType > DiffusionType;
  DiffusionType::Pointer diffusion = DiffusionType::New();

  diffusion->SetNumberOfIterations (nIterations);
  diffusion->SetTimeStep (timeStep);
  diffusion->SetConductanceParameter (conductance);
  diffusion->SetInput (image);
  diffusion->Update ();

  image = diffusion->GetOutput ();
  vcl_printf ("\t done.\n\n");
}

// Applies a binary mask to `image` in place: every pixel whose corresponding
// `img_mask` pixel is 0 is set to 0. All other pixels are left untouched.
// The two images are walked in lockstep (scan order) over their own requested
// regions, so those regions are expected to contain the same number of pixels.
// NOTE(review): `result` is never read or written here; the masked image is
// `image` itself. The parameter is kept only for interface compatibility.
void compute_mask_img (const ImageType::Pointer& image, 
                       const Image8Type::Pointer& img_mask, 
                       ImageType::Pointer& result)
{
  (void) result; // intentionally unused, see note above

  typedef itk::ImageRegionIterator< ImageType > IteratorType;
  typedef itk::ImageRegionConstIterator< Image8Type > ConstIteratorType8;

  const ImageType::RegionType& imgRegion = image->GetRequestedRegion();
  const Image8Type::RegionType& maskRegion = img_mask->GetRequestedRegion();
  if (imgRegion.GetNumberOfPixels() != maskRegion.GetNumberOfPixels()) {
    vcl_printf ("compute_mask_img(): Warning: image and mask regions differ in size; "
                "masking only the common leading pixels in scan order.\n");
  }

  IteratorType it (image, imgRegion);
  ConstIteratorType8 itm (img_mask, maskRegion);

  // Stop as soon as either iterator is exhausted, so a smaller mask is never
  // read past its end.
  for (it.GoToBegin(), itm.GoToBegin(); !it.IsAtEnd() && !itm.IsAtEnd(); ++it, ++itm) {
    if (itm.Get() == 0) {
      it.Set (0);
    }
  }
}

// Computes the axis-aligned bounding box of all pixels in `image`'s requested
// region whose intensity is strictly greater than `bg_thresh`.
// Outputs the inclusive index bounds [xmin..xmax] x [ymin..ymax] x [zmin..zmax].
// Returns false (and prints an error) when no pixel exceeds the threshold; the
// output bounds are then left at INT_MAX / INT_MIN.
// The image is asserted to be 3-D.
bool detect_bnd_box (const ImageType::Pointer& image, 
                     const float bg_thresh, 
                     int& xmin, int& ymin, int& zmin, 
                     int& xmax, int& ymax, int& zmax)
{
  vcl_printf ("    detect_bnd_box(): bg_thresh %f.\n", bg_thresh);

  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexedIteratorType;
  IndexedIteratorType iit (image, image->GetRequestedRegion());
  assert (iit.GetIndex().GetIndexDimension() == 3);

  // Start from an empty (inverted) box. The maxima start at INT_MIN, not
  // -INT_MAX, so they are the true minimum of int.
  xmin = INT_MAX;
  ymin = INT_MAX;
  zmin = INT_MAX;
  xmax = INT_MIN;
  ymax = INT_MIN;
  zmax = INT_MIN;

  bool found = false;
  for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
    if (iit.Get() <= bg_thresh)
      continue;
    found = true;

    const ImageType::IndexType idx = iit.GetIndex();
    const int x = static_cast<int> (idx[0]);
    const int y = static_cast<int> (idx[1]);
    const int z = static_cast<int> (idx[2]);

    if (x < xmin)
      xmin = x;
    if (y < ymin)
      ymin = y;
    if (z < zmin)
      zmin = z;
    if (x > xmax)
      xmax = x;
    if (y > ymax)
      ymax = y;
    if (z > zmax)
      zmax = z;
  }

  // The box is empty exactly when no pixel passed the threshold.
  if (!found) {
    vcl_printf ("Error: no pixel with intensity > bg_thresh!\n");
    return false;
  }

  vcl_printf ("      (%d, %d, %d) - (%d, %d, %d).\n\n",
              xmin, ymin, zmin, xmax, ymax, zmax);
  return true;
}

void compute_grid_imgs (const ImageType::Pointer& image, 
                        const int xmin, const int ymin, const int zmin, 
                        const int xmax, const int ymax, const int zmax, 
                        const int n_grid, 
                        vcl_vector<ImageType::Pointer>& image_grid,
                        vcl_vector<ImageType::IndexType>& grid_center_index)
{
  vcl_printf ("    compute_grid_imgs(): %d * %d * %d grids.\n", 
              n_grid, n_grid, n_grid);

  const int total_grids = n_grid * n_grid * n_grid;

  const int total_size_x = xmax-xmin+1;
  const int total_size_y = ymax-ymin+1;
  const int total_size_z = zmax-zmin+1;
  vcl_printf ("      size of region in interest: %d * %d * %d.\n", 
              total_size_x, total_size_y, total_size_z);

  const int grid_size_x = vcl_ceil (double(total_size_x) / n_grid);
  const int grid_size_y = vcl_ceil (double(total_size_y) / n_grid);
  const int grid_size_z = vcl_ceil (double(total_size_z) / n_grid);
  vcl_printf ("      regular grid size: %d * %d * %d.\n", 
              grid_size_x, grid_size_y, grid_size_z);

  const int grid_size_last_x = total_size_x - grid_size_x * (n_grid-1);
  const int grid_size_last_y = total_size_y - grid_size_y * (n_grid-1);
  const int grid_size_last_z = total_size_z - grid_size_z * (n_grid-1);
  vcl_printf ("      last slice, column, row grid size: %d * %d * %d.\n", 
              grid_size_last_x, grid_size_last_y, grid_size_last_z);
    
  ImageType::RegionType region;
  ImageType::IndexType index;
  ImageType::SizeType  size;

  image_grid.resize (total_grids);
  int i = 0;
  for (int x=0; x<n_grid; x++) {
    int start_x = grid_size_x * x + xmin;
    int grid_x = grid_size_x;
    if (x==n_grid-1) //last row
      grid_x = grid_size_last_x;    

    for (int y=0; y<n_grid; y++) {
      int start_y = grid_size_y * y + ymin;
      int grid_y = grid_size_y;
      if (y==n_grid-1) //last column
        grid_y = grid_size_last_y;

      for (int z=0; z<n_grid; z++) {
        int start_z = grid_size_z * z + zmin;
        int grid_z = grid_size_z;
        if (z==n_grid-1) //last slice
          grid_z = grid_size_last_z;

        //create image_grid[i] with size grid_x * grid_y * grid_z.
        image_grid[i] = ImageType::New();

        //setup grid index
        index[0] = start_x; // first index on X
        index[1] = start_y; // first index on Y
        index[2] = start_z; // first index on Z

        //setup grid size
        size[0] = grid_x;
        size[1] = grid_y;
        size[2] = grid_z;

        //setup grid center
        ImageType::IndexType grid_center;
        grid_center[0] = start_x + grid_x / 2;
        grid_center[1] = start_y + grid_y / 2;
        grid_center[2] = start_z + grid_z / 2;
        grid_center_index.push_back (grid_center);

        region.SetIndex (index);
        region.SetSize (size);
        image_grid[i]->SetRegions (region);

        image_grid[i]->SetOrigin (image->GetOrigin());        
        image_grid[i]->SetSpacing (image->GetSpacing());
        image_grid[i]->Allocate();

        //copy image pixel values into image_grid[i].
        typedef itk::ImageRegionConstIterator < ImageType > ConstIteratorType;
        typedef itk::ImageRegionIterator < ImageType > IteratorType;
        ConstIteratorType it (image, region);
        IteratorType itg (image_grid[i], image_grid[i]->GetRequestedRegion());

        float max_pixel = -FLT_MAX;
        for (it.GoToBegin(), itg.GoToBegin(); !it.IsAtEnd(); ++it, ++itg) {
          //debug:
          //ImageType::IndexType idx = it.GetIndex();
          //int itx = idx[0];
          //int ity = idx[1];
          //int itz = idx[2];
          //ImageType::IndexType idxg = itg.GetIndex();
          //int itgx = idxg[0];
          //int itgy = idxg[1];
          //int itgz = idxg[2];

          ImageType::PixelType pixel = it.Get();
          itg.Set (pixel);
          if (pixel > max_pixel)
            max_pixel = pixel;
        }

        vcl_printf ("      grid %d [%d * %d * %d] max_pixel %f, center (%d, %d, %d).\n", 
                    i, grid_x, grid_y, grid_z, max_pixel,
                    grid_center[0], grid_center[1], grid_center[2]);

        ///debug: write image_grid[i] to file for further examination.
        ///save_img8 ("grid.mhd", image_grid[i]);

        i++;
      }
    }
  }

}

void compute_gain_from_grids (const vcl_vector<ImageType::Pointer>& gain_field_g_grid, 
                              const ImageType::Pointer& img_y, const float bg_thresh,
                              ImageType::Pointer& gain_field_g)
{
  //Given a set of fitting images, compute a new gain field image.
  vcl_printf ("    compute_gain_from_grids(): \n");

  int i;
  typedef itk::ImageRegionIterator < ImageType > IteratorType;
  typedef itk::ImageRegionIteratorWithIndex < ImageType > IndexIteratorType;  

  //Compute the global mean pixel value.
  //Use only non-background pixels!!
  IteratorType yit (img_y, img_y->GetRequestedRegion());
  double sum = 0;
  int count = 0;
  for (i=0; i<gain_field_g_grid.size(); i++) {
    IndexIteratorType iit (gain_field_g_grid[i], gain_field_g_grid[i]->GetRequestedRegion());
    for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
      //Skip if this pixel is in background.
      ImageType::IndexType idx = iit.GetIndex();
      yit.SetIndex (idx);
      if (yit.Get() <= bg_thresh)
        continue;

      double pixel = iit.Get();
      sum += pixel;
      count++;
    }
  }    
  double mean = sum / count;
  vcl_printf ("      %d non-background pixels, intensity mean = %f\n", count, mean);

  
  IteratorType git (gain_field_g, gain_field_g->GetRequestedRegion());
  float max = -FLT_MAX;
  float min = FLT_MAX;

  for (i=0; i<gain_field_g_grid.size(); i++) {
    IndexIteratorType iit (gain_field_g_grid[i], gain_field_g_grid[i]->GetRequestedRegion());
    for (iit.GoToBegin(); !iit.IsAtEnd(); ++iit) {
      double pixel = iit.Get();
      //gain_field = pixel / mean.
      pixel = pixel / mean;
      iit.Set (pixel);

      //update gain_field_g[]
      ImageType::IndexType idx = iit.GetIndex();
      git.SetIndex (idx);
      double value = git.Get();      
      //Update global gain field 
      //gain_field_g(x,y,z) = gain_field_g(x,y,z) * gain_field_g_grid(x,y,z)
      value *= pixel; 
      git.Set (value);
      if (value > max)
        max = value;
      if (value < min)
        min = value;
    }
  }

  vcl_printf ("      min / max of g[]: %f / %f.\n", min, max);
}

// Divides `image` in place by `gain_field`, pixel by pixel, over
// gain_field's requested region. Both iterators use that same region.
// Pixels whose gain is exactly 0 are left unchanged: dividing by zero would
// write inf/NaN into the image. Zero gains occur, for example, where
// mask_gain_field zeroed background pixels.
void update_gain_to_image (const ImageType::Pointer& gain_field, 
                           const ImageType::Pointer& image)
{
  vcl_printf ("    update_gain_to_image():\n");

  typedef itk::ImageRegionConstIteratorWithIndex < ImageType > ConstIndexIteratorType;
  typedef itk::ImageRegionIterator < ImageType > IteratorType;
  ConstIndexIteratorType itg (gain_field, gain_field->GetRequestedRegion());
  IteratorType it (image, gain_field->GetRequestedRegion());

  for (itg.GoToBegin(), it.GoToBegin(); !itg.IsAtEnd(); ++itg, ++it) {
    const ImageType::PixelType gain = itg.Get();
    if (gain == 0)
      continue; // undefined correction; keep the original pixel

    const ImageType::PixelType pixel = it.Get();
    it.Set (pixel / gain);
  }
}

// Prints element [2] of every per-grid centroid triple in centroid_v_grid
// (labelled "for the WM" in the original code) next to centroid_vn_grid, then
// returns the sum of squared differences
//   SSD = sum_i (centroid_v_grid[i][2] - centroid_vn_grid[i])^2.
// Each centroid_v_grid[i] is asserted to hold exactly 3 values, and both
// vectors are asserted to have the same length. If they differ and asserts
// are disabled, only the common prefix is compared.
double compute_diff_norm (const vcl_vector<vcl_vector<float> >& centroid_v_grid, 
                          const vcl_vector<float>& centroid_vn_grid)
{
  vcl_printf ("    compute_diff_norm(): \n");

  //print centroid_v_grid[] for the WM.
  vcl_printf ("      centroid_v_grid[%d]:  ", static_cast<int> (centroid_v_grid.size()));
  for (unsigned int i = 0; i < centroid_v_grid.size(); i++) {
    assert (centroid_v_grid[i].size() == 3);
    vcl_printf ("%4.0f ", centroid_v_grid[i][2]);
  }

  //print centroid_vn_grid
  vcl_printf ("\n      centroid_vn_grid[%d]: ", static_cast<int> (centroid_vn_grid.size()));
  for (unsigned int i = 0; i < centroid_vn_grid.size(); i++) {
    vcl_printf ("%4.0f ", centroid_vn_grid[i]);
  }

  //Compute SSD.
  assert (centroid_v_grid.size() == centroid_vn_grid.size());
  const unsigned int n = centroid_v_grid.size() < centroid_vn_grid.size()
                         ? static_cast<unsigned int> (centroid_v_grid.size())
                         : static_cast<unsigned int> (centroid_vn_grid.size());
  double SSD = 0;
  for (unsigned int i = 0; i < n; i++) {
    // Keep the difference in floating point; an int here truncated
    // fractional differences (e.g. 0.7 -> 0).
    const double diff = double(centroid_v_grid[i][2]) - double(centroid_vn_grid[i]);
    SSD += diff * diff;
  }
  return SSD;
}

// Masks the final gain field using `image` and `bg_thresh`: every gain value
// whose corresponding `image` pixel is <= bg_thresh is set to 0.
// Both iterators walk image's requested region (not gain_field_g's), so
// gain_field_g is expected to cover at least that region.
void mask_gain_field (const ImageType::Pointer& image, 
                      const float bg_thresh,
                      const ImageType::Pointer& gain_field_g)
{
  typedef itk::ImageRegionConstIterator < ImageType > SrcIterator;
  typedef itk::ImageRegionIterator < ImageType > GainIterator;

  const ImageType::RegionType region = image->GetRequestedRegion();
  SrcIterator src (image, region);
  GainIterator gain (gain_field_g, region);

  src.GoToBegin();
  gain.GoToBegin();
  while (!src.IsAtEnd()) {
    const ImageType::PixelType value = src.Get();
    if (value <= bg_thresh)
      gain.Set (0);
    ++src;
    ++gain;
  }
}
