#ifndef ImageProcessing_cxx
#define ImageProcessing_cxx

#include "ImageProcessing.h"

#include "itkIdentityTransform.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkCastImageFilter.h"
#include "itkSliceBySliceImageFilter.h"
#include "itkConvolutionImageFilter.h"
#include "itkAffineTransform.h"
#include "itkNearestNeighborInterpolateImageFunction.h"

#include "itkConnectedThresholdImageFilter.h"
#include "itkBinaryThresholdImageFilter.h"

#include "itkComplexToRealImageFilter.h"


#define PI 3.14159265

/**************************************************************************************************************************/
ImageProcessing::ImageProcessing()
{
	// Allocate the ITK/VTK image holders so the getters never hand out null objects.
	this->Image          = ImageType::New();
	this->ITKImage       = ImageType::New();
	this->OutputImage    = vtkKWImage::New();
	this->ConvolvedImage = DoubleImageType::New();

	for (int i = 0; i < 3; i++)
	{
		// Default geometry used by ResampleImage(): extent 300 per axis at spacing 2
		// (ResampleImage() derives the voxel count as extent / spacing).
		this->ImageSize[i]    = 300;
		this->ImageSpacing[i] = 2;

		// CSI grid geometry is normally supplied via SetGridSpacing/SetGridFOV/SetGridCenter.
		// Zero it here so ExtractGridVoxels() never reads indeterminate values if a setter
		// was skipped.
		this->GridSpacing[i] = 0;
		this->GridFOV[i]     = 0;
		this->GridCenter[i]  = 0;
	}

	// No grid until SetGrid() is called.
	this->Grid     = NULL;
	this->numCells = 0;

	// Identity orientation until SetGridDirCos() provides the real direction cosines.
	this->transform = AffineTransformType::New();
	this->transform->SetIdentity();
	this->DirCos = this->transform->GetMatrix();
	this->RegTransform = TransformType::New();
}
/**************************************************************************************************************************/
void ImageProcessing::ResampleImage()
{
	//std::ofstream debugStream;
 //   debugStream.open("E:\\Development\\output\\resample.txt");
	
	ResampleFilterType::Pointer resamplefilter = ResampleFilterType::New();

	if (this->DirCos != transform->GetMatrix())
	{
		this->transform->SetMatrix(this->DirCos);
	}

	/*debugStream <<"Direction Cosines"<<std::endl;
	debugStream << this->DirCos <<std::endl;*/

	resamplefilter->SetTransform( this->transform );  
	
	typedef itk::LinearInterpolateImageFunction< ImageType, double>			InterpolatorType;

	InterpolatorType::Pointer interpolator = InterpolatorType::New();	
	resamplefilter->SetInterpolator( interpolator );
	resamplefilter->SetDefaultPixelValue( 0 );

	// Set Image Parameters  Probably put this as user input at some point.
	ImageType::SpacingType outputspacing;
		outputspacing[0] = this->ImageSpacing[0];
		outputspacing[1] = this->ImageSpacing[1];
		outputspacing[2] = this->ImageSpacing[2];
	resamplefilter->SetOutputSpacing( outputspacing );

	ImageType::SizeType outputsize;
		outputsize[0] = this->ImageSize[0]/outputspacing[0];
		outputsize[1] = this->ImageSize[1]/outputspacing[1];
		outputsize[2] = this->ImageSize[2]/outputspacing[2];
	resamplefilter->SetSize(outputsize);
	
	ImageType::PointType outputorigin;
	ImageType::PointType origin;
	
		outputorigin[0] = (int)-(outputsize[0]*outputspacing[0])/2;  //this->ImageSize[0])/2;
		outputorigin[1] = (int)-(outputsize[1]*outputspacing[1])/2; //(this->ImageSize[1])/2;
		outputorigin[2] = (int)-(outputsize[2]*outputspacing[2])/2; //(this->ImageSize[2])/2; 

	//origin = DirCos * outputorigin;

	resamplefilter->SetOutputOrigin( outputorigin );
	resamplefilter->SetInput(this->Image);

	ImageType::DirectionType direction;
	direction = this->Image->GetDirection();
	
//	resamplefilter->SetOutputDirection(direction);

	//*************************************
	//Set output Image
	resamplefilter->Update();
	this->OutputImage->SetITKImageBase(resamplefilter->GetOutput());

	//debugStream.close();
}
/**************************************************************************************************************************/
void ImageProcessing::ExtractGridVoxels(){

	//Function extracts from the anatomical image the volume that is defined by the CSI grid.
	//In order to compensate for the Chemical Shift and Point Spread Function, a extra border set
	//at a percentage of the volume is extracted as well which is hard coded into the function
	
	/*std::ofstream debugStream;
    debugStream.open("E:\\Development\\output\\extractIP.txt");*/
	
	//*************************************************
	//Padding of extracted region around CSI grid
	double extraBorder[3];
	extraBorder[0] = 0 * this->GridSpacing[0]; //Don't need extra in this dim because FOV > VOI
	extraBorder[1] = 0 * this->GridSpacing[1];
	extraBorder[2] = 1.23 * this->GridSpacing[2];
	
	//*************************************************
	//Direction Cosines -> rotation matrix
	MatrixType IDirCos;
	IDirCos = this->DirCos.GetInverse();
	/*debugStream<<this->DirCos<<std::endl;
	debugStream<<IDirCos<<std::endl;*/
	//****************************************************

	//****************************************************
	//Create Extract Filter
	typedef itk::ResampleImageFilter<ImageType, ImageType>				ResampleFilterType;	
	ResampleFilterType::Pointer gridFilter = ResampleFilterType::New();
	
	//Set Transform	
	typedef itk::AffineTransform<double,3>								AffineTransformType;
	AffineTransformType::Pointer transform = AffineTransformType::New();
	transform->SetMatrix(this->DirCos);
	gridFilter->SetTransform(transform);  
	
	//Set Interpolation Method 
	typedef itk::NearestNeighborInterpolateImageFunction< ImageType,double > NNInterpolatorType;
	NNInterpolatorType::Pointer interpolator = NNInterpolatorType::New();
	gridFilter->SetInterpolator( interpolator );
	gridFilter->SetDefaultPixelValue( 0 );
	
	//Set Spacing
	ImageType::SpacingType gridSpacing;
		gridSpacing[0] = 1;
		gridSpacing[1] = 1;
		gridSpacing[2] = 1;
	gridFilter->SetOutputSpacing( gridSpacing );
	
	//Set Size
	ImageType::SizeType outputsize;
		outputsize[0] = (this->GridFOV[0] + extraBorder[0]);
		outputsize[1] = (this->GridFOV[1] + extraBorder[1]);
		outputsize[2] = (this->GridFOV[2] + extraBorder[2]);
	gridFilter->SetSize(outputsize);

	ImageType::PointType temp;
	temp[0] = outputsize[0];
	temp[1] = outputsize[1];
	temp[2] = outputsize[2];

	ImageType::PointType temp2;
	temp2 = this->DirCos * temp;

	//Set Origin
	ImageType::PointType origin;
		origin[0] = this->GridCenter[0] - temp2[0]/2;
		origin[1] = this->GridCenter[1] - temp2[1]/2; 
		origin[2] = this->GridCenter[2] - temp2[2]/2;
	
	ImageType::PointType outputorigin;
		outputorigin = IDirCos * origin;

	gridFilter->SetOutputOrigin( outputorigin );	
	
	//Set Input Image and Update
	gridFilter->SetInput(this->Image );
	gridFilter->Update();

	//*************************************
	//Set output Image
	ImageType::DirectionType direction;
	direction = this->DirCos;

	ImageType::Pointer tempImage = ImageType::New();
	tempImage = gridFilter->GetOutput();
	tempImage->SetDirection(direction);
	tempImage->SetOrigin(origin);
	this->OutputImage->SetITKImageBase( tempImage );
	//debugStream.close();

}
/**************************************************************************************************************************/
void ImageProcessing::Threshold(){

//	vtkErrorMacro("Threshold begins");

	typedef itk::ConnectedThresholdImageFilter<ImageType,ImageType> ConnectedFilterType;

	ConnectedFilterType::Pointer connectedFilter = ConnectedFilterType::New();
	
	PixelType lower = 700;
	PixelType upper = 3000;

	connectedFilter->SetInput(this->Image);

	connectedFilter->SetLower(lower);
	connectedFilter->SetUpper(upper);
	
	ImageType::IndexType seed;
	ImageType::PointType inputOrigin;
	ImageType::SizeType size;

	size = this->Image->GetLargestPossibleRegion().GetSize();
	inputOrigin = this->Image->GetOrigin();
	
	seed[0] = size[0]/2;
	seed[1] = size[1]/2;
	seed[2] = size[2]/2;
	
	connectedFilter->AddSeed(seed);
	connectedFilter->SetReplaceValue(255);
	connectedFilter->Update();

	//ThresholdFilterType::Pointer thresholdfilter = ThresholdFilterType::New();
	//thresholdfilter->SetInput(this->Image);
	//thresholdfilter->SetInsideValue(255);
	//thresholdfilter->SetLowerThreshold(lower);
	//thresholdfilter->SetUpperThreshold(upper);
	//thresholdfilter->Update();

	//std::ofstream debugStream;
 //   debugStream.open("E:\\Development\\output\\threshold.txt");

	//connectedFilter->Print(debugStream, 0);

	ImageType::Pointer thresholdImage = ImageType::New();
	thresholdImage = connectedFilter->GetOutput();
	this->OutputImage->SetITKImageBase( connectedFilter->GetOutput() );

	/*thresholdImage->Print(debugStream, 0);
	debugStream.close();*/
}//end of Threshold
/**************************************************************************************************************************/
void ImageProcessing::MaskFilter()
{
	//typedef itk::MaskImageFilter<ImageType, ImageType> MaskFilterType;
	MaskFilterType::Pointer maskFilter = MaskFilterType::New();
	maskFilter->SetInput1(this->Image);
	maskFilter->SetInput2(this->Mask);
	maskFilter->SetOutsideValue(0);
	maskFilter->Update();

	this->OutputImage->SetITKImageBase( maskFilter->GetOutput()) ;

}
/**************************************************************************************************************************/
void ImageProcessing::SetImage( vtkKWImage * image)
{
	// Uses the ITK image wrapped by `image` as the processing input (see GetITK()).
	ImageType::Pointer itkInput = this->GetITK(image);
	this->Image = itkInput;
}
/**************************************************************************************************************************/
void ImageProcessing::SetMask( vtkKWImage * image)
{
	// Uses the ITK image wrapped by `image` as the mask for MaskFilter() (see GetITK()).
	ImageType::Pointer itkMask = this->GetITK(image);
	this->Mask = itkMask;
}
/**************************************************************************************************************************/
void ImageProcessing::SetImage( ImageType::Pointer image)
{
	// Shares (does not copy) the given ITK image as the processing input.
	this->Image = image;
}

