package tractography;

import numerics.*;

import java.util.Arrays;

/**
 * <dl>
 * <dt>Purpose: To perform Tractography operations.
 * <BR><BR>
 *
 * <dt>Description:
 * <dd> This class does streamline tractography. No interpolation is applied.
 *
 * </dl>
 *
 * @version $Id: NonInterpolatedFibreTracker.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author  Philip Cook
 * 
 */
public class NonInterpolatedFibreTracker extends FibreTracker {


    /** Construct a tracker with additional options set.
     * @param data the dataset within which the tracking will take place.
     * @param ipThresh the minimum dot product between the tract direction over adjacent voxels.
     */
    public NonInterpolatedFibreTracker(TractographyImage data, double ipThresh) {

	super(data, ipThresh);
    }


    protected final Tract trackFromSeed(Point3D seedPos, int pdIndex, boolean direction) {
	
	Vector3D trackingDirection, previousBearing;

	Tract path = new Tract(200, 100.0);
	
	Point3D currentPos = seedPos;
	Point3D tmpPos;

	int i,j,k;

	path.addPoint(seedPos, 0.0);

	if (!inBounds(seedPos)) {
	    return path;
	}

	i = (int)(currentPos.x / xVoxelDim);
	j = (int)(currentPos.y / yVoxelDim);
	k = (int)(currentPos.z / zVoxelDim);

	if ( isotropic[i][j][k] ) {
	    return path;
	}

	Vector3D[] pds = getPDs(i,j,k);

	trackingDirection = pds[pdIndex];

        if (direction == BACKWARD) {
	    trackingDirection = trackingDirection.negated();
	}

	// Now track the rest of the way, always moving in direction that has the largest dot product
	// with the vector currentPos - previousPos

	// sometimes the tracking can get stuck in an infinite loop
	// As a safeguard, terminate tracts that exceed 10000 points
	short pointsAdded = 1;
	short maxPoints = 10000;
	
	
	
	Vector3D checkCurveBearing = trackingDirection;

	// 6 planes that can intersect with e1
	// front, back, top, bottom, left, right
	double[] planeIntersections = new double[6];

	// check that we don't curve more than IP threshold in the course of the last step
	double checkCurveDisplacement = 0.0; // displacement since last curve check

	while (true) {

	    visitedVoxel[i][j][k] = pointsAdded;

	    double frontPlane = yVoxelDim * (1 + j);
	    double backPlane = yVoxelDim * j;

	    double topPlane = zVoxelDim * (1 + k);
	    double bottomPlane = zVoxelDim * k;

	    double rightPlane = xVoxelDim * (1 + i);
	    double leftPlane = xVoxelDim * i;

            if (trackingDirection.y != 0.0) {
		
		planeIntersections[0] = (frontPlane - currentPos.y) / trackingDirection.y;
		planeIntersections[1] = (backPlane - currentPos.y) / trackingDirection.y;

	    }
	    else {
		// meets at infinity...
                planeIntersections[0] = Double.MAX_VALUE;
		planeIntersections[1] = Double.MAX_VALUE;
	    }		

	    if (trackingDirection.z != 0.0) {
		planeIntersections[2] = (topPlane - currentPos.z) / trackingDirection.z;
		planeIntersections[3] = (bottomPlane - currentPos.z) / trackingDirection.z;

	    }
	    else {
		planeIntersections[2] = Double.MAX_VALUE;
		planeIntersections[3] = Double.MAX_VALUE;
	    }		

	    if (trackingDirection.x != 0.0 ) {
		planeIntersections[4] = (leftPlane - currentPos.x) / trackingDirection.x;
		planeIntersections[5] = (rightPlane - currentPos.x) / trackingDirection.x;
	    }
	    else {
		planeIntersections[4] = Double.MAX_VALUE;
		planeIntersections[5] = Double.MAX_VALUE;
	    }
		
	    // now sort
	    Arrays.sort(planeIntersections);

	    // scale by smallest positive element
	    // allow value of 0.0; occurs only if point sits precisely on a voxel boundary
	    int planeCounter = 0;
	    while (planeIntersections[planeCounter] < 0.0) {
		planeCounter++;

		// this could happen if point was on boundary between voxels and planeIntersection had 
		// to be > 0. No longer a problem
// 		if (planeCounter == 6) {
// 		    throw new IllegalStateException("No plane intersects tracking vector.\nPosition " + currentPos + "\nTracking vector " + trackingDirection);
// 		}
	    }

	    double displacement = planeIntersections[planeCounter] + 0.001;

	    tmpPos = currentPos.displaced(trackingDirection.scaled(displacement));

	    checkCurveDisplacement += displacement;

	    // check new point in image bounds
	    if(!inBounds(tmpPos)) {
		return path;
	    }
	    
	    i = (int)(tmpPos.x / xVoxelDim);
	    j = (int)(tmpPos.y / yVoxelDim);
	    k = (int)(tmpPos.z / zVoxelDim);


	    if (visitedVoxel[i][j][k] > 0) {

		// if we left this voxel some time ago and looped, terminate fibre
		if (pointsAdded - visitedVoxel[i][j][k] > 1) {
		    return path;
		}
                else if (pointsAdded == visitedVoxel[i][j][k]) {
                    // stuck in same voxel

		    displacement += 0.001;

                    tmpPos = currentPos.displaced(trackingDirection.scaled(displacement));

                    // Quit if we are out of bounds
                    if (!inBounds(tmpPos)) {
                        return path;
                    }
	
                    i = (int)(tmpPos.x / xVoxelDim);
                    j = (int)(tmpPos.y / yVoxelDim);
                    k = (int)(tmpPos.z / zVoxelDim);

                }
		else {
		                                                  
		    // if the vector field converges, eg we have | / | \ |
		    // and we are tracking from *                    *
		    // then we'll get shuttled back and forth between the two voxels
		    // in order to avoid this from causing lots of points to get added,
		    // we increase displacement to 0.1 mm

		    // update visitedVoxel since we are going through i,j,k 
		    visitedVoxel[i][j][k] = pointsAdded;

		    displacement = displacement < 0.1 ? 0.1 : displacement;

                    tmpPos = currentPos.displaced(trackingDirection.scaled(displacement));
                    

                    // Quit if we are out of bounds
                    if (!inBounds(tmpPos)) {
                        return path;
                    }
	
                    i = (int)(tmpPos.x / xVoxelDim);
                    j = (int)(tmpPos.y / yVoxelDim);
                    k = (int)(tmpPos.z / zVoxelDim);
                    
                }
	
	    }

	    currentPos = tmpPos;

	    previousBearing = trackingDirection; 

	    if ( isotropic[i][j][k] ) {
		return path;
	    }

	    trackingDirection = getTrackingDirection(i,j,k, previousBearing);

	    if (checkCurveDisplacement >= zVoxelDim) {
		checkCurveDisplacement = 0.0;
		
		if ( checkCurveBearing.dot(trackingDirection) < ipThreshold) {
		    path.addPoint(currentPos, displacement);
		    return path;
		}
		checkCurveBearing = trackingDirection;
	    }

	    // record position
	    path.addPoint(currentPos, displacement);

	    pointsAdded++;

	    if (pointsAdded == maxPoints) {
		return path;
	    }

	    
	}


	
    }


    protected Vector3D getTrackingDirection(int i, int j, int k, Vector3D previousDirection) {
	
	Vector3D[] voxelPDs = getPDs(i,j,k);
	
	double maxProd = 0.0;
	int pdIndex = 0;

	if (voxelPDs.length > 1) {
	    
	    for (int p = 0; p < voxelPDs.length; p++) {
		double dotProd = Math.abs(voxelPDs[p].dot(previousDirection));
		if (dotProd > maxProd) {
		    maxProd = dotProd;
		    pdIndex = p;
		}
	    }
	}
	
	Vector3D trackingDirection = voxelPDs[pdIndex];

	if ( trackingDirection.dot(previousDirection) < 0.0) {
	    // move along -e1
	    trackingDirection = trackingDirection.negated();
	}

	return trackingDirection;
    }


}
