
#include<cstdlib>
#include<iomanip>
#include<iostream>
#include<sstream>
#include<string>
#include<vector>

#include "itkNrrdImageIO.h"
#include "itkImageFileWriter.h"
#include "itkImageFileReader.h"
#include "itkImage.h"
#include "itkVector.h"
#include "itkVectorGradientAnisotropicDiffusionImageFilter.h"
#include "itkDiffusionTensor3D.h"
#include "itkImageLinearIteratorWithIndex.h"
#include "itkVariableLengthVector.h"
#include "itkImageDuplicator.h"
#include "itkMetaDataObject.h"

#include "itkVectorImage.h"

#include "itkRicianFilter.h"
#include "itkGaussianFilter.h"

#include "DWIRicianNoiseRemovalFilterCLP.h"

//#define NUMIMAGES 7 
//#define NUMIMAGES 19 
template <unsigned int NUMIMAGES>
class N
{
public:
  typedef  itk::Image<float, 3> DiffusionDataImageType;
  typedef  itk::ImageFileWriter<DiffusionDataImageType> DiffusionDataImageWriterType;
  typedef  itk::ImageFileReader<DiffusionDataImageType> DiffusionDataImageReaderType;
  typedef  itk::Vector< float,NUMIMAGES> VectorPixelType;
  typedef  itk::Image< VectorPixelType,3> VectorDiffusionDataImageType;
  typedef  itk::ImageDuplicator< DiffusionDataImageType> ImageDuplicator;
  typedef  itk::ImageDuplicator< VectorDiffusionDataImageType> ImageDuplicator2;
  
  typedef  itk::ImageFileWriter<VectorDiffusionDataImageType> VectorDiffusionDataImageWriterType;
  typedef  itk::ImageFileReader<VectorDiffusionDataImageType> VectorDiffusionDataImageReaderType;
  
  typedef typename itk::Image< VectorPixelType,3>::Pointer VectDiffDataImagePointerType;
};
typedef unsigned int uint32;

/*
 * createDWIVectorImage: returns a 3D volume of vectors whose components are
 * the DWI signals along each gradient direction: component v of the vector
 * at a voxel is the value of dwi[v] at the same index.
 *
 * dwi:       array of numImages scalar 3D DWI volumes. All volumes are
 *            assumed to share dwi[0]'s region (this is not checked).
 * numImages: number of volumes in dwi; must be in [1, NumImage], where
 *            NumImage is the fixed vector dimension. Components with index
 *            >= numImages are set to 0.
 *
 * The output takes its region, spacing, origin and direction from dwi[0].
 * An invalid numImages prints an error and exits with -1.
 */
