
#include "itkByteSwapper.h"
#include "itkImageFileWriter.h"
#include "itkNumericTraits.h"
#include "itkRescaleIntensityImageFilter.h"

#include "DTImageScalarSource.h"
#include "DTImageWriter.h"

#include <cmath>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

// Default constructor: no input image is attached yet, and the weighting
// factor used by WriteRawImages() (exp(-g^T D g * m_BFactor)) defaults to 1000.
DTImageWriter::
DTImageWriter()
{
  m_Input = 0;

  m_BFactor = 1000.0;
}

// Destructor: nothing is released by hand here; the members (including the
// m_Input smart pointer) clean up through their own destructors.
DTImageWriter::
~DTImageWriter()
{
}

// Attach the tensor image that the Write*() methods will serialize.
// Passing a null pointer makes every Write*() call throw "No DTI data specified".
void
DTImageWriter::
SetInput(DTImageType* dti)
{
  this->m_Input = dti;
}

void
DTImageWriter::
SetGradients(const DynArray<DiffusionTensor::MatrixType>& gradlist)
{
  if (gradlist.GetSize() < 7)
  {
    itkExceptionMacro(<< "Need at least 7 gradients");
  }

  // TODO: check zero gradients

  m_Gradients = gradlist;
}

void
DTImageWriter::
WriteVTKImage(const char* fname)
{
  if (fname == NULL)
    itkExceptionMacro(<< "NULL argument");

  if (m_Input.IsNull())
    itkExceptionMacro(<< "No DTI data specified");

  std::ofstream outf;

  outf.open(fname, std::ios::out | std::ios::binary);

  if (outf.fail())
    itkExceptionMacro(<< "Cannot open " << fname);

  DTImageSizeType size = m_Input->GetLargestPossibleRegion().GetSize();

  DTImageSpacingType spacing = m_Input->GetSpacing();

  DTImagePointType origin = m_Input->GetOrigin();

  // Header
  outf << "# vtk DataFile Version 2.0\n"; 
  outf << "DTI Tensor\n"; // Name
  outf << "BINARY\n"; // Type
  outf << "DATASET STRUCTURED_POINTS\n";
  outf << "DIMENSIONS " << size[0] << " " << size[1] << " " << size[2] << "\n";
  outf << "SPACING " << spacing[0] << " " << spacing[1] << " " << spacing[2];
  outf << "\n";
  outf << "ORIGIN " << origin[0] << " " << origin[1] << " " << origin[2] << "\n";
  outf << "POINT_DATA " << size[0]*size[1]*size[2] << "\n";
  outf << "TENSORS mytensors float\n";

  ShortImageIndexType ind;

  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        DiffusionTensor::MatrixType D = m_Input->GetPixel(ind).GetMatrix();

        for (unsigned int i = 0; i < 3; i++)
          for (unsigned int j = 0; j < 3; j++)
          {
            float x = static_cast<float>(D(i, j));

            itk::ByteSwapper<float>::SwapFromSystemToBigEndian(&x);

            outf.write(reinterpret_cast<char*>(&x), sizeof(float));
          }
      }

  outf.close();

  
}

void
DTImageWriter::
WriteStackedImage(const char* fname)
{
  if (fname == NULL)
    itkExceptionMacro(<< "NULL argument");

  if (m_Input.IsNull())
    itkExceptionMacro(<< "No DTI data specified");

  DTImageSizeType size = m_Input->GetLargestPossibleRegion().GetSize();

  unsigned int actual_zdim = size[2];
  size[2] = 6*actual_zdim;

  DTImageRegionType region;
  region.SetSize(size);

  FloatImagePointer tmp = FloatImageType::New();
  tmp->CopyInformation(m_Input);
  tmp->SetRegions(region);
  tmp->Allocate();

  DTImageIndexType ind;
  DTImageIndexType other_ind;
  for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
    for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
      for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
      {
        other_ind[0] = ind[0];
        other_ind[1] = ind[1];
        other_ind[2] = ind[2] % actual_zdim;

        unsigned int k = ind[2] / actual_zdim;

        const DiffusionTensor::ScalarType* elems =
          m_Input->GetPixel(other_ind).GetElements();

        tmp->SetPixel(ind, elems[k]);
      }

  typedef itk::ImageFileWriter<FloatImageType> WriterType;
  WriterType::Pointer writer = WriterType::New();

  writer->SetFileName(fname);
  writer->SetInput(tmp);
  writer->UseCompressionOn();
  writer->Update();
}

