/**
 * @file  itkJointSegmentationRegistrationFunction.txx
 * @brief Function class implementing the EM based joint segmentation-registration algorithm.
 *
 * Copyright (c) 2011-2014 University of Pennsylvania. All rights reserved.<br />
 * See http://www.cbica.upenn.edu/sbia/software/license.html or COPYING file.
 *
 * Contact: SBIA Group <sbia-software at uphs.upenn.edu>
 */

#ifndef __itkJointSegmentationRegistrationFunction_txx
#define __itkJointSegmentationRegistrationFunction_txx


#include <itkExceptionObject.h>
#include <vnl/vnl_math.h>
#include <vnl/algo/vnl_qr.h>
#include <itkNiftiImageIO.h>

#include "itkJointSegmentationRegistrationFunction.h"


// Numeric constants shared by the functions in this file.
// NOTE(review): these are defined at global scope (before `namespace itk`)
// in a header-included .txx file, so the short names PI/eps/epss are visible
// in every translation unit that includes it.

// pi, used in the Gaussian normalization of the class likelihoods.
const double PI = 3.1415926535897932384626433832795;
// Small regularizer added to weight sums, mean estimates and covariance
// diagonals (UpdateMeansAndVariances) and to the determinant in the generic
// (non-fast) likelihood path, to keep divisions finite.
const double eps = 1e-8;
// Tiny regularizer added to the covariance determinant in ComputeLikelihood4
// and to the posterior denominator in ComputeWeightImages.
const double epss = 1e-32;
// Sentinel written into every class-mean entry by the constructor. In the
// first iteration UpdateMeansAndVariances only overwrites entries that still
// hold this value, so user-supplied means are preserved.
const double INITIAL_MEAN_VALUE = 1e32;


namespace itk {


//////////////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////////////

/****************************************************************************/
/**
 * Default constructor.
 *
 * - Sets a zero neighborhood radius.
 * - Sets the default time step and Sigma2 / DeltaSigma2 values.
 * - Sizes the per-channel fixed/moving image, interpolator and warper
 *   containers (filled with NULL until they are set).
 * - Creates one weight (posterior) image object per moving channel.
 * - Seeds every class mean with INITIAL_MEAN_VALUE ("not supplied by the
 *   user") and every class covariance with the identity.
 * - Uses unit spacing, zero origin and identity direction for the fixed-image
 *   geometry until InitializeIteration() caches the real values.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
JointSegmentationRegistrationFunction<TFixedImage, TMovingImage, TDeformationField, VNumberOfFixedChannels, VNumberOfMovingChannels>
::JointSegmentationRegistrationFunction()
{
	int i, j;
	RadiusType r;

	for (j = 0; j < ImageDimension; j++) {
		r[j] = 0;
	}
	this->SetRadius(r);

	m_TimeStep = 1.0;
	// this is a conservative choice //
	SetSigma2(1.0);
	SetDeltaSigma2(0.0);

	m_bEstimateTumorRegionOnly = false;
	// NOTE(review): 10e-6 == 1e-5; confirm this (rather than 1e-6) is intended
	m_fTumorRegionThreshold = 10e-6;

	std::cout << "creating moving and fixed image vector objects .." << std::endl;
	m_FixedImageVector.resize(NumberOfFixedChannels, NULL);
	m_MovingImageVector.resize(NumberOfMovingChannels, NULL);
	m_MovingImageInterpolatorVector.resize(NumberOfMovingChannels, NULL);
	m_MovingImageWarperVector.resize(NumberOfMovingChannels, NULL);

	std::cout << "creating weight image vector objects .." << std::endl;
	m_WeightImageVector.resize(NumberOfMovingChannels,NULL);

	for (i = 1; i <= NumberOfMovingChannels; i++) {
		m_WeightImageVector.at(i-1) = WeightImageType::New();
	}
	std::cout << "populating the weight image vector done." << std::endl; 

	/** initializing mean and variance vectors **/
	for (i = 1; i <= NumberOfMovingChannels; i++) {
		// INITIAL_MEAN_VALUE marks "no user-supplied mean" for the first M-step
		MeanType  v(INITIAL_MEAN_VALUE);
		GetMeanVector()->push_back(v); 

		VarianceType s;
		s.set_identity();
		GetVarianceVector()->push_back(s);
	}
	std::cout << "populating the means and covarince vectors done." << std::endl; 

	/* we dont have information on input images at this point so:*/
	m_FixedImageSpacing.Fill(1.0);
	m_FixedImageOrigin.Fill(0.0);
	m_FixedImageDirection.SetIdentity();
	m_Normalizer = 0.0;
	
	m_Metric = NumericTraits<double>::max();
	// Reset at the start of every iteration by InitializeIteration(). Initialize
	// it here too so it never holds an indeterminate value before that call.
	m_SumOfEMLogs = 0.0;
	m_NumberOfPixelsProcessed = 0L;
	m_RMSChange = NumericTraits<double>::max();
}

//////////////////////////////////////////////////////////////////////////////
// Debugging
//////////////////////////////////////////////////////////////////////////////

/****************************************************************************/
/**
 * Writes the current metric to two log files in the working directory:
 * - "s_variables_log.txt" gets one appended line per call
 *   ("iteration:<n>   metric:<value>");
 * - "s_cost.txt" is truncated and holds only the latest metric value.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void JointSegmentationRegistrationFunction<TFixedImage, TMovingImage, TDeformationField, VNumberOfFixedChannels, VNumberOfMovingChannels>
::LogVariables()
{
	{
		// history file: append, closed when this scope ends
		std::ofstream history("s_variables_log.txt", std::ios::out | std::ios::app);
		history << "iteration:" << this->GetNumberOfElapsedIterations() << "   metric:" << GetMetric()  << std::endl;
	}

	// latest-cost file: default (truncating) output mode
	std::ofstream latest("s_cost.txt");
	latest << GetMetric() << std::endl;
}

//////////////////////////////////////////////////////////////////////////////
//Computes the posteriors  
//////////////////////////////////////////////////////////////////////////////

/**
 * Evaluates the Gaussian-type likelihood of one class for 4-channel data.
 *
 * @param m    4x4 covariance matrix, flattened; ComputeWeightImages fills it
 *             as m[4*row + col].
 * @param ym   observation minus the class mean (y - mu).
 * @param like output: 1 / sqrt(2*PI*det) * exp(-0.5 * ym^T m^-1 ym), with
 *             det = det(m) + epss.
 *
 * NOTE(review): the usual normalizer in 4 dimensions is
 * sqrt((2*PI)^4 * det). The factor used here differs from it by a
 * class-independent constant. No check is made for det <= 0, where the
 * square root is not real.
 */
