
#include "itkImageFileReader.h"
#include "itkMinimumMaximumImageFilter.h"

#include "vtkDataArray.h"
#include "vtkPointData.h"
#include "vtkSmartPointer.h"
#include "vtkStructuredPoints.h"
#include "vtkStructuredPointsReader.h"

#include "DTImageReader.h"

#include <cmath>
#include <iostream>

/**
 * Default constructor.
 *
 * Sets the S0 threshold fraction to 0.05 and the b-factor to 1000, leaves the
 * output image empty, and installs a default table of 7 gradients: one zero
 * gradient (the baseline S0 image) followed by six directions.  The table is
 * then normalized with NormalizeGradients().
 */
DTImageReader
::DTImageReader()
{
  m_ThresholdS0Fraction = 0.05;
  m_BFactor = 1000.0;
  m_Output = 0;

  // Basser scheme as default, image orientation should be RPI.
  // Row n holds the (x, y, z) components of gradient g_n.
  static const double kDefaultGradients[7][3] =
  {
    {  0, 0,  0 },  // g0: zero gradient (baseline)
    {  1, 0,  1 },  // g1
    { -1, 0,  1 },  // g2
    {  0, 1,  1 },  // g3
    {  0, 1, -1 },  // g4
    {  1, 1,  0 },  // g5
    { -1, 1,  0 }   // g6
  };
  const unsigned int numDefaultGradients = 7;

  m_Gradients.Allocate(numDefaultGradients);

  // One 3x1 column matrix is reused; Append() receives each filled state.
  DiffusionTensor::MatrixType column(3, 1);
  for (unsigned int n = 0; n < numDefaultGradients; ++n)
  {
    for (unsigned int c = 0; c < 3; ++c)
      column(c, 0) = kDefaultGradients[n][c];
    m_Gradients.Append(column);
  }

  this->NormalizeGradients();
}

/** Destructor.  No explicit cleanup is performed here. */
DTImageReader
::~DTImageReader()
{
}

/**
 * Replace the gradient table and normalize it.
 *
 * Every entry must be a 3x1 matrix.  At least 7 entries are required, and at
 * least one must have a Frobenius norm of exactly zero (the baseline S0
 * image).  Once all checks pass, the list is copied into m_Gradients and
 * NormalizeGradients() is applied.
 *
 * @param glist Candidate gradient directions, one per raw image.
 * @throws itk::ExceptionObject if fewer than 7 entries are given, an entry is
 *         not 3x1, or no entry is exactly zero.  m_Gradients is not modified
 *         when an exception is thrown.
 */
void
DTImageReader
::SetGradients(const DynArray<DiffusionTensor::MatrixType>& glist)
{
  const unsigned int count = glist.GetSize();

  if (count < 7)
  {
    itkExceptionMacro(<< "Need at least 7 gradients");
  }

  bool hasZeroGradient = false;
  for (unsigned int n = 0; n < count; ++n)
  {
    const bool isColumn3 = glist[n].rows() == 3 && glist[n].columns() == 1;
    if (!isColumn3)
    {
      itkExceptionMacro(<< "Invalid gradient matrix size");
    }

    if (glist[n].frobenius_norm() == 0)
    {
      hasZeroGradient = true;
    }
  }

  if (!hasZeroGradient)
  {
    itkExceptionMacro(<< "Need at least one image for zero gradient case");
  }

  m_Gradients = glist;
  this->NormalizeGradients();
}

/**
 * Scale every gradient in m_Gradients to unit Euclidean length.
 *
 * Gradients whose norm is below 1e-20 are left untouched.  This keeps them
 * recognizable as zero (baseline) gradients: ReadRawImages treats a gradient
 * as zero when its frobenius_norm() is below that same 1e-20 cut-off.
 */
void
DTImageReader
::NormalizeGradients()
{
  for (unsigned int i = 0; i < m_Gradients.GetSize(); i++)
  {
    // For vnl-style matrices (which MatrixType appears to be), frobenius_norm()
    // is already sqrt(sum of squares), i.e. the length of the 3x1 gradient.
    // Taking sqrt() of it again, as before, did not yield unit vectors.
    const double norm = m_Gradients[i].frobenius_norm();
    if (norm < 1e-20)
      continue;
    m_Gradients[i] /= norm;
  }
}

