package tractography;

import numerics.*;

/**
 * Abstract superclass for classes that perform streamline tractography.
 * <p>
 * Subclasses implement {@link #trackFromSeed(Point3D, int, boolean)}, which follows one
 * principal direction (PD) from a seed point in one direction. This class combines the
 * forward and backward halves into a single tract, iterates over seed points and regions
 * of interest, and clears {@link #visitedVoxel} after every complete tract.
 *
 * @version $Id: FibreTracker.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author  Philip Cook
 *
 */
public abstract class FibreTracker {

    /**
     * Per-voxel counters that this class resets to zero after every complete
     * (forward + backward) tract.
     * NOTE(review): presumably written by subclasses to record voxels visited while
     * tracking (e.g. for loop detection) -- confirm in the subclasses.
     */
    protected short[][][] visitedVoxel = null;

    /** Direction flag: start tracking along the principal direction. */
    public static final boolean FORWARD = true;

    /** Direction flag: start tracking opposite to the principal direction. */
    public static final boolean BACKWARD = false;

    /**
     * Minimum cosine of the angle between the tracking directions in successive steps
     * (i.e. a curvature threshold: larger values allow less bending).
     */
    protected final double ipThreshold;

    /** Voxel size along x, y and z, in mm per voxel. */
    protected final double xVoxelDim;
    protected final double yVoxelDim;
    protected final double zVoxelDim;

    /** Image dimensions along x, y and z, in voxels. */
    protected final int xDataDim;
    protected final int yDataDim;
    protected final int zDataDim;

    /** Exclusive upper boundaries of the tracking volume, in mm. */
    private final double xUBound;
    private final double yUBound;
    private final double zUBound;

    /**
     * Mask returned by {@code image.getIsotropicMask()}.
     * NOTE(review): presumably true where a voxel is treated as isotropic by the
     * trackers -- confirm against TractographyImage.
     */
    protected final boolean[][][] isotropic;

    /** The image within which tracking takes place. */
    protected final TractographyImage image;

    /** A row of {@code zDataDim} zeros, copied into {@link #visitedVoxel} to clear it. */
    short[] zeros;

    /**
     * Construct a FibreTracker.
     * @param image the image within which the tracking will take place.
     * @param ipThresh the minimum cosine of the angle between the tracking direction
     * across successive voxels.
     *
     */
    protected FibreTracker(TractographyImage image, double ipThresh) {

        ipThreshold = ipThresh;

        xDataDim = image.xDataDim();
        yDataDim = image.yDataDim();
        zDataDim = image.zDataDim();

        xVoxelDim = image.xVoxelDim();
        yVoxelDim = image.yVoxelDim();
        zVoxelDim = image.zVoxelDim();

        // Upper bounds of the volume in mm; a point must lie strictly below these.
        xUBound = xDataDim * xVoxelDim;
        yUBound = yDataDim * yVoxelDim;
        zUBound = zDataDim * zVoxelDim;

        visitedVoxel = new short[xDataDim][yDataDim][zDataDim];

        isotropic = image.getIsotropicMask();

        this.image = image;

        zeros = new short[zDataDim];
    }


    /**
     * Track paths from every seed point of a region of interest.
     * @param roi the region whose seed points ({@code roi.getSeedPoints()}) are tracked.
     * Seeds outside the image are handled by {@link #trackFromSeed(Point3D)}.
     * @return a <code>TractCollection</code> containing the results of tracking from all seed
     * points in the ROI.
     * @see tractography.TractCollection
     */
    public final TractCollection trackPaths(RegionOfInterest roi) {

        Point3D[] points = roi.getSeedPoints();

        TractCollection paths = new TractCollection(points.length + 1, 100.0);

        for (int i = 0; i < points.length; i++) {
            // Don't call trackFromSeed(Point3D, pdIndex, direction) directly:
            // PICo subclasses override trackFromSeed(Point3D).
            paths.addTractCollection(trackFromSeed(points[i]));
        }

        return paths;
    }


