package tractography;

import misc.DT;
import numerics.*;

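/**
 * Provides linearly interpolated estimates of the principal direction
 * at any point within the dataset.
 *
 * <p>For each of the eight voxels surrounding the point, the principal
 * direction most closely aligned (ignoring sign) with the previous tracking
 * direction is selected and oriented to agree with it. Neighbours without any
 * principal direction (background voxels) are replaced by the vector of the
 * voxel containing the point. The eight vectors are then combined with the
 * interpolation weights, and the result is normalized.
 *
 * @version $Id: VectorLinearInterpolator.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author Philip Cook
 * 
 */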
public final class VectorLinearInterpolator extends EightNeighbourInterpolator 
    implements ImageInterpolator {

    /** The image whose per-voxel principal directions are interpolated. */
    private TractographyImage image;

    /**
     * Construct an interpolator whose grid (data dimensions and voxel
     * dimensions) is taken from the given image.
     *
     * @param data the dataset to use for interpolation.
     */
    public VectorLinearInterpolator(TractographyImage data) {
        super(data.xDataDim(), data.yDataDim(), data.zDataDim(),
              data.xVoxelDim(), data.yVoxelDim(), data.zVoxelDim());
        image = data;
    }

    /**
     * Selects the principal direction at voxel (i, j, k) whose absolute dot
     * product with {@code previousDir} is largest, then orients it against
     * {@code previousDir}.
     *
     * <p>On ties, or when every candidate has a zero dot product, the earliest
     * candidate is kept. The chosen vector is returned as-is when its dot
     * product with {@code previousDir} is strictly positive, and negated
     * otherwise.
     *
     * @param i voxel index in x.
     * @param j voxel index in y.
     * @param k voxel index in z.
     * @param previousDir the direction the chosen vector should agree with.
     * @return the oriented vector, or {@code null} if the voxel has no
     *         principal directions (e.g. a background voxel).
     */
    private Vector3D chooseVector(int i, int j, int k, Vector3D previousDir) {

        // The array comes from the image; it is only read here, never written.
        Vector3D[] candidates = image.getPDs(i, j, k);

        if (candidates.length == 0) {
            return null;
        }

        Vector3D best = candidates[0];

        if (candidates.length > 1) {
            double bestAbsDot = 0.0;
            for (Vector3D candidate : candidates) {
                double absDot = Math.abs(candidate.dot(previousDir));
                // strict comparison: the first of several equal maxima wins
                if (absDot > bestAbsDot) {
                    bestAbsDot = absDot;
                    best = candidate;
                }
            }
        }

        return best.dot(previousDir) > 0.0 ? best : best.negated();
    }

    /**
     * Interpolates the tracking direction at {@code point}.
     *
     * <p>The vector in each of the eight neighbouring voxels is chosen and
     * oriented relative to {@code previousDirection}. Neighbours with no
     * principal direction take the vector of the voxel identified by
     * {@code setInterpolationVoxels}. If that voxel itself has none, a zero
     * vector is used. The weighted sum of the eight vectors is then normalized.
     *
     * <p>NOTE(review): if every contribution is zero, the result depends on how
     * {@code Vector3D.normalized()} treats a zero vector — confirm.
     *
     * @param point the position at which to interpolate.
     * @param previousDirection the current tracking direction, used to pick and
     *        orient the vector in each neighbour.
     * @return the normalized interpolated direction.
     */
    public Vector3D getTrackingDirection(Point3D point, Vector3D previousDirection) {

        double[] weights = new double[8];
        int[] bounds = new int[6];

        // Fills the weights and the lower/upper voxel index per axis. The
        // returned index is presumably the neighbour that contains the point.
        int home = setInterpolationVoxels(point, weights, bounds);

        // Neighbour n is addressed by its bits: bit 2 -> x, bit 1 -> y, bit 0 -> z.
        Vector3D[] neighbours = new Vector3D[8];
        for (int n = 0; n < neighbours.length; n++) {
            int xi = bounds[n >> 2];
            int yi = bounds[2 + ((n >> 1) & 1)];
            int zi = bounds[4 + (n & 1)];
            neighbours[n] = chooseVector(xi, yi, zi, previousDirection);
        }

        // A precision error could, in rare cases, leave even the home voxel
        // empty; fall back to a zero vector in that case.
        Vector3D fallback = (neighbours[home] != null)
            ? neighbours[home]
            : new Vector3D(0.0, 0.0, 0.0);

        // Background neighbours inherit the home voxel's vector.
        for (int n = 0; n < neighbours.length; n++) {
            if (neighbours[n] == null) {
                neighbours[n] = fallback;
            }
        }

        // Seed each sum with the first term so the left-to-right summation
        // order matches a fully written-out expression exactly.
        double sumX = neighbours[0].x * weights[0];
        double sumY = neighbours[0].y * weights[0];
        double sumZ = neighbours[0].z * weights[0];

        for (int n = 1; n < neighbours.length; n++) {
            Vector3D v = neighbours[n];
            double w = weights[n];
            sumX += v.x * w;
            sumY += v.y * w;
            sumZ += v.z * w;
        }

        return new Vector3D(new double[] {sumX, sumY, sumZ}).normalized();
    }

    /**
     * Get the initial tracking direction for a seed point.
     *
     * <p>The seed voxel is found by dividing each coordinate of {@code point}
     * by the corresponding voxel dimension and truncating to {@code int}. The
     * requested principal direction of that voxel, possibly negated, is then
     * used as the previous direction for
     * {@link #getTrackingDirection(Point3D, Vector3D)}.
     *
     * @param point the seed point, in the same units as the voxel dimensions.
     * @param pdIndex index of the principal direction to start from in the
     *        seed voxel.
     * @param direction if true, the direction will be the PD, if false, it will
     *        be the negated PD.
     * @return the tracking direction for this point.
     */
    public Vector3D getTrackingDirection(Point3D point, int pdIndex, boolean direction) {

        Vector3D[] pds = image.getPDs((int) (point.x / xVoxelDim),
                                      (int) (point.y / yVoxelDim),
                                      (int) (point.z / zVoxelDim));

        Vector3D seedPd = pds[pdIndex];

        return getTrackingDirection(point, direction ? seedPd : seedPd.negated());
    }

}