template< unsigned int NumImage >
typename N<NumImage>::VectDiffDataImagePointerType
createDWIVectorImage(typename N<NumImage>::DiffusionDataImageType::Pointer dwi[],uint32 numImages)
{
  typedef typename N<NumImage>::DiffusionDataImageType        ScalarImageType;
  typedef typename N<NumImage>::VectorDiffusionDataImageType  VectorImageType;
  typedef typename ScalarImageType::SizeType::SizeValueType   SizeValueType;
  typedef typename ScalarImageType::IndexType::IndexValueType IndexValueType;

  // dwi[0] supplies the geometry, and itk::Vector has a fixed length, so
  // reject counts that would read from an empty array or write past the end
  // of the vector pixel.
  if(numImages == 0 || numImages > NumImage)
    {
      std::cerr<<"createDWIVectorImage: got "<<numImages
               <<" volumes, expected 1.."<<NumImage<<std::endl;
      exit(-1);
    }

  // Reuse the input's region (start index AND size) instead of assuming the
  // region starts at index 0, so every GetPixel below stays in bounds.
  const typename ScalarImageType::RegionType region=
    dwi[0]->GetLargestPossibleRegion();
  const typename ScalarImageType::IndexType start=region.GetIndex();
  const typename ScalarImageType::SizeType size=region.GetSize();

  //Allocate the 3D volume of vectors with the same geometry as the input
  typename VectorImageType::Pointer vectorDWI=VectorImageType::New();
  vectorDWI->SetRegions(region);
  vectorDWI->SetSpacing(dwi[0]->GetSpacing());
  vectorDWI->SetOrigin(dwi[0]->GetOrigin());   // keep the physical position
  vectorDWI->SetDirection(dwi[0]->GetDirection());
  vectorDWI->Allocate();

  // Scalar and vector images are both 3D, so one index addresses both.
  typename ScalarImageType::IndexType pixelIndex;
  typename VectorImageType::PixelType vectorPixelValue;
  vectorPixelValue.Fill(0);  // components >= numImages stay 0

  for(SizeValueType i=0;i<size[0];i++)
    for(SizeValueType j=0;j<size[1];j++)
      for(SizeValueType k=0;k<size[2];k++)
        {
          pixelIndex[0]=start[0]+static_cast<IndexValueType>(i);
          pixelIndex[1]=start[1]+static_cast<IndexValueType>(j);
          pixelIndex[2]=start[2]+static_cast<IndexValueType>(k);

          //Set the vector components
          for(uint32 vdim=0;vdim<numImages;vdim++)
            vectorPixelValue[vdim]=dwi[vdim]->GetPixel(pixelIndex);

          vectorDWI->SetPixel(pixelIndex,vectorPixelValue);
        }
  return(vectorDWI);
}

/*
 * readDWI_ImagesREAL: reads a single file (named by prefix) containing a 3D
 * volume of NumImage-component vectors and returns it.
 * On a read error the ITK exception is printed to std::cerr and the process
 * exits with -1.
 */
template< unsigned int NumImage >
typename N<NumImage>::VectorDiffusionDataImageType::Pointer 
readDWI_ImagesREAL(const std::string& prefix)
{
  typename N<NumImage>::VectorDiffusionDataImageReaderType::Pointer dwiReader= 
    N<NumImage>::VectorDiffusionDataImageReaderType::New();
  dwiReader->SetFileName(prefix);
  try
    {
      dwiReader->Update();
    }
  catch (itk::ExceptionObject & e) 
    {
      // Caught by reference: no copy, and derived ITK exceptions keep their
      // own description instead of being sliced to the base class.
      std::cerr << e << std::endl;
      exit(-1);
    }
  return(dwiReader->GetOutput());
}


/*
 * Write out DWI data
 */
template< unsigned int NumImage >
void writeDWI_ImagesREAL(const std::string& prefix,
			 typename N<NumImage>::VectorDiffusionDataImageType::Pointer image,
			 itk::MetaDataDictionary & dict)
{
  
  typename N<NumImage>::VectorDiffusionDataImageWriterType::Pointer dwiWriter= 
   N<NumImage>::VectorDiffusionDataImageWriterType::New();
  dwiWriter->SetFileName(prefix);       
  //Copy the header information (including the gradient directions)
  //image->SetMetaDataDictionary(dict);
  dwiWriter->SetInput(image);
  try{
    dwiWriter->Update();
  }
  catch (itk::ExceptionObject e) 
    {
      std::cerr << e << std::endl;
      exit(-1);
    }
}
/*
 * readDWI_Images returns a pointer to a 3D volume of NumImage-component
 * vectors built from numImages scalar volumes read from the files
 * <prefix>000.nrrd, <prefix>001.nrrd, ... (index zero-padded to 3 digits).
 * A read error prints the ITK exception and exits with -1; numImages == 0
 * is rejected the same way.
 *
 * NOTE: in this file DoIt reads a single vector file through
 * readDWI_ImagesREAL instead of calling this function.
 */
