
#include "itkBinaryBallStructuringElement.h"
#include "itkBinaryDilateImageFilter.h"
#include "itkDiscreteGaussianImageFilter.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkImageRegionSplitter.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkSignedDanielssonDistanceMapImageFilter.h"

// DEBUG
//#include "itkImageFileWriter.h"

#include "vnl/vnl_math.h"

#include "MersenneTwisterRNG.h"

#include "ContrastEnhancementFilter.h"

#include <cfloat>
#include <cmath>
#include <cstdlib>

#include <vector>


// Half-normal distribution
// Density of a half-normal distribution with variance `var`, evaluated at x.
//
// Uses the "theta" parameterisation
//   f(x) = (2 * theta / pi) * exp(-x^2 * theta^2 / pi),   x >= 0
// with theta^2 = (pi - 2) / (2 * var), so that Var[X] == var.
// The density is 0 for negative x.
//
// NOTE(review): var is expected to be > 0; var == 0 makes theta infinite and
// the result NaN.
float _halfNormal(float var, float x)
{
  // Support of the half-normal is [0, inf).
  if (x < 0.0)
    return 0.0;

  const float thetaSquared = (vnl_math::pi - 2.0) / (2.0 * var);
  const float theta = sqrtf(thetaSquared);

  // Keep the normalisation in double precision, as in the original expression.
  const double normalization = 2.0 * theta / vnl_math::pi;
  const float exponent = -(x * x) * thetaSquared / vnl_math::pi;

  return normalization * expf(exponent);
}

ContrastEnhancementFilter
::ContrastEnhancementFilter()
{
  // Default to ring-shaped tumor enhancement.
  m_Mode = RingEnhancement;

  // Reaction-diffusion time integration.
  m_MaximumIterations = 100;
  m_TimeStep = 0.05;
  m_GrowCoefficient = 2.0;
  m_DeathCoefficient = 1.0;

  // Source/sink construction: variance of the Gaussian blur applied to the
  // distance maps, and ball radius used to dilate the sampled masks.
  m_BlurVariance = 1.0;
  m_DilationRadius = 1;

  // NOTE(review): these two fractions do not appear to drive the current
  // algorithm (the seed fraction is only range-checked).
  m_SeedCountFraction = 0.05;
  m_VolumeFraction = 0.2;

  // 0 means "choose automatically" in ThreadedStep().
  m_NumberOfThreads = 0;

  // No inputs have been set yet.
  m_Modified = false;
}

ContrastEnhancementFilter
::~ContrastEnhancementFilter()
{
  // Intentionally empty: the image members are assigned from ITK ::New() and
  // held by smart pointers, so they release themselves.
}

DynArray<ContrastEnhancementFilter::ProbabilityImagePointer>
ContrastEnhancementFilter
::GetOutput()
{
  // Lazily recompute: only re-run the simulation when the inputs changed
  // since the last Update().
  if (!m_Modified)
    return m_OutputProbabilities;

  this->Update();
  return m_OutputProbabilities;
}

void
ContrastEnhancementFilter
::SetInputProbabilities(DynArray<ProbabilityImagePointer> plist)
{
  // Store the class probability images and flag the output as stale so that
  // the next GetOutput() triggers Update().
  m_InputProbabilities = plist;
  m_Modified = true;
}

void
ContrastEnhancementFilter
::CheckInput()
{
  // Validates the filter inputs before Update() runs.
  // Throws itk::ExceptionObject when:
  //  - fewer than 5 class probability images are given (index 4 is used as
  //    the tumor class throughout this file),
  //  - any class probability image is null,
  //  - no vessel probability image is set,
  //  - the seed count fraction lies outside [0, 1].
  if (m_InputProbabilities.GetSize() < 5)
    itkExceptionMacro(<< "Input probabilites < 5");

  // Every image is dereferenced voxel-by-voxel later on.
  for (unsigned int i = 0; i < m_InputProbabilities.GetSize(); i++)
  {
    if (m_InputProbabilities[i].IsNull())
      itkExceptionMacro(<< "Input probability " << i << " is null");
  }

  if (m_VesselProbability.IsNull())
    itkExceptionMacro(<< "No vessel probability");

  // The message promises a [0, 1] range; check both bounds.
  if (m_SeedCountFraction < 0 || m_SeedCountFraction > 1)
    itkExceptionMacro(<< "Seed fraction must be between 0 and 1");
}

void
ContrastEnhancementFilter
::Update()
{
  // Runs the full simulation:
  //   validate inputs -> build source/sink masks and initial phi
  //   -> m_MaximumIterations threaded reaction-diffusion steps
  //   -> write m_OutputProbabilities.
  // Intermediate images that are only needed during the run are released
  // afterwards.

  this->CheckInput();

  this->Initialize();

  const char* modeLabel = " no ";
  if (m_Mode == RingEnhancement)
    modeLabel = " ring ";
  else if (m_Mode == UniformEnhancement)
    modeLabel = " uniform ";
  std::cout << "  Generating " << modeLabel << " tumor enhancement"
            << std::endl;

  // The enhancement volume returned by each step is not used for early
  // termination; the run always performs m_MaximumIterations steps.
  for (unsigned int i = 1; i <= m_MaximumIterations; i++)
    this->ThreadedStep();

  this->Finish();

  m_Phi = 0;
  m_NewPhi = 0;
  m_SourceMask = 0;

  m_Modified = false;
}