    /**
     * Track paths from a single seed point, one tract per principal direction in the
     * seed voxel.
     * @param seedPoint the point in mm to track from.
     * @return a <code>TractCollection</code> containing the results of tracking from this
     * seed point. For single fibre trackers, this <code>TractCollection</code> will contain
     * a single tract; multi-fibre trackers may return more than one tract. If the seed lies
     * outside the image, the collection holds one tract containing only the seed point; if
     * the seed voxel has no principal directions, the collection is empty.
     *
     * Note to developers: Other trackers override this method, so it should always be called
     * for tractography. Only this method should call
     * #trackFromSeed(Point3D seedPoint, int pdIndex, boolean direction).
     *
     */
    public TractCollection trackFromSeed(Point3D seedPoint) {

        TractCollection collection = new TractCollection(4, 100.0);

        if (!inBounds(seedPoint)) {
            collection.addTract(seedOnlyTract(seedPoint));
            return collection;
        }

        int numPDs = image.numberOfPDs(voxelIndex(seedPoint.x, xVoxelDim, xDataDim),
                                       voxelIndex(seedPoint.y, yVoxelDim, yDataDim),
                                       voxelIndex(seedPoint.z, zVoxelDim, zDataDim));

        for (int p = 0; p < numPDs; p++) {
            // might need to replace FORWARD and BACKWARD with a vector
            collection.addTract(trackBothWays(seedPoint, p));
        }

        return collection;
    }


    /**
     * Track a path from a single seed point along one chosen principal direction.
     * @param seedPoint the seed point in mm.
     * @param pd the index of the principal direction to follow from the seed point.
     * @return a <code>TractCollection</code> containing a single tract. If the seed lies
     * outside the image, that tract contains only the seed point.
     */
    public TractCollection trackFromSeed(Point3D seedPoint, int pd) {

        TractCollection collection = new TractCollection(2, 100.0);

        if (!inBounds(seedPoint)) {
            collection.addTract(seedOnlyTract(seedPoint));
            return collection;
        }

        collection.addTract(trackBothWays(seedPoint, pd));

        return collection;
    }


    /**
     * Track a path from a single seed point in one direction.
     * @param seedPoint the point in mm to track from.
     * @param pdIndex the principal direction to follow.
     * @param direction if {@link tractography.FibreTracker#FORWARD FORWARD}, start tracking along
     * the PD, otherwise track in the opposite direction.
     *
     * @return a <code>Tract</code> containing the results of tracking.
     *
     * @see #trackFromSeed(numerics.Point3D)
     *
     */
    protected abstract Tract trackFromSeed(Point3D seedPoint, int pdIndex, boolean direction);
    // don't call this directly from outside this class (see trackFromSeed(Point3D))


    /**
     * @param point a position in mm to test.
     * @return <code>false</code> if the position is outside the dataset,
     * <code>true</code> otherwise. The lower bounds are inclusive, the upper bounds exclusive.
     */
    protected final boolean inBounds(Point3D point) {
        return point.x >= 0.0 && point.y >= 0.0 && point.z >= 0.0
            && point.x < xUBound && point.y < yUBound && point.z < zUBound;
    }


    /**
     * Finds <code>Tract</code>'s that enter two specified regions.
     * Tracts will be seeded in both regions.
     * @param roi1 the first region.
     * @param roi2 the second region.
     * @return a <code>TractCollection</code> containing all tracts seeded in one region
     * that contain at least one point inside the other region.
     */
    public final TractCollection getConnectingPaths(RegionOfInterest roi1, RegionOfInterest roi2) {

        TractCollection connectingPaths = new TractCollection(200, 100.0);

        // First test paths from roi1 for entry into roi2, then paths from roi2 into roi1.
        addTractsEnteringRegion(trackPaths(roi1), roi2, connectingPaths);
        addTractsEnteringRegion(trackPaths(roi2), roi1, connectingPaths);

        return connectingPaths;
    }


    /**
     * Compute a connection probability map by repeatedly tracking from one seed.
     * @param seed the seed point in mm.
     * @param iterations the number of times to track from the seed.
     * @return the connection probabilities from the accumulated tracts.
     *
     * @deprecated this method is for testing, and only works with one PD per voxel.
     * For applications, write streamline output and then use the streamline processing app.
     */
    protected double[][][] connectionProbability(Point3D seed, int iterations) {
        // Named to avoid shadowing the image field.
        ConnectionProbabilityImage cpImage =
            new ConnectionProbabilityImage(xDataDim, yDataDim, zDataDim,
                                           xVoxelDim, yVoxelDim, zVoxelDim);

        for (int i = 0; i < iterations; i++) {
            cpImage.processTracts(trackFromSeed(seed));
        }

        return cpImage.getConnectionProbabilities();
    }