template< unsigned int NumImage >
typename N<NumImage>::VectorDiffusionDataImageType::Pointer 
readDWI_Images(const uint32 numImages,char* prefix)
{
  typedef typename N<NumImage>::DiffusionDataImageType ScalarImageType;

  if(numImages == 0)
    {
      std::cerr<<"readDWI_Images: numImages must be at least 1"<<std::endl;
      exit(-1);
    }

  typename N<NumImage>::DiffusionDataImageReaderType::Pointer dwiReader= N<NumImage>::DiffusionDataImageReaderType::New();
  typename N<NumImage>::ImageDuplicator::Pointer duplicatorFilter= N<NumImage>::ImageDuplicator::New();

  // std::vector instead of a variable-length array (a non-standard compiler
  // extension); the smart pointers it holds are released automatically.
  std::vector<typename ScalarImageType::Pointer> dwi(numImages);
  std::ostringstream ostr;

  for(uint32 i=0;i<numImages;i++)
    {
      ostr<<prefix<<std::setfill('0')<<std::setw(3)<<i<<".nrrd";
      std::cout<<"Reading File="<<ostr.str()<<std::endl;
      dwiReader->SetFileName((ostr.str()).c_str());
      ostr.str("");
      try
        {
          dwiReader->Update();
          // The same reader is reused for every file, so keep a deep copy of
          // its output before the next read.
          duplicatorFilter->SetInputImage(dwiReader->GetOutput()); 
          duplicatorFilter->Update();
          dwi[i]=duplicatorFilter->GetOutput();
        }
      catch (itk::ExceptionObject & e) 
        {
          std::cerr << e << std::endl;
          exit(-1);
        }
    }
  // NumImage cannot be deduced from the arguments, so it is given explicitly.
  return(createDWIVectorImage<NumImage>(&dwi[0],numImages));
}

/*
 * squareDWIVolume: squares every component of every voxel of vectorDWI.
 *
 * NOTE: the image is modified IN PLACE; the returned pointer is the same
 * object that was passed in. Callers that need the original values must
 * duplicate the image first.
 */
template< unsigned int NumImage >
typename N<NumImage>::VectorDiffusionDataImageType::Pointer
squareDWIVolume(typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI)
{
  typedef typename N<NumImage>::VectorDiffusionDataImageType  VectorImageType;
  typedef typename VectorImageType::SizeType::SizeValueType   SizeValueType;
  typedef typename VectorImageType::IndexType::IndexValueType IndexValueType;

  std::cout<<"Squaring Image....."<<std::endl;

  // Walk the actual region (start index + size) rather than assuming it
  // starts at index 0.
  const typename VectorImageType::RegionType region=
    vectorDWI->GetLargestPossibleRegion();
  const typename VectorImageType::IndexType start=region.GetIndex();
  const typename VectorImageType::SizeType size=region.GetSize();

  typename VectorImageType::IndexType vectorPixelIndex;
  typename VectorImageType::PixelType vectorPixelValue;

  for(SizeValueType i=0;i<size[0];i++)
    for(SizeValueType j=0;j<size[1];j++)
      for(SizeValueType k=0;k<size[2];k++)
        {
          vectorPixelIndex[0]=start[0]+static_cast<IndexValueType>(i);
          vectorPixelIndex[1]=start[1]+static_cast<IndexValueType>(j);
          vectorPixelIndex[2]=start[2]+static_cast<IndexValueType>(k);
          vectorPixelValue=vectorDWI->GetPixel(vectorPixelIndex);
          // Component index 'c' (the old code reused 'i', shadowing the x index).
          for(unsigned int c=0;c<vectorPixelValue.Size();c++)
            vectorPixelValue[c]*=vectorPixelValue[c];
          vectorDWI->SetPixel(vectorPixelIndex,vectorPixelValue);
        }
  return(vectorDWI);
}

/*
 * biasCorrect: per-component bias correction applied after filtering squared
 * values. Each component x becomes
 *     sqrt(x - 2*sigma^2)   if x - 2*sigma^2 > 0,
 *     0                     otherwise.
 * NOTE(review): 2*sigma^2 is presumably the Rician noise bias of squared
 * magnitudes, with sigma the noise standard deviation — confirm.
 *
 * The image is modified IN PLACE; the returned pointer is vectorDWI itself.
 */