void
ContrastEnhancementFilter
::ComputeSourceSinkMask()
{

  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  ProbabilityImageRegionType region = 
    m_InputProbabilities[0]->GetLargestPossibleRegion();

  ProbabilityImageIndexType ind;
  ProbabilityImageSizeType size = region.GetSize();
  ProbabilityImageSpacingType spacing = m_InputProbabilities[0]->GetSpacing();

  ByteImagePointer brainMask = ByteImageType::New();
  brainMask->CopyInformation(m_InputProbabilities[0]);
  brainMask->SetRegions(region);
  brainMask->Allocate();
  brainMask->FillBuffer(0);

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        float p = 0;
        for (unsigned int i = 0; i < m_InputProbabilities.GetSize(); i++)
          p += m_InputProbabilities[i]->GetPixel(ind);
        if (p > 0)
          brainMask->SetPixel(ind, 1);
        else
          brainMask->SetPixel(ind, 0);
      }

  typedef itk::SignedDanielssonDistanceMapImageFilter<
    ByteImageType, ProbabilityImageType> DistanceMapFilterType;
  typedef itk::DiscreteGaussianImageFilter<
    ProbabilityImageType, ProbabilityImageType> BlurFilterType;

  ProbabilityImagePointer brainDistMap = 0;
  {
    DistanceMapFilterType::Pointer distanceMapFilter =
      DistanceMapFilterType::New();

    distanceMapFilter->InsideIsPositiveOff();
    distanceMapFilter->SetInput(brainMask);
    distanceMapFilter->SquaredDistanceOff();
    distanceMapFilter->UseImageSpacingOn();

    distanceMapFilter->Update();

    //brainDistMap = distanceMapFilter->GetDistanceMap();

    BlurFilterType::Pointer blurf = BlurFilterType::New();
    blurf->SetInput(distanceMapFilter->GetDistanceMap());
    blurf->SetVariance(m_BlurVariance);
    blurf->Update();
    brainDistMap = blurf->GetOutput();
  }

  ByteImagePointer tumorMask = ByteImageType::New();
  tumorMask->CopyInformation(m_InputProbabilities[0]);
  tumorMask->SetRegions(region);
  tumorMask->Allocate();
  tumorMask->FillBuffer(0);

  float tumorCenter[3];
  for (int d = 0; d < 3; d++)
    tumorCenter[d] = 0;

  std::vector<ByteImageIndexType> tumorIndices;

  float sumPTumor = 1e-10;

  unsigned int numTumorVox = 0;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
#if 1
        float p_tumor = m_InputProbabilities[4]->GetPixel(ind);
        for (int d = 0; d < 3; d++)
          tumorCenter[d] += p_tumor*ind[d];
        sumPTumor += p_tumor;

        float maxp = 0;
        unsigned int maxi = 0;
        for (unsigned int i = 0; i < m_InputProbabilities.GetSize(); i++)
          if (maxp < m_InputProbabilities[i]->GetPixel(ind))
          {
            maxp = m_InputProbabilities[i]->GetPixel(ind);
            maxi = i;
          }
        if ((maxp > 1e-10) && (maxi == 4))
        {
          tumorMask->SetPixel(ind, 1);
          numTumorVox++;

          tumorIndices.push_back(ind);
        }
#else
        //if (m_InputProbabilities[4]->GetPixel(ind) > 1e-2)
        if (m_InputProbabilities[4]->GetPixel(ind) > 0.1)
        {
          for (int d = 0; d < 3; d++)
            tumorCenter[d] += ind[d];
          tumorMask->SetPixel(ind, 1);
          numTumorVox++;

          tumorIndices.push_back(ind);
        }
#endif
      }

std::cout << "Number of tumor voxels = " << numTumorVox << std::endl;

  ByteImageIndexType initInd = 
    tumorIndices[rng->GenerateUniformIntegerUpToK(numTumorVox-1)];

  for (int d = 0; d < 3; d++)
    tumorCenter[d] /= sumPTumor;
    //tumorCenter[d] /= numTumorVox;

//std::cout << "# of tumor voxels = " << numTumorVox << std::endl;

  ProbabilityImagePointer tumorDistMap = 0;
  {
    DistanceMapFilterType::Pointer distanceMapFilter =
      DistanceMapFilterType::New();

    distanceMapFilter->InsideIsPositiveOff();
    distanceMapFilter->SquaredDistanceOff();
    distanceMapFilter->SetInput(tumorMask);
    distanceMapFilter->UseImageSpacingOn();

    distanceMapFilter->Update();

    tumorDistMap = distanceMapFilter->GetDistanceMap();

    BlurFilterType::Pointer blurf = BlurFilterType::New();
    blurf->SetInput(distanceMapFilter->GetDistanceMap());
    blurf->SetVariance(m_BlurVariance);
    blurf->Update();
    tumorDistMap = blurf->GetOutput();
  }

  float maxTumorDistance = 0;
  float aveTumorDistance = 0;
  unsigned int countTumor = 0;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        if (tumorMask->GetPixel(ind) == 0)
          continue;
        float d = -1.0 * tumorDistMap->GetPixel(ind);
        if (d < 0)
          continue;
        if (maxTumorDistance < d)
          maxTumorDistance = d;
        aveTumorDistance += d;
        countTumor++;
      }
  aveTumorDistance /= countTumor;