static
inline void ComputeLikelihood4(double m[16], double ym[4], double* like) {
	double inv[16], det, det_1, ySIy;
	
	// inv = adjugate of m (transposed cofactors), expanded explicitly, so that
	// inv / det(m) is the inverse of m
	inv[ 0] =  m[ 5] * m[10] * m[15] - m[ 5] * m[11] * m[14] - m[ 9] * m[ 6] * m[15] + m[ 9] * m[ 7] * m[14] + m[13] * m[ 6] * m[11] - m[13] * m[ 7] * m[10];
	inv[ 4] = -m[ 4] * m[10] * m[15] + m[ 4] * m[11] * m[14] + m[ 8] * m[ 6] * m[15] - m[ 8] * m[ 7] * m[14] - m[12] * m[ 6] * m[11] + m[12] * m[ 7] * m[10];
	inv[ 8] =  m[ 4] * m[ 9] * m[15] - m[ 4] * m[11] * m[13] - m[ 8] * m[ 5] * m[15] + m[ 8] * m[ 7] * m[13] + m[12] * m[ 5] * m[11] - m[12] * m[ 7] * m[ 9];
	inv[12] = -m[ 4] * m[ 9] * m[14] + m[ 4] * m[10] * m[13] + m[ 8] * m[ 5] * m[14] - m[ 8] * m[ 6] * m[13] - m[12] * m[ 5] * m[10] + m[12] * m[ 6] * m[ 9];
	inv[ 1] = -m[ 1] * m[10] * m[15] + m[ 1] * m[11] * m[14] + m[ 9] * m[ 2] * m[15] - m[ 9] * m[ 3] * m[14] - m[13] * m[ 2] * m[11] + m[13] * m[ 3] * m[10];
	inv[ 5] =  m[ 0] * m[10] * m[15] - m[ 0] * m[11] * m[14] - m[ 8] * m[ 2] * m[15] + m[ 8] * m[ 3] * m[14] + m[12] * m[ 2] * m[11] - m[12] * m[ 3] * m[10];
	inv[ 9] = -m[ 0] * m[ 9] * m[15] + m[ 0] * m[11] * m[13] + m[ 8] * m[ 1] * m[15] - m[ 8] * m[ 3] * m[13] - m[12] * m[ 1] * m[11] + m[12] * m[ 3] * m[ 9];
	inv[13] =  m[ 0] * m[ 9] * m[14] - m[ 0] * m[10] * m[13] - m[ 8] * m[ 1] * m[14] + m[ 8] * m[ 2] * m[13] + m[12] * m[ 1] * m[10] - m[12] * m[ 2] * m[ 9];
	inv[ 2] =  m[ 1] * m[ 6] * m[15] - m[ 1] * m[ 7] * m[14] - m[ 5] * m[ 2] * m[15] + m[ 5] * m[ 3] * m[14] + m[13] * m[ 2] * m[ 7] - m[13] * m[ 3] * m[ 6];   
	inv[ 6] = -m[ 0] * m[ 6] * m[15] + m[ 0] * m[ 7] * m[14] + m[ 4] * m[ 2] * m[15] - m[ 4] * m[ 3] * m[14] - m[12] * m[ 2] * m[ 7] + m[12] * m[ 3] * m[ 6];  
	inv[10] =  m[ 0] * m[ 5] * m[15] - m[ 0] * m[ 7] * m[13] - m[ 4] * m[ 1] * m[15] + m[ 4] * m[ 3] * m[13] + m[12] * m[ 1] * m[ 7] - m[12] * m[ 3] * m[ 5];   
	inv[14] = -m[ 0] * m[ 5] * m[14] + m[ 0] * m[ 6] * m[13] + m[ 4] * m[ 1] * m[14] - m[ 4] * m[ 2] * m[13] - m[12] * m[ 1] * m[ 6] + m[12] * m[ 2] * m[ 5];   
	inv[ 3] = -m[ 1] * m[ 6] * m[11] + m[ 1] * m[ 7] * m[10] + m[ 5] * m[ 2] * m[11] - m[ 5] * m[ 3] * m[10] - m[ 9] * m[ 2] * m[ 7] + m[ 9] * m[ 3] * m[ 6];  
	inv[ 7] =  m[ 0] * m[ 6] * m[11] - m[ 0] * m[ 7] * m[10] - m[ 4] * m[ 2] * m[11] + m[ 4] * m[ 3] * m[10] + m[ 8] * m[ 2] * m[ 7] - m[ 8] * m[ 3] * m[ 6];  
	inv[11] = -m[ 0] * m[ 5] * m[11] + m[ 0] * m[ 7] * m[ 9] + m[ 4] * m[ 1] * m[11] - m[ 4] * m[ 3] * m[ 9] - m[ 8] * m[ 1] * m[ 7] + m[ 8] * m[ 3] * m[ 5];   
	inv[15] =  m[ 0] * m[ 5] * m[10] - m[ 0] * m[ 6] * m[ 9] - m[ 4] * m[ 1] * m[10] + m[ 4] * m[ 2] * m[ 9] + m[ 8] * m[ 1] * m[ 6] - m[ 8] * m[ 2] * m[ 5];

	// Laplace expansion of det(m) along the first row, plus a tiny regularizer
	// so that the reciprocal below stays finite when det(m) == 0
	det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12] + epss;  
	
	det_1 = 1.0 / det;

	// Mahalanobis term ym^T m^-1 ym, computed as (ym^T inv ym) / det
	ySIy = ym[0] * (inv[ 0] * ym[0] + inv[ 1] * ym[1] + inv[ 2] * ym[2] + inv[ 3] * ym[3]) +
	       ym[1] * (inv[ 4] * ym[0] + inv[ 5] * ym[1] + inv[ 6] * ym[2] + inv[ 7] * ym[3]) +
	       ym[2] * (inv[ 8] * ym[0] + inv[ 9] * ym[1] + inv[10] * ym[2] + inv[11] * ym[3]) +
	       ym[3] * (inv[12] * ym[0] + inv[13] * ym[1] + inv[14] * ym[2] + inv[15] * ym[3]);
	ySIy *= det_1;

	*like = 1.0 / vcl_sqrt(2.0 * PI * det) * exp (-0.5 * ySIy);
}

/**
 * Solves the 3x3 linear system m * x = b with the adjugate formula
 * x = adj(m) * b / det(m).
 *
 * @param m  3x3 matrix in row-major order (m[3*row + col]); need not be
 *           symmetric.
 * @param b  right-hand side vector.
 * @param x  output, receives the solution. It is set to zero when m is
 *           exactly singular (det == 0).
 *
 * cof[] holds the cofactor matrix C with the checkerboard signs already
 * applied (cof[3*row + col] = C(row, col)). Hence det(m) is the plain sum
 * m(0,0)C(0,0) + m(0,1)C(0,1) + m(0,2)C(0,2), and adj(m) = C^T.
 */
static
inline void SolveLinearSystem3(double m[9], double b[3], double* x) {
	double cof[9], det, det_1;

	cof[0] =  (m[4] * m[8] - m[7] * m[5]);
	cof[1] = -(m[3] * m[8] - m[5] * m[6]);
	cof[2] =  (m[3] * m[7] - m[4] * m[6]);
	cof[3] = -(m[1] * m[8] - m[2] * m[7]);
	cof[4] =  (m[0] * m[8] - m[2] * m[6]);
	cof[5] = -(m[0] * m[7] - m[1] * m[6]);
	cof[6] =  (m[1] * m[5] - m[2] * m[4]);
	cof[7] = -(m[0] * m[5] - m[2] * m[3]);
	cof[8] =  (m[0] * m[4] - m[1] * m[3]);

	// The signs already live in cof[], so the first-row expansion is a plain
	// sum. Subtracting m[1]*cof[1] would apply the minus sign twice.
	det = m[0] * cof[0] + m[1] * cof[1] + m[2] * cof[2];

	if (det == 0) {
		x[0] = x[1] = x[2] = 0;
		return;
	}
	det_1 = 1.0 / det;

	// x = C^T * b / det: row r of adj(m) is column r of the cofactor matrix
	x[0] = (cof[0] * b[0] + cof[3] * b[1] + cof[6] * b[2]) * det_1;
	x[1] = (cof[1] * b[0] + cof[4] * b[1] + cof[7] * b[2]) * det_1;
	x[2] = (cof[2] * b[0] + cof[5] * b[1] + cof[8] * b[2]) * det_1;
}

#define USE_FAST_LIKELIHOOD