template< unsigned int NumImage >
typename N<NumImage>::VectorDiffusionDataImageType::Pointer
biasCorrect(typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI,float sigma)
{
  typedef typename N<NumImage>::VectorDiffusionDataImageType  VectorImageType;
  typedef typename VectorImageType::SizeType::SizeValueType   SizeValueType;
  typedef typename VectorImageType::IndexType::IndexValueType IndexValueType;

  std::cout<<"Performing Bias Correction..."<<std::endl;

  // Loop-invariant: computed once instead of once per component.
  const float bias=2*sigma*sigma;

  // Walk the actual region (start index + size) rather than assuming it
  // starts at index 0.
  const typename VectorImageType::RegionType region=
    vectorDWI->GetLargestPossibleRegion();
  const typename VectorImageType::IndexType start=region.GetIndex();
  const typename VectorImageType::SizeType size=region.GetSize();

  typename VectorImageType::IndexType vectorPixelIndex;
  typename VectorImageType::PixelType vectorPixelValue;

  for(SizeValueType i=0;i<size[0];i++)
    for(SizeValueType j=0;j<size[1];j++)
      for(SizeValueType k=0;k<size[2];k++)
        {
          vectorPixelIndex[0]=start[0]+static_cast<IndexValueType>(i);
          vectorPixelIndex[1]=start[1]+static_cast<IndexValueType>(j);
          vectorPixelIndex[2]=start[2]+static_cast<IndexValueType>(k);

          vectorPixelValue=vectorDWI->GetPixel(vectorPixelIndex);
          // Component index 'c' (the old code reused 'i', shadowing the x index).
          for(unsigned int c=0;c<vectorPixelValue.Size();c++)
            {
              const float corrected=vectorPixelValue[c]-bias;
              if(corrected<=0)
                vectorPixelValue[c]=0;   // clamp: no sqrt of non-positive values
              else
                vectorPixelValue[c]=sqrt(corrected);
            }
          vectorDWI->SetPixel(vectorPixelIndex,vectorPixelValue);
        }
  return(vectorDWI);
}

/*
 * anisoFilter: smooths the 3D volume of DWI vectors with the variant
 * selected by filterType and returns the filtered volume:
 *   0 - itk::VectorGradientAnisotropicDiffusionImageFilter on vectorDWI;
 *   1 - the same filter on the squared values (squareDWIVolume), followed
 *       by biasCorrect with sigma;
 *   2 - itk::RicianFilter; its attachment term is initialized with
 *       (sigma, lamda1, lamda2);
 *   3 - itk::GaussianFilter; its attachment term is initialized with
 *       (sigma, lamda2).
 * numIterations, conductance and timeStep are passed to every filter.
 * Filters 2 and 3 receive an unmodified deep copy of vectorDWI as their
 * "noisy" image.
 *
 * NOTE: filterType 1 squares vectorDWI IN PLACE, so the caller's image is
 * modified. Any other filterType prints an error and exits with -1 (instead
 * of returning a null pointer).
 */