std::cout << "  Max tumor distance = " << maxTumorDistance << std::endl;
std::cout << "  Average tumor distance = " << aveTumorDistance << std::endl;

  //float sourceRad = aveTumorDistance * 0.5;
  //float sinkRad = aveTumorDistance * 0.75;
  float sourceRad = maxTumorDistance * 0.1;
  float sinkRad = maxTumorDistance * 0.5;

  m_SourceMask = ByteImageType::New();
  m_SourceMask->CopyInformation(m_InputProbabilities[0]);
  m_SourceMask->SetRegions(region);
  m_SourceMask->Allocate();
  m_SourceMask->FillBuffer(0);

  m_SinkMask = ByteImageType::New();
  m_SinkMask->CopyInformation(m_InputProbabilities[0]);
  m_SinkMask->SetRegions(region);
  m_SinkMask->Allocate();
  m_SinkMask->FillBuffer(0);

/*
  while (true)
  {
    ind[0] = rng->GenerateUniformIntegerUpToK(size[0]-1);
    ind[1] = rng->GenerateUniformIntegerUpToK(size[1]-1);
    ind[2] = rng->GenerateUniformIntegerUpToK(size[2]-1);

    float d = -1.0 * tumorDistMap->GetPixel(ind);

    // Outside
    if (d < 0)
      continue;

    float p_tumor = m_InputProbabilities[4]->GetPixel(ind);

    if (p_tumor > 0.1)
    {
      initInd = ind;
      break;
    }
  }
*/

  // Metropolis-Hastings for tumor source and sink
  // Random walk with Gaussian proposal probabilities

  float searchStd = 80;
  float searchVar = searchStd*searchStd;

  ProbabilityImagePointer probTumor = m_InputProbabilities[4];

  ProbabilityImageIndexType currInd = initInd;
  //ProbabilityImageIndexType currInd;
  //for (int d = 0; d < 3; d++)
  //  currInd[d] = (long)tumorCenter[d];

  //unsigned int maxCountSource = 200;
  //unsigned int maxCountSink = 200;
  //unsigned int maxCountVess = 1000;
  unsigned int maxCountSource = numTumorVox / 20 + 10;
  unsigned int maxCountSink = numTumorVox / 20 + 10;
  unsigned int maxCountVess = 5 * maxCountSource;

/*
  if (m_Mode == NoEnhancement)
  {
    maxCountSource = rng->GenerateUniformIntegerUpToK(2);
  }
*/

  unsigned int countSource = 0;
  for (unsigned int iter = 0; iter < 10000000; iter++)
  {
    if (m_Mode == NoEnhancement)
      break;

    ProbabilityImageIndexType nextInd
      = tumorIndices[rng->GenerateUniformIntegerUpToK(numTumorVox-1)];
/*
    for (int d = 0; d < 3; d++)
    {
      //float shift = rng->GenerateNormal(0, searchVar);
      float shift = (2.0*rng->GenerateUniformRealClosedInterval() - 1.0) * searchStd;
      nextInd[d] = currInd[d] + (long)shift;
    }
*/

    for (int d = 0; d < 3; d++)
    {
      if (nextInd[d] < 0)
        nextInd[d] = 0;
      if (nextInd[d] >= (long)size[d])
        nextInd[d] = size[d]-1;
    }

    float d_curr = -1.0 * tumorDistMap->GetPixel(currInd);
    float d_next = -1.0 * tumorDistMap->GetPixel(nextInd);

    float p_curr = 0;
    float p_next = 0;

    if (m_Mode == RingEnhancement)
    {
      p_curr =
        _halfNormal(sourceRad*sourceRad, d_curr) * probTumor->GetPixel(currInd);
      p_next =
        _halfNormal(sourceRad*sourceRad, d_next) * probTumor->GetPixel(nextInd);
    }
    else if (m_Mode == UniformEnhancement)
    {
      p_curr = probTumor->GetPixel(currInd);
      p_next = probTumor->GetPixel(nextInd);
    }

    float ratio = p_next / (p_curr + 1e-10);
    if (ratio >= 1.0)
    {
      currInd = nextInd;
    }
    else
    {
      float q = rng->GenerateUniformRealOpenInterval();
      if (q <= ratio)
        currInd = nextInd;
    }

    // Burn-in period
    if (iter < 200)
      continue;

    if (m_SourceMask->GetPixel(currInd) == 0)
    {
      m_SourceMask->SetPixel(currInd, 1);
      countSource++;
    }

    if (countSource >= maxCountSource)
      break;
  }