/****************************************************************************/
/**
 * E-step: recomputes the posterior (weight) image of every class.
 *
 * At each voxel:
 * - the fixed-image channels form the observation y;
 * - the warped moving images (outputs of the per-class warpers) give the
 *   priors;
 * - the current class means/covariances give Gaussian likelihoods.
 *
 * The value written into the i-th weight image is
 *     prior[i] * like[i] / (sum_k prior[k] * like[k] + epss).
 * Voxels whose observation is treated as empty get weight 0 for every class,
 * and weight 1 for BG when USE_WARPED_BG is defined. "Empty" means a channel
 * sum <= 0 on the fast path, or an all-zero vector on the generic path.
 *
 * With USE_FAST_LIKELIHOOD the closed-form 4-channel likelihood
 * (ComputeLikelihood4) is used. If NumberOfFixedChannels != 4, the function
 * prints a message and returns without touching the weight images. Without
 * USE_FAST_LIKELIHOOD a generic vnl QR based evaluation is used.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void JointSegmentationRegistrationFunction<TFixedImage, TMovingImage, TDeformationField, VNumberOfFixedChannels, VNumberOfMovingChannels>
::ComputeWeightImages()
{
	// declared outside the #ifdef so the generic path compiles as well
	int i;

#ifdef USE_FAST_LIKELIHOOD
	// flattened row-major copies of the class covariances and means, so the
	// voxel loop below works on plain arrays
	double vv[NumberOfMovingChannels][16];
	double mv[NumberOfMovingChannels][4];
	int j, k;

	// assume 4 channels for image
	if (NumberOfFixedChannels != 4) {
		std::cout << "NumberOfFixedChannels must be 4 to use fast likelihood" << std::endl;
		return;
	}

	for (i = 0; i < NumberOfMovingChannels; i++) {
		for (j = 0; j < 4; j++) {
			for (k = 0; k < 4; k++) {
				vv[i][j*4+k] = GetVarianceVector()->at(i)(j, k);
			}
			mv[i][j] = GetMeanVector()->at(i)(j);
		}
	}
#endif

	// define some iterators over moving images (warped priors)
	std::vector<MovingImageConstIteratorType> m_it_vect;
	for (i = 1; i < NumberOfMovingChannels + 1; i++) {
		MovingImageConstIteratorType m_it(GetNthImageWarper(i)->GetOutput(),
			GetNthImageWarper(i)->GetOutput()->GetLargestPossibleRegion());	
		m_it.GoToBegin();
		m_it_vect.push_back(m_it);
	}
	std::vector<FixedImageConstIteratorType> f_it_vect;
	for (i = 1; i < NumberOfFixedChannels + 1; i++) {
		FixedImageConstIteratorType f_it(GetNthFixedImage(i), GetNthFixedImage(i)->GetLargestPossibleRegion());
		f_it.GoToBegin();
		f_it_vect.push_back(f_it);
	}
	// define iterators over weight images
	std::vector<WeightImageIteratorType> w_it_vect;
	for (i = 1; i < NumberOfMovingChannels + 1; i++) {
		WeightImageIteratorType w_it(GetNthWeightImage(i), GetNthWeightImage(i)->GetLargestPossibleRegion());
		w_it.GoToBegin();
		w_it_vect.push_back(w_it);
	}

	// looping through all voxels (all images are walked in lock-step)
	while (!w_it_vect.at(0).IsAtEnd()) {
		double prior[NumberOfMovingChannels];
		double like[NumberOfMovingChannels];
#ifdef USE_FAST_LIKELIHOOD
		double y[4], ym[4];
		double ys = 0;
#else
		MeanType y;
#endif

		// gather the observation vector from the fixed images
		for (i = 0; i < NumberOfFixedChannels; i++) {
#ifdef USE_FAST_LIKELIHOOD
			y[i] = f_it_vect.at(i).Get();
			ys += y[i];
#else
			y(i) = f_it_vect.at(i).Get();
#endif
		}

		// read priors
		for (i = 0; i < NumberOfMovingChannels; i++) {
			prior[i] = m_it_vect.at(i).Get();
		}

		// decide whether this voxel carries an observation
#ifdef USE_FAST_LIKELIHOOD
		const bool bHasObservation = (ys > 0);
#else
		const bool bHasObservation = !y.is_zero();
#endif
		if (bHasObservation) {
			// computing likelihood values, one per class
			for (i = 0; i < NumberOfMovingChannels; i++) {
#ifdef USE_FAST_LIKELIHOOD
				ym[0] = y[0] - mv[i][0]; 
				ym[1] = y[1] - mv[i][1]; 
				ym[2] = y[2] - mv[i][2]; 
				ym[3] = y[3] - mv[i][3]; 
				ComputeLikelihood4(vv[i], ym, &like[i]);
#else
				// factorize the covariance once; reuse it for both the solve
				// and the determinant
				const MeanType ym = y - GetMeanVector()->at(i);
				vnl_qr<double> qr(GetVarianceVector()->at(i));
				MeanType SIy = qr.solve(ym);
				double ySIy = dot_product(ym, SIy);
				double detS = qr.determinant() + eps;
				like[i] =  1.0/vcl_sqrt(2.0*PI*detS)*exp(-.5*ySIy);
#endif
			}
			// compute the denominator (regularized so it is never zero)
			double denum = epss;
			for (i = 0; i < NumberOfMovingChannels; i++) {
				denum += prior[i] * like[i];
			}
			// compute the posterior
			for (i = 0; i < NumberOfMovingChannels; i++) {
				w_it_vect.at(i).Set(prior[i] * like[i] / denum);
			}
		} else {
			// no observation: zero posterior for every class
			for (i = 0; i < NumberOfMovingChannels; i++) {
				w_it_vect.at(i).Set(0.0);
			}
#ifdef USE_WARPED_BG
			w_it_vect.at(BG).Set(1.0);
#endif
		}      

		// forwarding the iterators
		for (i = 0; i < NumberOfFixedChannels; i++) {
			++f_it_vect.at(i);
		}
		for (i = 0; i < NumberOfMovingChannels; i++) {
			++w_it_vect.at(i);
			++m_it_vect.at(i);
		}
	}
	// the iterator vectors release their contents when they go out of scope
}

//////////////////////////////////////////////////////////////////////////////
//Updates the means and variances  
//////////////////////////////////////////////////////////////////////////////

/****************************************************************************/
/**
 * M-step: re-estimates the mean vector and covariance matrix of every class
 * from the fixed-image intensities, weighted by the current posteriors
 * (weight images).
 *
 * - Only voxels whose channel-intensity sum is positive contribute.
 * - All accumulators start at eps (the covariance sums at eps * identity), so
 *   the estimates stay finite for classes with (almost) no weight.
 * - In the first iteration (no elapsed iterations) a mean entry is only
 *   replaced if it still holds INITIAL_MEAN_VALUE. User-supplied means are
 *   therefore kept, and they are also used to centre the covariance
 *   estimate. Covariances are always re-estimated.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
JointSegmentationRegistrationFunction<TFixedImage, TMovingImage, TDeformationField, VNumberOfFixedChannels, VNumberOfMovingChannels>
::UpdateMeansAndVariances(void)
{
	// per-class weighted sums of intensity vectors (for the means)
	double mv_sum[NumberOfMovingChannels][NumberOfFixedChannels];
	// per-class weighted sums of centred outer products (for the covariances)
	double vv_sum[NumberOfMovingChannels][NumberOfFixedChannels][NumberOfFixedChannels];
	// per-class sums of the weights
	double w_sum[NumberOfMovingChannels];
	// class means used to centre the covariance estimate
	double mv[NumberOfMovingChannels][NumberOfFixedChannels];
	// per-voxel intensity vector and its deviation from a class mean
	double y[NumberOfFixedChannels];
	double ym[NumberOfFixedChannels];
	int i, j, k;

	// iterators over the fixed images (one per channel)
	std::vector<FixedImageConstIteratorType> f_it_vect;
	for (i = 1; i <= NumberOfFixedChannels; i++) {
		FixedImageConstIteratorType f_it(GetNthFixedImage(i), GetNthFixedImage(i)->GetLargestPossibleRegion());
		f_it.GoToBegin();
		f_it_vect.push_back(f_it);
	}

	// iterators over the weight (posterior) images (one per class)
	std::vector<WeightImageIteratorType> w_it_vect;
	for (i = 1; i <= NumberOfMovingChannels; i++) {
		WeightImageIteratorType w_it(GetNthWeightImage(i), GetNthWeightImage(i)->GetLargestPossibleRegion());
		w_it.GoToBegin();
		w_it_vect.push_back(w_it);
	}

	// regularized initial values of the accumulators
	for (k = 0; k < NumberOfMovingChannels; k++) {
		w_sum[k] = eps;
		for (j = 0; j < NumberOfFixedChannels; j++) {
			mv_sum[k][j] = eps;
			for (i = 0; i < NumberOfFixedChannels; i++) {
				vv_sum[k][j][i] = (i == j) ? eps : 0;
			}
		}
	}

	// pass 1: weighted sums for the means
	while (!w_it_vect.at(0).IsAtEnd()) {
		double ys = 0;
		for (i = 0; i < NumberOfFixedChannels; i++) {
			y[i] = f_it_vect.at(i).Get();
			ys += y[i];
		}
		if (ys > 0) {
			for (k = 0; k < NumberOfMovingChannels; k++) {
				double p = (double)w_it_vect.at(k).Get();
				for (i = 0; i < NumberOfFixedChannels; i++) {
					mv_sum[k][i] += p * y[i];
				}
				w_sum[k] += p;
			}
		}
		for (i = 0; i < NumberOfFixedChannels; i++) {
			++f_it_vect.at(i);
		}
		for (i = 0; i < NumberOfMovingChannels; i++) {
			++w_it_vect.at(i);
		}
	}

	// update the means. In the first iteration, favor user-supplied values:
	// only entries still holding INITIAL_MEAN_VALUE are replaced, and the kept
	// user values centre the covariance estimate below.
	const bool bFirstIteration = !this->GetNumberOfElapsedIterations();
	for (k = 0; k < NumberOfMovingChannels; k++) {
		double w_1 = 1.0 / (w_sum[k] + eps);
		for (i = 0; i < NumberOfFixedChannels; i++) {
			mv[k][i] = w_1 * mv_sum[k][i] + eps;
			if (!bFirstIteration || GetMeanVector()->at(k)[i] == INITIAL_MEAN_VALUE) {
				GetMeanVector()->at(k)[i] = mv[k][i];
			} else {
				mv[k][i] = GetMeanVector()->at(k)[i];
			}
		}
	}

	// rewind all iterators for the covariance pass
	for (i = 0; i < NumberOfFixedChannels; i++) {
		f_it_vect.at(i).GoToBegin();
	}
	for (i = 0; i < NumberOfMovingChannels; i++) {
		w_it_vect.at(i).GoToBegin();
	}

	// pass 2: weighted sums of centred outer products for the covariances
	while (!w_it_vect.at(0).IsAtEnd()) {
		double ys = 0;
		for (i = 0; i < NumberOfFixedChannels; i++) {
			y[i] = f_it_vect.at(i).Get();
			ys += y[i];
		}
		if (ys > 0) {
			for (k = 0; k < NumberOfMovingChannels; k++) {
				double p = (double)w_it_vect.at(k).Get();
				for (i = 0; i < NumberOfFixedChannels; i++) {
					ym[i] = y[i] - mv[k][i];
				}
				for (j = 0; j < NumberOfFixedChannels; j++) {
					for (i = 0; i < NumberOfFixedChannels; i++) {
						vv_sum[k][j][i] += p * ym[j] * ym[i];
					}
				}
			}
		}
		for (i = 0; i < NumberOfFixedChannels; i++) {
			++f_it_vect.at(i);
		}
		for (i = 0; i < NumberOfMovingChannels; i++) {
			++w_it_vect.at(i);
		}
	}

	// normalize the covariance sums by the class weights
	for (k = 0; k < NumberOfMovingChannels; k++) {
		double w_1 = 1.0 / (w_sum[k] + eps);
		for (j = 0; j < NumberOfFixedChannels; j++) {
			for (i = 0; i < NumberOfFixedChannels; i++) {
				GetVarianceVector()->at(k)[j][i] = w_1 * vv_sum[k][j][i];
			}
		}
	}
}

/*
 * Standard "PrintSelf" method.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::PrintSelf(std::ostream& os, Indent indent) const
{
	// Prints the scalar and geometry state set up by the constructor and
	// InitializeIteration(). Only members readable from a const method are
	// printed here.
	os << indent << "TimeStep: " << m_TimeStep << std::endl;
	os << indent << "EstimateTumorRegionOnly: " << m_bEstimateTumorRegionOnly << std::endl;
	os << indent << "TumorRegionThreshold: " << m_fTumorRegionThreshold << std::endl;
	os << indent << "FixedImageOrigin: " << m_FixedImageOrigin << std::endl;
	os << indent << "FixedImageSpacing: " << m_FixedImageSpacing << std::endl;
	os << indent << "FixedImageDirection: " << std::endl << m_FixedImageDirection << std::endl;
	os << indent << "Normalizer: " << m_Normalizer << std::endl;
	os << indent << "Metric: " << m_Metric << std::endl;
	os << indent << "NumberOfPixelsProcessed: " << m_NumberOfPixelsProcessed << std::endl;
	os << indent << "RMSChange: " << m_RMSChange << std::endl;
}

//////////////////////////////////////////////////////////////////////////////
//Preparing the function before each iteration
//////////////////////////////////////////////////////////////////////////////

#ifdef USE_11_PRIORS
// Tables for the 11-class tissue model (only compiled with USE_11_PRIORS).
// All tables are indexed by class, in the order of the enum below.
// BG, CSF, VT, GM, WM, VS (VEIN), ED, NCR (CAN), TU (CAE), RTN, RTE
// label: class names; SaveMeansAndVariances writes label[i] as the heading
// of class i.
static const char label[NumberOfPriorChannels][32] = { "BG", "CSF", "VT", "GM", "WM", "VS", "ED", "NCR", "TU", "RTN", "RTE" };
// label2/label3: alternative spellings of the same class names
// (VEIN/VEINS, CAN, CAE); presumably used where names are matched elsewhere
// in the file — TODO confirm.
static const char label2[NumberOfPriorChannels][32] = { "BG", "CSF", "VT", "GM", "WM", "VEIN", "ED", "CAN", "CAE", "RTN", "RTE" };
static const char label3[NumberOfPriorChannels][32] = { "BG", "CSF", "VT", "GM", "WM", "VEINS", "ED", "CAN", "CAE", "RTN", "RTE" };
// NOTE(review): the meaning of these integer tables is not established in
// this part of the file. label_idx looks like per-class output label values
// and label_s_idx like a coarser (merged) mapping — confirm before relying on it.
static int label_idx[NumberOfPriorChannels] = { 0, 10, 50, 150, 250, 25, 100, 175, 200, 210, 220 };
static int label_s_idx[NumberOfPriorChannels] = { 0, 10, 50, 150, 250, 10, 250, 200, 200, 200, 200 };
//static int label_s_idx[NumberOfPriorChannels] = { 0, 10, 10, 150, 250, 10, 250, 250, 250, 250, 250 };
// Class (channel) indices, 0-based. BG is used by InitializeIteration (and
// by ComputeWeightImages under USE_WARPED_BG) to address the background class.
enum {
	BG = 0,
    CSF,
    VT,
    GM,
    WM,
    VS,
    ED,
    NCR,
    TU,
	RTN,
	RTE
};
#endif


/**
 * Writes the class means and covariances to two text files.
 *
 * For every class the label name (label[i]) is written on its own line,
 * followed by:
 * - in the means file: the mean as one space-separated row;
 * - in the variances file: the covariance as VNumberOfFixedChannels
 *   space-separated rows.
 * All values are printed with "%f".
 *
 * @param means_file      output path for the means
 * @param variances_file  output path for the covariances
 * @param pMeanVector     class means; must hold >= VNumberOfMovingChannels entries
 * @param pVarianceVector class covariances; same size requirement
 * @return TRUE on success; FALSE if an argument is invalid or a file could not
 *         be opened or fully written.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
BOOL
JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::SaveMeansAndVariances(const char* means_file, const char* variances_file, MeanVectorType* pMeanVector, VarianceVectorType* pVarianceVector)
{
	FILE* fp;
	unsigned int i, j, k;
	BOOL bOk;

	// Validate up front. Indexing a too-short vector below would otherwise
	// throw (or read out of bounds) while a FILE is open, leaking it.
	if (means_file == NULL || variances_file == NULL || pMeanVector == NULL || pVarianceVector == NULL ||
		pMeanVector->size() < VNumberOfMovingChannels || pVarianceVector->size() < VNumberOfMovingChannels) {
		TRACE("Invalid arguments passed to %s\n", "SaveMeansAndVariances");
		return FALSE;
	}

	////////////////////////////////////////////////////////////////////////////////
	// save means
	fp = fopen(means_file, "w");
	if (fp == NULL) {
		TRACE("Failed to open file: '%s'\n", means_file);
		return FALSE;
	}
	for (i = 0; i < VNumberOfMovingChannels; i++) {
		const MeanType& mean = (*pMeanVector)[i];
		fprintf(fp, "%s\n", label[i]);
		for (j = 0; j < VNumberOfFixedChannels; j++) {
			fprintf(fp, "%f", mean(j));
			if (j == VNumberOfFixedChannels-1) {
				fprintf(fp, "\n");
			} else {
				fprintf(fp, " ");
			}
		}
	}
	// fprintf failures (e.g. disk full) only show up in the stream error state
	// or when the buffered data is flushed by fclose
	bOk = (ferror(fp) == 0);
	if (fclose(fp) != 0) {
		bOk = FALSE;
	}
	if (!bOk) {
		TRACE("Failed to write file: '%s'\n", means_file);
		return FALSE;
	}
	////////////////////////////////////////////////////////////////////////////////

	////////////////////////////////////////////////////////////////////////////////
	// save variances
	fp = fopen(variances_file, "w");
	if (fp == NULL) {
		TRACE("Failed to open file: '%s'\n", variances_file);
		return FALSE;
	}
	for (i = 0; i < VNumberOfMovingChannels; i++) {
		const VarianceType& var = (*pVarianceVector)[i];
		fprintf(fp, "%s\n", label[i]);
		for (j = 0; j < VNumberOfFixedChannels; j++) {
			for (k = 0; k < VNumberOfFixedChannels; k++) {
				fprintf(fp, "%f", var(j, k));
				if (k == VNumberOfFixedChannels-1) {
					fprintf(fp, "\n");
				} else {
					fprintf(fp, " ");
				}
			}
		}
	}
	bOk = (ferror(fp) == 0);
	if (fclose(fp) != 0) {
		bOk = FALSE;
	}
	if (!bOk) {
		TRACE("Failed to write file: '%s'\n", variances_file);
		return FALSE;
	}
	////////////////////////////////////////////////////////////////////////////////

	return TRUE;
}

/****************************************************************************/
/**
 * Prepares the function for one iteration.
 *
 * First iteration only (no elapsed iterations):
 * - caches the origin, spacing and direction of the first fixed image;
 * - sets DeltaSigma2 = (1 - Sigma2) / max iterations;
 * - builds one interpolator and one warper per moving channel, wired to the
 *   current displacement/deformation field.
 *
 * Every iteration:
 * - re-warps the moving images;
 * - on the first iteration only, initializes the weight images from the
 *   warped images. Voxels with the edge-padding value
 *   NumericTraits<MovingPixelType>::max() get weight 1 for BG and 0 for all
 *   other classes;
 * - re-estimates the class means/covariances and recomputes the posteriors;
 * - logs the metric and resets the per-iteration metric accumulators.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::InitializeIteration()
{
	int i;

	std::cout << "function initialize iteration.." << std::endl;
	if (!this->GetNumberOfElapsedIterations()) {
		std::cout << "Cache fixed image information" << std::endl;
		m_FixedImageOrigin  = GetNthFixedImage(1)->GetOrigin();
		m_FixedImageSpacing = GetNthFixedImage(1)->GetSpacing();
		m_FixedImageDirection = GetNthFixedImage(1)->GetDirection();

		// Setting the dsigma2: spread the distance from the current Sigma2 to
		// 1.0 evenly over the maximum number of iterations. With no iterations
		// configured, keep it at 0 instead of dividing by zero (inf/NaN).
		double dsigma2 = 0.0;
		if (GetNumberOfMaximumIterations() > 0) {
			dsigma2 = (1.0 - GetSigma2()) / double(GetNumberOfMaximumIterations());
		}
		SetDeltaSigma2(dsigma2);

		std::cout << "Initializing vector objects in function .." << std::endl;
		for (i = 1; i <= NumberOfMovingChannels; i++) {
			std::cout << "setting up moving image interpolators .." << std::endl;
			typename DefaultInterpolatorType::Pointer interp = DefaultInterpolatorType::New();
			SetNthImageInterpolator(static_cast<InterpolatorType*>(interp.GetPointer()),i);
			GetNthImageInterpolator(i)->SetInputImage(GetNthMovingImage(i));

			std::cout << "setting up moving image warpers .." << std::endl;
			SetNthImageWarper(WarperType::New(), i);
			GetNthImageWarper(i)->SetInterpolator(GetNthImageInterpolator(i));
			// out-of-field voxels are marked with max(); later code tests for it
			GetNthImageWarper(i)->SetEdgePaddingValue(NumericTraits<MovingPixelType>::max());
			GetNthImageWarper(i)->SetOutputOrigin(this->m_FixedImageOrigin);
			GetNthImageWarper(i)->SetOutputSpacing(this->m_FixedImageSpacing);
			GetNthImageWarper(i)->SetOutputDirection(this->m_FixedImageDirection);
			GetNthImageWarper(i)->SetInput(GetNthMovingImage(i));
#if ITK_VERSION_MAJOR >= 4
			GetNthImageWarper(i)->SetDisplacementField(this->GetDisplacementField());
			GetNthImageWarper(i)->GetOutput()->SetRequestedRegion(this->GetDisplacementField()->GetRequestedRegion());
#else
			GetNthImageWarper(i)->SetDeformationField(this->GetDeformationField());
			GetNthImageWarper(i)->GetOutput()->SetRequestedRegion(this->GetDeformationField()->GetRequestedRegion());
#endif
		} //for
	} //if
	
	std::cout <<"warping moving images .."<< std::endl;
	for (i = 1; i <= NumberOfMovingChannels; i++) {
		GetNthImageWarper(i)->Update();
	}
	
	if (!this->GetNumberOfElapsedIterations()) {
		std::cout << "initialization of posteriors to the warped atlas" << std::endl;
		// setting weight images (posterior probs) into reasonable value
		for (i = 1; i <= NumberOfMovingChannels; i++) {
			WeightImageIteratorType w_it(GetNthWeightImage(i),GetNthWeightImage(i)->GetLargestPossibleRegion()); 
			w_it.GoToBegin();

			MovingImageConstIteratorType m_it(GetNthImageWarper(i)->GetOutput(), GetNthImageWarper(i)->GetOutput()->GetLargestPossibleRegion()); 
			m_it.GoToBegin();

			while (!w_it.IsAtEnd()) {
				if (m_it.Get() != NumericTraits<MovingPixelType>::max()) {
					w_it.Set(m_it.Get());
				} else {
					// outside the warped atlas: assign the voxel to background
					if ((i-1) == BG) {
						w_it.Set(1.0);
					} else {
						w_it.Set(0.0);
					}
				}
				++w_it;
				++m_it;
			}
		}
	}
	
	std::cout << "updating means and variances ..." << std::endl;
	UpdateMeansAndVariances();

	for (i = 0; i < NumberOfMovingChannels; i++) {
		std::cout << "means of class " << i+1 << ": " << GetMeanVector()->at(i) << std::endl;
	}

#if 0
	{
		char means_file[1024];
		char variances_file[1024];
		sprintf(means_file, "jsr_means_%d.txt", this->GetNumberOfElapsedIterations());
		sprintf(variances_file, "jsr_variances_%d.txt", this->GetNumberOfElapsedIterations());
		SaveMeansAndVariances(means_file, variances_file, GetMeanVector(), GetVarianceVector());
	}
#endif

	std::cout << "updating weight images ..." << std::endl;
	ComputeWeightImages();

	std::cout << "metric: " << GetMetric() << std::endl;
	
#if 0
	typedef itk::ImageFileWriter<FixedImageType> WriterType; 
	typename WriterType::Pointer  writer_test  = WriterType::New();
	typename itk::NiftiImageIO::Pointer imageIO = itk::NiftiImageIO::New();
	writer_test->SetImageIO(imageIO);
	std::cout << "writing weight images..." << std::endl;

	for (i = 1; i <= NumberOfMovingChannels; i++) {
		char name[1024];
		sprintf(name,"jsr_weight_image_%d_%d.nii.gz", this->GetNumberOfElapsedIterations(), i-1);
		writer_test->SetFileName( name );
		writer_test->SetInput(GetNthWeightImage(i));
		writer_test->Update();
	}

	std::cout << "writing warped images..." << std::endl;
	typedef itk::ChangeLabelImageFilter<MovingImageType, MovingImageType> ChangeLabelFilterType;
	typename ChangeLabelFilterType::Pointer change_filter = ChangeLabelFilterType::New();

	for (i = 1; i <= NumberOfMovingChannels ; i++) {
		char name[1024];
		sprintf(name,"jsr_warped_image_%d_%d.nii.gz", this->GetNumberOfElapsedIterations(), i-1);
		writer_test->SetFileName( name );
		change_filter->SetInput(GetNthImageWarper(i)->GetOutput());
		change_filter->SetChange(NumericTraits<MovingPixelType>::max(), 0.0);
		writer_test->SetInput(change_filter->GetOutput());
		writer_test->Update();
	} 
#endif

	// log the variables
	LogVariables();
	
	// initialize metric computation variables
	m_NumberOfPixelsProcessed = 0L;
	m_SumOfEMLogs = 0L;
	m_RMSChange = 0L;

	std::cout << "Initialize iteration done, computing updates ..." << std::endl;
}

//////////////////////////////////////////////////////////////////////////////
//Computes the update in a non-boundary neighborhood.
//We skip computing Hessians (because of their computational load)
//and update the deformation field using a Levenberg-Marquardt scheme.
//////////////////////////////////////////////////////////////////////////////

/****************************************************************************/
/**
 * Compute the displacement update for the pixel at the centre of @p it.
 *
 * For every moving channel k (one per class) the warped moving value m_k
 * (output of the k-th image warper) and the weight-image value p_k are read
 * at the pixel. If any m_k equals the "outside" sentinel
 * NumericTraits<MovingPixelType>::max(), a zero update is returned at once
 * and nothing is accumulated into @p gd.
 *
 * Otherwise the spatial gradient g_k of each warped moving channel is
 * estimated by finite differences (central where possible, one-sided at the
 * edges of the first fixed image's largest possible region or next to
 * sentinel pixels, zero when no valid neighbour exists), divided by the fixed
 * image spacing, and mapped to physical space when ITK >= 4 or
 * ITK_USE_ORIENTED_IMAGE_DIRECTION is defined. The update u solves
 *
 *     A u = b
 *     A = GetSigma2() * I + sum_k 0.5 * p_k / ((m_k+epss)(m_k+eps)) * g_k g_k^T
 *     b = sum_k p_k / (m_k+eps) * g_k
 *
 * When m_bEstimateTumorRegionOnly is set, b is only accumulated if the
 * TU+NCR+ED weights pass m_fTumorRegionThreshold (otherwise b stays zero).
 * u is forced to zero when |b|^2 <= 1e-4, and the returned update is zero
 * when u is not finite.
 *
 * Finally a per-pixel EM cost
 *     cost = sum_k p_k * ( log N(y; mu_k, Sigma_k) + log m_k )
 * is evaluated from the fixed-image channel values y (m_k <= 0 is clamped to
 * epss before the log). If @p gd is non-null, -cost, a pixel count of 1 and
 * |update|^2 are added to it.
 *
 * @param it  neighborhood iterator; only its centre index is used.
 * @param gd  per-thread GlobalDataStruct; may be NULL (nothing accumulated).
 * @return    the displacement update for this pixel.
 *
 * NOTE(review): USE_FAST_LIKELIHOOD builds require ImageDimension == 3 and
 * NumberOfFixedChannels == 4. Otherwise they print a message and return early
 * (before or after the update is computed) without accumulating into @p gd.
 */
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
typename JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::PixelType
JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::ComputeUpdate(const NeighborhoodType &it, void * gd, const FloatOffsetType& itkNotUsed(offset))
{ 
	//cout << "compute update" << endl;
	GlobalDataStruct *globalData = (GlobalDataStruct *)gd;
	PixelType update;
	update.Fill(0.0);
	int i, j, k, dim;
	//return update;

	const IndexType index = it.GetIndex();
	// Bounds of the first fixed image's largest possible region; LastIndex is
	// one past the last valid index along each axis.
	IndexType FirstIndex = this->GetNthFixedImage(1)->GetLargestPossibleRegion().GetIndex();
	IndexType LastIndex = this->GetNthFixedImage(1)->GetLargestPossibleRegion().GetIndex() + this->GetNthFixedImage(1)->GetLargestPossibleRegion().GetSize();
	// Get moving image related information
	// check if the point was mapped outside of the moving image using
	// the "special value" NumericTraits<MovingPixelType>::max()
	// NOTE(review): stack arrays sized by NumberOfMovingChannels assume it is a
	// compile-time constant (presumably an enum in the header) - confirm.
	double movingValues[NumberOfMovingChannels];
	double posteriorValues[NumberOfMovingChannels];

	// Channels are 1-based in the GetNth* accessors; local arrays are 0-based.
	for (i = 1; i < NumberOfMovingChannels+1; i++) {
		movingValues[i-1] = GetNthImageWarper(i)->GetOutput()->GetPixel(index);
		posteriorValues[i-1] =  GetNthWeightImage(i)->GetPixel(index);
		if (movingValues[i-1] == NumericTraits <MovingPixelType>::max()) {
			// Pixel warped from outside the moving image: no update, no metric contribution.
			update.Fill(0.0);
			return update;
		}
	}
	
	// declare a variable to hold grad vectors
	CovariantVectorType warpedMovingGradients[NumberOfMovingChannels];
	MovingPixelType movingPixValues[NumberOfMovingChannels];

	// we don't use a CentralDifferenceImageFunction here to be able to
	// check for NumericTraits<MovingPixelType>::max()
	// tmpIndex is shifted along one axis at a time and always restored before
	// moving on to the next axis/channel.
	IndexType tmpIndex = index;

	// to the number of the dimensions in the images
	for (dim = 0; dim < ImageDimension; dim++)
	{
		// bounds checking
		if ((FirstIndex[dim]==LastIndex[dim]) || (index[dim] < FirstIndex[dim]) || (index[dim] >= LastIndex[dim]))
		{
			// case one: completely out of bounds
			for (i = 1; i < NumberOfMovingChannels + 1; i++) {
				warpedMovingGradients[i-1][dim] = 0.0;
			}
			continue;
		} else if (index[dim] == FirstIndex[dim]) {
			// case two: starting edge touch
			// compute derivative
			tmpIndex[dim] += 1;
			for (i = 1; i < NumberOfMovingChannels+1; i++) {
				movingPixValues[i-1] = GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex);
				if (movingPixValues[i-1] == NumericTraits <MovingPixelType>::max()) {
					// weird crunched border case
					warpedMovingGradients[i-1][dim] = 0.0;
				} else {
					// forward difference
					warpedMovingGradients[i-1][dim] = static_cast<double>(movingPixValues[i-1]) - movingValues[i-1];
					warpedMovingGradients[i-1][dim] /= m_FixedImageSpacing[dim]; 
				}
			}
			tmpIndex[dim] -= 1;
			continue;
		} else if (index[dim] == (LastIndex[dim]-1)) {
			//case three: ending edge touch
			// compute derivative
			tmpIndex[dim] -= 1;
			/** case three: **/
			for (i = 1; i < NumberOfMovingChannels+1; i++) {
				movingPixValues[i-1] = GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex);
				if (movingPixValues[i-1] == NumericTraits<MovingPixelType>::max()) {
					// weird crunched border case
					warpedMovingGradients[i-1][dim] = 0.0;
				} else {
					// backward difference
					warpedMovingGradients[i-1][dim] = movingValues[i-1] - static_cast<double>(movingPixValues[i-1]);
					warpedMovingGradients[i-1][dim] /= m_FixedImageSpacing[dim]; 
				}
			}
			tmpIndex[dim] += 1;
			continue;
		} 

		// Interior pixel: central difference, falling back to a one-sided
		// difference (or zero) when a neighbour holds the sentinel value.
		tmpIndex[dim] += 1;

		for (i = 1; i < NumberOfMovingChannels+1; i++) {
			movingPixValues[i-1] = GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex);
			if (movingPixValues[i-1] == NumericTraits<MovingPixelType>::max()) {
				// backward difference
				warpedMovingGradients[i-1][dim] = movingValues[i-1];
				tmpIndex[dim] -= 2;
				if (GetNthImageWarper(i)->GetOutput()->GetPixel( tmpIndex )== NumericTraits<MovingPixelType>::max()) {
					// weird crunched border case
					warpedMovingGradients[i-1][dim] = 0.0;
				} else {
					// backward difference
					warpedMovingGradients[i-1][dim] -= static_cast<double>(GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex));
					warpedMovingGradients[i-1][dim] /= m_FixedImageSpacing[dim];
				}
				tmpIndex[dim] += 2;
			} else {
				warpedMovingGradients[i-1][dim] = static_cast<double>(movingPixValues[i-1]);
				tmpIndex[dim] -= 2;
				if (GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex)== NumericTraits<MovingPixelType>::max()) {
					// forward difference
					warpedMovingGradients[i-1][dim] -= movingValues[i-1];
					warpedMovingGradients[i-1][dim] /= m_FixedImageSpacing[dim];
				} else {
					// normal case, central difference
					warpedMovingGradients[i-1][dim] -= static_cast<double>( GetNthImageWarper(i)->GetOutput()->GetPixel(tmpIndex));
					warpedMovingGradients[i-1][dim] *= 0.5 / m_FixedImageSpacing[dim];
				}
				tmpIndex[dim] += 2;
			}
		} // for number of channels

		tmpIndex[dim] -= 1;
	}//for dimensions

	// so far warpedMovingGradients have been calculated.

	// adding orientation information: map index-space gradients to physical
	// space using the first fixed image's direction when the ITK build supports it.