template< unsigned int NumImage >
typename N<NumImage>::VectorDiffusionDataImageType::Pointer 
anisoFilter(
            typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI,
            uint32 numIterations,
            float conductance,
            float timeStep,
            uint32 filterType,
            float sigma,
            float lamda1=1.0,
            float lamda2=1.0
            )
{
  typedef typename N<NumImage>::VectorDiffusionDataImageType VectorImageType;
  typedef itk::VectorGradientAnisotropicDiffusionImageFilter < VectorImageType, VectorImageType >  PeronaMalikFilterType;
  typedef itk::RicianFilter < VectorImageType, VectorImageType >  RicianFilterType;
  typedef itk::GaussianFilter< VectorImageType, VectorImageType >  GaussianFilterType;

  // Filters 2 and 3 need an untouched copy of the input as their "noisy"
  // reference. The copy duplicates the whole volume, so make it only when
  // one of those filters is selected.
  typename VectorImageType::Pointer noisyVectorDWI;
  if(filterType==2 || filterType==3)
    {
      typename N<NumImage>::ImageDuplicator2::Pointer duplicatorFilter= N<NumImage>::ImageDuplicator2::New();
      duplicatorFilter->SetInputImage(vectorDWI); 
      duplicatorFilter->Update();
      noisyVectorDWI=duplicatorFilter->GetOutput();
    }

  typename VectorImageType::Pointer input;
  typename VectorImageType::Pointer filteredVolume;

  std::cout<<"k="<<conductance;
  std::cout<<",dt="<<timeStep;
  std::cout<<",numiter="<<numIterations<<std::endl;

  typename PeronaMalikFilterType::Pointer filter1;
  typename RicianFilterType::Pointer filter2;
  typename GaussianFilterType::Pointer filter3;

  switch(filterType)
    {
    case 0:  //Simple Aniso Filtering
      std::cout<<"Simple AnisoTropic Filtering ..."<<std::endl;
      filter1= PeronaMalikFilterType::New();
      filter1->SetInput(vectorDWI);
      filter1->SetNumberOfIterations(numIterations);
      filter1->SetTimeStep(timeStep);
      filter1->SetConductanceParameter(conductance);
      filter1->Update();
      filteredVolume=(filter1->GetOutput());
      break;
    case 1: //Aniso Filtering squared image
      std::cout<<"Squared Magnitude Filtering.."<<std::endl;
      // squareDWIVolume works in place and returns vectorDWI itself; feed the
      // filter the returned (squared) image explicitly.
      input=squareDWIVolume<NumImage>(vectorDWI);
      filter1=PeronaMalikFilterType::New();
      filter1->SetInput(input);
      filter1->SetNumberOfIterations(numIterations);
      filter1->SetTimeStep(timeStep);
      filter1->SetConductanceParameter(conductance);
      filter1->Update();
      filteredVolume=biasCorrect<NumImage>(filter1->GetOutput(),sigma);
      break;
    case 2: //Rician Correction Term
      std::cout<<"Rician Filtering"<<std::endl;
      std::cout<<"Bias Correction term"<<std::endl;
      filter2=RicianFilterType::New();
      filter2->SetInput(vectorDWI);
      filter2->SetNumberOfIterations(numIterations);
      filter2->SetTimeStep(timeStep);
      filter2->SetConductanceParameter(conductance);
      // Use the untouched copy (as case 3 does) rather than the filter's own
      // input, so the reference stays independent of the input buffer even
      // if the filter were to operate in place.
      filter2->SetNoisyImage(noisyVectorDWI);
      filter2->InitializeAttachmentTermObjects(sigma,lamda1,lamda2);
      filter2->Update();
      filteredVolume=(filter2->GetOutput());
      break;
    case 3: //Gaussian Correction Term
      std::cout<<"ANiso Filtering with Gaussian"<<std::endl;
      std::cout<<"Bias Correction term"<<std::endl;
      filter3=GaussianFilterType::New();
      filter3->SetInput(vectorDWI);
      filter3->SetNumberOfIterations(numIterations);
      filter3->SetTimeStep(timeStep);
      filter3->SetConductanceParameter(conductance);
      filter3->SetNoisyImage(noisyVectorDWI);
      filter3->InitializeAttachmentTermObject(sigma,lamda2);
      filter3->Update();
      filteredVolume=(filter3->GetOutput());
      break;
    default:
      // Without this, a null pointer would be returned and dereferenced by
      // the caller.
      std::cerr<<"Unknown filter type "<<filterType
               <<" (expected 0, 1, 2 or 3)"<<std::endl;
      exit(-1);
    }

  std::cout<<"Finished applying Filter"<<std::endl;

  return(filteredVolume);
}

/*
 * createScalarDWIImages: converts a vector DWI volume into scalar DWI
 * volumes (the inverse of createDWIVectorImage): dwi[v] receives component v
 * of every voxel of vectorDWI.
 *
 * vectorDWI: volume of NumImage-component vectors.
 * dwi:       caller-provided array with room for numImages pointers; each
 *            entry is overwritten with a newly allocated image that copies
 *            vectorDWI's region, spacing, origin and direction.
 * numImages: number of components to extract; must be in [0, NumImage].
 *            An out-of-range value prints an error and exits with -1.
 */