void
DTImageWriter::
WriteRawImages(const char* basename)
{
  if (basename == NULL)
    itkExceptionMacro(<< "NULL argument");

  if (m_Input.IsNull())
    itkExceptionMacro(<< "No DTI data specified");

  if (m_Gradients.GetSize() < 7)
    itkExceptionMacro(<< "Need at least 7 gradients");

  unsigned int numGradients = m_Gradients.GetSize();

  DTImageRegionType region = m_Input->GetLargestPossibleRegion();
  DTImageSizeType size = region.GetSize();

  // Allocate pointers to output images
  DynArray<ShortImagePointer> outputList;
  outputList.Initialize(numGradients, 0);

  // Compute MD
  DTImageScalarSource::Pointer ssource = DTImageScalarSource::New();
  ssource->SetInput(m_Input);
  ssource->Update();

  // Rescale MD
  typedef itk::RescaleIntensityImageFilter<
    DTImageScalarSource::FloatImageType, ShortImageType> MDRescalerType;

  MDRescalerType::Pointer rescaler = MDRescalerType::New();
  rescaler->SetInput(ssource->GetMDImage());
  rescaler->SetOutputMinimum(0);
  rescaler->SetOutputMaximum(itk::NumericTraits<short>::max());
  rescaler->Update();

  ShortImagePointer MDimg = rescaler->GetOutput();

  // Fill output image list
  for (unsigned int i = 0; i < numGradients; i++)
  {
    // Use MD for zero gradient images
    if (m_Gradients[i].frobenius_norm() < 1e-20)
    {
      outputList[i] = MDimg;
    }

    DiffusionTensor::MatrixType g = m_Gradients[i];

    // Normalize g
    double norm = sqrt(g.frobenius_norm());
    if (norm != 0.0)
      g /= norm;

    ShortImagePointer img = ShortImageType::New();
    img->CopyInformation(m_Input);
    img->SetRegions(region);
    img->Allocate();

    ShortImageIndexType ind;
    for (ind[2] = 0; ind[2] < (long)size[2]; ind[2]++)
      for (ind[1] = 0; ind[1] < (long)size[1]; ind[1]++)
        for (ind[0] = 0; ind[0] < (long)size[0]; ind[0]++)
        {
          DiffusionTensor T = m_Input->GetPixel(ind);
          DiffusionTensor::MatrixType D = T.GetMatrix();

          DiffusionTensor::MatrixType x = g.transpose() * D * g;

          double pixS0 = MDimg->GetPixel(ind);
          double pixSk = pixS0 * exp(-x(0,0) * m_BFactor);

          if (pixSk < 0)
            pixSk = 0;
          if (pixSk > itk::NumericTraits<short>::max())
            pixSk = itk::NumericTraits<short>::max();

          img->SetPixel(ind, static_cast<short>(pixSk));
        }

    outputList[i] = img;

  }

  // Write all the images
  typedef itk::ImageFileWriter<ShortImageType> WriterType;

  for (unsigned int i = 0; i < numGradients; i++)
  {
    WriterType::Pointer writer = WriterType::New();

    std::string fnstr = basename;

    std::stringstream outss;
    outss << "_dti" << i << ".mha" << std::ends;

    fnstr += outss.str();

    writer->SetFileName(fnstr.c_str());
    writer->SetInput(outputList[i]);
    writer->UseCompressionOn();
    writer->Update();
  }

}
