package tractography;

import misc.DT;
import numerics.*;

/**
 * Provides interpolated measurements of the diffusion tensor 
 * at any point within the dataset.
 * <p>
 * Each tensor component is interpolated independently as a weighted sum over the
 * eight voxels neighbouring the point, using the weights supplied by
 * {@code setInterpolationVoxels}. Where a voxel holds more than one tensor, the
 * tensor whose principal direction is most closely aligned (ignoring sign) with the
 * current tracking direction is used.
 *
 * @version $Id: DT_LinearInterpolator.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author Philip Cook
 * 
 */
public final class DT_LinearInterpolator extends EightNeighbourInterpolator 
    implements ImageInterpolator {

    /** Number of neighbouring voxels that contribute to each interpolated value. */
    private static final int NEIGHBOURS = 8;

    /** Number of components in a DT, as taken by its six-argument constructor. */
    private static final int DT_COMPONENTS = 6;

    /** Source of the tensors and principal directions in each voxel. */
    private final DT_TractographyImage image;

    /**
     * Construct an interpolator.
     *
     * @param data the dataset to use for interpolation; its data and voxel
     * dimensions are passed to the superclass.
     */
    public DT_LinearInterpolator(DT_TractographyImage data) {
        super(data.xDataDim(), data.yDataDim(), data.zDataDim(),
              data.xVoxelDim(), data.yVoxelDim(), data.zVoxelDim());
        image = data;
    }

    /**
     * Gets the interpolated tensor at some point. Each tensor component is
     * interpolated independently as a weighted sum over the eight neighbouring
     * voxels. If there are multiple tensors in a voxel, the one aligned closest to
     * the current tracking direction is used. If any neighbouring voxels are
     * classified as background (contain no tensors), their tensor is replaced by the
     * tensor in the voxel containing the point; if that voxel is itself background,
     * an all-zero tensor is used in its place.
     *
     * @param point the point in mm to interpolate at.
     * @param previousDir the current tracking direction, used to choose between
     * multiple tensors in a voxel.
     * @return the interpolated tensor.
     */
    protected DT getInterpolatedTensor(Point3D point, Vector3D previousDir) {

        double[] interpFraction = new double[NEIGHBOURS];

        // dims[0..1] hold the two x indices, dims[2..3] the y indices and
        // dims[4..5] the z indices of the neighbourhood
        int[] dims = new int[6];

        // fills interpFraction and dims; returns the neighbour index (0-7) of the
        // voxel that contains the point
        int inVoxel = setInterpolationVoxels(point, interpFraction, dims);

        // Neighbour n takes its x index from bit 2 of n, y from bit 1 and z from
        // bit 0, so n = 0..7 corresponds to voxel offsets 000, 001, ..., 111.
        DT[] tensors = new DT[NEIGHBOURS];
        for (int n = 0; n < NEIGHBOURS; n++) {
            int x = dims[n / 4];
            int y = dims[2 + ((n / 2) % 2)];
            int z = dims[4 + (n % 2)];
            tensors[n] = chooseTensor(x, y, z, previousDir);
        }

        if (tensors[inVoxel] == null) {
            // highly unlikely but theoretically possible that a precision error might cause this
            tensors[inVoxel] = new DT(0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
        }

        double[][] neighbourComponents = new double[NEIGHBOURS][];
        for (int n = 0; n < NEIGHBOURS; n++) {
            // a null tensor means the neighbour is background: substitute the
            // tensor of the voxel containing the point
            DT dt = (tensors[n] != null) ? tensors[n] : tensors[inVoxel];
            neighbourComponents[n] = dt.getComponents();
        }

        double[] interpolated = new double[DT_COMPONENTS];
        for (int c = 0; c < DT_COMPONENTS; c++) {
            // accumulate left to right from neighbour 0 so the floating-point result
            // matches the explicit sum t000 + t001 + ... + t111
            double sum = neighbourComponents[0][c] * interpFraction[0];
            for (int n = 1; n < NEIGHBOURS; n++) {
                sum += neighbourComponents[n][c] * interpFraction[n];
            }
            interpolated[c] = sum;
        }

        return new DT(interpolated[0], interpolated[1], interpolated[2],
                      interpolated[3], interpolated[4], interpolated[5]);
    }

    /**
     * Chooses which tensor in voxel (i, j, k) to use for interpolation.
     *
     * @param i x index of the voxel.
     * @param j y index of the voxel.
     * @param k z index of the voxel.
     * @param previousDir the current tracking direction.
     * @return null if the voxel contains no tensors (background); the only tensor if
     * there is exactly one; otherwise the tensor at the index of the principal
     * direction with the largest absolute dot product with previousDir. Ties, and
     * the case where every dot product is zero, resolve to the first tensor.
     */
    private DT chooseTensor(int i, int j, int k, Vector3D previousDir) {

        DT[] dts = image.getDTs(i, j, k);

        if (dts.length == 0) {
            return null;
        }
        if (dts.length == 1) {
            return dts[0];
        }

        // NOTE(review): assumes pds[d] is the principal direction of dts[d]
        Vector3D[] pds = image.getPDs(i, j, k);

        int choice = 0;
        double maxDot = 0.0;

        for (int d = 0; d < pds.length; d++) {
            // absolute value: the comparison ignores the sign of each direction
            double dot = Math.abs(pds[d].dot(previousDir));
            if (dot > maxDot) {
                maxDot = dot;
                choice = d;
            }
        }

        return dts[choice];
    }

    /**
     * Gets the tracking direction at a point. The direction is the first eigenvector
     * of the interpolated tensor's sorted eigensystem, with its sign chosen so that it
     * has a positive dot product with previousDirection. Otherwise it is negated,
     * including when the dot product is exactly zero.
     *
     * @param point the point in mm.
     * @param previousDirection the previous tracking direction.
     * @return the tracking direction at point.
     */
    public Vector3D getTrackingDirection(Point3D point, Vector3D previousDirection) {
        DT dt = getInterpolatedTensor(point, previousDirection);
        double[][] seig = dt.sortedEigenSystem();

        // rows 1-3, column 0 of the sorted eigensystem hold the first eigenvector
        Vector3D e1 = new Vector3D(seig[1][0], seig[2][0], seig[3][0]);

        return (e1.dot(previousDirection) > 0.0) ? e1 : e1.negated();
    }

    /**
     * Get the initial tracking direction, given a pdIndex and a seed point.
     * <p>
     * The selected principal direction of the voxel containing the seed, or its
     * negation, is used as the reference direction. The result is the interpolated
     * tracking direction at the seed, oriented to agree with that reference.
     *
     * @param point the seed point in mm.
     * @param pdIndex index of the principal direction, within the voxel containing
     * the point, to start along.
     * @param direction if true, the reference direction is the PD; if false, it is
     * the negated PD.
     * @return the tracking direction for this point.
     */
    public Vector3D getTrackingDirection(Point3D point, int pdIndex, boolean direction) {

        // voxel containing the point (integer truncation; no bounds check here)
        int x = (int) (point.x / xVoxelDim);
        int y = (int) (point.y / yVoxelDim);
        int z = (int) (point.z / zVoxelDim);

        Vector3D pd = image.getPDs(x, y, z)[pdIndex];

        return getTrackingDirection(point, direction ? pd : pd.negated());
    }

}
