#include "watershed.h"

/*
 * Local helper macros used by the watershed code in this file.
 *
 * int32_T is a textual alias for plain int; the MAX_/MIN_ macros give the
 * limits of 32-bit signed/unsigned integers.
 * NOTE(review): true/false are redefined as integer macros here although the
 * code below uses `bool`; confirm this is intentional for the compiler in use.
 * NOTE(review): uint32_T is not defined in this part of the file - presumably
 * it comes from watershed.h; verify before using the uint32 limits.
 */
#define int32_T int 
#define true  1
#define false 0 
#define  MAX_int32_T     ((int32_T)(2147483647))    /* 2147483647  */
#define  MIN_int32_T     ((int32_T)(-2147483647-1)) /* -2147483648 */
#define  MAX_uint32_T    ((uint32_T)(0xFFFFFFFFU))  /* 4294967295  */
#define  MIN_uint32_T    ((uint32_T)(0))
 
/*
 * Collect the linear indices of the in-bounds 26-connected neighbours of
 * pixel p.
 *
 * coor[p] holds the (x, y, z) coordinates of p.  Neighbour indices are
 * computed as nz*xdim*ydim + ny*xdim + nx, so the pixel numbering used by the
 * caller must be x-fastest, then y, then z.  nh must have room for 26
 * entries.  Returns the number of indices written to nh.
 */
static int watershed_neighbors(int **coor, int p, int xdim, int ydim, int zdim, int *nh)
{
  int nhsize = 0;
  for(int z=-1; z<=1; z++)
    for(int y=-1; y<=1; y++)
      for(int x=-1; x<=1; x++)
        {
          int nx = coor[p][0]+x;
          int ny = coor[p][1]+y;
          int nz = coor[p][2]+z;
          if( (nx >= 0) && (nx < xdim)
              && (ny >= 0) && (ny < ydim)
              && (nz >= 0) && (nz < zdim)
              && ((x!=0)||(y!=0)||(z!=0)))
            {
              nh[nhsize] = nz*xdim*ydim+ny*xdim+nx;
              nhsize += 1;
            }
        }
  return nhsize;
}

/*
 * Flooding watershed over a 3-D volume using 26-connectivity.
 *
 *   I          - pixel values, N entries, indexed by linear pixel index
 *   coor       - coor[p] = {x, y, z} coordinates of pixel p
 *   N          - number of pixels
 *   xdim..zdim - volume dimensions used for bounds checks and index math
 *   sort_index - pixel indices ordered so that I[sort_index[k]] is
 *                non-decreasing in k (pixels are processed level by level)
 *   L          - output labels, N entries; every entry is first set to INIT,
 *                then pixels receive either a positive basin label or WSHED
 *
 * The routine processes one grey level at a time: pixels at that level are
 * marked MASK, existing basins are extended into them breadth-first through
 * the pixel queue, and any pixel still MASK afterwards seeds a new basin.
 */
