
#include "PointSpreadFunction.h"
#include "itkAddImageFilter.h"
#include "itkConstantPadImageFilter.h"
#include "itkImageFileWriter.h"
#include "itkSubtractImageFilter.h"

// Constructs the object with its default settings: kernel choice 1,
// 24 phase-encode cells, freshly created (empty) input and output images,
// and a zero padding record on all three axes.
PointSpreadFunction::PointSpreadFunction()
{
	// Empty placeholder images; SetImage() and Update() replace them.
	this->inputImage = PSFImageType::New();
	this->outputImage = PSFImageType::New();

	// Default processing parameters.
	this->KernelChoice = 1;
	this->numCells = 24; // TODO: generalize this default

	// Nothing has been padded yet, so there is nothing to crop on any axis.
	for (int axis = 0; axis < 3; ++axis)
	{
		this->buffer[axis] = 0;
	}
}
void PointSpreadFunction::Update()
{


	typedef itk::ExtractImageFilter< PSFImageType, ImageSliceType > FilterType;
	FilterType::Pointer filter = FilterType::New();

	PSFImageType::RegionType inputRegion = this->inputImage->GetLargestPossibleRegion();
	PSFImageType::SizeType size = inputRegion.GetSize();
	PSFImageType::IndexType start = inputRegion.GetIndex();

	int plane = 0; //plane of CSI grid is the smallest size of the input region
	
	if ((size[2] < size[1]) && (size[2] < size[0]))
	{
		plane = 2;
	}
	else if ((size[1] < size[2]) && (size[1] < size[0]))
	{
		plane = 1;
	}
	else 
	{
		plane = 0;
	}
	
	//number of slices in the input region
	int maxSlice = size[plane];
	
	//Create output image
	PSFImageType::Pointer imageOut = PSFImageType::New();
	PSFImageType::RegionType outputRegion;
	outputRegion.SetSize(size);
	imageOut->SetRegions(outputRegion);
	imageOut->Allocate();
	imageOut->FillBuffer(0);
	imageOut->SetDirection(inputImage->GetDirection());
	imageOut->SetOrigin(inputImage->GetOrigin());
	
	//Iterate through input image slice by slice and convolve each slice with the PSF
	for (int sliceNumber = 0; sliceNumber < maxSlice; sliceNumber++)
	{
		start[plane] = sliceNumber;
		size[plane] = 0;
		//std::cout<<sliceNumber<<std::endl;
		//Extract each slice
		PSFImageType::RegionType desiredRegion;
		desiredRegion.SetSize(  size  );
		desiredRegion.SetIndex( start );
		filter->SetExtractionRegion( desiredRegion );
		filter->SetInput( inputImage );
		try
		{
			filter->Update();
		}
		catch( itk::ExceptionObject & excp )
		{
			std::cerr << "Error extracting the slice image: " << std::endl;
			std::cerr << excp << std::endl;
		}

		//Send each slice to Convolve fuction
		ImageSliceType::Pointer sliceImage = ImageSliceType::New();
		sliceImage = Convolve(filter->GetOutput());

		//For Debugging and testing purposes, write out only the middle slice results
		if (sliceNumber == maxSlice/2)
		{
			typedef itk::ImageFileWriter<ImageSliceType> SliceWriterType;
			SliceWriterType::Pointer writer = SliceWriterType::New();
			writer->SetFileName("E:/Development/output/slice.nii");
			writer->SetInput(sliceImage);
			try
			{
				writer->Update();
			}
			catch( itk::ExceptionObject & excp )
			{
				std::cerr << "Error writing the slice image: " << std::endl;
				std::cerr << excp << std::endl;
			}
		}


		//Put the convolved slice into the 3D output image in the appropriate slice
		ImageSliceType::IndexType index;
		PSFImageType::IndexType outIndex;

		itk::ImageRegionIterator<ImageSliceType> iter(sliceImage, sliceImage->GetLargestPossibleRegion());
		itk::ImageRegionIterator<PSFImageType> iterOut(imageOut, imageOut->GetLargestPossibleRegion());
		iter.GoToBegin();

		while( !iter.IsAtEnd() ) 
		{
			index = iter.GetIndex();

			if (plane == 0)
			{
				outIndex[0] = sliceNumber;
				outIndex[1] = index[0];
				outIndex[2] = index[1];
			}
			else if (plane == 1)
			{
				outIndex[0] = index[0];
				outIndex[1] = sliceNumber;
				outIndex[2] = index[1];
			}
			else
			{
				outIndex[0] = index[0];
				outIndex[1] = index[1];
				outIndex[2] = sliceNumber;
			}
			
			iterOut.SetIndex(outIndex);

			DoublePixelType pixel( iter.Get() );
			iterOut.Set(pixel);
			++iter; 
		}
	}
	//Because the image is padded during the convoution process, we need to crop it back to its original size
	this->outputImage = CropImage(imageOut);
}

