
#include "DTImageScalarSource.h"

#include "vnl/vnl_math.h"

DTImageScalarSource::
DTImageScalarSource()
{
  // Create the two scalar output images up front. They are only sized,
  // allocated and filled once a tensor input is available (see Update()).
  m_FAImage = FloatImageType::New();
  m_MDImage = FloatImageType::New();

  // Start with no tensor input; the getters return null until one is set.
  m_DTImage = 0;
}

DTImageScalarSource::
~DTImageScalarSource()
{
  // Intentionally empty: the image members are held through smart-pointer
  // handles (they expose IsNull()/GetPointer()), so nothing is released
  // manually here.
}

void
DTImageScalarSource::
SetInput(DTImageType* dti)
{
  // Remember the new tensor image and regenerate the MD/FA maps right away.
  // Passing a null pointer clears the input; Update() then does nothing.
  m_DTImage = dti;
  Update();
}

DTImageScalarSource::FloatImageType*
DTImageScalarSource::
GetMDImage()
{
  // Mean-diffusivity image; null when no tensor input has been set, because
  // in that case the image was never computed.
  return m_DTImage.IsNull() ? 0 : m_MDImage.GetPointer();
}

DTImageScalarSource::FloatImageType*
DTImageScalarSource::
GetFAImage()
{
  // Fractional-anisotropy image; null when no tensor input has been set,
  // because in that case the image was never computed.
  return m_DTImage.IsNull() ? 0 : m_FAImage.GetPointer();
}

void
DTImageScalarSource::
Update()
{
  if (m_DTImage.IsNull())
    return;

  itkDebugMacro(<< "Update MD and FA");

  DTImageRegionType region = m_DTImage->GetLargestPossibleRegion();

  // Allocate space for scalar images
  m_MDImage->CopyInformation(m_DTImage);
  m_MDImage->SetRegions(region);
  m_MDImage->Allocate();

  m_FAImage->CopyInformation(m_DTImage);
  m_FAImage->SetRegions(region);
  m_FAImage->Allocate();

  // Assign values pixel-by-pixel

  DTImageIndexType ind;

  DTImageSizeType size = region.GetSize();
 
  //double scaleFA = sqrt(3.0 / 2.0);
  double scaleFA = sqrt(0.5);

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {

        DiffusionTensor::MatrixEigenType eig(
          m_DTImage->GetPixel(ind).GetMatrix());

        double lambda1 = eig.get_eigenvalue(2);
        double lambda2 = eig.get_eigenvalue(1);
        double lambda3 = eig.get_eigenvalue(0);

        if (vnl_math_abs(lambda1) < vnl_math_abs(lambda2))
        {
          double s = lambda1;
          lambda1 = lambda2;
          lambda2 = s;
        }
        if (vnl_math_abs(lambda1) < vnl_math_abs(lambda3))
        {
          double s = lambda1;
          lambda1 = lambda3;
          lambda3 = s;
        }

        double md = (lambda1 + lambda2 + lambda3) / 3.0;

/*
        double dev1 = lambda1 - md;
        double dev2 = lambda2 - md;
        double dev3 = lambda3 - md;
        double fa =
          scaleFA * sqrt(dev1*dev1 + dev2*dev2 + dev3*dev3) /
            (sqrt(lambda1*lambda1 + lambda2*lambda2 + lambda3*lambda3)+1e-20);
*/

        double dev1 = lambda1 - lambda2;
        double dev2 = lambda1 - lambda3;
        double dev3 = lambda2 - lambda3;
        double fa =
          scaleFA * sqrt(dev1*dev1 + dev2*dev2 + dev3*dev3) /
            (sqrt(lambda1*lambda1 + lambda2*lambda2 + lambda3*lambda3)+1e-20);

        if (fa < 0.0)
          fa = 0.0;
        if (fa > 1.0)
          fa = 1.0;

        m_MDImage->SetPixel(ind, (float)md);
        m_FAImage->SetPixel(ind, (float)fa);
      }

  return;

}
