package tractography;

import numerics.*;

import java.util.Random;

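/**
 * Provides interpolated tracking directions at any point within the dataset.
 * The direction at a point is taken from one of the eight voxels surrounding it,
 * chosen at random with probability proportional to that voxel's trilinear
 * interpolation weight; background voxels are given zero weight. This is similar
 * to the method of Behrens et al. (Magnetic Resonance in Medicine, 50:1077-1088, 2003).
 *
 *
 * @version $Id: NeighbourChoiceInterpolator.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 * @author  Philip Cook
 * 
 */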
public class NeighbourChoiceInterpolator extends EightNeighbourInterpolator implements ProbabilisticInterpolator,
 ImageInterpolator {

    /** Interpolation weights of the 8 neighbours; overwritten on every call to setInterpolationVoxels. */
    private final double[] interpFraction = new double[8];

    /**
     * Neighbour voxel indices filled by setInterpolationVoxels. Entries 0-1 are x indices,
     * 2-3 are y indices and 4-5 are z indices (see how neighbour n is decoded below).
     */
    private final int[] dims = new int[6];

    private final TractographyImage image;

    /** Source of randomness used to pick a neighbour. */
    private final Random ran;

    /** Cached PDs per voxel; only valid where the matching randomized flag is true. */
    private final Vector3D[][][][] voxelPDs;

    /** True for voxels whose PDs have been fetched since the last resetRandomization(). */
    private final boolean[][][] randomized;

    /** True for voxels where the image reports zero PDs. */
    private final boolean[][][] background;

    /** An all-false row used to clear randomized quickly in resetRandomization(). */
    private final boolean[] line;

    /**
     * Creates an interpolator over the given image.
     *
     * @param image the image supplying PDs and the data / voxel dimensions.
     * @param r the random number generator used to choose neighbours.
     */
    public NeighbourChoiceInterpolator(TractographyImage image, Random r) {

        super(image.xDataDim(), image.yDataDim(), image.zDataDim(),
              image.xVoxelDim(), image.yVoxelDim(), image.zVoxelDim());

        line = new boolean[zDataDim];
        this.image = image;
        ran = r;
        voxelPDs = new Vector3D[xDataDim][yDataDim][zDataDim][];
        randomized = new boolean[xDataDim][yDataDim][zDataDim];
        background = new boolean[xDataDim][yDataDim][zDataDim];

        // Precompute the background mask once; it is consulted on every tracking step.
        for (int k = 0; k < zDataDim; k++) {
            for (int j = 0; j < yDataDim; j++) {
                for (int i = 0; i < xDataDim; i++) {
                    background[i][j][k] = (image.numberOfPDs(i, j, k) == 0);
                }
            }
        }
    }


    /**
     * Get the tracking direction at some point. The direction comes from one of the 8 voxels
     * surrounding the point. Each voxel is chosen with probability proportional to its trilinear
     * interpolation weight, and background voxels are never chosen. Successive calls with the same
     * point may return different directions.
     *
     * The PDs of a voxel are fetched from the image (with previousDirection) at most once
     * between calls to resetRandomization(). If the chosen voxel has several PDs, the one with
     * the largest absolute dot product with previousDirection is used; ties go to the lowest index.
     *
     * @param point the point in mm to interpolate at.
     * @param previousDirection the previous tracking direction. It is used to pick among
     * multiple PDs and to orient the result.
     * @return a PD of the chosen neighbour, negated if needed so that its dot product with
     * previousDirection is positive.
     * @throws IllegalStateException if no neighbour has positive weight (e.g. all 8 are background).
     */
    public Vector3D getTrackingDirection(Point3D point, Vector3D previousDirection) {

        setInterpolationVoxels(point, interpFraction, dims);

        // Probability of choosing a background voxel is zero.
        double sumInterp = 0.0;

        for (int i = 0; i < 8; i++) {
            int x = dims[i / 4];
            int y = dims[2 + ((i / 2) % 2)];
            int z = dims[4 + (i % 2)];

            if (background[x][y][z]) {
                interpFraction[i] = 0.0;
            }

            sumInterp += interpFraction[i];
        }

        double random = ran.nextDouble() * sumInterp;

        double cumulSum = 0.0; // add up fractions as we go

        // Loop over neighbours and pick the first whose cumulative weight reaches random.
        for (int i = 0; i < 8; i++) {

            cumulSum += interpFraction[i];

            // Zero-weight (e.g. background) neighbours must never be selected. Without the
            // weight check, the precision tolerance would select neighbour 0 when it has zero
            // weight and random < 1E-9. The tolerance guards against rounding in cumulSum.
            if (interpFraction[i] > 0.0 && random - cumulSum < 1E-9) {
                int x = dims[i / 4];
                int y = dims[2 + ((i / 2) % 2)];
                int z = dims[4 + (i % 2)];

                if (!randomized[x][y][z]) {
                    voxelPDs[x][y][z] = image.getPDs(x, y, z, previousDirection);
                    randomized[x][y][z] = true;
                }

                Vector3D trackingDir = mostAlignedPD(voxelPDs[x][y][z], previousDirection);

                // Orient the PD so that tracking does not reverse direction.
                if (previousDirection.dot(trackingDir) > 0.0) {
                    return trackingDir;
                }
                return trackingDir.negated();
            }
        }

        // Reached only if no neighbour has positive weight (e.g. all 8 are background).
        // Whenever sumInterp > 0 a neighbour is always selected above, because cumulSum is
        // accumulated in the same order as sumInterp and ends equal to it.
        throw new java.lang.IllegalStateException
            ("Rejected all 8 neighbourhood tensors. cumulsum == "
             + cumulSum + ", interpFractions == {" + interpFraction[0] + ", " + interpFraction[1] +
             ", " + interpFraction[2] + ", " + interpFraction[3] + ", " + interpFraction[4] + ", " +
             interpFraction[5] + ", " + interpFraction[6] + ", " + interpFraction[7] + "}"
             );
    }


    /**
     * Chooses the PD to follow within a voxel. A single PD is returned as-is. Otherwise the PD
     * with the largest |dot product| with previousDirection is returned; ties go to the lowest index.
     */
    private static Vector3D mostAlignedPD(Vector3D[] pds, Vector3D previousDirection) {
        if (pds.length == 1) {
            return pds[0];
        }

        double maxDot = -1.0;
        int index = -1;

        for (int p = 0; p < pds.length; p++) {
            double dot = Math.abs(pds[p].dot(previousDirection));
            if (dot > maxDot) {
                maxDot = dot;
                index = p;
            }
        }

        return pds[index];
    }


    /**
     * Get the initial tracking direction, given a pdIndex and a seed point.
     *
     * The seed voxel is found by dividing each coordinate of point by the voxel size and
     * truncating. No bounds check is made. Its PDs are fetched from the image (without a
     * previous direction) if they have not been fetched since the last resetRandomization().
     * The chosen PD is then passed as the previous direction to
     * getTrackingDirection(Point3D, Vector3D).
     *
     * @param point the seed point in mm.
     * @param pdIndex follow this pd for the first tracking step.
     * @param direction if true, the direction will be the pd, if false, it will be the
     * negated pd.
     * @return the tracking direction for this point.
     */
    public Vector3D getTrackingDirection(Point3D point, int pdIndex, boolean direction) {

        int x = (int) (point.x / xVoxelDim);
        int y = (int) (point.y / yVoxelDim);
        int z = (int) (point.z / zVoxelDim);

        if (!randomized[x][y][z]) {
            voxelPDs[x][y][z] = image.getPDs(x, y, z);
            randomized[x][y][z] = true;
        }

        Vector3D trackingDir = getTrackingDirection(point, voxelPDs[x][y][z][pdIndex]);

        return direction ? trackingDir : trackingDir.negated();
    }


    /**
     * Marks every voxel as not yet randomized, so that its PDs are fetched again from the
     * image the next time the voxel is used.
     */
    public final void resetRandomization() {

        // Doing it this way will make this method have a higher percentage in the profiler,
        // but the total execution time will be significantly less than if we allocate a
        // new volume.
        for (int j = 0; j < yDataDim; j++) {
            for (int i = 0; i < xDataDim; i++) {
                System.arraycopy(line, 0, randomized[i][j], 0, zDataDim);
            }
        }
    }


}