std::cout << "Obtained " << countSource << " tumor sources" << std::endl;

  //for (int d = 0; d < 3; d++)
  //  currInd[d] = (long)tumorCenter[d];

  unsigned int countSink = 0;
  for (unsigned int iter = 0; iter < 10000000; iter++)
  {
    if (m_Mode == NoEnhancement || m_Mode == UniformEnhancement)
      break;

    if (countSink > (2*countSource))
      break;

    ProbabilityImageIndexType nextInd
      = tumorIndices[rng->GenerateUniformIntegerUpToK(numTumorVox-1)];
/*
    for (int d = 0; d < 3; d++)
    {
      //float shift = rng->GenerateNormal(0, searchVar);
      float shift = (2.0*rng->GenerateUniformRealClosedInterval() - 1.0) * searchStd;
      nextInd[d] = currInd[d] + (long)shift;
    }
*/

    for (int d = 0; d < 3; d++)
    {
      if (nextInd[d] < 0)
        nextInd[d] = 0;
      if (nextInd[d] >= (long)size[d])
        nextInd[d] = size[d]-1;
    }

    float d_curr = -1.0 * tumorDistMap->GetPixel(currInd);
    float d_next = -1.0 * tumorDistMap->GetPixel(nextInd);

    float p_curr = 0;
    float p_next = 0;

    if (m_Mode == RingEnhancement)
    {
      p_curr =
        _halfNormal(sinkRad*sinkRad, maxTumorDistance-d_curr) * probTumor->GetPixel(currInd);
      p_next =
      _halfNormal(sinkRad*sinkRad, maxTumorDistance-d_next) * probTumor->GetPixel(nextInd);
     }
     else if (m_Mode == UniformEnhancement)
     {
       break;
     }
     else if (m_Mode == NoEnhancement)
     {
       p_curr = probTumor->GetPixel(currInd);
       p_next = probTumor->GetPixel(nextInd);
     }

     float ratio = p_next / (p_curr + 1e-20);
     if (ratio >= 1.0)
     {
       currInd = nextInd;
     }
     else
     {
       float q = rng->GenerateUniformRealOpenInterval();
       if (q <= ratio)
         currInd = nextInd;
     }

     // Burn-in period
     if (iter < 200)
       continue;

     if (m_SinkMask->GetPixel(currInd) == 0)
     {
       m_SinkMask->SetPixel(currInd, 1);
       countSink++;
     }

     if (countSink >= maxCountSink)
       break;
  }

std::cout << "Obtained " << countSink << " tumor sinks" << std::endl;

  // Determining vessel source regions.
  // Accept reject: vessels
  for (int d = 0; d < 3; d++)
    currInd[d] = rng->GenerateUniformIntegerUpToK(size[d]-1);

  unsigned int countVess = 0;
  for (unsigned int iter = 0; iter < 10000000; iter++)
  {
    float p_curr = m_VesselProbability->GetPixel(currInd);

/*
    ProbabilityImageIndexType nextInd;
    nextInd[0] = rng->GenerateUniformIntegerUpToK(size[0]-1);
    nextInd[1] = rng->GenerateUniformIntegerUpToK(size[1]-1);
    nextInd[2] = rng->GenerateUniformIntegerUpToK(size[2]-1);
*/

    ProbabilityImageIndexType nextInd;
    for (int d = 0; d < 3; d++)
    {
      //float shift = rng->GenerateNormal(0, searchVar);
      float shift = (2.0*rng->GenerateUniformRealClosedInterval() - 1.0) * searchStd;
      nextInd[d] = currInd[d] + (long)shift;
    }

    for (int d = 0; d < 3; d++)
    {
      if (nextInd[d] < 0)
        nextInd[d] = 0;
      if (nextInd[d] >= (long)size[d])
        nextInd[d] = size[d]-1;
    }

    if (m_SourceMask->GetPixel(nextInd) != 0)
      continue;

    // Outside brain
    //float d = brainDistMap->GetPixel(ind);
    //if (d > 0.5)
    //  continue;

    float p_next = m_VesselProbability->GetPixel(nextInd);

    float ratio = p_next / (p_curr + 1e-20);
    if (ratio >= 1.0)
    {
      currInd = nextInd;
    }
    else
    {
      float q = rng->GenerateUniformRealOpenInterval();
      if (q <= ratio)
        currInd = nextInd;
    }

    if (m_SourceMask->GetPixel(currInd) == 0)
    {
      m_SourceMask->SetPixel(currInd , 1);
      countVess++;
    }

    if (countVess >= maxCountVess)
      break;
  }

#if 1
  typedef itk::BinaryBallStructuringElement<unsigned char, 3> StructElementType;
  typedef
    itk::BinaryDilateImageFilter<ByteImageType, ByteImageType,
      StructElementType> DilateType;

  StructElementType structel;
  structel.SetRadius(m_DilationRadius);
  structel.CreateStructuringElement();

  DilateType::Pointer dilSource = DilateType::New();
  dilSource->SetDilateValue(1);
  dilSource->SetKernel(structel);
  dilSource->SetInput(m_SourceMask);
  dilSource->Update();
  m_SourceMask =  dilSource->GetOutput();

  DilateType::Pointer dilSink = DilateType::New();
  dilSink->SetDilateValue(1);
  dilSink->SetKernel(structel);
  dilSink->SetInput(m_SinkMask);
  dilSink->Update();
  m_SinkMask =  dilSink->GetOutput();

#endif

