
#ifndef _WeiLevoyTextureGenerator_txx
#define _WeiLevoyTextureGenerator_txx

#include "itkBSplineInterpolateImageFunction.h"
#include "itkDiscreteGaussianImageFilter.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkResampleImageFilter.h"

#include "itkBinaryErodeImageFilter.h"
#include "BinaryBlockStructuringElement.h"

//TODO: the ITK kd-tree implementation appears to be buggy; consider a wrapper around ANN instead?

#include "itkKdTree.h"
#include "itkKdTreeGenerator.h"
#include "itkListSample.h"
#include "itkVariableLengthVector.h"

#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "vnl/vnl_vector.h"

#include "WeiLevoyTextureGenerator.h"

#include "MersenneTwisterRNG.h"

#include "TreeStructuredVectorQuantizer.h"

#include <stdlib.h>

template <class TOutputImage>
WeiLevoyTextureGenerator<TOutputImage>
::WeiLevoyTextureGenerator()
{
  // Default output geometry: 64 voxels with unit spacing along every axis.
  const unsigned int numDims = TOutputImage::GetImageDimension();
  for (unsigned int d = 0; d < numDims; ++d)
  {
    m_Spacing[d] = 1.0;
    m_Size[d] = 64;
  }

  // Default synthesis parameters: a single pyramid level, neighborhood
  // radius 3, and at most 5000 texture neighborhoods kept in the search
  // database (see the subsampling step in GenerateData).
  m_NeighborhoodRadius = 3;
  m_NumberOfLevels = 1;
  m_MaxTextureSamples = 5000;
}

template <class TOutputImage>
WeiLevoyTextureGenerator<TOutputImage>
::~WeiLevoyTextureGenerator()
{
  // Intentionally empty: the constructor only assigns plain parameter values,
  // so there is nothing to release explicitly.
}

template <class TOutputImage>
void
WeiLevoyTextureGenerator<TOutputImage>
::GenerateOutputInformation()
{
  // Describe the output image: a region of m_Size voxels starting at index 0,
  // sampled with spacing m_Spacing.
  TOutputImage* outputImage = this->GetOutput(0);

  IndexType startIndex;
  startIndex.Fill(0);

  typename TOutputImage::RegionType outputRegion;
  outputRegion.SetIndex(startIndex);
  outputRegion.SetSize(m_Size);

  outputImage->SetRegions(outputRegion);
  outputImage->SetSpacing(m_Spacing);

  // The origin is deliberately left untouched: a zero output origin is
  // assumed.
}

template <class TOutputImage>
DynArray<typename WeiLevoyTextureGenerator<TOutputImage>::OffsetType>
WeiLevoyTextureGenerator<TOutputImage>
::_BuildNeighborhoodOffsets(long radius)
{
  // Builds the causal neighborhood of a (2*radius+1)^D box.
  //
  // The box is walked in raster order (dimension 0 varies fastest), starting
  // at (-radius, ..., -radius). Every offset visited before the center that
  // has at least one negative component is kept. The zero offset (center) is
  // appended LAST, so callers can use GetSize()-1 as the center position.
  const unsigned int numDims = TOutputImage::GetImageDimension();

  DynArray<OffsetType> causal;

  OffsetType cursor;
  cursor.Fill(-radius);

  for (;;)
  {
    // Classify the current offset.
    bool hasNegative = false;
    bool isCenter = true;
    for (unsigned int d = 0; d < numDims; d++)
    {
      if (cursor[d] < 0)
        hasNegative = true;
      if (cursor[d] != 0)
        isCenter = false;
    }

    if (hasNegative)
      causal.Append(cursor);

    // Stop once the center has been reached.
    if (isCenter)
      break;

    // Advance like an odometer: bump dimension 0, carrying into higher
    // dimensions whenever a component runs past +radius.
    unsigned int d = 0;
    while (d < numDims)
    {
      cursor[d]++;
      if (cursor[d] <= radius)
        break;
      cursor[d] = -radius;
      d++;
    }

    // Every dimension carried over: the whole box has been visited.
    if (d == numDims)
      break;
  }

  // The center goes last.
  cursor.Fill(0);
  causal.Append(cursor);

  return causal;
}