#if ITK_VERSION_MAJOR >= 4
	CovariantVectorType usedGradients[NumberOfMovingChannels];
	for (i = 1; i < NumberOfMovingChannels+1; i++) {
		this->GetNthFixedImage(1)->TransformLocalVectorToPhysicalVector(warpedMovingGradients[i-1], usedGradients[i-1]);
	}
#else
#ifdef ITK_USE_ORIENTED_IMAGE_DIRECTION
	CovariantVectorType usedGradients[NumberOfMovingChannels];
	for (i = 1; i < NumberOfMovingChannels+1; i++) {
		this->GetNthFixedImage(1)->TransformLocalVectorToPhysicalVector(warpedMovingGradients[i-1], usedGradients[i-1]);
	}
#else
	CovariantVectorType usedGradients[NumberOfMovingChannels];
	for (i = 1; i < NumberOfMovingChannels+1; i++) {
		usedGradients[i-1] = warpedMovingGradients[i-1];
	} 
#endif
#endif

#ifdef USE_FAST_LIKELIHOOD
	// Plain-array workspace for the hand-written 3x3 solver.
	double gv[NumberOfMovingChannels][3];
	double Av[9];
	double bv[3];

	// assume 3 dimension
	if (ImageDimension != 3) {
		std::cout << "ImageDimension must be 3" << std::endl;
		return update;
	}