    /**
     * @return the minimum cosine of the angle between successive tracking directions.
     */
    public double ipThreshold() {
        return ipThreshold;
    }


    /**
     * PICo trackers override this. Call this instead of image.getPDs if you want the tracker to work with
     * probabilistic subclasses.
     * @param i voxel index along x.
     * @param j voxel index along y.
     * @param k voxel index along z.
     * @return the principal directions of voxel (i, j, k).
     */
    protected Vector3D[] getPDs(int i, int j, int k) {
        return image.getPDs(i, j, k);
    }


    /**
     * Track a path from a single seed point, for many monte-carlo iterations,
     * and wrap the result in a <code>TractCollection</code>.
     * @param seedPoint the point (in mm) to track from.
     * @param mcIterations the number of monte-carlo iterations.
     * @return the tracts from all iterations.
     */
    public TractCollection getPICoTracts(Point3D seedPoint, int mcIterations) {

        // The PD count is only used to presize the collection. Guard the lookup so an
        // out-of-bounds seed does not index outside the image; trackFromSeed(Point3D)
        // handles such seeds itself.
        int numPDs = 1;
        if (inBounds(seedPoint)) {
            numPDs = image.numberOfPDs(voxelIndex(seedPoint.x, xVoxelDim, xDataDim),
                                       voxelIndex(seedPoint.y, yVoxelDim, yDataDim),
                                       voxelIndex(seedPoint.z, zVoxelDim, zDataDim));
        }

        TractCollection collection =
            new TractCollection(Math.max(numPDs * mcIterations, 0) + 1, 100.0);

        for (int i = 0; i < mcIterations; i++) {
            collection.addTractCollection(trackFromSeed(seedPoint));
        }

        return collection;
    }


    /**
     * Track forward and backward along one PD, join the halves, and clear the
     * visited-voxel array ready for the next tract.
     */
    private Tract trackBothWays(Point3D seedPoint, int pdIndex) {
        Tract forward = trackFromSeed(seedPoint, pdIndex, FORWARD);
        Tract backward = trackFromSeed(seedPoint, pdIndex, BACKWARD);

        forward.joinTract(backward);

        clearVisitedVoxels();

        return forward;
    }


    /** Reset every entry of {@link #visitedVoxel} to zero. */
    private void clearVisitedVoxels() {
        for (int x = 0; x < xDataDim; x++) {
            for (int y = 0; y < yDataDim; y++) {
                System.arraycopy(zeros, 0, visitedVoxel[x][y], 0, zDataDim);
            }
        }
    }


    /** Build the tract returned for a seed outside the image: just the seed point. */
    private static Tract seedOnlyTract(Point3D seedPoint) {
        Tract t = new Tract(2, 100.0);
        t.addPoint(seedPoint);
        return t;
    }


    /**
     * Convert an in-bounds mm coordinate to a voxel index along one axis.
     * Floating-point rounding in {@code mm / voxelDim} can map a point just below the
     * upper bound onto {@code dataDim}, so the result is clamped to {@code dataDim - 1}.
     */
    private static int voxelIndex(double mm, double voxelDim, int dataDim) {
        int index = (int) (mm / voxelDim);
        return index < dataDim ? index : dataDim - 1;
    }


    /** Copy into {@code destination} every tract of {@code candidates} that enters {@code region}. */
    private static void addTractsEnteringRegion(TractCollection candidates, RegionOfInterest region,
                                                TractCollection destination) {
        for (int i = 0; i < candidates.numberOfTracts(); i++) {
            Tract tract = candidates.getTract(i);
            if (entersRegion(tract, region)) {
                destination.addTract(tract);
            }
        }
    }


    /** @return true if any point of {@code tract} lies inside {@code region}. */
    private static boolean entersRegion(Tract tract, RegionOfInterest region) {
        int numPoints = tract.numberOfPoints();
        for (int n = 0; n < numPoints; n++) {
            if (region.containsMMPoint(tract.getPoint(n))) {
                return true;
            }
        }
        return false;
    }

}