void compute_watershed(float *I, int **coor, int N, int xdim, int ydim, int zdim, unsigned long *sort_index, long *L)
{
  int current_label = 0;
  queueADT pixel_queue;
  int32_T *dist;
  int32_T closest_dist;
  int32_T closest_label_value;
  bool closest_label_value_is_unique;
  int32_T fictitious = FICTITIOUS;
  int32_T wshed = WSHED;
  int k;
  int num_processed_pixels;
  int k1;
  int k2;
  int mask = MASK;
  int p;
  int q;
  int r;
  int current_distance;
  float current_level;
  int nh[26];      /* scratch buffer for neighbour indices (at most 26) */
  int nhsize;


#ifdef DO_NAN_CHECK
  for (k = 0; k < N; k++)
    {
      if (isnanf(I[k]))
        {
          printf("\n Images:watershed:expectedNonnan:%s",
                 "Input image may not contain NaNs.");
        }
    }
#endif /* DO_NAN_CHECK */

  /*
   * If the input array is empty, there's nothing to do here.
   */
  if (N == 0)
    {
      return;
    }

  /*
   * Initialize output array.
   */
  for (k = 0; k < N; k++)
    {
      L[k] = INIT;
    }

  /*
   * Initialize the pixel queue.
   */
  pixel_queue = QueueCreat();

  /*
   * Initialize the distance array.  It MUST start out zeroed: the basin
   * extension step below tests dist[q] == 0 on masked pixels that have not
   * been written yet, so the trailing "()" (value-initialization) is
   * required for correct plateau handling.
   */
  dist = new int32_T [N]();

  num_processed_pixels = 0;
  while (num_processed_pixels < N)
    {
      /*
       * Find the next set of pixels that all have the same value.
       */
      k1 = num_processed_pixels;
      current_level = I[(int) sort_index[k1]];
      k2 = k1;
      do
        {
          k2++;
        }
      while ((k2 < N) && (I[(int) sort_index[k2]] == current_level));
      k2--;

      /*
       * Mask all image pixels whose value equals current_level.
       */
      for (k = k1; k <= k2; k++)
        {
          p = (int) sort_index[k];
          L[p] = mask;

          nhsize = watershed_neighbors(coor, p, xdim, ydim, zdim, nh);
          for(int i=0; i<nhsize; i++)
            {
              q = nh[i];
              if ((L[q] > 0) || (L[q] == wshed))
                {
                  /*
                   * Initialize queue with neighbors at current_level
                   * of current basins or watersheds.
                   */
                  dist[p] = 1;
                  QueueEnter(pixel_queue, p);
                  break;
                }
            }
          num_processed_pixels++;
        }

      current_distance = 1;
      QueueEnter(pixel_queue, fictitious);

      /*
       * Extend the basins.
       */
      while (true)
        {
          p = QueueDelete(pixel_queue);
          if (p == fictitious)
            {
              /* The fictitious marker separates successive distance rings. */
              if (QueueLength(pixel_queue) == 0)
                {
                  break;
                }
              else
                {
                  QueueEnter(pixel_queue, fictitious);
                  current_distance++;
                  p = QueueDelete(pixel_queue);
                }
            }

          /*
           * NOTE: the code from here down to "detect and process
           * new minima" is a modified version of the algorithm originally
           * published in Vincent and Soille's paper.  That algorithm
           * could make several changes to L[p] during a single
           * sweep of its neighbors, which sometimes results in incorrect
           * labeling.  This seems to be particularly a problem in
           * higher dimensions with the correspondingly larger number
           * of neighbors.  Here the algorithm is changed to make a
           * sweep of the neighborhood, accumulating key information
           * about its configuration, and then, after the neighborhood
           * sweep is finished, make one and only one change to L[p].
           */

          /*
           * Find the labeled or watershed neighbors with the closest
           * distance.  At the same time, put any masked neighbors
           * whose distance is 0 onto the queue and set their distance
           * to current_distance + 1.
           */
          closest_dist = MAX_int32_T;
          closest_label_value = 0;
          closest_label_value_is_unique = true;

          nhsize = watershed_neighbors(coor, p, xdim, ydim, zdim, nh);
          for(int i=0; i<nhsize; i++)
            {
              q = nh[i];

              if ((L[q] > 0) || (L[q] == WSHED))
                {
                  if (dist[q] < closest_dist)
                    {
                      closest_dist = dist[q];
                      if (L[q] > 0)
                        {
                          closest_label_value = L[q];
                        }
                    }
                  else if (dist[q] == closest_dist)
                    {
                      if (L[q] > 0)
                        {
                          if ((closest_label_value > 0) &&
                              (L[q] != closest_label_value))
                            {
                              closest_label_value_is_unique = false;
                            }
                          closest_label_value = L[q];
                        }
                    }
                }

              else if ((L[q] == MASK) && (dist[q] == 0))
                {
                  /*
                   * q is a plateau pixel.
                   */
                  dist[q] = current_distance + 1;
                  QueueEnter(pixel_queue, q);
                }
            }

          /*
           * Label p.
           */
          if ((closest_dist < current_distance) && (closest_label_value > 0))
            {
              if (closest_label_value_is_unique &&
                  ((L[p] == MASK) || (L[p] == WSHED)))
                {
                  L[p] = closest_label_value;
                }
              else if (! closest_label_value_is_unique ||
                       (L[p] != closest_label_value))
                {
                  L[p] = WSHED;
                }
            }
          else if (L[p] == MASK)
            {
              L[p] = WSHED;
            }
        }

      /*
       * Detect and process new minima at current_level.
       */
      for (k = k1; k <= k2; k++)
        {
          p = (int) sort_index[k];
          dist[p] = 0;
          if (L[p] == mask)
            {
              /*
               * p is inside a new minimum: flood-fill every connected
               * pixel that is still masked with a fresh label.
               */
              current_label++;  /* create a new label */
              QueueEnter(pixel_queue, p);
              L[p] = current_label;
              while (QueueLength(pixel_queue) > 0)
                {
                  q = QueueDelete(pixel_queue);

                  /*
                   * Inspect neighbors of q.
                   */
                  nhsize = watershed_neighbors(coor, q, xdim, ydim, zdim, nh);
                  for(int i=0; i<nhsize; i++)
                    {
                      r = nh[i];
                      if (L[r] == mask)
                        {
                          QueueEnter(pixel_queue, r);
                          L[r] = current_label;
                        }
                    }
                }
            }
        }
    }

  QueueDestroy(pixel_queue);
  delete [] dist;
}