/*
  //DEBUG: avoid this so we get proper boundaries?
  // No source or sink in non-tumor and non-vessel voxels
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        float maxp = 0;
        unsigned int maxk = 0;
        for (unsigned int k = 0; k < m_InputProbabilities.GetSize(); k++)
          if (maxp < m_InputProbabilities[k]->GetPixel(ind))
          {
            maxp = m_InputProbabilities[k]->GetPixel(ind);
            maxk = k;
          }

        if (maxk == 4)
          continue;

        float pvess = m_VesselProbability->GetPixel(ind);
        if (pvess > maxp)
          continue;

        m_SourceMask->SetPixel(ind, 0);
        m_SinkMask->SetPixel(ind, 0);
      }
*/

  // Deal with overlapping source/sink regions
  typedef itk::ImageRegionIteratorWithIndex<ByteImageType> IteratorType;
  IteratorType it(m_SourceMask, m_SourceMask->GetLargestPossibleRegion());
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    ByteImageType::IndexType ind = it.GetIndex();

    bool issource = m_SourceMask->GetPixel(ind) != 0;
    bool issink = m_SinkMask->GetPixel(ind) != 0;

    if (issource && issink)
    {
      // Cancel out
      //m_SourceMask->SetPixel(ind, 0);
      //m_SinkMask->SetPixel(ind, 0);

/*
      float q = rng->GenerateUniformRealOpenInterval();
      if (q < 0.5)
        m_SourceMask->SetPixel(ind, 0);
      else
        m_SinkMask->SetPixel(ind, 0);
*/

      // Sink assignments take priority
      //m_SourceMask->SetPixel(ind, 0);

      // Source assignments take priority
      m_SinkMask->SetPixel(ind, 0);
    }

    // Can't have source in tail regions of normal
    float p_normal = 0;
    for (unsigned int i = 0; i < 3; i++)
      p_normal += m_InputProbabilities[i]->GetPixel(ind);
    if (p_normal > 1e-2)
    {
      m_SourceMask->SetPixel(ind, 0);
    }
  }

// DEBUG
/*
  typedef itk::ImageFileWriter<ByteImageType> WriterType;
  WriterType::Pointer writer = WriterType::New();
  writer->SetInput(m_SourceMask);
  writer->SetFileName("enh_source.mha");
  writer->Update();
  writer->SetInput(m_SinkMask);
  writer->SetFileName("enh_sink.mha");
  writer->Update();
*/

}


void
ContrastEnhancementFilter
::Initialize()
{

  ProbabilityImageRegionType region = 
    m_InputProbabilities[0]->GetLargestPossibleRegion();

  ProbabilityImageIndexType ind;
  ProbabilityImageSizeType size = region.GetSize();
  ProbabilityImageSpacingType spacing = m_InputProbabilities[0]->GetSpacing();

  m_InputVolume = 0;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        float p = 0;
        for (unsigned int i = 0; i < m_InputProbabilities.GetSize(); i++)
          p += m_InputProbabilities[i]->GetPixel(ind);
        m_InputVolume += p;
      }
  m_InputVolume *= spacing[0]*spacing[1]*spacing[2];

  this->ComputeSourceSinkMask();

  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  // Initial enhancement probabilities
  m_Phi = ProbabilityImageType::New();
  m_Phi->CopyInformation(m_InputProbabilities[0]);
  m_Phi->SetRegions(region);
  m_Phi->Allocate();
  m_Phi->FillBuffer(0);

  m_NewPhi = ProbabilityImageType::New();
  m_NewPhi->CopyInformation(m_InputProbabilities[0]);
  m_NewPhi->SetRegions(region);
  m_NewPhi->Allocate();
  m_NewPhi->FillBuffer(0);

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        float maxp =  0;
        unsigned int maxk = 0;
        for (unsigned int k = 0; k < m_InputProbabilities.GetSize(); k++)
        {
          float p = m_InputProbabilities[k]->GetPixel(ind);
          if (p > maxp)
          {
            maxp = p;
            maxk = k;
          }
        }

// Initial value
// Tumor equally likely to be enhanced
// CSF at source region also equally likely to be enhanced
// p_enh(x) = 0.5*p_tumor(x) + 0.5*p_csf(x)*I{x \in Xsource}
// OR
// random noise???

//PP
        float pvess = m_VesselProbability->GetPixel(ind);
        float ptumor = m_InputProbabilities[4]->GetPixel(ind);

        // Random, but proportional to relevant prob so boundaries are OK ???
        //float phi0 =  (pvess+ptumor) * rng->GenerateUniformRealClosedInterval();
        //float phi0 =  rng->GenerateUniformRealClosedInterval();

/*
        if (pvess > maxp)
        {
          maxk = 5;
          maxp = pvess;
        }

        if (maxp < 1e-5)
          continue;

        float phi0 = 0;
        if (maxk == 4 || maxk == 5)
          phi0 = rng->GenerateUniformRealClosedInterval();
*/

/*
        float phi0 = 0;
        if (m_SourceMask->GetPixel(ind) != 0)
          phi0 = 0.5;
*/

// Initially only fill seed regions
/*
        float phi0 = 0;
        if (m_SourceMask->GetPixel(ind) != 0)
          phi0 = pvess + ptumor;
        phi0 *= 0.5;
*/

        //float phi0 = 1.0;
        //float phi0 = (pvess + ptumor) * rng->GenerateUniformRealClosedInterval();
        float phi0 = 1.0 + rng->GenerateUniformRealClosedInterval();

        //if (pvess > ptumor)
        //  phi0 = 0.5 + rng->GenerateUniformRealClosedInterval() * 0.5;

        if ( (pvess + ptumor) < 1e-10)
          phi0 = 0;

        //if (m_Mode == RingEnhancement)
        //  phi0 *= exp(-dist*dist);

        //if (m_Mode == UniformEnhancement)
          ////phi0 = (phi0 + 0.5) / 1.5;
          //phi0 = 0.5 + rng->GenerateUniformRealClosedInterval() * 0.5;

        if (m_Mode == NoEnhancement)
        {
          //phi0 = pvess * rng->GenerateUniformRealClosedInterval();
          if (pvess < 1e-2)
            phi0 = 0;
          else
            phi0 = 0.5 + rng->GenerateUniformRealClosedInterval() * 0.5;
        }

        if (m_SinkMask->GetPixel(ind) != 0)
          phi0 = -2.0;
        if (m_SourceMask->GetPixel(ind) != 0)
          phi0 = 2.0;

        m_Phi->SetPixel(ind, phi0);
      }

  typedef itk::DiscreteGaussianImageFilter<
    ProbabilityImageType, ProbabilityImageType> BlurFilterType;
  BlurFilterType::Pointer blurf = BlurFilterType::New();
  blurf->SetInput(m_Phi);
  blurf->SetVariance(1.0);
  blurf->Update();
  m_Phi = blurf->GetOutput();

  m_DiffusionCoeffImage = ProbabilityImageType::New();
  m_DiffusionCoeffImage->CopyInformation(m_InputProbabilities[0]);
  m_DiffusionCoeffImage->SetRegions(region);
  m_DiffusionCoeffImage->Allocate();

  m_DiffusionCoeffImage->FillBuffer(0);