// Convolves one 2D slice with the point spread function by filtering it in
// the frequency domain.
//
// Steps: forward FFT of the slice, split into real and imaginary images,
// FFT-shift both, multiply by the kernel from CreateKernel() (built for
// this->KernelChoice and the slice size), shift back, recombine into a
// complex image and inverse FFT.
//
// The multiplication is a full complex product:
//   (R + iI)(Kr + iKi) = (R*Kr - I*Ki) + i(R*Ki + I*Kr)
// Multiplying real-by-real and imaginary-by-imaginary only (as an earlier
// version did) zeroes the imaginary half of the spectrum whenever the kernel
// is purely real, which is the case for every kernel CreateKernel() builds.
//
// Returns the output of the inverse FFT filter. Exceptions from the real-part,
// recombination and inverse-FFT updates are logged to std::cerr and swallowed;
// an exception from the imaginary-part update propagates to the caller.
ImageSliceType::Pointer  PointSpreadFunction::Convolve(ImageSliceType::Pointer sliceImage )
{
	/*******************************************************************/
	// FFT of the slice image; the result is a complex image.
	typedef itk::VnlFFTRealToComplexConjugateImageFilter< DoublePixelType, 2 >  FFTFilterType;
	FFTFilterType::Pointer fftFilter = FFTFilterType::New();
	fftFilter->SetInput( sliceImage );

	/*******************************************************************/
	// Separate the real image from the complex image.
	typedef itk::ComplexToRealImageFilter<PSFComplexImageType,ImageSliceType>  Complex2RealFilterType;
	Complex2RealFilterType::Pointer realFilter = Complex2RealFilterType::New();
	realFilter->SetInput(fftFilter->GetOutput());
	try
	{
		realFilter->Update();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << "Error extracting the real part of the FFT image: " << std::endl;
		std::cerr << excp << std::endl;
	}

	/*******************************************************************/
	// Separate the imaginary image from the complex image.
	typedef itk::ComplexToImaginaryImageFilter<PSFComplexImageType,ImageSliceType>  Complex2ImagFilterType;
	Complex2ImagFilterType::Pointer imagFilter = Complex2ImagFilterType::New();
	imagFilter->SetInput(fftFilter->GetOutput());
	imagFilter->Update();

	/*******************************************************************/
	// Shift both images so they line up with the centred kernel.
	typedef itk::FFTShiftImageFilter< ImageSliceType, ImageSliceType > ShiftFilterType;
	ShiftFilterType::Pointer tissueShiftReal = ShiftFilterType::New();
	tissueShiftReal->SetInput( realFilter->GetOutput() );
	tissueShiftReal->InverseOff();

	ShiftFilterType::Pointer tissueShiftImag = ShiftFilterType::New();
	tissueShiftImag->SetInput( imagFilter->GetOutput() );
	tissueShiftImag->InverseOff();

	/*******************************************************************/
	// Create the complex convolution kernel for the configured choice and
	// split it into its real and imaginary images.
	const PSFComplexImageType::SizeType size = sliceImage->GetLargestPossibleRegion().GetSize();
	PSFComplexImageType::Pointer kernel = CreateKernel(this->KernelChoice, size);

	c2rFilterType::Pointer c2rFilter = c2rFilterType::New();
	c2rFilter->SetInput(kernel);

	c2iFilterType::Pointer c2iFilter = c2iFilterType::New();
	c2iFilter->SetInput(kernel);

	/*******************************************************************/
	// Complex multiplication of the slice spectrum with the kernel.
	typedef itk::MultiplyImageFilter< ImageSliceType, ImageSliceType, ImageSliceType > MultiplyFilterType;
	typedef itk::SubtractImageFilter< ImageSliceType, ImageSliceType, ImageSliceType > SubtractFilterType;
	typedef itk::AddImageFilter< ImageSliceType, ImageSliceType, ImageSliceType > AddFilterType;

	// R * Kr
	MultiplyFilterType::Pointer realTimesReal = MultiplyFilterType::New();
	realTimesReal->SetInput1( tissueShiftReal->GetOutput() );
	realTimesReal->SetInput2( c2rFilter->GetOutput() );

	// I * Ki
	MultiplyFilterType::Pointer imagTimesImag = MultiplyFilterType::New();
	imagTimesImag->SetInput1( tissueShiftImag->GetOutput() );
	imagTimesImag->SetInput2( c2iFilter->GetOutput() );

	// R * Ki
	MultiplyFilterType::Pointer realTimesImag = MultiplyFilterType::New();
	realTimesImag->SetInput1( tissueShiftReal->GetOutput() );
	realTimesImag->SetInput2( c2iFilter->GetOutput() );

	// I * Kr
	MultiplyFilterType::Pointer imagTimesReal = MultiplyFilterType::New();
	imagTimesReal->SetInput1( tissueShiftImag->GetOutput() );
	imagTimesReal->SetInput2( c2rFilter->GetOutput() );

	// Real part of the product: R*Kr - I*Ki
	SubtractFilterType::Pointer productReal = SubtractFilterType::New();
	productReal->SetInput1( realTimesReal->GetOutput() );
	productReal->SetInput2( imagTimesImag->GetOutput() );

	// Imaginary part of the product: R*Ki + I*Kr
	AddFilterType::Pointer productImag = AddFilterType::New();
	productImag->SetInput1( realTimesImag->GetOutput() );
	productImag->SetInput2( imagTimesReal->GetOutput() );

	/*******************************************************************/
	// Shift the filtered images back.
	ShiftFilterType::Pointer shiftReal = ShiftFilterType::New();
	shiftReal->SetInput(productReal->GetOutput());
	shiftReal->InverseOn();

	ShiftFilterType::Pointer shiftImag = ShiftFilterType::New();
	shiftImag->SetInput(productImag->GetOutput());
	shiftImag->InverseOn();

	/*******************************************************************/
	// Recombine the real and imaginary images into one complex image.
	RI2CImageFilter::Pointer complexFilter = RI2CImageFilter::New();
	complexFilter->SetInput1( shiftReal->GetOutput() );
	complexFilter->SetInput2( shiftImag->GetOutput() );

	try
	{
		complexFilter->Update();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << "Error recombining the real and imaginary images: " << std::endl;
		std::cerr << excp << std::endl;
	}

	/*******************************************************************/
	// Inverse FFT
	typedef itk::VnlFFTComplexConjugateToRealImageFilter<double, 2 >  IFFTFilterType;
	IFFTFilterType::Pointer fftInverseFilter = IFFTFilterType::New();
	fftInverseFilter->SetInput(complexFilter->GetOutput());
	try
	{
		fftInverseFilter->Update();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << "Error computing the inverse FFT: " << std::endl;
		std::cerr << excp << std::endl;
	}

	/*******************************************************************/
	// Return the convolved image.
	return fftInverseFilter->GetOutput();
}
/***********************************************************************/
// Sets the image to be processed. The ITK image held by the vtkKWImage is
// cast to PSFImageType and then handed to PadImage(); the padded result
// becomes this->inputImage.
void PointSpreadFunction::SetImage( vtkKWImage * image)
{
	typedef itk::CastImageFilter<ImageType,PSFImageType> DoubleCastFilterType;

	// Unwrap the ITK image and convert it to the working pixel type.
	ImageType::Pointer source = GetITK(image);
	DoubleCastFilterType::Pointer toWorkingType = DoubleCastFilterType::New();
	toWorkingType->SetInput(source);
	toWorkingType->Update();

	// The image needs to be padded with zeros before it can be processed.
	PSFImageType::Pointer converted = toWorkingType->GetOutput();
	this->inputImage = PadImage(converted);
}
/***********************************************************************/
void PointSpreadFunction::SetKernelChoice(int choice) 
{
	this->KernelChoice = choice;
}
/***********************************************************************/
void PointSpreadFunction::SetPhaseEncodeSteps(int steps) 
{
	this->numCells = steps;
}
/***********************************************************************/
// Returns the current output image, cast back to ImageType and wrapped in a
// vtkKWImage. A new vtkKWImage is created with New() on every call.
// NOTE(review): the caller presumably takes ownership and must Delete() the
// returned object - confirm against the callers.
vtkKWImage * PointSpreadFunction::GetOutput()
{
	typedef itk::CastImageFilter<PSFImageType,ImageType> CastFilterType;

	// Convert from the working pixel type back to the external image type.
	CastFilterType::Pointer toImageType = CastFilterType::New();
	toImageType->SetInput(this->outputImage);
	toImageType->Update();

	vtkKWImage * wrapped = vtkKWImage::New();
	wrapped->SetITKImageBase(toImageType->GetOutput());
	return wrapped;
}
/***********************************************************************/
// Builds the complex-valued convolution kernel, laid out with its centre at
// the middle of the image, for a slice of the given size.
//
// choice : 1 = square kernel (1 + 0i inside the window, 0 outside)
//          2 = Hamming kernel (radially weighted inside the window, 0 outside)
//          any other value leaves the kernel zero everywhere.
// size   : size of the 2D kernel image to create.
//
// The window spans numCells/2 samples on either side of the centre of EACH
// axis (size[axis]/2), both ends inclusive. The bounds are computed per axis
// so that the kernel stays centred when the slice is not square.
//
// Returns the newly allocated kernel image; all imaginary parts are zero.
PSFComplexImageType::Pointer PointSpreadFunction::CreateKernel(int choice, PSFComplexImageType::SizeType size)
{
	PSFComplexImageType::Pointer complexImage = PSFComplexImageType::New();

	PSFComplexImageType::RegionType region;
	region.SetSize(size);
	complexImage->SetRegions(region);
	complexImage->Allocate();
	complexImage->FillBuffer(0);

	// Kernel window, based on the number of MR phase encode steps in the CSI grid.
	const long halfWidth = static_cast<long>(this->numCells / 2);
	const long center0 = static_cast<long>(size[0] / 2);
	const long center1 = static_cast<long>(size[1] / 2);
	const long min0 = center0 - halfWidth;
	const long max0 = center0 + halfWidth;
	const long min1 = center1 - halfWidth;
	const long max1 = center1 + halfWidth;

	const PSFComplexPixelType zero(0.0, 0.0);

	itk::ImageRegionIterator<PSFComplexImageType> iter(complexImage, complexImage->GetLargestPossibleRegion());
	for (iter.GoToBegin(); !iter.IsAtEnd(); ++iter)
	{
		const PSFComplexImageType::IndexType index = iter.GetIndex();
		const bool insideWindow =
			(min0 <= index[0]) && (index[0] <= max0) &&
			(min1 <= index[1]) && (index[1] <= max1);

		switch ( choice )
		{
		case 1 : // Square kernel for full phase encoding
			iter.Set( insideWindow ? PSFComplexPixelType(1.0, 0.0) : zero );
			break;

		case 2 : // Hamming kernel
			if ( insideWindow )
			{
				// Offsets of this sample from the centre of the image.
				const double i = static_cast<double>(index[1]) - static_cast<double>(size[1])/2;
				const double j = static_cast<double>(index[0]) - static_cast<double>(size[0])/2;
				const double d = sqrt(i*i + j*j);

				// NOTE(review): the window length is hard-coded to 32 rather than
				// derived from numCells, pi is approximated by 3.14, and the
				// weight at d == ks/2 is about 0.54 while samples beyond it get
				// 0.08 - confirm that this is the intended window.
				const int ks = 32;
				if (d > ks/2 )
				{
					iter.Set( PSFComplexPixelType(0.08, 0.0) );
				}
				else
				{
					iter.Set( PSFComplexPixelType((0.54 + 0.46*cos(3.14*d/ks) ), 0.0) );
				}
			}
			else
			{
				iter.Set( zero );
			}
			break;

		default : // Unknown choice: keep the zero the buffer was filled with
			break;
		}
	}

	return complexImage;
}
/***********************************************************************/
// Returns the ITK image held by a vtkKWImage as an ImageType smart pointer.
// The base-class pointer is statically cast to ImageType with no run-time
// type check, and constness is cast away; no pixel data is copied.
ImageType::Pointer PointSpreadFunction::GetITK(vtkKWImage * image)
{
	const ImageType * typedView = static_cast< const ImageType * >( image->GetITKImageBase() );
	ImageType::Pointer output = const_cast<ImageType *>(typedView);
	return output;
}
/***********************************************************************/
void PointSpreadFunction::WriteImage()
{	
	typedef itk::CastImageFilter<PSFImageType,ImageType> CastFilterType;
	CastFilterType::Pointer cast = CastFilterType::New();
	cast->SetInput(this->outputImage);
	cast->Update();

	typedef itk::ImageFileWriter<ImageType>    WriterType;
	WriterType::Pointer writer = WriterType::New();
	writer->SetFileName( "E:/Development/output/convolved.nii" );
	writer->SetInput( cast->GetOutput() );  
	try
	{
		writer->Update();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << excp << std::endl;
	}
}
/***********************************************************************/
// Zero-pads an image so that every dimension becomes the next power of two
// that is greater than or equal to its current size.
//
// The amount of padding added on each axis is recorded in this->buffer so
// CropImage() can undo it. The original data is placed buffer/2 voxels in
// from the low end of each axis. Direction, spacing and origin are copied
// from the input image.
//
// Side effect: the padded image is written to a hard-coded debug file. A
// failure of that write is logged and ignored so that a missing debug
// directory cannot abort the padding.
//
// Returns the newly allocated padded image.
PSFImageType::Pointer PointSpreadFunction::PadImage(PSFImageType::Pointer image)
{
	const PSFImageType::SizeType size = image->GetLargestPossibleRegion().GetSize();

	// Grow each axis to a power of two and remember how much was added.
	PSFImageType::SizeType padSize;
	for (unsigned int axis = 0; axis < 3; ++axis)
	{
		padSize[axis] = 1;
		while( size[axis] > padSize[axis] )
		{
			padSize[axis] = padSize[axis]*2;
		}
		this->buffer[axis] = padSize[axis] - size[axis];
	}

	// Zero-filled destination carrying the geometry of the input image.
	PSFImageType::Pointer padded = PSFImageType::New();
	PSFImageType::RegionType region;
	region.SetSize(padSize);
	padded->SetRegions(region);
	padded->Allocate();
	padded->FillBuffer(0);
	padded->SetDirection(image->GetDirection());
	padded->SetSpacing(image->GetSpacing());
	padded->SetOrigin( image->GetOrigin() );

	itk::ImageRegionIterator<PSFImageType> iter(image, image->GetLargestPossibleRegion());
	itk::ImageRegionIterator<PSFImageType> padIter(padded, padded->GetLargestPossibleRegion());
	padIter.GoToBegin();

	// Copy every input voxel to its offset position inside the padded image.
	for (iter.GoToBegin(); !iter.IsAtEnd(); ++iter)
	{
		const PSFImageType::IndexType indexIN = iter.GetIndex();
		PSFImageType::IndexType indexOUT;
		indexOUT[0] = this->buffer[0]/2 + indexIN[0];
		indexOUT[1] = this->buffer[1]/2 + indexIN[1];
		indexOUT[2] = this->buffer[2]/2 + indexIN[2];
		padIter.SetIndex(indexOUT);
		padIter.Set( iter.Get() );
	}

	// Debug output of the padded image.
	typedef itk::ImageFileWriter<PSFImageType> PSFWriterType;
	PSFWriterType::Pointer writer = PSFWriterType::New();
	writer->SetInput(padded);
	writer->SetFileName("E:/Development/output/padded.nii");
	try
	{
		writer->Update();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << "Error writing the padded image: " << std::endl;
		std::cerr << excp << std::endl;
	}

	return padded;

}
/***********************************************************************/
// Removes the padding recorded in this->buffer from an image: the result is
// smaller by buffer[axis] voxels on each axis and is read starting buffer/2
// voxels in from the low end of each axis of the input. Direction, spacing
// and origin are copied from the input image.
//
// Returns the newly allocated cropped image.
PSFImageType::Pointer PointSpreadFunction::CropImage(PSFImageType::Pointer image)
{
	const PSFImageType::SizeType paddedSize = image->GetLargestPossibleRegion().GetSize();

	// Size of the image once the recorded padding is taken away.
	PSFImageType::SizeType croppedSize;
	for (unsigned int axis = 0; axis < 3; ++axis)
	{
		croppedSize[axis] = paddedSize[axis] - this->buffer[axis];
	}

	// Zero-filled destination carrying the geometry of the input image.
	PSFImageType::Pointer cropped = PSFImageType::New();
	PSFImageType::RegionType croppedRegion;
	croppedRegion.SetSize(croppedSize);
	cropped->SetRegions(croppedRegion);
	cropped->Allocate();
	cropped->FillBuffer(0);
	cropped->SetDirection(image->GetDirection());
	cropped->SetSpacing(image->GetSpacing());
	cropped->SetOrigin( image->GetOrigin() );

	itk::ImageRegionIterator<PSFImageType> srcIter(image, image->GetLargestPossibleRegion());
	srcIter.GoToBegin();

	itk::ImageRegionIterator<PSFImageType> dstIter(cropped, cropped->GetLargestPossibleRegion());

	// Walk the cropped image and fetch each voxel from its padded position.
	for (dstIter.GoToBegin(); !dstIter.IsAtEnd(); ++dstIter)
	{
		const PSFImageType::IndexType dstIndex = dstIter.GetIndex();
		PSFImageType::IndexType srcIndex;
		for (unsigned int axis = 0; axis < 3; ++axis)
		{
			srcIndex[axis] = this->buffer[axis]/2 + dstIndex[axis];
		}
		srcIter.SetIndex(srcIndex);
		dstIter.Set( srcIter.Get() );
	}

	return cropped;

}