/*
 * Watershed segmentation of a 3-D volume.
 *
 *   im               - input volume, indexed im[x][y][z]
 *   xdim, ydim, zdim - volume dimensions
 *   water            - output label volume, indexed water[x][y][z]; labels
 *                      produced by compute_watershed are narrowed to short
 *   gaussian         - sigma passed to gaussian_smooth before the gradient
 *                      is taken
 *
 * Steps: smooth the input, compute the forward-difference gradient magnitude
 * (zero on the one-voxel border), sort voxels by that magnitude with indexx,
 * run compute_watershed, and scatter the labels back into `water`.
 */
void watershed(float ***im, int xdim, int ydim, int zdim, short ***water, float gaussian)
{
  /* Widen before multiplying so the product is not evaluated in int. */
  long imsize = (long)xdim*ydim*zdim;
  float *I = vector(imsize);
  int **coor = imatrix(imsize,3);

  float ***sim = f3tensor(xdim,ydim,zdim);
  gaussian_smooth(im, xdim, ydim, zdim, gaussian, sim);

  /*
   * Gradient magnitude, linearised x-fastest (the ordering that
   * compute_watershed's neighbour index arithmetic relies on).  The
   * magnitude is computed straight from the smoothed volume; border voxels
   * have no forward neighbour in every direction and are left at zero.
   */
  int index = 0;
  for(int k=0; k<zdim; k++)
    for(int j=0; j<ydim; j++)
      for(int i=0; i<xdim; i++)
        {
          float magnitude = 0;
          if( (i >= 1) && (i < xdim-1)
              && (j >= 1) && (j < ydim-1)
              && (k >= 1) && (k < zdim-1))
            {
              float gx = sim[i+1][j][k] - sim[i][j][k];
              float gy = sim[i][j+1][k] - sim[i][j][k];
              float gz = sim[i][j][k+1] - sim[i][j][k];
              magnitude = sqrt(gx*gx + gy*gy + gz*gz);
            }
          I[index] = magnitude;
          coor[index][0] = i;
          coor[index][1] = j;
          coor[index][2] = k;
          index += 1;
        }

  free_f3tensor(sim,xdim,ydim,zdim);

  unsigned long *sort_index = lvector(imsize);
  indexx(imsize, I, sort_index);

  long *L = new long [imsize];

  compute_watershed(I, coor, imsize, xdim, ydim, zdim, sort_index, L);

  /* Scatter the linear label array back into the 3-D output volume. */
  for(long i=0; i<imsize; i++)
    water[coor[i][0]][coor[i][1]][coor[i][2]] = (short)L[i];

  free_vector(I, imsize);
  free_lvector(sort_index, imsize);

  delete [] L;

  free_imatrix(coor,imsize,3);
}