template< unsigned int NumImage >
void createScalarDWIImages(typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI,
                           typename N<NumImage>::DiffusionDataImageType::Pointer dwi[],
                           int numImages)
{
  typedef typename N<NumImage>::DiffusionDataImageType        ScalarImageType;
  typedef typename N<NumImage>::VectorDiffusionDataImageType  VectorImageType;
  typedef typename VectorImageType::SizeType::SizeValueType   SizeValueType;
  typedef typename VectorImageType::IndexType::IndexValueType IndexValueType;

  // itk::Vector has a fixed length: reading component >= NumImage would run
  // past the end of the pixel.
  if(numImages < 0 || static_cast<unsigned int>(numImages) > NumImage)
    {
      std::cerr<<"createScalarDWIImages: got "<<numImages
               <<" components, expected 0.."<<NumImage<<std::endl;
      exit(-1);
    }
  const unsigned int count=static_cast<unsigned int>(numImages);

  // Reuse the input's region (start index AND size) instead of assuming the
  // region starts at index 0.
  const typename VectorImageType::RegionType region=
    vectorDWI->GetLargestPossibleRegion();
  const typename VectorImageType::IndexType start=region.GetIndex();
  const typename VectorImageType::SizeType size=region.GetSize();

  //Allocate one scalar volume per extracted component
  for(unsigned int v=0;v<count;v++)
    {
      dwi[v]=ScalarImageType::New();
      dwi[v]->SetRegions(region);
      dwi[v]->SetSpacing(vectorDWI->GetSpacing());
      dwi[v]->SetOrigin(vectorDWI->GetOrigin());   // keep the physical position
      dwi[v]->SetDirection(vectorDWI->GetDirection());
      dwi[v]->Allocate();
    }

  // Scalar and vector images are both 3D, so one index addresses both.
  typename VectorImageType::IndexType pixelIndex;
  typename VectorImageType::PixelType vectorPixelValue;

  for(SizeValueType i=0;i<size[0];i++)
    for(SizeValueType j=0;j<size[1];j++)
      for(SizeValueType k=0;k<size[2];k++)
        {
          pixelIndex[0]=start[0]+static_cast<IndexValueType>(i);
          pixelIndex[1]=start[1]+static_cast<IndexValueType>(j);
          pixelIndex[2]=start[2]+static_cast<IndexValueType>(k);

          //Scatter the vector components into the scalar volumes
          vectorPixelValue=vectorDWI->GetPixel(pixelIndex);
          for(unsigned int v=0;v<count;v++)
            dwi[v]->SetPixel(pixelIndex,vectorPixelValue[v]);
        }
}


/*
 * writeFilteredDWI writes out the vector DWI volume as numImages separate
 * scalar DWI volumes, <prefix>000.nrrd, <prefix>001.nrrd, ... (index
 * zero-padded to 3 digits), which can then be estimated using the estimate
 * tensor program. A write error prints the ITK exception and exits with -1.
 * numImages == 0 writes nothing.
 */
template< unsigned int NumImage >
void writeFilteredDWI(typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI,
                      const uint32 numImages,
                      char *prefix)
{
  if(numImages == 0)
    return;  // nothing to write

  // std::vector instead of a variable-length array (a non-standard compiler
  // extension); the smart pointers it holds are released automatically.
  std::vector<typename N<NumImage>::DiffusionDataImageType::Pointer> dwi(numImages);
  // NumImage cannot be deduced from the arguments, so it is given explicitly.
  createScalarDWIImages<NumImage>(vectorDWI,&dwi[0],static_cast<int>(numImages));

  typename N<NumImage>::DiffusionDataImageWriterType::Pointer dwiWriter =  N<NumImage>::DiffusionDataImageWriterType::New();
  std::ostringstream ostr;

  for(uint32 i=0;i<numImages;i++)
    {
      ostr<<prefix<<std::setfill('0')<<std::setw(3)<<i<<".nrrd";
      std::cout<<"Writing File="<<ostr.str()<<std::endl;
      dwiWriter->SetFileName((ostr.str()).c_str());
      dwiWriter->SetInput(dwi[i]);
      ostr.str("");
      try
        {
          dwiWriter->Update();
        }
      catch (itk::ExceptionObject & e) 
        {
          std::cerr << e << std::endl;
          exit(-1);
        }
    }
}



