package tractography;

import numerics.*;

import java.util.*;

/**
 * Connection probability image with targets. Connection probabilities are defined to target
 * regions rather than to individual voxels: every voxel of a target shares that target's value.
 *
 * @version $Id: TargetCP_Image.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author  Philip Cook
 * 
 */
public class TargetCP_Image {


    // voxel dims of seed space
    private final double xVoxelDim;
    private final double yVoxelDim;
    private final double zVoxelDim;


    // dimensions of seed space
    private final int xDataDim;
    private final int yDataDim;
    private final int zDataDim;

    // target label of each voxel; labels > 0 are targets, anything else is background
    private final short[][][] targets;

    // smallest and largest positive target label present in targets (both 0 if there are none)
    private final int minTargetIndex;
    private final int maxTargetIndex;

    // number of labels in [minTargetIndex, maxTargetIndex] (0 if there are no targets)
    private final int numTargets;

    // streamlineCounts[d][t] is the streamline count for target label t. It is indexed directly
    // by the label, so it has maxTargetIndex + 1 entries and entries below minTargetIndex stay 0.
    // If directional, d = 0 is the forward direction and d = 1 the backward direction;
    // otherwise only d = 0 exists.
    private int[][] streamlineCounts;

    // total number of streamlines used to make this image
    private int totalStreamlines = 0;

    // update only the target that the streamline hits first
    private boolean countFirstEntry = true;

    // if directional, vector tells us forwards / backwards (or left / right, or whatever)
    private Vector3D forwards = null;

    private boolean directional = false;


    /**
     * Initializes the image with the dimensions of the seed space.
     *
     * @param targets target label of each voxel in seed space. A voxel with a label greater than
     * zero belongs to that target. The array is stored by reference, not copied.
     * @param xVoxelDim voxel size in x, in the same units as streamline point coordinates.
     * @param yVoxelDim voxel size in y, in the same units as streamline point coordinates.
     * @param zVoxelDim voxel size in z, in the same units as streamline point coordinates.
     */
    public TargetCP_Image(short[][][] targets,
                          double xVoxelDim, double yVoxelDim, double zVoxelDim) {

        xDataDim = targets.length;
        yDataDim = targets[0].length;
        zDataDim = targets[0][0].length;

        this.targets = targets;

        // find the range of positive target labels actually present
        int minLabel = Integer.MAX_VALUE;
        int maxLabel = 0;

        for (int i = 0; i < xDataDim; i++) {
            for (int j = 0; j < yDataDim; j++) {
                for (int k = 0; k < zDataDim; k++) {
                    int label = targets[i][j][k];
                    if (label > 0) {
                        minLabel = Math.min(minLabel, label);
                        maxLabel = Math.max(maxLabel, label);
                    }
                }
            }
        }

        if (maxLabel > 0) {
            minTargetIndex = minLabel;
            maxTargetIndex = maxLabel;
            numTargets = maxLabel - minLabel + 1;
        }
        else {
            // no targets in the image
            minTargetIndex = 0;
            maxTargetIndex = 0;
            numTargets = 0;
        }

        // indexed directly by target label, hence maxTargetIndex + 1 entries
        streamlineCounts = new int[1][maxTargetIndex + 1];

        this.xVoxelDim = xVoxelDim;
        this.yVoxelDim = yVoxelDim;
        this.zVoxelDim = zVoxelDim;
    }



    /**
     * Call this method to split the streamlines in two at the seed point and produce 
     * separate connection probability maps from each. The vector <code>v</code> defines a "forward" 
     * direction and should be approximately tangential to the streamline path at the seed point.
     * <p> 
     * Two volumes will be returned from <code>getStreamlineCounts</code> 
     * and <code>getConnectionProbabilities</code>. The first volume is connection probabilities 
     * for streamline segments where the dot product of the streamline direction and v is greater 
     * than zero. The second volume is connection probabilities in the other direction. 
     * <p>
     * As an example, consider mapping the connectivity from the corpus callosum. We have streamlines
     * seeded along the mid-sagittal line, which proceed to the left and right hemispheres. We would
     * therefore use v = [-1, 0, 0]. Streamlines are then split at the seed point and the segments  
     * proceeding left (or in any direction with a negative x component) form the first connection 
     * probability volume. Those proceeding to the right (direction with a positive x component)
     * form the second connection probability volume. We can therefore get connection probability
     * maps to the left and right hemispheres from the same seed point. This example shows
     * the importance of choosing the direction v carefully. It is important that v is not 
     * perpendicular to the streamline direction at any seed point.
     * <p>
     * This method allocates new, zeroed streamline counts. Any counts accumulated before it is
     * called are discarded, but the total streamline count is not reset. Call it before
     * processing any streamlines.
     *
     * @param v a "forward" direction. Should be approximately tangential to the streamline 
     * trajectory at the seed point.
     */
    public void setDirectional(Vector3D v) {
        forwards = v.normalized();
        directional = true;

        streamlineCounts = new int[2][maxTargetIndex + 1];
    }