/*
 * Replace every voxel by the mean input intensity of its watershed region.
 *
 *   im               - input intensities, indexed im[x][y][z]
 *   water            - region labels, indexed water[x][y][z]
 *   result           - output volume, indexed result[x][y][z]
 *   xdim, ydim, zdim - volume dimensions
 *
 * Voxels whose label is 0 (or negative) receive 0 in the output; every other
 * voxel receives the arithmetic mean of `im` over all voxels sharing its
 * label.  Sums and counts are accumulated in double / long so that large
 * regions do not lose precision.
 */
void watershed_post(float ***im, short ***water, float ***result,int xdim, int ydim, int zdim)
{
  /* Highest label present determines the size of the accumulator tables. */
  int num_region = 0;
  for(int k=0; k<zdim; k++)
    for(int j=0; j<ydim; j++)
      for(int i=0; i<xdim; i++)
        {
          if(water[i][j][k]>num_region)
            num_region = (int)water[i][j][k];
        }
  num_region += 1;

  /* "()" value-initializes both tables to zero. */
  double *sum = new double [num_region]();
  long *count = new long [num_region]();

  for(int k=0; k<zdim; k++)
    for(int j=0; j<ydim; j++)
      for(int i=0; i<xdim; i++)
        {
          int label = (int)water[i][j][k];
          if(label > 0)
            {
              sum[label] += im[i][j][k];
              count[label] += 1;
            }
        }

  for(int k=0; k<zdim; k++)
    for(int j=0; j<ydim; j++)
      for(int i=0; i<xdim; i++)
        {
          int label = (int)water[i][j][k];
          /* count[label] is at least 1 here: this voxel itself was counted. */
          if(label > 0)
            result[i][j][k] = (float)(sum[label]/count[label]);
          else
            result[i][j][k] = 0;
        }

  delete [] sum;
  delete [] count;
}
	

/*
 * Build a normalized 1-D Gaussian kernel.
 *
 *   sigma      - standard deviation of the Gaussian
 *   kernel     - out: receives a buffer obtained from vector(); the caller
 *                releases it (callers in this file use free_vector)
 *   windowsize - out: number of taps, 1 + 2*ceil(2.5*sigma)
 *
 * The taps are scaled so they sum to 1.  If sigma is not strictly positive
 * (zero, negative or NaN) a single-tap identity kernel {1} is produced so
 * that smoothing becomes a no-op instead of yielding NaNs.
 * Exits the process with status 1 if the kernel cannot be allocated.
 */
void make_gaussian_kernel(float sigma, float **kernel, int *windowsize)
{
   int i, center;
   float x, fx, sum=0.0;
   const int degenerate = !(sigma > 0);   /* also true for NaN */

   if(degenerate)
      *windowsize = 1;
   else
      *windowsize =(int)( 1 + 2 * ceil(2.5 * sigma));
   center = (int)((*windowsize) / 2);

   if((*kernel = vector(*windowsize)) == NULL){
     printf("Error callocing the gaussian kernel array.\n");
      exit(1);
   }

   if(degenerate){
      (*kernel)[0] = 1.0;
      return;
   }

   /*
    * The usual 1/(sigma*sqrt(2*pi)) prefactor is omitted: it is the same for
    * every tap and cancels in the normalization below.  sigma is squared in
    * double so a very small sigma does not underflow to a zero divisor.
    */
   const double two_sigma_sq = 2.0 * (double)sigma * (double)sigma;
   for(i=0;i<(*windowsize);i++){
      x = (float)(i - center);
      fx = (float)exp(-((double)x * (double)x) / two_sigma_sq);
      (*kernel)[i] = fx;
      sum += fx;
   }

   /* sum >= 1 because the centre tap is exp(0) = 1. */
   for(i=0;i<(*windowsize);i++) (*kernel)[i] /= sum;
}

/*
 * Map a possibly out-of-range sample index onto [0, n-1] for border handling.
 *
 * Indices below 0 are mirrored about 0 (-1 -> 1) and indices >= n are
 * mirrored about the last sample with repetition (n -> n-1).  If the mirrored
 * index is still outside the range (kernel wider than the axis), it is
 * clamped to the nearest valid sample so the caller never reads out of
 * bounds.  n must be at least 1.
 */