#endif

	// making vnl vectors for gradients
#ifdef USE_FAST_LIKELIHOOD
	for (i = 0; i < NumberOfMovingChannels; i++) {
		for (j = 0; j < 3; j++) {
			gv[i][j] = usedGradients[i][j];
		}
	}
#else
	std::vector<vnl_vector_fixed<double, ImageDimension>>  grads_vector; 
	for (i = 1; i < NumberOfMovingChannels+1; i++) {
		vnl_vector_fixed<double, ImageDimension> v(0.0);
		for (j = 0; j < ImageDimension; j++) {
			v(j) = usedGradients[i-1][j];
		}
		grads_vector.push_back(v);
	}
#endif

	// we solve Ax=b for x
	// making b = sum_k p_k / (m_k+eps) * g_k
	// TU, NCR, ED are class indices declared elsewhere (presumably tumor,
	// necrosis and edema - confirm against the header).
	// NOTE(review): the fast path tests the tumor-region sum with '>=' while the
	// vnl path below uses '>'; the two builds disagree exactly at the threshold.
#ifdef USE_FAST_LIKELIHOOD
	for (j = 0; j < 3; j++) {
		bv[j] = 0;
	}
	if (m_bEstimateTumorRegionOnly) {
		if (posteriorValues[TU]+posteriorValues[NCR]+posteriorValues[ED] >= m_fTumorRegionThreshold) {
			for (i = 0; i < NumberOfMovingChannels; i++) {
				double pm = (posteriorValues[i] / (movingValues[i]+eps));
				for (j = 0; j < 3; j++) {
					bv[j] += pm * gv[i][j];
				}
			}
		}
	} else {
		for (i = 0; i < NumberOfMovingChannels; i++) {
			double pm = (posteriorValues[i] / (movingValues[i]+eps));
			for (j = 0; j < 3; j++) {
				bv[j] += pm * gv[i][j];
			}
		}
	}