// Contrast agent spreads only in fluid and tumor (and vessels)
// spreads faster in tumor since packed with vessels ??
  typedef itk::ImageRegionIteratorWithIndex<ProbabilityImageType> IteratorType;
  IteratorType it(m_DiffusionCoeffImage, m_DiffusionCoeffImage->GetLargestPossibleRegion());
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    ProbabilityImageType::IndexType ind = it.GetIndex();
//PP
#if 1
    float p_vess = m_VesselProbability->GetPixel(ind);
    float p_tumor = m_InputProbabilities[4]->GetPixel(ind);

    // Mix coeff? Bad for border regions with low probability
    //m_DiffusionCoeffImage->SetPixel(ind, p_tumor*1.5 + p_vess*1.2 + 0.1);

    // Detect likely vessel / tumor regions, adjust accordingly
    if (p_tumor > 1e-20 || p_vess > 1e-20)
    {
      float diffC = 0.01;
      if (p_tumor > p_vess)
        diffC = 0.5;
      else
        diffC = 0.4;
      m_DiffusionCoeffImage->SetPixel(ind, diffC);
    }
#else
    float maxp = 0;
    unsigned int maxk = 0;
    for (unsigned int k = 0; k < m_InputProbabilities.GetSize(); k++)
      if (maxp < m_InputProbabilities[k]->GetPixel(ind))
      {
        maxp = m_InputProbabilities[k]->GetPixel(ind);
        maxk = k;
      }
    float diffC = 0.01;
    // Tumor:
    // Inside tumor has plenty of vessels should be more diffuse than
    // healthy vessels
    if (maxk == 4)
      diffC = 1.5;
    // Vessel:
    float pvess = m_VesselProbability->GetPixel(ind);
    if (pvess > maxp)
      diffC = 1.2;
    m_DiffusionCoeffImage->SetPixel(ind, diffC);
#endif
  }

}

float
ContrastEnhancementFilter
::Step()
{

  ProbabilityImageIndexType ind;

  ProbabilityImageOffsetType xofft = {{1, 0, 0}};
  ProbabilityImageOffsetType yofft = {{0, 1, 0}};
  ProbabilityImageOffsetType zofft = {{0, 0, 1}};

  m_NewPhi->FillBuffer(0);

  ProbabilityImageSizeType size = m_InputProbabilities[0]->GetLargestPossibleRegion().GetSize();
  ProbabilityImageSpacingType spacing = m_InputProbabilities[0]->GetSpacing();

  float dxsq = spacing[0]*spacing[0];
  float dysq = spacing[1]*spacing[1];
  float dzsq = spacing[2]*spacing[2];

  float volume = 0.0;
  for (ind[2] = 1; ind[2] < (long)(size[2]-1); ind[2]++)
    for (ind[1] = 1; ind[1] < (long)(size[1]-1); ind[1]++)
      for (ind[0] = 1; ind[0] < (long)(size[0]-1); ind[0]++)
      {
        float t0 = m_DiffusionCoeffImage->GetPixel(ind);

        float tx_f = (m_DiffusionCoeffImage->GetPixel(ind+xofft) + t0) / 2.0; 
        float tx_b = (m_DiffusionCoeffImage->GetPixel(ind-xofft) + t0) / 2.0; 

        float ty_f = (m_DiffusionCoeffImage->GetPixel(ind+yofft) + t0) / 2.0; 
        float ty_b = (m_DiffusionCoeffImage->GetPixel(ind-yofft) + t0) / 2.0; 

        float tz_f = (m_DiffusionCoeffImage->GetPixel(ind+zofft) + t0) / 2.0; 
        float tz_b = (m_DiffusionCoeffImage->GetPixel(ind-zofft) + t0) / 2.0; 

        float phi0 = m_Phi->GetPixel(ind);

        float diffusionT = 0;

        diffusionT += tx_f * (m_Phi->GetPixel(ind+xofft) - phi0) / dxsq;
        diffusionT -= tx_b * (phi0 - m_Phi->GetPixel(ind-xofft)) / dxsq;

        diffusionT += ty_f * (m_Phi->GetPixel(ind+yofft) - phi0) / dysq;
        diffusionT -= ty_b * (phi0 - m_Phi->GetPixel(ind-yofft)) / dysq;

        diffusionT += tz_f * (m_Phi->GetPixel(ind+zofft) - phi0) / dzsq;
        diffusionT -= tz_b * (phi0 - m_Phi->GetPixel(ind-zofft)) / dzsq;

        float growT = 0;
        //if (m_SourceMask->GetPixel(ind) != 0)
          growT = m_GrowCoefficient * phi0 * (1.0 - phi0);

        float deathT = 0;
        if (m_SinkMask->GetPixel(ind) != 0)
          deathT = m_DeathCoefficient * phi0;

        float dphi = m_TimeStep * (diffusionT + growT - deathT);

        //sumChange += fabs(dphi);

        float newphi = phi0 + dphi;
        if (newphi < 0.0)
          newphi = 0;
        if (newphi > 1.0)
          newphi = 1.0;

        volume += newphi;

        m_NewPhi->SetPixel(ind, newphi);
      }

  volume *= spacing[0]*spacing[1]*spacing[2];

  // Swap pointers
  ProbabilityImagePointer tmp = m_Phi;
  m_Phi = m_NewPhi;
  m_NewPhi = tmp;

  return volume;

}