static int gaussian_smooth_reflect(int idx, int n)
{
   if(idx < 0)
      idx = -idx;
   else if(idx >= n)
      idx = 2*n - idx - 1;

   if(idx < 0)
      return 0;
   if(idx >= n)
      return n - 1;
   return idx;
}

/*
 * Separable 3-D Gaussian smoothing.
 *
 *   image            - input volume, indexed image[x][y][z]
 *   xdim, ydim, zdim - volume dimensions
 *   sigma            - standard deviation handed to make_gaussian_kernel
 *   smoothedim       - output volume (caller-allocated), same indexing
 *
 * The 1-D kernel is applied along y, then x, then z, using two temporary
 * volumes.  Samples beyond the border are taken from mirrored positions.
 * Exits the process with status 1 if a temporary volume cannot be allocated.
 */
void gaussian_smooth(float ***image, int xdim, int ydim, int zdim, float sigma, float ***smoothedim)
{
   int windowsize;          /* Dimension of the gaussian kernel. */
   float *kernel;           /* A one dimensional gaussian kernel. */
   float ***tmpim, ***tmpim1; /* Buffers for the separable passes. */

   /* Create a 1-dimensional gaussian smoothing kernel. */
   make_gaussian_kernel(sigma, &kernel, &windowsize);
   const int center = windowsize / 2;   /* Half of the windowsize. */

   /* Allocate the temporary buffer volumes. */
   if((tmpim = f3tensor(xdim,ydim,zdim)) == NULL)
     {
       printf("Error allocating the buffer image.\n");
       exit(1);
     }
   if((tmpim1 = f3tensor(xdim,ydim,zdim)) == NULL)
     {
       printf("Error allocating the buffer image.\n");
       exit(1);
     }

   /* Pass 1: blur along the y axis, image -> tmpim1. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;   /* Weighted sum of samples. */
           float sum = 0.0;   /* Sum of the kernel weights used. */
           for(int yy=(-center); yy<=center; yy++)
             {
               int indx = gaussian_smooth_reflect(y+yy, ydim);
               dot += image[x][indx][z] * kernel[center+yy];
               sum += kernel[center+yy];
             }
           tmpim1[x][y][z] = dot/sum;
         }

   /* Pass 2: blur along the x axis, tmpim1 -> tmpim. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;
           float sum = 0.0;
           for(int xx=(-center); xx<=center; xx++)
             {
               int indx = gaussian_smooth_reflect(x+xx, xdim);
               dot += tmpim1[indx][y][z] * kernel[center+xx];
               sum += kernel[center+xx];
             }
           tmpim[x][y][z] = dot/sum;
         }

   /* Pass 3: blur along the z axis, tmpim -> smoothedim. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;
           float sum = 0.0;
           for(int zz=(-center); zz<=center; zz++)
             {
               int indx = gaussian_smooth_reflect(z+zz, zdim);
               dot += tmpim[x][y][indx] * kernel[center+zz];
               sum += kernel[center+zz];
             }
           smoothedim[x][y][z] = dot/sum;
         }

   free_f3tensor(tmpim,xdim,ydim,zdim);
   free_f3tensor(tmpim1,xdim,ydim,zdim);
   free_vector(kernel, windowsize);
}


/*
 * Map a possibly out-of-range sample index onto [0, n-1] for border handling.
 *
 * Indices below 0 are mirrored about 0 (-1 -> 1) and indices >= n are
 * mirrored about the last sample with repetition (n -> n-1).  If the mirrored
 * index is still outside the range (kernel wider than the axis), it is
 * clamped to the nearest valid sample so the caller never reads out of
 * bounds.  n must be at least 1.
 */
static int gaussian_smooth_mask_reflect(int idx, int n)
{
   if(idx < 0)
      idx = -idx;
   else if(idx >= n)
      idx = 2*n - idx - 1;

   if(idx < 0)
      return 0;
   if(idx >= n)
      return n - 1;
   return idx;
}