template <class TOutputImage>
void
WeiLevoyTextureGenerator<TOutputImage>
::_WrapIndex(IndexType& ind, const SizeType& size)
{
  // Periodic (toroidal) boundary condition: maps each component of ind into
  // [0, size[j]) by modular arithmetic. This works for arbitrarily large
  // excursions in either direction.
  unsigned int imageDimension = TOutputImage::GetImageDimension();
  for (unsigned int j = 0; j < imageDimension; j++)
  {
    // Work in signed arithmetic. Adding the signed index to the unsigned size
    // (as the previous version did) is evaluated in unsigned arithmetic and is
    // wrong whenever ind[j] < -size[j].
    const long extent = static_cast<long>(size[j]);
    if (extent <= 0)
    {
      // Degenerate axis: avoid a modulo by zero.
      ind[j] = 0;
      continue;
    }

    long wrapped = ind[j] % extent;  // in (-extent, extent)
    if (wrapped < 0)
      wrapped += extent;
    ind[j] = wrapped;
  }
}

template <class TOutputImage>
void
WeiLevoyTextureGenerator<TOutputImage>
::_ReflectIndex(IndexType& ind, const SizeType& size)
{
  // Mirror boundary condition that does not repeat the edge voxel: index -1
  // maps to 1, and index size maps to size-2. Each component of ind is folded
  // into [0, size[j]).
  //
  // The mirror pattern is periodic with period 2*(size-1). Folding modulo
  // that period keeps the result in range for ANY input. The previous
  // expression could produce -1 when ind % size == size-1 (e.g. size 1 with
  // ind 1, or size 2 with ind 3), and divided by zero for size 0. For
  // indices in [-(size-1), 2*size-2] the result is unchanged.
  unsigned int imageDimension = TOutputImage::GetImageDimension();
  for (unsigned int j = 0; j < imageDimension; j++)
  {
    const long extent = static_cast<long>(size[j]);
    if (extent <= 1)
    {
      // A single-voxel (or empty) axis: everything reflects onto 0.
      ind[j] = 0;
      continue;
    }

    const long period = 2 * (extent - 1);
    long folded = ind[j] % period;
    if (folded < 0)
      folded += period;      // now in [0, period)
    if (folded >= extent)
      folded = period - folded;  // mirror about extent-1
    ind[j] = folded;
  }
}

template <class TOutputImage>
void
WeiLevoyTextureGenerator<TOutputImage>
::InitializeOutput(
  long rad, OutputImageType* out, const OutputImageType* text)
{
  // Copies the causal neighborhood of radius rad (as built by
  // _BuildNeighborhoodOffsets), taken around a randomly chosen texture
  // location, into the output around index 0. Out-of-range indices on either
  // image are folded back with _ReflectIndex.
  //
  // rad  : neighborhood radius of the copied half-patch
  // out  : output image, written in place
  // text : source texture image, read only

  unsigned int imageDimension = TOutputImage::GetImageDimension();

  DynArray<OffsetType> offsets_init =
    this->_BuildNeighborhoodOffsets(rad);

  SizeType textureSize_init =
    text->GetLargestPossibleRegion().GetSize();
  SizeType outSize_init =
    out->GetLargestPossibleRegion().GetSize();

  IndexType ind0;
  ind0.Fill(0);

  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  // Choose a random location in the texture. It is restricted to
  // [rad, size-1-rad] so the whole (2*rad+1)-wide patch lies inside the
  // texture. All arithmetic is signed: the texture size is unsigned, so
  // "size-1-2*rad" used to underflow to a huge value whenever the texture
  // was narrower than the patch, yielding out-of-bounds texture reads.
  IndexType ind_r;
  for (unsigned int j = 0; j < imageDimension; j++)
  {
    const long texSize = static_cast<long>(textureSize_init[j]);
    const long span = texSize - 1 - 2*rad;

    if (span < 0)
    {
      // Texture too small for a full interior patch along this axis: use the
      // middle voxel and let boundary reflection supply the rest.
      ind_r[j] = (texSize - 1) / 2;
    }
    else
    {
      ind_r[j] = static_cast<long>(
        rng->GenerateUniformIntegerUpToK(static_cast<unsigned long>(span)))
        + rad;
      if (ind_r[j] < rad)
        ind_r[j] = rad;
      if (ind_r[j] > texSize - 1 - rad)
        ind_r[j] = texSize - 1 - rad;
    }
  }

  // Assign causal neighborhood to the first half-patch in output
  for (unsigned int i = 0; i < offsets_init.GetSize(); i++)
  {
    IndexType ind_out = ind0 + offsets_init[i];
    IndexType ind_text = ind_r + offsets_init[i];

    // Take care of BC
    //this->_WrapIndex(ind_out, outSize_init);
    //this->_WrapIndex(ind_text, textureSize_init);
    this->_ReflectIndex(ind_out, outSize_init);
    this->_ReflectIndex(ind_text, textureSize_init);

    out->SetPixel(ind_out, text->GetPixel(ind_text));
  }

  // Update the center pixel. _BuildNeighborhoodOffsets appends the zero
  // offset last, so the loop above already did this; it is kept explicit
  // for clarity.
  out->SetPixel(ind0, text->GetPixel(ind_r));

}