/**************************************************************************************************************************/
void ImageProcessing::SetSize( ImageType::SizeType size)
{
	// Stores the output extent per axis used by ResampleImage(), which divides it by the
	// spacing to obtain the voxel count.
	for (unsigned int axis = 0; axis < 3; ++axis)
		this->ImageSize[axis] = size[axis];
}
/**************************************************************************************************************************/
void ImageProcessing::SetSpacing( ImageType::SpacingType spacing)
{
	// Stores the per-axis output spacing used by ResampleImage().
	for (unsigned int axis = 0; axis < 3; ++axis)
		this->ImageSpacing[axis] = spacing[axis];
}

/**************************************************************************************************************************/
void ImageProcessing::SetGrid(vtkStructuredGrid * grid)
{
	// Records the CSI grid and caches its cell count.
	// A null grid is accepted and treated as having no cells (previously it was dereferenced).
	this->Grid = grid;
	this->numCells = grid ? grid->GetNumberOfCells() : 0;
}
/**************************************************************************************************************************/
void ImageProcessing::SetGridSpacing(double * spacing)
{
	// Copies the three CSI voxel spacings; `spacing` must point to at least 3 doubles.
	for (int axis = 0; axis < 3; ++axis)
	{
		this->GridSpacing[axis] = spacing[axis];
	}
}
/**************************************************************************************************************************/
void ImageProcessing::SetGridFOV(double * FOV)
{
	// Copies the three CSI field-of-view extents; `FOV` must point to at least 3 doubles.
	for (int axis = 0; axis < 3; ++axis)
	{
		this->GridFOV[axis] = FOV[axis];
	}
}
/**************************************************************************************************************************/
void ImageProcessing::SetGridCenter(double * center)
{
	// Copies the CSI grid centre (3 coordinates); `center` must point to at least 3 doubles.
	for (int axis = 0; axis < 3; ++axis)
	{
		this->GridCenter[axis] = center[axis];
	}
}
/**************************************************************************************************************************/
void ImageProcessing::SetGridDirCos(MatrixType dirCos)
{
	// Stores the grid direction cosines used by ResampleImage() and ExtractGridVoxels().
	DirCos = dirCos;
}
/**************************************************************************************************************************/
vtkKWImage * ImageProcessing::GetOutput()
{
	// Result of the most recent processing call; this object keeps ownership.
	vtkKWImage * result = this->OutputImage;
	return result;
}
/**************************************************************************************************************************/
ImageType::Pointer ImageProcessing::GetITKOutput()
{
	// ITK image last produced by Separate() (an empty image until then).
	ImageType::Pointer result = this->ITKImage;
	return result;
}
/**************************************************************************************************************************/
void ImageProcessing::Separate(int label)
{	
	//std::ofstream debugStream;
 //   debugStream.open("E:\\Development\\output\\separate.txt");
	
	typedef itk::BinaryThresholdImageFilter<ImageType,ImageType>		BinaryThresholdFilterType;
	
	BinaryThresholdFilterType::Pointer Classfilter = BinaryThresholdFilterType::New();
	Classfilter->SetInput(this->Image);
	Classfilter->SetInsideValue(255);
	Classfilter->SetLowerThreshold(label);
	Classfilter->SetUpperThreshold(label);
	Classfilter->Update();

	ImageType::Pointer classImage = ImageType::New();
	classImage = Classfilter->GetOutput();
	
	//typedef itk::ImageFileWriter<ImageType>  WriterType;
	//WriterType::Pointer writer = WriterType::New();
	//writer->SetInput(classImage);
	//writer->SetFileName("E:\\Development\\output\\classImage.nii");
	//writer->Update();
	
	this->ITKImage = classImage;

	this->OutputImage->SetITKImageBase(classImage);
	
	//debugStream.close();
}
/**************************************************************************************************************************
ImageType::Pointer ImageProcessing::GetClassImage(int label)
{	
	return Separate(label);
}
/**************************************************************************************************************************/
ImageType::Pointer ImageProcessing::GetITK(vtkKWImage * image)
{
	// Returns the ITK image held by `image`, viewed as ImageType.
	//
	// Returns a null pointer when `image` is null, or when the stored ITK image is not an
	// ImageType (e.g. it has a different pixel type). The previous static_cast was an unchecked
	// downcast, which is undefined behaviour on such a mismatch; dynamic_cast makes it
	// detectable instead.
	//
	// No copy is made: the result shares its buffer with `image`, and const is cast away, so
	// writes through the result are visible through `image`.
	ImageType::Pointer output;	// default-constructed SmartPointer is null

	if (image)
	{
		const ImageType * typedImage = dynamic_cast< const ImageType * >( image->GetITKImageBase() );
		output = const_cast< ImageType * >( typedImage );
	}
	return output;
}
/**************************************************************************************************************************/
void ImageProcessing::WriteImage()
{
	/* Write this->OutputImage to a fixed debug location. ITK write errors are reported on
	   stderr rather than propagated. */
	vtkKWImageIO * kwWriter = vtkKWImageIO::New();
	kwWriter->SetFileName( "E:/Development/output/imageProcessOut.nii" );
	kwWriter->SetImageToBeWritten( this->OutputImage );
	try
	{
		kwWriter->WriteImage();
	}
	catch( itk::ExceptionObject & excp )
	{
		std::cerr << excp << std::endl;
	}

	// Objects obtained from New() follow the VTK convention and must be released with Delete();
	// the writer was previously leaked on every call.
	kwWriter->Delete();
}
/**************************************************************************************************************************/
void ImageProcessing::RegisterImages(ImageType::Pointer fixed, ImageType::Pointer moving)
{
  // Registers `moving` onto `fixed` and stores the resulting transform in this->RegTransform.
  //
  // Pipeline: MetricType (20 histogram bins, 10000 spatial samples), OptimizerType and
  // InterpolatorType as declared for this class. The transform is initialised with a
  // moments-based CenteredTransformInitializer.
  //
  // If either image is null, an error is printed and this->RegTransform is left unchanged.
  // If registration throws an itk::ExceptionObject, it is reported on stderr and the last
  // parameters the optimizer reached are still applied.
  if ( fixed.IsNull() || moving.IsNull() )
    {
    std::cerr << "RegisterImages: fixed and moving images must both be set." << std::endl;
    return;
    }

  MetricType::Pointer         metric        = MetricType::New();
  OptimizerType::Pointer      optimizer     = OptimizerType::New();
  InterpolatorType::Pointer   interpolator  = InterpolatorType::New();
  RegistrationType::Pointer   registration  = RegistrationType::New();

  TransformType::Pointer  transform = TransformType::New();
  transform->SetIdentity();

  registration->SetTransform(     transform     );
  registration->SetMetric(        metric        );
  registration->SetOptimizer(     optimizer     );
  registration->SetInterpolator(  interpolator  );

  metric->SetNumberOfHistogramBins( 20 );
  metric->SetNumberOfSpatialSamples( 10000 );

  registration->SetFixedImage(    fixed   );
  registration->SetMovingImage(   moving  );
  registration->SetFixedImageRegion( fixed->GetBufferedRegion() );

  // Align image centres of mass before optimisation.
  typedef itk::CenteredTransformInitializer< TransformType, ImageType, ImageType > TransformInitializerType;
  TransformInitializerType::Pointer initializer = TransformInitializerType::New();

  initializer->SetTransform( transform );
  initializer->SetFixedImage( fixed );
  initializer->SetMovingImage( moving );
  initializer->MomentsOn();
  initializer->InitializeTransform();

  registration->SetInitialTransformParameters( transform->GetParameters() );

  // Translation parameters get a much smaller scale than the others, so that a unit step
  // in translation is comparable to a unit step in the remaining parameters.
  const double translationScale = 1.0 / 1000.0;

  typedef OptimizerType::ScalesType       OptimizerScalesType;
  const unsigned int numberOfParameters = transform->GetNumberOfParameters();
  OptimizerScalesType optimizerScales( numberOfParameters );

  // Every entry is written, so no scale is left uninitialised and none is written out of
  // bounds, whatever the transform's parameter count (the old code always wrote 0..5).
  // NOTE(review): indices 4 and 5 as translation match a 2-D affine layout. For 3-D rigid
  // transforms (e.g. VersorRigid3D/Euler3D) translation is parameters 3..5 — confirm
  // against TransformType.
  for ( unsigned int p = 0; p < numberOfParameters; ++p )
    {
    optimizerScales[p] = ( p == 4 || p == 5 ) ? translationScale : 1.0;
    }

  optimizer->SetScales( optimizerScales );

  const double steplength = 0.1;
  const int maxNumberOfIterations = 300;

  optimizer->SetMaximumStepLength( steplength );
  optimizer->SetMinimumStepLength( 0.0001 );
  optimizer->SetNumberOfIterations( maxNumberOfIterations );

  optimizer->MinimizeOn();

  try
    {
    registration->StartRegistration();
    }
  catch( itk::ExceptionObject & err )
    {
    std::cerr << "ExceptionObject caught !" << std::endl;
    std::cerr << err << std::endl;
    }

  transform->SetParameters(  registration->GetLastTransformParameters()  );

  this->RegTransform = transform;
}
/**************************************************************************************************************************/
TransformType::Pointer ImageProcessing::GetRegistrationTransform()
{
	// Transform produced by the last RegisterImages() call
	// (a default-constructed transform until then).
	TransformType::Pointer result = this->RegTransform;
	return result;
}
#endif