
#ifndef _DiffusionTensor_h
#define _DiffusionTensor_h

#include "vnl/vnl_matrix.h"
#include "vnl/vnl_vector.h"
#include "vnl/algo/vnl_cholesky.h"
#include "vnl/algo/vnl_matrix_inverse.h"
#include "vnl/algo/vnl_qr.h"
#include "vnl/algo/vnl_symmetric_eigensystem.h"

// DiffusionTensor keeps a tensor as six ScalarType values (m_Elements) and
// offers arithmetic, conversion to/from a vnl matrix, and in-place
// adjustment helpers. Apart from GetElements(), every member function is
// defined out of line.
//
// NOTE(review): the six values are presumably the unique entries of a
// symmetric 3x3 matrix; their storage order is fixed by the out-of-line
// implementation (not visible here) - confirm before indexing
// GetElements() directly.
class DiffusionTensor
{
public:

  // ---------------------------------------------------------------------
  // Numeric and linear-algebra type aliases
  // ---------------------------------------------------------------------

  // Element type used for storage and for the vnl matrix/vector aliases.
  typedef float ScalarType;

  typedef vnl_matrix<ScalarType> MatrixType;
  typedef vnl_vector<ScalarType> VectorType;

  // Solver / decomposition helpers from vnl/algo.
  typedef vnl_matrix_inverse<ScalarType>        MatrixInverseType;
  typedef vnl_qr<ScalarType>                    MatrixQRType;
  typedef vnl_symmetric_eigensystem<ScalarType> MatrixEigenType;
  // NOTE(review): vnl_svd has no direct #include in this header; presumably
  // it is pulled in transitively by one of the vnl/algo headers above -
  // confirm, or include "vnl/algo/vnl_svd.h" explicitly.
  typedef vnl_svd<ScalarType>                   MatrixSVDType;
  // Unlike the aliases above, vnl_cholesky is named without a ScalarType
  // argument. NOTE(review): VXL's vnl_cholesky operates on double matrices;
  // confirm callers convert from the float MatrixType before using it.
  typedef vnl_cholesky                          MatrixCholeskyType;

  // ---------------------------------------------------------------------
  // Lifetime and assignment
  // ---------------------------------------------------------------------

  DiffusionTensor();
  DiffusionTensor(const DiffusionTensor& other);
  ~DiffusionTensor();

  // Returns a const reference rather than the conventional
  // DiffusionTensor&.
  const DiffusionTensor& operator=(const DiffusionTensor& other);

  // ---------------------------------------------------------------------
  // Arithmetic
  // ---------------------------------------------------------------------

  // Scale this tensor in place by a scalar factor; returns *this as const.
  const DiffusionTensor& operator*=(double factor);
  const DiffusionTensor& operator/=(double factor);

  // Produce a new tensor from this one and `other` / `factor`.
  // NOTE(review): these are not const-qualified, so they cannot be called
  // on a const DiffusionTensor; making them const also requires updating
  // the out-of-line definitions.
  DiffusionTensor operator+(const DiffusionTensor& other);
  DiffusionTensor operator-(const DiffusionTensor& other);
  DiffusionTensor operator*(double factor);
  DiffusionTensor operator/(double factor);

  // ---------------------------------------------------------------------
  // Element access
  // ---------------------------------------------------------------------

  // Read-only pointer to the six stored elements. It points into this
  // object, so it is only valid while this tensor is alive.
  const ScalarType* GetElements() const
  {
    return m_Elements;
  }

  // Overwrite all six elements from `elems`, checking whether any of them
  // is NaN or Inf. NOTE(review): `elems` presumably must supply at least 6
  // values, and the reaction to NaN/Inf is decided in the .cpp - confirm.
  void SetElements(const ScalarType* elems);

  // Set one element by flat index `i` (presumably 0..5 - confirm).
  void SetElementAt(unsigned int i, double v);
  // Set one element by (row `i`, column `j`) position.
  // NOTE(review): presumably also maps the symmetric (j, i) entry onto the
  // same stored element - confirm in the .cpp.
  void SetElementAt(unsigned int i, unsigned int j, double v);

  // ---------------------------------------------------------------------
  // Matrix conversion
  // ---------------------------------------------------------------------

  // Return the tensor expanded into a MatrixType.
  MatrixType GetMatrix() const;

  // Load this tensor's elements from a 3x3 matrix.
  void FromMatrix(const MatrixType& matrix);

  // ---------------------------------------------------------------------
  // In-place adjustments (behavior defined in the out-of-line code)
  // ---------------------------------------------------------------------

  // NOTE(review): presumably adjusts the tensor so it becomes positive
  // definite (e.g. via its eigenvalues) - confirm.
  void ForcePositiveDefinite();

  // NOTE(review): presumably rescales the tensor so its trace equals 1 -
  // confirm.
  void ForceUnitTrace();

  // NOTE(review): presumably shifts the tensor's mean diffusivity by `d`
  // - confirm.
  void AddMeanDiffusivity(double d);

private:

  // Backing storage for the six tensor elements.
  ScalarType m_Elements[6];
};

#endif