float
ContrastEnhancementFilter
::ThreadedStep()
{
  // Multi-threaded reaction-diffusion step. The image is split into disjoint
  // regions (m_SplitRegions); each worker thread runs _stepThread on one of
  // them, writing into m_NewPhi. Afterwards the enhancement volume over the
  // interior is summed and m_Phi/m_NewPhi are swapped.
  //
  // m_NumberOfThreads == 0 selects 3/4 of the global default thread count;
  // at least 2 threads are requested.

  itk::MultiThreader::Pointer threader = itk::MultiThreader::New();

  int numThreads = m_NumberOfThreads;
  if (numThreads == 0)
    numThreads = threader->GetGlobalDefaultNumberOfThreads() / 4 * 3;
  if (numThreads < 2)
    numThreads = 2;

  // SetNumberOfThreads clamps the request to the global maximum. Split by the
  // count actually granted; otherwise regions beyond the clamp would never
  // be processed and would stay 0 in m_NewPhi.
  threader->SetNumberOfThreads(numThreads);
  numThreads = threader->GetNumberOfThreads();

  ProbabilityImageRegionType region =
    m_InputProbabilities[0]->GetLargestPossibleRegion();

  typedef itk::ImageRegionSplitter<3> SplitterType;
  SplitterType::Pointer splitter = SplitterType::New();
  unsigned int numSplits = splitter->GetNumberOfSplits(region, numThreads);

  m_SplitRegions.Clear();
  for (unsigned int k = 0; k < numSplits; k++)
    m_SplitRegions.Append(splitter->GetSplit(k, numSplits, region));

  m_NewPhi->FillBuffer(0);

  // One thread per split region (numSplits <= granted thread count)
  threader->SetNumberOfThreads(numSplits);

  threader->SetSingleMethod(
    &ContrastEnhancementFilter::_stepThread, (void*)this);
  threader->SingleMethodExecute();

  ProbabilityImageSizeType size = region.GetSize();
  ProbabilityImageSpacingType spacing = m_InputProbabilities[0]->GetSpacing();

  float volume = 0.0;

  ProbabilityImageIndexType ind;

  for (ind[2] = 1; ind[2] < (long)(size[2]-1); ind[2]++)
    for (ind[1] = 1; ind[1] < (long)(size[1]-1); ind[1]++)
      for (ind[0] = 1; ind[0] < (long)(size[0]-1); ind[0]++)
      {
        volume += m_NewPhi->GetPixel(ind);
      }

  volume *= spacing[0]*spacing[1]*spacing[2];

  // Swap buffers: the freshly computed field becomes the current one
  ProbabilityImagePointer tmp = m_Phi;
  m_Phi = m_NewPhi;
  m_NewPhi = tmp;

  return volume;
}

