/**
 * @file  itkMultiResolutionJointSegmentationRegistration.txx
 * @brief Filter class implementing multi resolution joint segmentation-registration algorithm.
 *
 * Copyright (c) 2011-2014 University of Pennsylvania. All rights reserved.<br />
 * See http://www.cbica.upenn.edu/sbia/software/license.html or COPYING file.
 *
 * Contact: SBIA Group <sbia-software at uphs.upenn.edu>
 */

#ifndef _itkMultiResolutionJointSegmentationRegistration_txx
#define _itkMultiResolutionJointSegmentationRegistration_txx


#include <itkRecursiveGaussianImageFilter.h>
#include <itkRecursiveMultiResolutionPyramidImageFilter.h>
#include <itkImageRegionIterator.h>
#include <vnl/vnl_math.h>

#include "itkMultiResolutionJointSegmentationRegistration.h"


namespace itk {


/*
 * Default constructor.
 *
 * Builds the internal mini-pipeline: a single default segmentation-registration
 * filter, one image pyramid per fixed channel and per moving channel, and a
 * field expander used to resample deformation fields between levels.
 * Defaults: 3 resolution levels, 10 iterations per level, no initial field.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels> 
::MultiResolutionJointSegmentationRegistration()
{
	std::cout << "Multi resolution filter constructor .." << std::endl;

	// Input layout: index 0 holds the optional initial deformation field,
	// indices 1..NumberOfFixedChannels hold the fixed images and the next
	// NumberOfMovingChannels indices hold the moving images.
	this->SetNumberOfRequiredInputs(NumberOfFixedChannels + NumberOfMovingChannels);

#if ITK_VERSION_MAJOR >= 4
	// Input 0 (the initial deformation field) is optional, so it must not be
	// required under its ITK 4 name "Primary".
	this->RemoveRequiredInputName( "Primary" );
#endif

	// Default segmentation-registration filter (created once).
	m_SegmentationRegistrationFilter = DefaultSegmentationRegistrationType::New();

	// Number of resolution levels and the per-level iteration budget.
	m_NumberOfLevels = 3;
	m_NumberOfIterations.resize(m_NumberOfLevels);
	for (unsigned int ilevel = 0; ilevel < m_NumberOfLevels; ilevel++) {
		m_NumberOfIterations[ilevel] = 10;
	}

	// One pyramid per channel, each with m_NumberOfLevels levels.
	m_MovingImagePyramidVector.resize(NumberOfMovingChannels);
	m_FixedImagePyramidVector.resize(NumberOfFixedChannels);

	for (unsigned int i = 0; i < NumberOfFixedChannels; i++) {
		m_FixedImagePyramidVector.at(i) = FixedImagePyramidType::New();
		m_FixedImagePyramidVector.at(i)->SetNumberOfLevels(m_NumberOfLevels);
	}
	for (unsigned int i = 0; i < NumberOfMovingChannels; i++) {
		m_MovingImagePyramidVector.at(i) = MovingImagePyramidType::New();
		m_MovingImagePyramidVector.at(i)->SetNumberOfLevels(m_NumberOfLevels);
	}

	// Resampler used to bring deformation fields onto each level's grid.
	m_FieldExpander = FieldExpanderType::New();
	m_InitialDeformationField = NULL;

	m_CurrentLevel = 0;
	m_StopRegistrationFlag = false;
	std::cout << "Multi resolution filter constructor done .." << std::endl;
}

/*
 * Connect the n-th moving image (1-based channel index) to this filter.
 *
 * Moving images are stored after the fixed images in the process-object
 * input list, at input index NumberOfFixedChannels + n.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::SetNthMovingImage(const MovingImageType * ptr, unsigned int n)
{
	const unsigned int inputIndex = NumberOfFixedChannels + n;
	MovingImageType * movingImage = const_cast<MovingImageType *>(ptr);
	this->ProcessObject::SetNthInput(inputIndex, movingImage);
}

/*
 * Return the n-th moving image (1-based channel index), read from input
 * index NumberOfFixedChannels + n. Returns NULL when that input is unset or
 * is not a MovingImageType.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
const typename MultiResolutionJointSegmentationRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType, VNumberOfFixedChannels,VNumberOfMovingChannels>
::MovingImageType *
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GetNthMovingImage(unsigned int n) const
{
	const DataObject * input = this->ProcessObject::GetInput(NumberOfFixedChannels + n);
	return dynamic_cast<const MovingImageType *>(input);
}

/*
 * Connect the n-th fixed image (1-based channel index) to this filter.
 * Fixed images occupy input indices 1..NumberOfFixedChannels; input 0 is
 * the initial deformation field.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::SetNthFixedImage(const FixedImageType * ptr, unsigned int n)
{
	FixedImageType * fixedImage = const_cast<FixedImageType *>(ptr);
	this->ProcessObject::SetNthInput(n, fixedImage);
}

/*
 * Return the n-th fixed image (1-based channel index), read from input
 * index n. Returns NULL when that input is unset or is not a FixedImageType.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
const typename MultiResolutionJointSegmentationRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType,VNumberOfFixedChannels,VNumberOfMovingChannels>
::FixedImageType *
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GetNthFixedImage(unsigned int n) const
{
	const DataObject * input = this->ProcessObject::GetInput(n);
	return dynamic_cast<const FixedImageType *>(input);
}

/*
 * Retrieve the number of required inputs that are actually set.
 *
 * Counts the fixed channels 1..NumberOfFixedChannels and the moving
 * channels 1..NumberOfMovingChannels that currently have a valid image.
 * The optional initial deformation field (input 0) is not counted.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
std::vector< SmartPointer<DataObject> >::size_type
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GetNumberOfValidRequiredInputs() const
{
	typename std::vector< SmartPointer<DataObject> >::size_type validCount = 0;

	for (unsigned int channel = 1; channel <= NumberOfFixedChannels; ++channel) {
		if (this->GetNthFixedImage(channel) != NULL) {
			++validCount;
		}
	}

	for (unsigned int channel = 1; channel <= NumberOfMovingChannels; ++channel) {
		if (this->GetNthMovingImage(channel) != NULL) {
			++validCount;
		}
	}

	return validCount;
}

/*
 * Set the number of multi-resolution levels.
 *
 * When the value changes, the filter is marked modified and the per-level
 * iteration container is resized. The count is always propagated to every
 * moving and fixed channel pyramid.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::SetNumberOfLevels(unsigned int num)
{
	if (num != m_NumberOfLevels) {
		this->Modified();
		m_NumberOfLevels = num;
		m_NumberOfIterations.resize(m_NumberOfLevels);
	}

	for (unsigned int channel = 0; channel < NumberOfMovingChannels; ++channel) {
		m_MovingImagePyramidVector.at(channel)->SetNumberOfLevels(m_NumberOfLevels);
	}

	for (unsigned int channel = 0; channel < NumberOfFixedChannels; ++channel) {
		m_FixedImagePyramidVector.at(channel)->SetNumberOfLevels(m_NumberOfLevels);
	}
}

/*
 * Standard PrintSelf method.
 *
 * Prints the level configuration and the per-level iteration counts as
 * "[a, b, c]". It then prints the internal filters and pyramids and the
 * stop flag. An empty level list prints as "[]"; it is not indexed.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::PrintSelf(std::ostream& os, Indent indent) const
{
	Superclass::PrintSelf(os, indent);
	os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl;
	os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl;

	// Print the separator before each element except the first. This avoids
	// the unsigned underflow of "m_NumberOfLevels - 1" when there are no levels.
	os << indent << "NumberOfIterations: [";
	for (unsigned int ilevel = 0; ilevel < m_NumberOfLevels; ilevel++) {
		if (ilevel > 0) {
			os << ", ";
		}
		os << m_NumberOfIterations[ilevel];
	}
	os << "]" << std::endl;
	
	os << indent << "SegmentationRegistrationFilter: ";
	os << m_SegmentationRegistrationFilter.GetPointer() << std::endl;
	
	os << indent << "MovingImagePyramidVector: ";
	for (unsigned int i = 0; i < NumberOfMovingChannels; i++) {
		os << m_MovingImagePyramidVector.at(i) << std::endl;
	}

	os << indent << "FixedImagePyramid: ";
	for (unsigned int i = 0; i < NumberOfFixedChannels; i++) {
		os << m_FixedImagePyramidVector.at(i) << std::endl;
	}

	os << indent << "FieldExpander: ";
	os << m_FieldExpander.GetPointer() << std::endl;

	os << indent << "StopRegistrationFlag: ";
	os << m_StopRegistrationFlag << std::endl;
}

/*
 * Perform the deformable registration using a multi-resolution scheme
 * built on an internal mini-pipeline:
 *
 *  ref_pyramid ->  segmentor_registrator  ->  field_expander --|| tempField
 * test_pyramid ->           |                              |
 *                           |                              |
 *                           --------------------------------
 *
 * A tempField image is used to break the cycle between the
 * segmentor_registrator and field_expander.
 *
 * Outline:
 *  1. Validate the configuration and build one pyramid per fixed channel
 *     and per moving channel.
 *  2. Choose the starting field. Either m_InitialDeformationField is used
 *     directly, or input 0 is smoothed and resampled onto level 0 of the
 *     first fixed pyramid.
 *  3. For each level until Halt() returns true:
 *     - resample the current field onto that level's fixed grid;
 *     - run the segmentation-registration filter on the pyramid outputs;
 *     - release pyramid levels that are no longer needed.
 *  4. Graft the final field onto this filter's output. If the last level
 *     was subsampled, the field is first resampled to the full grid of
 *     the first fixed image.
 *
 * Throws (via itkExceptionMacro) in these cases:
 *  - the registration filter, any fixed or moving image, or any pyramid
 *    is missing;
 *  - both m_InitialDeformationField and input 0 are set.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void 
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GenerateData()
{
	int i;
	int dim, idim;

	// Check for NULL images and pointers
	if (!m_SegmentationRegistrationFilter) {
		itkExceptionMacro( << "Registration filter not set" );
	}

	// The initial field may come from m_InitialDeformationField or from
	// input 0, but not from both.
	if (this->m_InitialDeformationField && this->GetInput(0)) {
		// NOTE(review): "cunjunction" below is a typo for "conjunction"; it is
		// runtime text and is left unchanged here.
		itkExceptionMacro( << "Only one initial deformation can be given. "
			<< "SetInitialDeformationField should not be used in "
			<< "cunjunction with SetArbitraryInitialDeformationField "
			<< "or SetInput.");
	}

	// Fixed channels are 1-based because input 0 is the initial field.
	for (i = 1; i <= NumberOfFixedChannels; i++) {
		FixedImageConstPointer fixedImage = GetNthFixedImage(i);

		if (!fixedImage) {
			itkExceptionMacro( << i << "th fixed image not set" );
		}
		if (!GetNthFixedImagePyramid(i)) {
			itkExceptionMacro( << i << " th fixed pyramid not set" );
		}

		std::cout << "Creating fixed pyramid number: " << i << std::endl;
		GetNthFixedImagePyramid(i)->SetInput(fixedImage);
		GetNthFixedImagePyramid(i)->UpdateLargestPossibleRegion();
	} // fixed for channels

	// Moving channels are also 1-based (stored after the fixed inputs).
	for (i = 1; i <= NumberOfMovingChannels; i++) {
		MovingImageConstPointer movingImage = GetNthMovingImage(i);

		if (!movingImage) {
			itkExceptionMacro( << i << "th moving images not set" );
		}
		if (!GetNthMovingImagePyramid(i)) {
			itkExceptionMacro( << i << " th moving pyramid not set" );
		}

		std::cout << "Creating moving pyramid number: " << i << std::endl;
		// Create the image pyramids.
		GetNthMovingImagePyramid(i)->SetInput(movingImage);
		GetNthMovingImagePyramid(i)->UpdateLargestPossibleRegion();
	} // for moving channels

	std::cout << "pyramids created." << std::endl;

	// Initializations
	m_CurrentLevel = 0;
	m_StopRegistrationFlag = false;

	// Pyramid level to read at the current resolution level. Only channel 1's
	// pyramid is consulted; the other channels are assumed to have the same
	// number of levels (SetNumberOfLevels keeps them in sync).
	unsigned int movingLevel = vnl_math_min((int) m_CurrentLevel, (int) GetNthMovingImagePyramid(1)->GetNumberOfLevels());
	unsigned int fixedLevel  = vnl_math_min((int) m_CurrentLevel, (int) GetNthFixedImagePyramid(1)->GetNumberOfLevels());

	// Current deformation field, carried from one level to the next.
	DeformationFieldPointer tempField = NULL;

	DeformationFieldPointer inputPtr = const_cast<DeformationFieldType *>(this->GetInput(0));

	if (this->m_InitialDeformationField) {
		// Start from the user-supplied field; the loop below resamples it
		// onto the grid of the current fixed level.
		tempField = this->m_InitialDeformationField;
	} else if(inputPtr) {
		// Arbitrary initial deformation field is set.
		// smooth it and resample

		// First smooth it
		tempField = inputPtr;

		typedef RecursiveGaussianImageFilter< DeformationFieldType, DeformationFieldType> GaussianFilterType;
		typename GaussianFilterType::Pointer smoother = GaussianFilterType::New();

		// Separable smoothing: one recursive Gaussian pass per dimension.
		for (dim = 0; dim < DeformationFieldType::ImageDimension; ++dim) {
			// sigma accounts for the subsampling of the pyramid
			double sigma = 0.5 * static_cast<float>(GetNthFixedImagePyramid(1)->GetSchedule()[fixedLevel][dim]);

			// but also for a possible discrepancy in the spacing
			sigma *= GetNthFixedImage(1)->GetSpacing()[dim] / inputPtr->GetSpacing()[dim];

			smoother->SetInput(tempField);
			smoother->SetSigma(sigma);
			smoother->SetDirection(dim);

			smoother->Update();

			tempField = smoother->GetOutput();
			tempField->DisconnectPipeline();
		}

		// Now resample onto the grid of the current level of the first fixed pyramid.
		m_FieldExpander->SetInput(tempField);

		typename FloatImageType::Pointer fi = GetNthFixedImagePyramid(1)->GetOutput(fixedLevel);
		m_FieldExpander->SetSize(fi->GetLargestPossibleRegion().GetSize());
		m_FieldExpander->SetOutputStartIndex(fi->GetLargestPossibleRegion().GetIndex());
		m_FieldExpander->SetOutputOrigin(fi->GetOrigin());
		m_FieldExpander->SetOutputSpacing(fi->GetSpacing());
		m_FieldExpander->SetOutputDirection(fi->GetDirection());

		m_FieldExpander->UpdateLargestPossibleRegion();
		m_FieldExpander->SetInput(NULL);
		tempField = m_FieldExpander->GetOutput();
		tempField->DisconnectPipeline();
	}

	// True when the most recently processed level had a shrink factor of 1 in
	// every dimension. In that case the final field needs no expansion.
	bool lastShrinkFactorsAllOnes = false;

	// Halt() also reports progress (m_CurrentLevel / m_NumberOfLevels).
	while (!this->Halt())
	{
		if (tempField.IsNull()) {
#if ITK_VERSION_MAJOR >= 4
			m_SegmentationRegistrationFilter->SetInitialDisplacementField(NULL);
#else
			m_SegmentationRegistrationFilter->SetInitialDeformationField(NULL);
#endif
		} else {
			// Resample the field to be the same size as the fixed image
			// at the current level
			m_FieldExpander->SetInput(tempField);

			typename FloatImageType::Pointer fi = GetNthFixedImagePyramid(1)->GetOutput(fixedLevel);
			m_FieldExpander->SetSize(fi->GetLargestPossibleRegion().GetSize());
			m_FieldExpander->SetOutputStartIndex(fi->GetLargestPossibleRegion().GetIndex());
			m_FieldExpander->SetOutputOrigin(fi->GetOrigin());
			m_FieldExpander->SetOutputSpacing(fi->GetSpacing());
			m_FieldExpander->SetOutputDirection(fi->GetDirection());

			m_FieldExpander->UpdateLargestPossibleRegion();
			m_FieldExpander->SetInput(NULL);
			tempField = m_FieldExpander->GetOutput();
			tempField->DisconnectPipeline();

#if ITK_VERSION_MAJOR >= 4
			m_SegmentationRegistrationFilter->SetInitialDisplacementField(tempField);
#else
			m_SegmentationRegistrationFilter->SetInitialDeformationField(tempField);
#endif
		}

		// Feed each channel's pyramid output for this level to the inner filter.
		for (i = 1; i <= NumberOfFixedChannels; i++) {
			// setup registration filter and pyramids 
			m_SegmentationRegistrationFilter->SetNthFixedImage(GetNthFixedImagePyramid(i)->GetOutput(fixedLevel), i);
		}
		for (i = 1; i <= NumberOfMovingChannels; i++) {
			// setup registration filter and pyramids 
			m_SegmentationRegistrationFilter->SetNthMovingImage(GetNthMovingImagePyramid(i)->GetOutput(movingLevel), i);
		}

		m_SegmentationRegistrationFilter->SetNumberOfIterations(m_NumberOfIterations[m_CurrentLevel]);

		// cache shrink factors for computing the next expand factors.
		lastShrinkFactorsAllOnes = true;
		for (idim = 0; idim < ImageDimension; idim++) {
			if (GetNthFixedImagePyramid(1)->GetSchedule()[fixedLevel][idim] > 1) {
				lastShrinkFactorsAllOnes = false;
				break;
			}
		}

		// compute new deformation field
		std::cout << "start " << std::endl;
		m_SegmentationRegistrationFilter->UpdateLargestPossibleRegion();
		std::cout << "end " << std::endl;
		tempField = m_SegmentationRegistrationFilter->GetOutput();
		tempField->DisconnectPipeline();

		// Increment level counter.  
		m_CurrentLevel++;
		movingLevel = vnl_math_min((int)m_CurrentLevel, (int)GetNthMovingImagePyramid(1)->GetNumberOfLevels());
		fixedLevel  = vnl_math_min((int)m_CurrentLevel, (int)GetNthFixedImagePyramid(1)->GetNumberOfLevels());

		// Invoke an iteration event.
		this->InvokeEvent(IterationEvent());

		// We can release data from pyramid which are no longer required.
		if (movingLevel > 0) {
			for (i = 1; i <= NumberOfMovingChannels; i++) {
				GetNthMovingImagePyramid(i)->GetOutput(movingLevel-1)->ReleaseData();
			}
		}
		if (fixedLevel > 0) {
			for (i = 1; i <= NumberOfFixedChannels; i++) {
				GetNthFixedImagePyramid(i)->GetOutput(fixedLevel-1)->ReleaseData();
			}
		}
	} // while not Halt()

	// NOTE(review): if no level ran (e.g. m_NumberOfLevels == 0) and no initial
	// field was given, tempField is still NULL here. The branch below then
	// feeds NULL to the expander. Confirm callers cannot reach this state.
	if (!lastShrinkFactorsAllOnes) {
		// Some of the last shrink factors are not one
		// graft the output of the expander filter to
		// to output of this filter

		// resample the field to the same size as the fixed image
		m_FieldExpander->SetInput(tempField);
		m_FieldExpander->SetSize(GetNthFixedImage(1)->GetLargestPossibleRegion().GetSize());
		m_FieldExpander->SetOutputStartIndex(GetNthFixedImage(1)->GetLargestPossibleRegion().GetIndex());
		m_FieldExpander->SetOutputOrigin(GetNthFixedImage(1)->GetOrigin());
		m_FieldExpander->SetOutputSpacing(GetNthFixedImage(1)->GetSpacing());
		m_FieldExpander->SetOutputDirection(GetNthFixedImage(1)->GetDirection());

		m_FieldExpander->UpdateLargestPossibleRegion();
		this->GraftOutput(m_FieldExpander->GetOutput());
	} else {
		// all the last shrink factors are all ones
		// graft the output of registration filter to
		// to output of this filter
		this->GraftOutput(tempField);
	}

	// Release memory
	m_FieldExpander->SetInput(NULL);
	m_FieldExpander->GetOutput()->ReleaseData();
	m_SegmentationRegistrationFilter->SetInput(NULL);
	m_SegmentationRegistrationFilter->GetOutput()->ReleaseData();
}

/*
 * Request an early stop. The request is forwarded to the internal
 * segmentation-registration filter, and the flag checked by Halt() is set
 * so that no further resolution level is started.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::StopRegistration()
{
	this->m_SegmentationRegistrationFilter->StopRegistration();
	this->m_StopRegistrationFlag = true;
}

/*
 * Decide whether the multi-resolution loop should stop.
 *
 * Progress is reported first, as the fraction of levels already processed
 * (only when at least one level is configured). Returns true once all
 * m_NumberOfLevels levels have run or a stop has been requested.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
bool
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::Halt()
{
	if (m_NumberOfLevels > 0) {
		const float levelsDone  = static_cast<float>(m_CurrentLevel);
		const float levelsTotal = static_cast<float>(m_NumberOfLevels);
		this->UpdateProgress(levelsDone / levelsTotal);
	}

	const bool allLevelsDone = (m_CurrentLevel >= m_NumberOfLevels);
	return allLevelsDone || m_StopRegistrationFlag;
}

/*
 * Set the output meta-information.
 *
 * If an initial deformation field is connected as input 0, the superclass
 * copies the information from it. Otherwise every output copies its
 * information from the first fixed image, if that image is set.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GenerateOutputInformation()
{
	if (this->GetInput(0)) {
		this->Superclass::GenerateOutputInformation();
		return;
	}

	const FixedImageType * referenceImage = GetNthFixedImage(1);
	if (!referenceImage) {
		return;
	}

	for (unsigned int outputIndex = 0; outputIndex < this->GetNumberOfOutputs(); ++outputIndex) {
		DataObject * output = this->GetOutput(outputIndex);
		if (output) {
			output->CopyInformation(referenceImage);
		}
	}
}

/*
 * Propagate requested regions to the inputs.
 *
 * The superclass runs first. Then every moving image requests its largest
 * possible region, while the initial deformation field and every fixed
 * image receive the output's requested region.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::GenerateInputRequestedRegion()
{
	Superclass::GenerateInputRequestedRegion();

	for (unsigned int channel = 1; channel <= NumberOfMovingChannels; ++channel) {
		MovingImagePointer moving = const_cast<MovingImageType *>(GetNthMovingImage(channel));
		if (moving) {
			moving->SetRequestedRegionToLargestPossibleRegion();
		}
	}

	DeformationFieldPointer initialField = const_cast<DeformationFieldType *>(this->GetInput());
	DeformationFieldPointer output = this->GetOutput();
	if (initialField) {
		initialField->SetRequestedRegion(output->GetRequestedRegion());
	}

	for (unsigned int channel = 1; channel <= NumberOfFixedChannels; ++channel) {
		FixedImagePointer fixed = const_cast<FixedImageType *>(GetNthFixedImage(channel));
		if (fixed) {
			fixed->SetRequestedRegion(output->GetRequestedRegion());
		}
	}
}

/*
 * Enlarge the output requested region to the largest possible region.
 * The superclass runs first; then, if ptr is a deformation field, its whole
 * extent is requested.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
MultiResolutionJointSegmentationRegistration<TFixedImage, TMovingImage, TDeformationField, TRealType, VNumberOfFixedChannels, VNumberOfMovingChannels>
::EnlargeOutputRequestedRegion(DataObject * ptr)
{
	Superclass::EnlargeOutputRequestedRegion(ptr);

	if (DeformationFieldType * field = dynamic_cast<DeformationFieldType *>(ptr)) {
		field->SetRequestedRegionToLargestPossibleRegion();
	}
}


} // end namespace itk


#endif