/**
 * Estimate the diffusion tensor image from raw diffusion-weighted images.
 *
 * filenames[i] is the image acquired with gradient m_Gradients[i].  Images
 * whose gradient norm is below 1e-20 are averaged into a baseline image S0.
 *
 * A voxel is fitted when its S0 is above
 *   thresholdS0 = minS0 + m_ThresholdS0Fraction * (maxS0 - minS0)
 * and also positive, because its logarithm is taken.  For each such voxel,
 * the six unique tensor elements (Dxx, Dxy, Dxz, Dyy, Dyz, Dzz) are fitted to
 *   (log S0 - log Sk) / m_BFactor = g_k^T D g_k
 * over the non-zero gradients g_k:
 *   - by QR least squares when there are more than 6 such gradients;
 *   - by direct inversion when there are exactly 6.
 * Each fitted tensor is forced positive definite and stored in m_Output.
 *
 * Skipped voxels are not written after m_Output->Allocate(); they keep
 * whatever Allocate() leaves (presumably a default-constructed
 * DiffusionTensor -- TODO confirm).
 *
 * @param filenames One image file per entry of m_Gradients, in the same order.
 * @throws itk::ExceptionObject if fewer than 7 images are given, the image
 *         and gradient counts differ, image sizes differ, there is no
 *         zero-gradient image, or fewer than 6 non-zero gradients remain.
 */