#else
	vnl_vector_fixed<double, ImageDimension> b(0.0);
	if (m_bEstimateTumorRegionOnly) {
		if (posteriorValues[TU]+posteriorValues[NCR]+posteriorValues[ED] > m_fTumorRegionThreshold) {
			for (i = 0; i < NumberOfMovingChannels; i++) {
				b += (posteriorValues[i] / (movingValues[i]+eps)) * grads_vector.at(i);
			}
		}
	} else {
		for (i = 0; i < NumberOfMovingChannels; i++) {
			b += (posteriorValues[i] / (movingValues[i]+eps)) * grads_vector.at(i);
		}
	}
#endif

	// making A = GetSigma2() * I + sum_k 0.5 * p_k / ((m_k+epss)(m_k+eps)) * g_k g_k^T
	// The sigma2 * I term acts as the damping term of the Levenberg-Marquardt-style step.
#ifdef USE_FAST_LIKELIHOOD
	{
		double s;
		s = GetSigma2();
		// decreasing of sigma (deprecated)
		//s *= (1.0 - float(this->GetNumberOfElapsedIterations()) * GetDeltaSigma2());
		for (j = 0; j < 3; j++) {
			for (i = 0; i < 3; i++) {
				if (j == i) {
					Av[j*3+i] = s;
				} else {
					Av[j*3+i] = 0;
				}
			}
		}
		for (k = 0; k < NumberOfMovingChannels; k++) {
			double pm = 0.5 * posteriorValues[k] / ((movingValues[k]+epss) * (movingValues[k]+eps));
			for (j = 0; j < 3; j++) {
				for (i = 0; i < 3; i++) {
					Av[j*3+i] += pm * gv[k][j] * gv[k][i];
				}
			}
		}
	}