    /**
     * @return <code>true</code> if <code>setDirectional</code> has been called.
     */
    public boolean directional() {
        return directional;
    }


    /**
     * Add some streamlines to the image, passing each one to <code>processTract</code>.
     *
     */
    public final void processTracts(TractCollection tc) {

        for (int tCounter = 0; tCounter < tc.numberOfTracts(); tCounter++) {
            processTract(tc.getTract(tCounter));
        }
    }


    /**
     * Add a single streamline to the image. The total streamline count is always incremented.
     * Target counts are updated according to the directional and first-entry settings.
     *
     */
    public final void processTract(Tract t) {

        totalStreamlines++;

        if (directional) {
            // make a separate map for each direction (forwards and backwards from seed)
            processDirectional(t);
        }
        else if (countFirstEntry) {
            // must determine which target the fibre hits first
            processFirstEntry(t);
        }
        else {
            // not directional and not count first entry:
            // increment the count of every target the streamline enters, once per target
            Voxel[] voxels = t.toVoxelList(xVoxelDim, yVoxelDim, zVoxelDim).getVoxels();
            countTargetsAlong(voxels, 0, 1, streamlineCounts[0]);
        }
    }


    /**
     * Directional counting: splits the streamline at its seed and counts the targets reached by
     * each half separately. The half whose direction agrees with <code>forwards</code> goes to
     * index 0, the other half to index 1.
     */
    private void processDirectional(Tract t) {

        VoxelList voxelList = t.toVoxelList(xVoxelDim, yVoxelDim, zVoxelDim);
        Voxel[] voxels = voxelList.getVoxels();

        int voxelSeedIndex = voxelList.seedPointIndex();
        int tractSeedIndex = t.seedPointIndex();
        int numPoints = t.numberOfPoints();

        // tangent points upwards (ie if you follow the tangent from point s
        // it points towards point s+1)
        Vector3D tangent;

        if (tractSeedIndex < numPoints - 1) {
            tangent = new Vector3D(t.getPoint(tractSeedIndex + 1),
                                   t.getPoint(tractSeedIndex)).normalized();
        }
        else if (tractSeedIndex > 0) {
            tangent = new Vector3D(t.getPoint(tractSeedIndex),
                                   t.getPoint(tractSeedIndex - 1)).normalized();
        }
        else {
            // just one point in streamline: there is no direction, so its target
            // (if any) is counted in both directions
            int voxelTarget = targetAt(voxels[0]);

            if (voxelTarget > 0) {
                streamlineCounts[0][voxelTarget]++;
                streamlineCounts[1][voxelTarget]++;
            }

            return;
        }

        // the upward half (seed towards the last point) follows the tangent
        int upwardIndex = forwards.dot(tangent) > 0.0 ? 0 : 1;
        int downwardIndex = 1 - upwardIndex;

        // each half keeps its own record of targets hit, which allows a streamline
        // to connect at both ends to the same target
        countTargetsAlong(voxels, voxelSeedIndex, 1, streamlineCounts[upwardIndex]);
        countTargetsAlong(voxels, voxelSeedIndex, -1, streamlineCounts[downwardIndex]);
    }


    /**
     * Non-directional first-entry counting. Finds the first target point on each side of the
     * seed and credits only the one with the shorter path length from the seed. Ties go to the
     * backward side; this includes a seed lying inside a target, which both searches find.
     */
    private void processFirstEntry(Tract t) {

        int seed = t.seedPointIndex();
        int numPoints = t.numberOfPoints();

        int forwardPoint = firstTargetPoint(t, seed, 1, numPoints);
        int backwardPoint = firstTargetPoint(t, seed, -1, numPoints);

        int chosenPoint;

        if (forwardPoint < 0 && backwardPoint < 0) {
            return; // no target reached on either side
        }
        else if (forwardPoint < 0) {
            chosenPoint = backwardPoint;
        }
        else if (backwardPoint < 0) {
            chosenPoint = forwardPoint;
        }
        else if (t.pathLengthFromSeed(forwardPoint) < t.pathLengthFromSeed(backwardPoint)) {
            chosenPoint = forwardPoint;
        }
        else { // also happens if distances are equal
            chosenPoint = backwardPoint;
        }

        streamlineCounts[0][targetAtPoint(t.getPoint(chosenPoint))]++;
    }


    /**
     * Finds the first streamline point, walking from <code>start</code> in steps of
     * <code>step</code> (+1 or -1), that lies in a target.
     *
     * @return the index of that point, or -1 if no such point exists.
     */
    private int firstTargetPoint(Tract t, int start, int step, int numPoints) {
        for (int p = start; p >= 0 && p < numPoints; p += step) {
            if (targetAtPoint(t.getPoint(p)) > 0) {
                return p;
            }
        }
        return -1;
    }