void
DTImageReader
::ReadRawImages(const StringList& filenames)
{
  itkDebugMacro(<< "ReadRawImages");

  unsigned int numImages = filenames.GetSize();

  if (numImages < 7)
    itkExceptionMacro(<< "Need at least 7 images");

  unsigned int numGradients = m_Gradients.GetSize();

  if (numImages != numGradients)
    itkExceptionMacro(<< "Number of input images and gradients don't match");

  // Load images
  DynArray<FloatImagePointer> inputImages;

  typedef itk::ImageFileReader<FloatImageType> ReaderType;
  for (unsigned int i = 0; i < numImages; i++)
  {
    itkDebugMacro(<< "Reading " << filenames[i] << "...");

    ReaderType::Pointer reader = ReaderType::New();
    reader->SetFileName(filenames[i].c_str());
    reader->Update();

    inputImages.Append(reader->GetOutput());
  }

  // Check image size consistency
  FloatImageSizeType size =
    inputImages[0]->GetLargestPossibleRegion().GetSize();
  for (unsigned int i = 1; i < numImages; i++)
  {
    FloatImageSizeType size_i =
      inputImages[i]->GetLargestPossibleRegion().GetSize();
    if (size != size_i)
      itkExceptionMacro(<< "Image size mismatch");
  }

  FloatImageIndexType ind;

  // Average the zero gradient images for S0
  FloatImagePointer avgS0Img = FloatImageType::New();
  avgS0Img->CopyInformation(inputImages[0]);
  avgS0Img->SetRegions(inputImages[0]->GetLargestPossibleRegion().GetSize());
  avgS0Img->Allocate();

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
        avgS0Img->SetPixel(ind, 0);

  unsigned int numZeroGradients = 0;
  for (unsigned int j = 0; j < numImages; j++)
  {
    // Gradients at or above this norm are diffusion-weighted, not baseline.
    if (m_Gradients[j].frobenius_norm() >= 1e-20)
      continue;

    FloatImagePointer img_j = inputImages[j];
    for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
      for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
        for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
          avgS0Img->SetPixel(ind,
            avgS0Img->GetPixel(ind) + img_j->GetPixel(ind));

    numZeroGradients++;
  }

  if (numZeroGradients == 0)
    itkExceptionMacro(<< "No baseline (S0) image");

  if (numZeroGradients > 1)
  {
    for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
      for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
        for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
          avgS0Img->SetPixel(ind,
            avgS0Img->GetPixel(ind) / numZeroGradients);
  }

  // Compute threshold for S0 (lowest x% of image intensities)
  typedef itk::MinimumMaximumImageFilter<FloatImageType> MinMaxType;
  MinMaxType::Pointer minmax = MinMaxType::New();
  minmax->SetInput(avgS0Img);
  minmax->Update();

  double minS0 = minmax->GetMinimum();
  double maxS0 = minmax->GetMaximum();
  double thresholdS0 = minS0 + m_ThresholdS0Fraction*(maxS0-minS0);

  itkDebugMacro(<< "Lower threshold for S0 is " << thresholdS0);

  // Create list of images and gradient directions for non-zero gradient cases
  DynArray<FloatImagePointer> SkImages;
  SkImages.Allocate(numGradients - numZeroGradients);
  DynArray<DiffusionTensor::MatrixType> SkGradients;
  SkGradients.Allocate(numGradients - numZeroGradients);

  for (unsigned int j = 0; j < numImages; j++)
  {
    if (m_Gradients[j].frobenius_norm() >= 1e-20)
    {
// TODO: align the image to average S0
      SkImages.Append(inputImages[j]);
      SkGradients.Append(m_Gradients[j]);
    }
  }

  unsigned int numSkImages = SkImages.GetSize();

  if (numSkImages < 6)
    itkExceptionMacro(<< "Underdetermined linear system, need more non-zero gradients");

  // Design matrix: row k holds the coefficients of (Dxx, Dxy, Dxz, Dyy, Dyz,
  // Dzz) in g_k^T D g_k.  It depends only on the gradients, not on the voxel,
  // so it is built and factored once here instead of at every voxel.
  DiffusionTensor::MatrixType lhs(numSkImages, 6);
  for (unsigned int k = 0; k < numSkImages; k++)
  {
    const DiffusionTensor::MatrixType& gmatrix = SkGradients[k];

    const double gx = gmatrix(0, 0);
    const double gy = gmatrix(1, 0);
    const double gz = gmatrix(2, 0);

    lhs(k, 0) = gx*gx;          // Dxx
    lhs(k, 1) = 2.0 * (gx*gy);  // Dxy
    lhs(k, 2) = 2.0 * (gx*gz);  // Dxz
    lhs(k, 3) = gy*gy;          // Dyy
    lhs(k, 4) = 2.0 * (gy*gz);  // Dyz
    lhs(k, 5) = gz*gz;          // Dzz
  }

  // More equations than unknowns -> least squares via QR.  Otherwise the
  // system is square (exactly 6 gradients) and is solved by inversion.
  const bool overdetermined = lhs.rows() > lhs.columns();
  DiffusionTensor::MatrixQRType qr(lhs);
  DiffusionTensor::MatrixType invL;
  if (!overdetermined)
  {
    // gcc workaround: convert the inverse via copy-initialization
    DiffusionTensor::MatrixType tmpInv =
      DiffusionTensor::MatrixInverseType(lhs);
    invL = tmpInv;
  }

  // Allocate space for the output image
  m_Output = DTImageType::New();
  m_Output->CopyInformation(inputImages[0]);
  m_Output->SetRegions(inputImages[0]->GetLargestPossibleRegion());
  m_Output->Allocate();

  // Per-voxel work buffers; every entry is overwritten for each fitted voxel.
  DiffusionTensor::MatrixType rhs(numSkImages, 1);
  DiffusionTensor::MatrixType M(3, 3);

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        double pixS0 = avgS0Img->GetPixel(ind);

        // Skip two kinds of voxel:
        // - background voxels at or below the threshold;
        // - non-positive baselines, whose log is undefined.  These can pass
        //   the threshold when S0 has negative values.
        if (pixS0 <= thresholdS0 || pixS0 <= 0.0)
          continue;

        double logS0 = log(pixS0);

        for (unsigned int k = 0; k < numSkImages; k++)
        {
          double pixSk = SkImages[k]->GetPixel(ind);

          // Clamp so the log stays finite for zero/negative signal
          if (pixSk < 1e-20)
            pixSk = 1e-20;

          rhs(k, 0) = (logS0 - log(pixSk)) / m_BFactor;
        }

        DiffusionTensor::MatrixType sol;
        if (overdetermined)
          sol = qr.solve(rhs);
        else
          sol = invL*rhs;

        // Unpack the upper triangle (row-major) into a symmetric 3x3 matrix
        unsigned int elemidx = 0;
        for (unsigned int i = 0; i < 3; i++)
          for (unsigned int j = i; j < 3; j++)
          {
            M(i, j) = sol(elemidx, 0);
            M(j, i) = M(i, j);
            elemidx++;
          }

        DiffusionTensor T;
        T.FromMatrix(M);
        T.ForcePositiveDefinite();

        m_Output->SetPixel(ind, T);
      }
}

