/*=========================================================================

  Program:   Insight Segmentation & Registration Toolkit
  Module:    $RCSfile: itkVectorImageTensorReorientWarpFilter.txx,v $
  Language:  C++
  Date:      $Date: 2005/07/27 15:21:12 $
  Version:   $Revision: 1.6 $

  Copyright (c) Insight Software Consortium. All rights reserved.
  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef __itkVectorImageTensorReorientWarpFilter_txx
#define __itkVectorImageTensorReorientWarpFilter_txx
#include "itkVectorImageTensorReorientWarpFilter.h"

#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkProgressReporter.h"
#include "itkNumericTraits.h"
#include "itkProgressReporter.h"
#include "itkMatrix.h"
#include "itkDiffusionTensor3D.h"
#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "itkNeighborhoodAlgorithm.h"
#include "itkConstantBoundaryCondition.h"
#include <vnl/algo/vnl_svd.h>

namespace itk
{

/**
 * Default constructor.
 *
 * Declares the two required inputs (image to warp, deformation field),
 * sets unit spacing / zero origin / identity direction for the output,
 * a one-component zero edge-padding value, a neighborhood radius of 1
 * along every axis, and installs DefaultInterpolatorType as interpolator.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::VectorImageTensorReorientWarpFilter()
{
  // Input 0: image being warped; input 1: deformation field.
  this->SetNumberOfRequiredInputs( 2 );

  // Output geometry defaults.
  m_OutputSpacing.Fill( 1.0 );
  m_OutputOrigin.Fill( 0.0 );
  m_OutputDirection.SetIdentity();

  // Default pixel length and a matching all-zero padding value.
  m_VectorLength = 1;
  m_EdgePaddingValue.SetSize( m_VectorLength );
  for ( unsigned int c = 0; c < static_cast<unsigned int>( m_VectorLength ); ++c )
    {
    m_EdgePaddingValue[c] = 0;
    }

  // One neighbor on each side per axis: enough for the central differences
  // taken over the deformation field.
  for ( unsigned int d = 0; d < ImageDimension; ++d )
    {
    m_NeighborhoodRadius[d] = 1;
    }

  // Install the default interpolator; callers may replace it later.
  typename DefaultInterpolatorType::Pointer defaultInterpolator =
    DefaultInterpolatorType::New();
  InterpolatorType * asInterpolator =
    static_cast<InterpolatorType*>( defaultInterpolator.GetPointer() );
  m_Interpolator = asInterpolator;
}

/**
 * Standard PrintSelf method.
 *
 * Prints the superclass state followed by every user-settable parameter
 * of this filter: output geometry, vector length, neighborhood radius,
 * edge padding value and the interpolator pointer.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::PrintSelf(std::ostream& os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  os << indent << "OutputSpacing: " << m_OutputSpacing << std::endl;
  os << indent << "OutputOrigin: " << m_OutputOrigin << std::endl;
  os << indent << "OutputDirection: " << m_OutputDirection << std::endl;
  os << indent << "VectorLength: " << m_VectorLength << std::endl;
  os << indent << "NeighborhoodRadius: " << m_NeighborhoodRadius << std::endl;
  os << indent << "EdgePaddingValue: "
     << static_cast<typename NumericTraits<PixelType>::PrintType>(m_EdgePaddingValue)
     << std::endl;
  os << indent << "Interpolator: " << m_Interpolator.GetPointer() << std::endl;
}

/**
 * Set the output image spacing from a raw array.
 * \param spacing  pointer to ImageDimension spacing values.
 * Wraps the array in a SpacingType and forwards to the typed overload.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::SetOutputSpacing(const double* spacing)
{
  const SpacingType wrappedSpacing( spacing );
  this->SetOutputSpacing( wrappedSpacing );
}

/**
 * Set the output image origin from a raw array.
 * \param origin  pointer to ImageDimension coordinates.
 * Wraps the array in a PointType and forwards to the typed overload.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::SetOutputOrigin(const double* origin)
{
  const PointType wrappedOrigin( origin );
  this->SetOutputOrigin( wrappedOrigin );
}


/**
 * Set the deformation field.
 * The field is stored as input number 1 of this ProcessObject; input 0 is
 * the image returned by GetInput().
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::SetDeformationField(
  DeformationFieldType * field )
{
  const unsigned int fieldInputIndex = 1;
  this->ProcessObject::SetNthInput( fieldInputIndex, field );
}


/**
 * Return a pointer to the deformation field (input number 1), or null if
 * it has not been set.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
typename VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::DeformationFieldType *
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::GetDeformationField(void)
{
  const unsigned int fieldInputIndex = 1;
  DeformationFieldType * field =
    static_cast<DeformationFieldType *>( this->ProcessObject::GetInput( fieldInputIndex ) );
  return field;
}


/**
 * Setup state of filter before multi-threading.
 * InterpolatorType::SetInputImage is not thread-safe and hence
 * has to be setup before ThreadedGenerateData.
 *
 * Also validates parameters that ThreadedGenerateData relies on, so that
 * invalid configurations fail with an exception instead of indexing out
 * of range inside the worker threads.
 *
 * \throws ExceptionObject if no interpolator or deformation field is set,
 *         or if VectorLength is smaller than 6.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::BeforeThreadedGenerateData()
{
  if( !m_Interpolator )
    {
    itkExceptionMacro(<< "Interpolator not set");
    }

  if( !this->GetDeformationField() )
    {
    itkExceptionMacro(<< "Deformation field not set");
    }

  // ThreadedGenerateData reads the last six components of each pixel as a
  // symmetric 3x3 tensor (indices VectorLength-6 .. VectorLength-1); a
  // shorter vector would be indexed out of range.
  if( m_VectorLength < 6 )
    {
    itkExceptionMacro(<< "VectorLength is " << m_VectorLength
                      << " but must be at least 6 to hold a 3x3 tensor");
    }

  // Connect input image to interpolator
  m_Interpolator->SetInputImage( this->GetInput() );
}


/**
 * Compute the output for the region specified by outputRegionForThread.
 *
 * For every output pixel:
 *  - the physical point is displaced by the deformation field;
 *  - if the displaced point is inside the input buffer, the input vector is
 *    interpolated there, the last six components are treated as the upper
 *    triangle (xx, xy, xz, yy, yz, zz) of a symmetric 3x3 tensor D, and D is
 *    replaced by R * D * R^T. Components before the tensor are copied
 *    unchanged;
 *  - otherwise the edge padding value is written.
 *
 * R is the rotation factor (U * V^T from an SVD) of the polar decomposition
 * of (I + J). J is a central-difference estimate of the displacement
 * gradient taken over the deformation field, divided by the field spacing.
 * The sign convention (previous - next) is kept from the original code.
 * It yields I - grad(u), which to first order is the inverse of I + grad(u).
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::ThreadedGenerateData(
  const OutputImageRegionType& outputRegionForThread,
  int threadId )
{
  typedef typename InterpolatorType::OutputType  OutputType;

  OutputImagePointer outputPtr = this->GetOutput();
  DeformationFieldPointer fieldPtr = this->GetDeformationField();

  // BeforeThreadedGenerateData guarantees this is >= 6.
  const unsigned int vectorLength = static_cast<unsigned int>( m_VectorLength );
  // Index of the first of the six tensor components.
  const unsigned int tensorOffset = vectorLength - 6;

  // support progress methods/callbacks
  ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());

  // iterator for the output image
  ImageRegionIteratorWithIndex<OutputImageType> outputIt(
    outputPtr, outputRegionForThread );

  // iterator for the deformation field
  ImageRegionIterator<DeformationFieldType> fieldIt(
    fieldPtr, outputRegionForThread );

  // Neighborhood iterator over the deformation field for central
  // differences; neighbors outside the buffered region read as zero.
  boundaryConditionType cbc;
  cbc.SetConstant(static_cast<VectorType>(0.0));
  ConstNeighborhoodIteratorType
    bit( m_NeighborhoodRadius, fieldPtr, outputRegionForThread );
  bit.OverrideBoundaryCondition(&cbc);
  bit.GoToBegin();

  // 0.5 / spacing turns the index-space central difference into a
  // derivative per physical unit (the displacement is added to physical
  // points below, so it is in physical units).
  // NOTE(review): the field's direction cosines are not taken into account.
  double halfInverseSpacing[3];
  for ( unsigned int i = 0; i < 3; ++i )
    {
    halfInverseSpacing[i] = 0.5 / fieldPtr->GetSpacing()[i];
    }

  // Loop-invariant work and buffers, set up once per thread.
  vnlMatrixType identity;
  identity.set_identity();
  vnlMatrixType J;
  realPixelType outputValue( vectorLength );
  realPixelType outValue( vectorLength );

  // Padding value of exactly vectorLength components. Set() copies that
  // many components, so a shorter m_EdgePaddingValue would otherwise be read
  // past its end; missing components are zero.
  const unsigned int userPaddingSize = m_EdgePaddingValue.GetSize();
  realPixelType paddingValue( vectorLength );
  for ( unsigned int k = 0; k < vectorLength; ++k )
    {
    if ( k < userPaddingSize )
      {
      paddingValue[k] = static_cast<double>( m_EdgePaddingValue[k] );
      }
    else
      {
      paddingValue[k] = 0.0;
      }
    }

  IndexType index;
  PointType point;
  DisplacementType displacement;

  while( !outputIt.IsAtEnd() )
    {
    // get the output image index
    index = outputIt.GetIndex();
    outputPtr->TransformIndexToPhysicalPoint( index, point );

    // get the required displacement
    displacement = fieldIt.Get();

    // compute the required input image point
    for( unsigned int j = 0; j < ImageDimension; j++ )
      {
      point[j] += displacement[j];
      }

    if( m_Interpolator->IsInsideBuffer( point ) )
      {
      // Displacement gradient estimate: J(j,i) ~ -d u_j / d x_i.
      // Requires a neighborhood radius >= 1 along each of the 3 axes.
      for ( unsigned int i = 0; i < 3; ++i )
        {
        for ( unsigned int j = 0; j < 3; ++j )
          {
          J(j,i) = halfInverseSpacing[i] *
                   ( bit.GetPrevious(i)[j] - bit.GetNext(i)[j] );
          }
        }

      // Rotation factor of the polar decomposition of (I + J).
      // NOTE(review): if det(I + J) < 0 this is a reflection, not a rotation.
      vnl_svd<CompType> svd( J + identity );
      MatrixType rotationMatrix( svd.U() * svd.V().transpose() );

      // get the interpolated value
      const OutputType interpolatedValue = m_Interpolator->Evaluate( point );
      for( unsigned int k = 0; k < vectorLength; k++ )
        {
        outputValue[k] = static_cast<double>( interpolatedValue[k] );
        }

      // Assuming the scalar components are before the tensor components:
      // rebuild the symmetric tensor from its upper triangle.
      MatrixType tensorD;
      tensorD(0,0) = outputValue[tensorOffset];
      tensorD(0,1) = outputValue[tensorOffset + 1];
      tensorD(0,2) = outputValue[tensorOffset + 2];
      tensorD(1,0) = outputValue[tensorOffset + 1];
      tensorD(1,1) = outputValue[tensorOffset + 3];
      tensorD(1,2) = outputValue[tensorOffset + 4];
      tensorD(2,0) = outputValue[tensorOffset + 2];
      tensorD(2,1) = outputValue[tensorOffset + 4];
      tensorD(2,2) = outputValue[tensorOffset + 5];

      vnlMatrixType OutMatrix =
        rotationMatrix * tensorD * rotationMatrix.GetTranspose();

      // Non-tensor components pass through unchanged.
      for( unsigned int k = 0; k < tensorOffset; k++ )
        {
        outValue[k] = outputValue[k];
        }

      outValue[tensorOffset]     = OutMatrix(0,0);
      outValue[tensorOffset + 1] = OutMatrix(0,1);
      outValue[tensorOffset + 2] = OutMatrix(0,2);
      outValue[tensorOffset + 3] = OutMatrix(1,1);
      outValue[tensorOffset + 4] = OutMatrix(1,2);
      outValue[tensorOffset + 5] = OutMatrix(2,2);

      outputIt.Set( outValue );
      }
    else
      {
      outputIt.Set( paddingValue );
      }
    ++outputIt;
    ++fieldIt;
    ++bit;
    progress.CompletedPixel();
    }
}


/**
 * Propagate requested regions to the inputs.
 *
 * The whole input image is requested because the displaced points may fall
 * anywhere in it. The deformation field is asked for the output requested
 * region padded by m_NeighborhoodRadius and cropped to the field's largest
 * possible region. ThreadedGenerateData takes central differences over the
 * field, and neighbors outside the buffered region would otherwise be
 * replaced by the constant-zero boundary condition.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::GenerateInputRequestedRegion()
{
  // call the superclass's implementation
  Superclass::GenerateInputRequestedRegion();

  // request the largest possible region for the input image
  InputImagePointer inputPtr =
    const_cast< InputImageType * >( this->GetInput() );

  if( inputPtr )
    {
    inputPtr->SetRequestedRegionToLargestPossibleRegion();
    }

  DeformationFieldPointer fieldPtr = this->GetDeformationField();
  OutputImagePointer outputPtr = this->GetOutput();
  if( !fieldPtr || !outputPtr )
    {
    return;
    }

  // Output requested region plus the neighborhood used for derivatives.
  typedef typename DeformationFieldType::RegionType FieldRegionType;
  FieldRegionType fieldRequestedRegion = outputPtr->GetRequestedRegion();
  fieldRequestedRegion.PadByRadius( m_NeighborhoodRadius );

  if( !fieldRequestedRegion.Crop( fieldPtr->GetLargestPossibleRegion() ) )
    {
    // No overlap with the field: fall back to the unpadded request and let
    // the pipeline report the inconsistency.
    fieldRequestedRegion = outputPtr->GetRequestedRegion();
    }

  fieldPtr->SetRequestedRegion( fieldRequestedRegion );
}


/**
 * Set the output's origin, spacing, direction and vector length from the
 * filter parameters. If a deformation field is connected, copy its largest
 * possible region to the output.
 */
template <class TInputImage,class TOutputImage,class TDeformationField>
void
VectorImageTensorReorientWarpFilter<TInputImage,TOutputImage,TDeformationField>
::GenerateOutputInformation()
{
  Superclass::GenerateOutputInformation();

  // Geometry and pixel length come from the filter's own settings ...
  OutputImagePointer output = this->GetOutput();
  output->SetOrigin( m_OutputOrigin );
  output->SetSpacing( m_OutputSpacing );
  output->SetDirection( m_OutputDirection );
  output->SetVectorLength( m_VectorLength );

  // ... while the extent follows the deformation field, when present.
  DeformationFieldPointer field = this->GetDeformationField();
  if( !field )
    {
    return;
    }
  output->SetLargestPossibleRegion( field->GetLargestPossibleRegion() );
}


} // end namespace itk

#endif
 