/*
 * Separable 3-D Gaussian smoothing restricted to a mask.
 *
 *   image            - input volume, indexed image[x][y][z]
 *   xdim, ydim, zdim - volume dimensions
 *   sigma            - standard deviation handed to make_gaussian_kernel
 *   smoothedim       - output volume (caller-allocated), same indexing
 *   mask             - voxels with mask > 0 contribute to the smoothing;
 *                      all others get weight 0
 *
 * Each 1-D pass (y, then x, then z) computes a weighted average in which
 * every sample's kernel weight is multiplied by its 0/1 mask weight; the
 * divisor starts at 1.0e-6 so a neighbourhood with no masked-in samples does
 * not divide by zero.  Finally, every output voxel that is <= 0 is replaced
 * by the corresponding mask value.
 * Exits the process with status 1 if a temporary volume cannot be allocated.
 */
void gaussian_smooth_mask(float ***image, int xdim, int ydim, int zdim, float sigma, float ***smoothedim, unsigned char ***mask)
{
   int windowsize;            /* Dimension of the gaussian kernel. */
   float *kernel;             /* A one dimensional gaussian kernel. */
   float ***tmpim, ***tmpim1; /* Buffers for the separable passes. */
   float ***fmask;            /* Mask converted to 0.0 / 1.0 weights. */

   /* Create a 1-dimensional gaussian smoothing kernel. */
   make_gaussian_kernel(sigma, &kernel, &windowsize);
   const int center = windowsize / 2;   /* Half of the windowsize. */

   /* Allocate the temporary buffer volumes. */
   if((tmpim = f3tensor(xdim,ydim,zdim)) == NULL)
     {
       printf("Error allocating the buffer image.\n");
       exit(1);
     }
   if((tmpim1 = f3tensor(xdim,ydim,zdim)) == NULL)
     {
       printf("Error allocating the buffer image.\n");
       exit(1);
     }
   if((fmask = f3tensor(xdim,ydim,zdim)) == NULL)
     {
       printf("Error allocating the buffer image.\n");
       exit(1);
     }

   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           if(mask[x][y][z]>0)
             fmask[x][y][z] = 1.0;
           else
             fmask[x][y][z] = 0.0;
         }

   /* Pass 1: blur along the y axis, image -> tmpim1. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;      /* Weighted sum of masked-in samples. */
           float sum = 1.0e-6;   /* Sum of weights; small bias avoids /0. */
           for(int yy=(-center); yy<=center; yy++)
             {
               int indx = gaussian_smooth_mask_reflect(y+yy, ydim);
               dot += image[x][indx][z] * kernel[center+yy] *fmask[x][indx][z];
               sum += kernel[center+yy]*fmask[x][indx][z];
             }
           tmpim1[x][y][z] = dot/sum;
         }

   /* Pass 2: blur along the x axis, tmpim1 -> tmpim. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;
           float sum = 1.0e-6;
           for(int xx=(-center); xx<=center; xx++)
             {
               int indx = gaussian_smooth_mask_reflect(x+xx, xdim);
               dot += tmpim1[indx][y][z] * kernel[center+xx] *fmask[indx][y][z];
               sum += kernel[center+xx] * fmask[indx][y][z];
             }
           tmpim[x][y][z] = dot/sum;
         }

   /* Pass 3: blur along the z axis, tmpim -> smoothedim. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         {
           float dot = 0.0;
           float sum = 1.0e-6;
           for(int zz=(-center); zz<=center; zz++)
             {
               int indx = gaussian_smooth_mask_reflect(z+zz, zdim);
               dot += tmpim[x][y][indx] * kernel[center+zz] *fmask[x][y][indx];
               sum += kernel[center+zz]*fmask[x][y][indx];
             }
           smoothedim[x][y][z] = dot/sum;
         }

   /* Non-positive results are replaced by the raw mask value. */
   for(int z=0; z<zdim; z++)
     for(int y=0; y<ydim; y++)
       for(int x=0; x<xdim; x++)
         if(smoothedim[x][y][z]<=0)
           smoothedim[x][y][z] = mask[x][y][z];

   free_f3tensor(fmask,xdim,ydim,zdim);
   free_f3tensor(tmpim,xdim,ydim,zdim);
   free_f3tensor(tmpim1,xdim,ydim,zdim);
   free_vector(kernel, windowsize);
}