void
DTImageReader
::ReadStackedImage(const char* fn)
{
  if (fn == NULL)
    itkExceptionMacro(<< "NULL argument");

  typedef itk::ImageFileReader<FloatImageType> ReaderType;
  ReaderType::Pointer reader = ReaderType::New();

  reader->SetFileName(fn);
  reader->Update();

  FloatImagePointer tmp = reader->GetOutput();

  DTImageSizeType size = tmp->GetLargestPossibleRegion().GetSize();

  if ((size[2] % 6) != 0)
    itkExceptionMacro("Z-axis dimension not a multiple of 6");

  unsigned int stacked_zdim = size[2];
  size[2] = stacked_zdim/6;

  DTImageRegionType region;
  region.SetSize(size);

  // Allocate space for the output image
  m_Output = DTImageType::New();
  m_Output->CopyInformation(tmp);
  m_Output->SetRegions(region);
  m_Output->Allocate();

  DTImageIndexType ind;
  DTImageIndexType input_ind;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        input_ind[0] = ind[0];
        input_ind[1] = ind[1];

        DiffusionTensor T;
        for (int k = 0; k < 6; k++)
        {
          input_ind[2] = ind[2] + k*size[2];
          double vv = tmp->GetPixel(input_ind);
          T.SetElementAt(k, vv);
        }
        T.ForcePositiveDefinite();

        m_Output->SetPixel(ind, T);
      }

}

void
DTImageReader
::ReadVTKImage(const char* fn)
{
  if (fn == NULL)
    itkExceptionMacro(<< "NULL argument");

  vtkSmartPointer<vtkStructuredPointsReader> tensorReader =
    vtkSmartPointer<vtkStructuredPointsReader>::New();
  tensorReader->SetFileName(fn);
  tensorReader->Update();

  vtkSmartPointer<vtkStructuredPoints> sp = tensorReader->GetOutput();

  int* dims = sp->GetDimensions();
  double* gaps = sp->GetSpacing();
  double* origs = sp->GetOrigin();

  DTImageSizeType size;
  DTImageSpacingType spacing;
  DTImagePointType origin;
  for (int i = 0; i < 3; i++)
  {
    size[i] = dims[i];
    spacing[i] = gaps[i];
    origin[i] = origs[i];
  }

  DTImageRegionType region;
  region.SetSize(size);

  // Allocate space for the output image
  m_Output = DTImageType::New();
  m_Output->SetRegions(region);
  m_Output->Allocate();
  m_Output->SetOrigin(origin);
  m_Output->SetSpacing(spacing);

  vtkSmartPointer<vtkDataArray> tensors = sp->GetPointData()->GetTensors();

  unsigned int tensorIndex = 0;

  DTImageIteratorType it(m_Output, region);

  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    DTImageIndexType ind = it.GetIndex();
    double* pt = tensors->GetTuple(tensorIndex++);

    DiffusionTensor T;

    int k = 0;
    for (int j = 0; j < 3; j++)
      for (int i = j; i < 3; i++)
      {
        float x = pt[i+3*j];

        //itk::ByteSwapper<float>::SwapFromSystemToBigEndian(&x);

        T.SetElementAt(k++, x);
      }

    T.ForcePositiveDefinite();

    m_Output->SetPixel(ind, T);
  }

}

void
DTImageReader
::ReadVectorImage(const char* fn)
{
  if (fn == NULL)
    itkExceptionMacro(<< "NULL argument");

  typedef itk::ImageFileReader<VectorImageType> ReaderType;

  ReaderType::Pointer reader = ReaderType::New();
  reader->SetFileName(fn);
  reader->Update();

  VectorImagePointer img = reader->GetOutput();

  DTImageSizeType size = img->GetLargestPossibleRegion().GetSize();
  DTImageSpacingType spacing = img->GetSpacing();
  DTImagePointType origin = img->GetOrigin();

  //if (size[3] != 6)
  //  itkExceptionMacro(<< "Expecting 6 vector elements");

  DTImageRegionType region;
  region.SetSize(size);

  // Allocate space for the output image
  m_Output = DTImageType::New();
  m_Output->CopyInformation(img);
  m_Output->SetRegions(region);
  m_Output->Allocate();
  m_Output->SetOrigin(origin);
  m_Output->SetSpacing(spacing);

  DTImageIteratorType it(m_Output, region);

  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    DTImageIndexType ind = it.GetIndex();
    
    VectorImageType::PixelType v = img->GetPixel(ind);

    DiffusionTensor T;

    int k = 0;
    for (int j = 0; j < 3; j++)
      for (int i = j; i < 3; i++)
      {
        float x = v[i+3*j];

        //itk::ByteSwapper<float>::SwapFromSystemToBigEndian(&x);

        T.SetElementAt(k++, x);
      }

    T.ForcePositiveDefinite();

    m_Output->SetPixel(ind, T);
  }

}