    /**
     * Walks <code>voxels</code> from <code>start</code> (inclusive) in steps of <code>step</code>
     * (+1 or -1) to the end of the array. It increments <code>counts[t]</code> the first time
     * each target <code>t</code> is entered, and stops after the first target if
     * <code>countFirstEntry</code> is set.
     */
    private void countTargetsAlong(Voxel[] voxels, int start, int step, int[] counts) {

        boolean[] hitTarget = new boolean[maxTargetIndex + 1];

        for (int v = start; v >= 0 && v < voxels.length; v += step) {
            int voxelTarget = targetAt(voxels[v]);

            // don't count the same target twice
            if (voxelTarget > 0 && !hitTarget[voxelTarget]) {
                counts[voxelTarget]++;
                hitTarget[voxelTarget] = true;

                if (countFirstEntry) {
                    return;
                }
            }
        }
    }


    /**
     * @return the target label of <code>voxel</code>.
     */
    private int targetAt(Voxel voxel) {
        return targets[voxel.x][voxel.y][voxel.z];
    }


    /**
     * Looks up the target label of the voxel containing <code>point</code>. The voxel index is
     * the coordinate divided by the voxel size, truncated towards zero.
     *
     * @return the target label of that voxel, or 0 (background) if the voxel index falls outside
     * the seed space.
     */
    private int targetAtPoint(Point3D point) {

        int i = (int) (point.x / xVoxelDim);
        int j = (int) (point.y / yVoxelDim);
        int k = (int) (point.z / zVoxelDim);

        if (i < 0 || i >= xDataDim || j < 0 || j >= yDataDim || k < 0 || k >= zDataDim) {
            return 0;
        }

        return targets[i][j][k];
    }


    /**
     * Get the streamline counts as an image. Each voxel of a target has the same value,
     * which is the number of streamlines that enter that target. Non-target voxels are zero.
     * 
     * @return a 4D image <code>sc</code>. If directional, <code>sc.length == 2</code> and 
     * <code>sc[0]</code> is the image for streamline segments proceeding in the forward 
     * direction. 
     *
     * @see #setDirectional(numerics.Vector3D)
     */
    public int[][][][] getStreamlineCounts() {

        int numArrays = directional ? 2 : 1;

        int[][][][] scImage = new int[numArrays][xDataDim][yDataDim][zDataDim];

        for (int i = 0; i < xDataDim; i++) {
            for (int j = 0; j < yDataDim; j++) {
                for (int k = 0; k < zDataDim; k++) {
                    int target = targets[i][j][k];
                    if (target > 0) {
                        for (int d = 0; d < numArrays; d++) {
                            scImage[d][i][j][k] = streamlineCounts[d][target];
                        }
                    }
                }
            }
        }

        return scImage;
    }


    /**
     * Get the connection probabilities as an image. These are the streamline counts divided by
     * the total number of streamlines. If no streamlines have been processed, all
     * probabilities are zero.
     *
     * @return a 4D image <code>cp</code>. If directional, <code>cp.length == 2</code> and 
     * <code>cp[0]</code> is the image for streamline segments proceeding in the forward 
     * direction. 
     *
     * @see #setDirectional(numerics.Vector3D)
     */
    public double[][][][] getConnectionProbabilities() {

        int numArrays = directional ? 2 : 1;

        double[][][][] cp = new double[numArrays][xDataDim][yDataDim][zDataDim];

        // with no streamlines every count is zero; avoid 0 / 0.0 producing NaN
        if (totalStreamlines == 0) {
            return cp;
        }

        double norm = (double) totalStreamlines;

        for (int i = 0; i < xDataDim; i++) {
            for (int j = 0; j < yDataDim; j++) {
                for (int k = 0; k < zDataDim; k++) {
                    int target = targets[i][j][k];
                    if (target > 0) {
                        for (int d = 0; d < numArrays; d++) {
                            cp[d][i][j][k] = streamlineCounts[d][target] / norm;
                        }
                    }
                }
            }
        }

        return cp;
    }


    /**
     * @return the total number of streamlines processed by <code>processTracts</code> or
     * <code>processTract</code>.
     *
     */
    public int totalStreamlines() {
        return totalStreamlines;
    }


    /**
     * Gets a copy of the streamline counts as an array <code>a</code>. Here
     * <code>a[d][t]</code> is the streamline count for target label <code>t</code>, with
     * <code>d</code> in {0, 1} if directional and <code>d == 0</code> otherwise. The array is
     * indexed directly by label; entries for labels that are not targets are zero.
     *
     */
    protected int[][] getStreamlineCountsArray() {

        int numArrays = directional ? 2 : 1;

        int[][] defCopy = new int[numArrays][maxTargetIndex + 1];

        for (int i = 0; i < numArrays; i++) {
            System.arraycopy(streamlineCounts[i], minTargetIndex, defCopy[i], minTargetIndex,
                             numTargets);
        }

        return defCopy;
    }



    /**
     * The default behaviour is to count the first entry of a streamline to a target zone. 
     * If this method is called with a <code>false</code> parameter, then streamlines are
     * allowed to connect to any number of target zones.
     * <p>
     * In directional mode, the first target along each half of the streamline is counted.
     * Otherwise, only the target closest to the seed by path length is counted.
     *
     * @see #setDirectional(numerics.Vector3D)
     */
    public void setCountFirstEntry(boolean b) {
        countFirstEntry = b;
    }


}