/*
 * Main Routine to perform
 * dwi filtering by squaring 
 * the dwi images
 */
template< unsigned int NumImage >
int DoIt(int argc, char* argv[])
{
  //Slicer3 argument
  PARSE_ARGS;
  
  if(argc!=10)
    {
      std::cout<<"USAGE:dwiFilter <arguments>"<<std::endl;
      std::cout<<"Arguments:"<<std::endl;
      std::cout<<"1. Input File Name"<<std::endl;
      std::cout<<"2. Output File Name  "<<std::endl;
      std::cout<<"3. NumIterations"<<std::endl;
      std::cout<<"4. Conductance"<<std::endl;
      std::cout<<"5. TimeStep"<<std::endl;
      std::cout<<"6. Filter Type :";
      std::cout<<" (Simple Aniso-0,Chi Squared-1,Rician-2,Gaussian-3)"<<std::endl;
      std::cout<<"7. Sigma for bias correction"<<std::endl; 
      std::cout<<"8. Lamda (Rician Correction Term)"<<std::endl;
      std::cout<<"9. Lamda (Gaussian Correction Term)"<<std::endl; 
      exit(1);
    }
  else
    {
      uint32 numIter=static_cast<uint32>(numIterations) ;
      float conductance=conductanceVal;
      float timeStep=timeStepVal;
      uint32 filterType= filterTypeId;
      float sigma=sigmaVal;
      float lamda1=lamda1Val;
      float lamda2=lamda2Val;
      
      uint32 numImages=NumImage;
      typename N<NumImage>::VectorDiffusionDataImageType::Pointer vectorDWI=readDWI_ImagesREAL<NumImage>(inputdwi);
      
      //Get all the info from the input images
      itk::MetaDataDictionary & dict = vectorDWI->GetMetaDataDictionary();
      
      typename N<NumImage>::VectorDiffusionDataImageType::Pointer smoothedvectorDWI=
	anisoFilter<NumImage>(vectorDWI,
		    numIter,
		    conductance,
		    timeStep,
		    filterType,
		    sigma,
		    lamda1,
		    lamda2); 
      //Copying the dwi information header
      smoothedvectorDWI->SetMetaDataDictionary (vectorDWI->GetMetaDataDictionary() );     
      writeDWI_ImagesREAL<NumImage>(outputdwi,smoothedvectorDWI,dict);     
    }
  
}



int main(int argc, char* argv[])
{ 
  PARSE_ARGS;

  //Read the image in order to get the number of directions.
  typedef itk::VectorImage<float, 3> OriginalImageType;
  typedef itk::ImageFileReader<OriginalImageType> ImageFileReader;
  ImageFileReader::Pointer reader = ImageFileReader::New();
  reader->SetFileName(inputdwi);
  try
    {
      reader->Update();
    }
  catch (itk::ExceptionObject e) 
    {
      std::cerr << e << std::endl;
      exit(-1);
    }
  const int numdir = reader->GetOutput()->GetNumberOfComponentsPerPixel();
  //std::cout << "NUM : " << numdir << std::endl;

  if((numdir != 7) && (numdir != 13) && (numdir != 59))
    {
      std::cout << "Cannot handle " << numdir << " gradient direction images." << std::endl;
      return -1;
    }

  switch(numdir)
    {
    case 7:
      DoIt<7>(argc, argv);
      break;
    case 13:
      DoIt<13>(argc, argv);
      break;
    case 59:
      DoIt<59>(argc, argv);
      break;
    }
  return EXIT_SUCCESS;
}