template <class TOutputImage>
void
WeiLevoyTextureGenerator<TOutputImage>
::GenerateData()
{
  // Synthesizes the output texture by a multiresolution, causal-neighborhood
  // nearest-neighbor search:
  //   1. Erode the texture mask by NeighborhoodRadius. Only texture voxels
  //      inside the eroded mask are used as samples.
  //   2. Fill the output with uniform noise in the (masked) texture intensity
  //      range, then seed a causal half-patch with InitializeOutput.
  //   3. Build Gaussian-blurred (not shrunk) pyramids of output and texture.
  //   4. From the coarsest to the finest level:
  //      - put texture feature vectors (causal neighbors at this level, plus
  //        a radius-1 causal box at the next coarser level) into a kd-tree,
  //        subsampled to at most MaxTextureSamples entries;
  //      - replace every output pixel with the center value of its nearest
  //        texture sample.
  //
  // Throws itk::ExceptionObject when:
  //   - the mask or texture is missing;
  //   - the texture is not (nearly) isotropic;
  //   - NumberOfLevels or MaxTextureSamples is zero;
  //   - no texture voxel survives erosion of the mask.

  unsigned int imageDimension = TOutputImage::GetImageDimension();

  if (m_MaskImage.IsNull())
    itkExceptionMacro(<< "No mask specified");
  if (m_TextureImage.IsNull())
    itkExceptionMacro(<< "No texture input specified");
  // Zero levels would make the level loops below index the pyramids at
  // m_NumberOfLevels-1, which is out of range (or wraps around if unsigned).
  if (m_NumberOfLevels < 1)
    itkExceptionMacro(<< "Number of levels must be at least 1");
  // With zero samples the kd-tree would be empty and the lookup of
  // neighbors[0] below would read out of bounds.
  if (m_MaxTextureSamples < 1)
    itkExceptionMacro(<< "Maximum number of texture samples must be at least 1");

  // Signed copy of the level count, so the level arithmetic below
  // (level+1, counting down to 0) never mixes signed and unsigned values.
  const long numLevels = static_cast<long>(m_NumberOfLevels);

  typename TOutputImage::SpacingType spacing =
    m_TextureImage->GetSpacing();
  for (unsigned int k = 1; k < imageDimension; k++)
    if (fabs(spacing[k]-spacing[0]) > 0.1)
      itkExceptionMacro(<< "Input texture should be almost isotropic");

  typedef itk::DiscreteGaussianImageFilter<TOutputImage, TOutputImage>
    GaussianFilterType;
  typedef itk::ResampleImageFilter<TOutputImage, TOutputImage>
    ResampleFilterType;

  typedef itk::BSplineInterpolateImageFunction<TOutputImage, double>
    BSplineInterpolatorType;

  typedef itk::VariableLengthVector<OutputImagePixelType> VVector;
  typedef itk::Statistics::ListSample<VVector> SampleType;
  typedef itk::Statistics::KdTree<SampleType> KdTreeType;
  typedef itk::Statistics::KdTreeGenerator<SampleType> TreeGeneratorType;

  typedef itk::ImageRegionIteratorWithIndex<TOutputImage> OutputIterator;

  typedef itk::ImageRegionIteratorWithIndex<TOutputImage> IteratorType;

  typename TOutputImage::Pointer outPtr = this->GetOutput(0);

  // Erode the texture mask image by specified radius
  typedef BinaryBlockStructuringElement<
    unsigned char, TOutputImage::ImageDimension>
    StructElementType;
  typedef
    itk::BinaryErodeImageFilter<ByteImageType, ByteImageType,
      StructElementType> ErodeType;

  StructElementType structel;

  typename StructElementType::RadiusType radius;
  radius.Fill(m_NeighborhoodRadius);
  structel.SetRadius(radius);

  typename ErodeType::Pointer erode = ErodeType::New();
  erode->SetErodeValue(1);
  erode->SetKernel(structel);
  erode->SetInput(m_MaskImage);

  erode->Update();

  ByteImagePointer erodedMask = erode->GetOutput();

  // Allocate output buffer
  outPtr->Allocate();

  // Find intensity range of texture image (eroded-mask voxels only)
  std::cout << "Compute texture intensity range..." << std::endl;
  IteratorType texIt(
    m_TextureImage, m_TextureImage->GetLargestPossibleRegion());
  float textureMin = vnl_huge_val(1.0f);
  float textureMax = -vnl_huge_val(1.0f);
  for (texIt.GoToBegin(); !texIt.IsAtEnd(); ++texIt)
  {
    IndexType ind = texIt.GetIndex();

    if (erodedMask->GetPixel(ind) == 0)
      continue;

    float v = texIt.Get();

    if (v < textureMin)
      textureMin = v;
    if (v > textureMax)
      textureMax = v;
  }

  // If no voxel survived erosion, min/max are still +inf/-inf. The noise
  // initialization would then be non-finite and every per-level kd-tree
  // would be built on an empty sample.
  if (textureMin > textureMax)
    itkExceptionMacro(<< "Texture mask is empty after erosion by the neighborhood radius");

  float textureRange = textureMax - textureMin;

  std::cout << "Texture range = " << textureMin << " to " << textureMax << std::endl;

#if 0
  float texturePad = 0.005 * textureRange;
  textureMin += texturePad;
  textureMax -= texturePad;
  textureRange -= 2.0 * texturePad;

std::cout << "Padded texture range = " << textureMin << " to " << textureMax << std::endl;
#endif


  // Initialization: fill output image with noise in the same range as texture
  std::cout << "Init image with noise..." << std::endl;

  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  // Re-initialize RNG to get consistent behavior at each run
  rng->Initialize(748379823758375815L);

  IteratorType outIt(outPtr, outPtr->GetLargestPossibleRegion());

  float outMin = vnl_huge_val(1.0f);
  float outMax = -vnl_huge_val(1.0f);

  outIt.GoToBegin();
  while (!outIt.IsAtEnd())
  {
    float u = rng->GenerateUniformRealClosedInterval();
    //float u = (float)rand() / (float)RAND_MAX;
    float pix = u*textureRange + textureMin;
    if (pix < outMin)
      outMin = pix;
    if (pix > outMax)
      outMax = pix;
    outIt.Set(static_cast<OutputImagePixelType>(pix));
    ++outIt;
  }

  std::cout << "Initial random range = " << outMin << " to " << outMax << std::endl;

  // Initialize coarsest level
  this->InitializeOutput(
    m_NeighborhoodRadius+1,
    outPtr,
    m_TextureImage);

  // Build output image pyramid. Level 0 is the output buffer itself; coarser
  // levels are blurred, full-size copies.
  std::cout << "Build output pyramid..." << std::endl;
  DynArray<OutputImagePointer> outputPyramid;
  outputPyramid.Allocate(m_NumberOfLevels);

  outputPyramid.Append(outPtr);
  for (long k = 1; k <= (numLevels-1); k++)
  {
    typename GaussianFilterType::Pointer blur = GaussianFilterType::New();
    blur->SetInput(outPtr);
    blur->SetUseImageSpacingOn();
    //blur->SetVariance(vnl_math_sqr(2*k * 0.5));
    blur->SetVariance(k*k); // stddev = half of shrink factor
    //blur->SetVariance(pow(2.0, 2.0*k)); // stddev = powers of 2
    blur->Update();

    outputPyramid.Append(blur->GetOutput());
  }

  // Store the image information for the pyramid. It is only consumed by the
  // disabled (#if 0) upsampling step below.
  std::cout << "Computing output image pyramid info..." << std::endl;
  DynArray<SizeType> imageSizes;
  DynArray<SpacingType> imageSpacings;

  imageSizes.Append(m_Size);
  imageSpacings.Append(m_Spacing);
  for (long k = 1; k <= (numLevels-1); k++)
  {
    SizeType size_k;
    for (unsigned int i = 0; i < imageDimension; i++)
    {
      size_k[i] = (unsigned long)floor((float)m_Size[i] / (float)(2*k));
      if (size_k[i] < 1)
        size_k[i] = 1;
    }

    SpacingType spacing_k;
    for (unsigned int i = 0; i < imageDimension; i++)
    {
      spacing_k[i] = m_Spacing[i] * (float)(2*k);
    }

    imageSizes.Append(size_k);
    imageSpacings.Append(spacing_k);
  }

  // Build texture pyramid
  std::cout << "Build texture pyramid..." << std::endl;
  DynArray<OutputImagePointer> texturePyramid;
  texturePyramid.Allocate(m_NumberOfLevels);

  texturePyramid.Append(m_TextureImage);
  for (long k = 1; k <= (numLevels-1); k++)
  {
    typename GaussianFilterType::Pointer blur = GaussianFilterType::New();
    blur->SetInput(m_TextureImage);
    blur->SetUseImageSpacingOn();
    //blur->SetVariance(vnl_math_sqr(2*k * 0.5));
    blur->SetVariance(k*k); // stddev = half of shrink factor
    //blur->SetVariance(pow(2.0, 2.0*k)); // stddev = powers of 2
    blur->Update();

    // No shrinking for textures: blurred full-size copies are used, and the
    // eroded mask is applied when the texture samples are collected.
    texturePyramid.Append(blur->GetOutput());
  }

  //
  // Generate textures by doing multiscale search
  //

  // Figure out the search radius for each level. The minimum radius is
  // always 1; it is capped so the box fits inside the texture.
  DynArray<unsigned int> radiusList;
  for (long level = 0; level < numLevels; level++)
  {
    // Signed subtraction: an unsigned radius minus the level could otherwise
    // wrap around instead of going negative.
    long r = static_cast<long>(m_NeighborhoodRadius) - level;
    if (r < 1)
      r = 1;

    // Find the smallest radius for this level
    SizeType textureSize =
      texturePyramid[level]->GetLargestPossibleRegion().GetSize();
    long minrad = (textureSize[0]-1) / 2;
    if (minrad < 1)
      minrad = 1;
    for (unsigned int dim = 1; dim < imageDimension; dim++)
    {
      long r_dim = (textureSize[dim]-1) / 2;
      if (r_dim < 1)
        r_dim = 1;
      if (r_dim < minrad)
        minrad = r_dim;
    }

    if (r > minrad)
      r = minrad;

    std::cout << "L = " << level << ", rad = " << r << std::endl;

    radiusList.Append(r);
  }


  //
  // Multiresolution search
  //

  SizeType textSize = texturePyramid[0]->GetLargestPossibleRegion().GetSize();
  std::cout << "Texture input size = " << textSize << std::endl;

  for (long level = numLevels-1; level >= 0; level--)
  {
    std::cout << "Texture search at level = " << level << std::endl;

/*
    // Fill in initial neighborhood
// TODO: NOTE: consistency across levels? just do it at coarsest?
std::cout << "Filling in initial neighborhood..." << std::endl;
    this->InitializeOutput(radiusList[level], outputPyramid[level],
      texturePyramid[level]);
*/

    std::cout << "Setting up iterators..." << std::endl;

    IteratorType outputIt(
      outputPyramid[level],
      outputPyramid[level]->GetLargestPossibleRegion());

    SizeType outSize_thisL =
      outputPyramid[level]->GetLargestPossibleRegion().GetSize();
    std::cout << "Size of output pyramid = " << outSize_thisL << std::endl;

    SizeType outSize_nextL = outSize_thisL;
    if (level < (numLevels-1))
    {
      outSize_nextL = outputPyramid[level+1]->GetLargestPossibleRegion().GetSize();
    }

    IteratorType textureIt(
      texturePyramid[level],
      texturePyramid[level]->GetLargestPossibleRegion());

    // Build list of offsets from radius used in current level
    DynArray<OffsetType> offsets_thisL =
      this->_BuildNeighborhoodOffsets(radiusList[level]);

    // Build list of offsets from radius used in previous level
    DynArray<OffsetType> offsets_nextL;

    if (level < (numLevels-1))
    {
      offsets_nextL = this->_BuildNeighborhoodOffsets(1);
    } // if level

    std::cout << "Offsets for this level = " << std::endl;
    for (unsigned int j = 0; j < offsets_thisL.GetSize(); j++)
      std::cout << offsets_thisL[j] << std::endl;

    // _BuildNeighborhoodOffsets stores the center offset last, so entries
    // [0, centerIndex) are the strictly causal neighbors.
    unsigned int centerIndex = offsets_thisL.GetSize() - 1;
    //unsigned int centerIndex = offsets_thisL.GetSize() / 2;

    std::cout << "centerIndex = " << centerIndex << std::endl;

    std::cout << "center = " << offsets_thisL[centerIndex] << std::endl;

    // Feature vector layout:
    //   [0, centerIndex)          : causal neighbors at this level
    //   [centerIndex, ...)        : radius-1 causal box (including its
    //                               center) at the next coarser level
    unsigned int numNeighFeatures =
      centerIndex + offsets_nextL.GetSize();
      //offsets_thisL.GetSize() + offsets_nextL.GetSize() - 1;

    std::cout << "Building texture feature tree..." << std::endl;
    typename SampleType::Pointer textureSample = SampleType::New();
    textureSample->SetMeasurementVectorSize(numNeighFeatures);

    // centerFeatures[k] is the texture value at the center of sample k.
    DynArray<OutputImagePixelType> centerFeatures;

    for (textureIt.GoToBegin(); !textureIt.IsAtEnd(); ++textureIt)
    {
      IndexType textInd = textureIt.GetIndex();

      if (erodedMask->GetPixel(textInd) == 0)
        continue;

//std::cout << "Adding location " << textInd << " to database" << std::endl;

      // Build neighborhood vector of texture image
      VVector N_t(numNeighFeatures);
      N_t.Fill(0);

      for (unsigned int i = 0; i < centerIndex; i++)
      {
        IndexType ind_t = textInd + offsets_thisL[i];
        //this->_WrapIndex(ind_t, textSize);
        this->_ReflectIndex(ind_t, textSize);
        N_t[i] = texturePyramid[level]->GetPixel(ind_t);
      }

      // Add the pixels in lower resolution to the neighborhood
      if (level < (numLevels-1))
      {
        // Pyramid levels are not shrunk, so the same index is the center
        // location at the next level.
        IndexType ind_nextL = textureIt.GetIndex();
        for (unsigned int i = 0; i < offsets_nextL.GetSize(); i++)
        {
          IndexType ind_tmp = ind_nextL + offsets_nextL[i];
          // Take care of BC
          //this->_WrapIndex(ind_tmp, textSize);
          this->_ReflectIndex(ind_tmp, textSize);
          N_t[centerIndex+i] = texturePyramid[level+1]->GetPixel(ind_tmp);
        }
      } // if level

#if 0
// Avoid duplicates?
      bool unique = true;
      float thresN = 1e-4 * numNeighFeatures;
      for (unsigned int k = 0; k < textureSample->Size(); k++)
      {
        VVector d = textureSample->GetMeasurementVector(k) - N_t;
        if (d.GetSquaredNorm() < thresN)
        {
          unique = false;
          break;
        }
      }

      if (!unique)
        continue;
#endif

      textureSample->PushBack(N_t);

      centerFeatures.Append(texturePyramid[level]->GetPixel(textInd));

    } // for textureIt

    // Randomly subsample the database down to m_MaxTextureSamples entries.
    if (textureSample->Size() > m_MaxTextureSamples)
    {
      unsigned int* selectInd =
        rng->GenerateIntegerSequence(
          m_MaxTextureSamples, textureSample->Size()-1);

      typename SampleType::Pointer selectedTextureSample = SampleType::New();
      selectedTextureSample->SetMeasurementVectorSize(numNeighFeatures);

      for (unsigned int k = 0; k < m_MaxTextureSamples; k++)
        selectedTextureSample->PushBack(
          textureSample->GetMeasurementVector(selectInd[k]));

      DynArray<OutputImagePixelType> selectedCenterFeatures;
      selectedCenterFeatures.Allocate(m_MaxTextureSamples);
      for (unsigned int k = 0; k < m_MaxTextureSamples; k++)
        selectedCenterFeatures.Append(centerFeatures[selectInd[k]]);

      delete [] selectInd;

      textureSample = selectedTextureSample;
      centerFeatures = selectedCenterFeatures;
    }

    std::cout << "  Constructing Kd tree..." << std::endl;
    std::cout << "    " << textureSample->Size() << " samples in DB" << std::endl;
    typename TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();

    treeGenerator->SetSample(textureSample);
    treeGenerator->SetBucketSize(100);
    treeGenerator->Update();

    typename KdTreeType::Pointer kdtree = treeGenerator->GetOutput();

    std::cout << "Processing pyramids..." << std::endl;

    outputIt.GoToBegin();
    while (!outputIt.IsAtEnd())
    {
      // Build the neighborhood of the output image. Borders are handled by
      // reflection. Only pixels before the center are included (causal
      // neighborhood).
      VVector N_s(numNeighFeatures);
      N_s.Fill(0);

      IndexType outInd = outputIt.GetIndex();

      for (unsigned int i = 0; i < centerIndex; i++)
      {
        IndexType ind_o = outInd + offsets_thisL[i];
        //this->_WrapIndex(ind_o, outSize_thisL);
        this->_ReflectIndex(ind_o, outSize_thisL);
        N_s[i] = outputPyramid[level]->GetPixel(ind_o);
      }

      // Add the pixels in lower resolution to the neighborhood
      if (level < (numLevels-1))
      {
        // Find center location for next level
        IndexType ind_nextL = outputIt.GetIndex();
        for (unsigned int i = 0; i < offsets_nextL.GetSize(); i++)
        {
          IndexType ind_tmp = ind_nextL + offsets_nextL[i];
          // BC
          //this->_WrapIndex(ind_tmp, outSize_nextL);
          this->_ReflectIndex(ind_tmp, outSize_nextL);
          N_s[centerIndex+i] = outputPyramid[level+1]->GetPixel(ind_tmp);
        }
      }

      unsigned int numMatches = 1;

      typename KdTreeType::InstanceIdentifierVectorType neighbors;
      kdtree->Search(N_s, numMatches, neighbors);

      // Keep the native pixel type: routing through float would silently
      // lose precision for double-valued images.
      OutputImagePixelType bestPixelValue = centerFeatures[neighbors[0]];

      outputIt.Set(bestPixelValue);

      ++outputIt;
    } // while outputIt

#if 0
    // Upsample the synthesized image before processing next level
    if (level > 0)
    {
std::cout << "Upsampling for next level..." << std::endl;
      typename BSplineInterpolatorType::Pointer bsplineInterp =
        BSplineInterpolatorType::New();
      bsplineInterp->SetSplineOrder(5);

      typename ResampleFilterType::Pointer resampler =
        ResampleFilterType::New();
      resampler->SetInterpolator(bsplineInterp);
      resampler->SetInput(outputPyramid[level]);
//TODO: size???
      resampler->SetOutputSpacing(imageSpacings[level-1]);
      resampler->SetSize(imageSizes[level-1]);
      resampler->Update();

      outputPyramid[level-1] = resampler->GetOutput();
    }
#endif
/*
    if (level > 0)
    {
      outputPyramid[level-1] = outputPyramid[level];
    }
*/

  } // for level

  std::cout << "Copy result to output..." << std::endl;
  // Copy final result to output buffer. outputPyramid[0] is outPtr unless
  // the disabled resampling step above replaces it.
  OutputIterator tmpIt(outputPyramid[0], outPtr->GetLargestPossibleRegion());

  tmpIt.GoToBegin();
  outIt.GoToBegin();
  while (!outIt.IsAtEnd())
  {
    OutputImagePixelType v = tmpIt.Get();
    outIt.Set(v);

    ++outIt;
    ++tmpIt;
  }

  std::cout << "Done GenerateData" << std::endl;

}

#endif