ITK_THREAD_RETURN_TYPE
ContrastEnhancementFilter
::_stepThread(void* arg)
{
  // Worker for ThreadedStep(). Thread k updates m_NewPhi over
  // m_SplitRegions[k] with one explicit Euler step of
  //   dphi/dt = div(D grad phi) + G*phi*(1-phi)*[x in source]
  //                             - K*phi*[x in sink],
  // reading only m_Phi and the coefficient/mask images. Image-border voxels
  // are skipped (left as filled by the caller). Results are clamped to
  // [0, 1].

  typedef itk::MultiThreader::ThreadInfoStruct  ThreadInfoType;
  ThreadInfoType * infoStruct = static_cast<ThreadInfoType*>( arg );
  ContrastEnhancementFilter* obj = static_cast<ContrastEnhancementFilter*>(
    infoStruct->UserData);

  unsigned int index = infoStruct->ThreadID;

  // Defensive: a thread without a matching split region has nothing to do
  if (index >= obj->m_SplitRegions.GetSize())
    return ITK_THREAD_RETURN_VALUE;

  ProbabilityImageRegionType region = obj->m_SplitRegions[index];

  ProbabilityImageSizeType size = obj->m_Phi->GetLargestPossibleRegion().GetSize();
  ProbabilityImageSpacingType spacing = obj->m_Phi->GetSpacing();

  ProbabilityImageOffsetType xofft = {{1, 0, 0}};
  ProbabilityImageOffsetType yofft = {{0, 1, 0}};
  ProbabilityImageOffsetType zofft = {{0, 0, 1}};

  float dxsq = spacing[0]*spacing[0];
  float dysq = spacing[1]*spacing[1];
  float dzsq = spacing[2]*spacing[2];

  typedef itk::ImageRegionIteratorWithIndex<ProbabilityImageType> IteratorType;

  IteratorType it(obj->m_NewPhi, region);
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    ProbabilityImageIndexType ind = it.GetIndex();

    // The stencil needs all six neighbours; skip the image border
    bool isEdge = false;
    for (unsigned int d = 0; d < 3; d++)
      if (ind[d] <= 0 || ind[d] >= (long)size[d]-1)
      {
        isEdge = true;
        break;
      }

    if (isEdge)
      continue;

    float t0 = obj->m_DiffusionCoeffImage->GetPixel(ind);

    // Face-averaged diffusion coefficients
    float tx_f = (obj->m_DiffusionCoeffImage->GetPixel(ind+xofft) + t0) / 2.0;
    float tx_b = (obj->m_DiffusionCoeffImage->GetPixel(ind-xofft) + t0) / 2.0;

    float ty_f = (obj->m_DiffusionCoeffImage->GetPixel(ind+yofft) + t0) / 2.0;
    float ty_b = (obj->m_DiffusionCoeffImage->GetPixel(ind-yofft) + t0) / 2.0;

    float tz_f = (obj->m_DiffusionCoeffImage->GetPixel(ind+zofft) + t0) / 2.0;
    float tz_b = (obj->m_DiffusionCoeffImage->GetPixel(ind-zofft) + t0) / 2.0;

    float diffusionT = 0;

    float phi0 = obj->m_Phi->GetPixel(ind);

    diffusionT += tx_f * (obj->m_Phi->GetPixel(ind+xofft) - phi0) / dxsq;
    diffusionT -= tx_b * (phi0 - obj->m_Phi->GetPixel(ind-xofft)) / dxsq;

    diffusionT += ty_f * (obj->m_Phi->GetPixel(ind+yofft) - phi0) / dysq;
    diffusionT -= ty_b * (phi0 - obj->m_Phi->GetPixel(ind-yofft)) / dysq;

    diffusionT += tz_f * (obj->m_Phi->GetPixel(ind+zofft) - phi0) / dzsq;
    diffusionT -= tz_b * (phi0 - obj->m_Phi->GetPixel(ind-zofft)) / dzsq;

    float growT = 0;
    if (obj->m_SourceMask->GetPixel(ind) != 0)
      growT = obj->m_GrowCoefficient * phi0 * (1.0 - phi0);

    float deathT = 0;
    if (obj->m_SinkMask->GetPixel(ind) != 0)
      deathT = obj->m_DeathCoefficient * phi0;

    float dphi = obj->m_TimeStep * (diffusionT + growT - deathT);

    float newphi = phi0 + dphi;
    if (newphi < 0.0)
      newphi = 0;
    if (newphi > 1.0)
      newphi = 1.0;

    // The split regions are disjoint, so each voxel of m_NewPhi is written
    // by exactly one thread. No lock is needed; a per-voxel mutex would only
    // serialize the workers.
    it.Set(newphi);
  }

  return ITK_THREAD_RETURN_VALUE;
}


void
ContrastEnhancementFilter
::Finish()
{
  // Builds m_OutputProbabilities (input class count + 1 images) from the
  // final enhancement field phi (clamped to [0, 1]):
  //   out[0..3] = in[0..3]                       (copied unchanged)
  //   out[4]    = p_tumor * (1 - phi)            (non-enhancing tumor)
  //   out[5]    = phi * (p_tumor + p_vessel)     (enhancement)
  // NOTE(review): the layout assumes exactly 5 input classes. With more
  // inputs, classes beyond index 4 are not copied, out[5] holds the
  // enhancement and the remaining outputs stay 0.

  ProbabilityImageIndexType ind;
  ProbabilityImageSizeType size =
    m_InputProbabilities[0]->GetLargestPossibleRegion().GetSize();

  const unsigned int numOutputs = m_InputProbabilities.GetSize() + 1;

  // Allocate zero-filled output probability images
  m_OutputProbabilities.Clear();
  m_OutputProbabilities.Allocate(numOutputs);
  for (unsigned int i = 0; i < numOutputs; i++)
  {
    ProbabilityImagePointer pImg = ProbabilityImageType::New();
    pImg->CopyInformation(m_InputProbabilities[0]);
    pImg->SetRegions(m_InputProbabilities[0]->GetLargestPossibleRegion());
    pImg->Allocate();
    pImg->FillBuffer(0);

    m_OutputProbabilities.Append(pImg);
  }

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        float p_enh = m_Phi->GetPixel(ind);

        if (p_enh < 0.0)
          p_enh = 0.0;
        if (p_enh > 1.0)
          p_enh = 1.0;

        float p_vess = m_VesselProbability->GetPixel(ind);
        float p_tumor = m_InputProbabilities[4]->GetPixel(ind);

        // wm, gm, csf, edema unchanged
        for (unsigned int k = 0; k < 4; k++)
          m_OutputProbabilities[k]->SetPixel(ind,
            m_InputProbabilities[k]->GetPixel(ind));

        float p_tumor_enh = p_enh * p_tumor;
        float p_vess_enh = p_enh * p_vess;

        // Non-enhancing tumor
        m_OutputProbabilities[4]->SetPixel(ind, p_tumor-p_tumor_enh);

        // Enhancement
        m_OutputProbabilities[5]->SetPixel(ind, p_vess_enh + p_tumor_enh);
      }
}