#else
	vnl_matrix_fixed<double, ImageDimension, ImageDimension> A(0.0);
	// adding a null component to A
	A.set_identity();
	A *= GetSigma2();
	// decreasing of sigma (deprecated)
	//A *= (1.0 - float(this->GetNumberOfElapsedIterations()) * GetDeltaSigma2());

	for (i = 0; i < NumberOfMovingChannels; i++) {
		A += (.5*posteriorValues[i] / (movingValues[i]+epss) / (movingValues[i]+eps)) * outer_product(grads_vector.at(i), grads_vector.at(i));
	}
#endif

	// compute update; skipped (u = 0) when |b|^2 <= 1e-4
	vnl_vector_fixed<double, ImageDimension> u;

#ifdef USE_FAST_LIKELIHOOD
	if (bv[0]*bv[0]+bv[1]*bv[1]+bv[2]*bv[2] > 1e-4) {
		double uv[3];
		SolveLinearSystem3(Av, bv, uv);
		u[0] = uv[0];
		u[1] = uv[1];
		u[2] = uv[2];
	} else {
		u[0] = 0;
		u[1] = 0;
		u[2] = 0;
	}
#else
	if (b.squared_magnitude() > 1e-4) {
		// this should be faster
		u = vnl_qr<double>(A).solve(b);
	} else {
		u.fill(0.0);
	}
#endif

	// NOTE(review): this message is written to std::cout per pixel, possibly
	// from several worker threads at once, so output may interleave.
	if (!u.is_finite()) {
		update.Fill(0.0);
		std::cout << "oops!, the update field is infinite!" << std::endl;
	} else {
		for (j = 0; j < ImageDimension; j++) {
			update[j] = u[j];    
		}
	}

	// compute energy: cost = sum_k p_k * ( log N(y; mu_k, Sigma_k) + log m_k )
	double cost = 0.0;
#ifdef USE_FAST_LIKELIHOOD
	// Copy per-class covariance (4x4, row-major) and mean into plain arrays.
	double y[4], ym[4];
	double vv[NumberOfMovingChannels][16];
	double mv[NumberOfMovingChannels][4];

	// assume 4 channels for image
	if (NumberOfFixedChannels != 4) {
		std::cout << "NumberOfFixedChannels must be 4 to use fast likelihood" << std::endl;
		return update;
	}

	for (i = 0; i < NumberOfMovingChannels; i++) {
		for (j = 0; j < 4; j++) {
			for (k = 0; k < 4; k++) {
				vv[i][j*4+k] = GetVarianceVector()->at(i)(j, k);
			}
			mv[i][j] = GetMeanVector()->at(i)(j);
		}
	}
#else
	vnl_vector_fixed<double, NumberOfFixedChannels> y(0.0);
#endif


	// y = fixed-image intensities (one per fixed channel) at this pixel
	for (i = 1; i < NumberOfFixedChannels+1; i++)
	{
#ifdef USE_FAST_LIKELIHOOD
		y[i-1] = GetNthFixedImage(i)->GetPixel(index);
#else
		y(i-1) = GetNthFixedImage(i)->GetPixel(index);
#endif
	}  

	// to the number of classes
	for(i = 0; i < NumberOfMovingChannels; i++)
	{
#ifdef USE_FAST_LIKELIHOOD
		// NOTE(review): ComputeLikelihood4 is assumed to return the Gaussian
		// density for residual ym under covariance vv[i] - confirm.
		double like;
		ym[0] = y[0] - mv[i][0]; 
		ym[1] = y[1] - mv[i][1]; 
		ym[2] = y[2] - mv[i][2]; 
		ym[3] = y[3] - mv[i][3]; 
		ComputeLikelihood4(vv[i], ym, &like);
		// we have to take care of negative values
		double val = like * movingValues[i];
		if (val <= 0.0) val = epss;
		cost += posteriorValues[i] * log(val);
#else
		// NOTE(review): the QR factorisation and determinant of each class
		// covariance are recomputed for every pixel although they do not depend
		// on the pixel; caching them per iteration would save work.
		// NOTE(review): a D-variate Gaussian (D = NumberOfFixedChannels) has
		// normaliser sqrt((2*PI)^D * detS); only one factor of 2*PI is used here.
		// This offsets the cost by a constant times p_k and does not affect u.
		MeanType SIy = vnl_qr<double>(GetVarianceVector()->at(i)).solve( y - GetMeanVector()->at(i));
		double ySIy = dot_product(y - GetMeanVector()->at(i), SIy);
		double detS = vnl_determinant(GetVarianceVector()->at(i))+eps;
		/** we have to take care of negative values **/
		double val = movingValues[i];
		if (val <= 0.0) val = epss;
		cost += posteriorValues[i]*( log(1.0/vcl_sqrt(2.0*PI*detS))-.5*ySIy + log( val ));
#endif
	}

	// Accumulate into the per-thread data; merged into the shared metric
	// when the global data pointer is released.
	if (globalData) {
		globalData->m_SumOfEMLogs -=  cost; // store minus the cost, so the metric is minimised
		globalData->m_NumberOfPixelsProcessed += 1;
		globalData->m_SumOfSquaredChange += update.GetSquaredNorm();
	}

	// update done
	return update;
}

//////////////////////////////////////////////////////////////////////////////
// Update the metric and release the per-thread-global data.
//////////////////////////////////////////////////////////////////////////////

/****************************************************************************/
template <class TFixedImage, class TMovingImage, class TDeformationField, unsigned int VNumberOfFixedChannels, unsigned int VNumberOfMovingChannels>
void
JointSegmentationRegistrationFunction<TFixedImage,TMovingImage,TDeformationField,VNumberOfFixedChannels,VNumberOfMovingChannels>
::ReleaseGlobalDataPointer(void *gd) const
{
	// Merge one thread's partial sums (pixel count, summed negated EM cost,
	// summed squared update norm) into the shared accumulators under
	// m_MetricCalculationLock, refresh m_Metric and m_RMSChange, then free
	// the per-thread GlobalDataStruct.
	GlobalDataStruct *globalData = static_cast<GlobalDataStruct *>(gd);

	// ComputeUpdate() explicitly tolerates a null global-data pointer; do the
	// same here rather than dereferencing it.
	if (!globalData) {
		return;
	}

	m_MetricCalculationLock.Lock();

	m_NumberOfPixelsProcessed += globalData->m_NumberOfPixelsProcessed;
	m_SumOfEMLogs += globalData->m_SumOfEMLogs;
	// NOTE(review): the visible iteration initializer resets
	// m_NumberOfPixelsProcessed, m_SumOfEMLogs and m_RMSChange but not
	// m_SumOfSquaredChange; confirm it is reset elsewhere, otherwise
	// m_RMSChange accumulates across iterations.
	m_SumOfSquaredChange += globalData->m_SumOfSquaredChange;

	if (m_NumberOfPixelsProcessed) {
		const double numberOfPixels = static_cast<double>(m_NumberOfPixelsProcessed);
		// Metric: mean per-pixel negated EM cost over all pixels merged so far.
		m_Metric = m_SumOfEMLogs / numberOfPixels;
		// RMS of the per-pixel update magnitudes merged so far.
		m_RMSChange = vcl_sqrt(m_SumOfSquaredChange / numberOfPixels);
	}

	m_MetricCalculationLock.Unlock();

	delete globalData;
}


} // end namespace itk

#endif
