package misc;

import tractography.*;
import tools.*;
import numerics.*;


/**
 * Sparsely populated 3D image. The image contains a vector of scalar values at each point
 * in 3D space. The vector length is variable, and is zero if no values are added to the voxel.
 * <p>
 * Each value is stored together with a weight. Storage for a voxel is allocated only when
 * the first value is added to it, and grows geometrically as further values arrive.
 *
 * @author Philip Cook
 * @version $Id: SparseVectorImage.java,v 1.1 2008/12/08 17:48:43 bennett Exp $
 */
public class SparseVectorImage implements VoxelwiseStatisticalImage {

    /** Per-voxel value arrays; an entry is null until the first value is added to that voxel. */
    private final double[][][][] data;

    /** Per-voxel weight arrays, kept parallel to {@link #data}. */
    private final double[][][][] weights;

    /** Number of voxels along each of the three image axes. */
    private final int[] dataDims;

    /** Size of a voxel along each axis, used to map a point to a voxel index. */
    private final double[] voxelDims;

    /** Number of values actually stored in each voxel (not the allocated capacity). */
    private final int[][][] vectorLengths;

    /** Capacity allocated for a voxel when its first value is added. */
    protected final static int INITIAL_ARRAY_LENGTH = 50;

    /** Multiplier applied to a voxel's capacity when it runs out of space. */
    private final static double GROWTH_FACTOR = 2.0;


    /**
     * Construct an empty image. The dimension arrays are copied, so later changes
     * to the caller's arrays do not affect this image.
     *
     * @param dataDims number of voxels along each of the three axes.
     * @param voxelDims size of a voxel along each of the three axes, in mm.
     *
     * @throws IllegalArgumentException if either array has fewer than three elements.
     */
    public SparseVectorImage(int[] dataDims, double[] voxelDims) {

        if (dataDims.length < 3 || voxelDims.length < 3) {
            throw new IllegalArgumentException("dataDims and voxelDims must each have three " +
                                               "elements, got " + dataDims.length + " and " +
                                               voxelDims.length);
        }

        // Defensive copies: the loop bounds and the point-to-voxel mapping must stay
        // consistent with the arrays allocated below.
        this.voxelDims = (double[])voxelDims.clone();
        this.dataDims = (int[])dataDims.clone();

        int xSize = this.dataDims[0];
        int ySize = this.dataDims[1];
        int zSize = this.dataDims[2];

        data = new double[xSize][ySize][zSize][];
        weights = new double[xSize][ySize][zSize][];
        vectorLengths = new int[xSize][ySize][zSize];
    }


    /**
     * Add a value to the voxel (i,j,k), with unit weight.
     *
     */
    public void addValue(int i, int j, int k, double v) {
        addValue(i, j, k, v, 1.0);
    }


    /**
     * Add a value to the voxel (i,j,k), with custom weight.
     *
     * @param weight must be non-negative.
     *
     * @throws IllegalArgumentException if the weight is negative or NaN.
     */
    public void addValue(int i, int j, int k, double v, double weight) {

        // Written as a negated >= so that NaN is rejected as well.
        if (!(weight >= 0.0)) {
            throw new IllegalArgumentException("weight must be non-negative, got " + weight);
        }

        int count = vectorLengths[i][j][k];

        if (data[i][j][k] == null) {
            data[i][j][k] = new double[INITIAL_ARRAY_LENGTH];
            weights[i][j][k] = new double[INITIAL_ARRAY_LENGTH];
        }
        else if (count == data[i][j][k].length) {
            // Grow only when this insertion actually needs the space.
            growDataVectors(i, j, k, GROWTH_FACTOR);
        }

        data[i][j][k][count] = v;
        weights[i][j][k][count] = weight;

        vectorLengths[i][j][k] = count + 1;
    }


    /**
     *
     * Add a value to the voxel containing the point, with unit weight.
     *
     */
    public void addValue(Point3D p, double v) {
        addValue(p, v, 1.0);
    }


    /**
     *
     * Add a value to the voxel containing the point, with custom weight.
     * <p>
     * The voxel index along each axis is the point coordinate divided by the voxel
     * size, truncated towards zero. Note that truncation maps coordinates between
     * minus one voxel size and zero to index 0.
     *
     * @param weight must be non-negative.
     */
    public void addValue(Point3D p, double v, double weight) {

        int i = (int)(p.x / voxelDims[0]);
        int j = (int)(p.y / voxelDims[1]);
        int k = (int)(p.z / voxelDims[2]);

        addValue(i, j, k, v, weight);
    }


    /**
     * Computes a 3D image, where each voxel intensity is some statistic of the associated vector.
     * Voxels that contain no values are left at zero. The "mean" and "var" statistics use the
     * weights supplied with the values; "max", "min" and "median" ignore them.
     *
     * @param stat one of "mean", "max", "min", "median", "var".
     *
     * @return an array with the same dimensions as the image.
     *
     * @throws IllegalArgumentException if stat is not one of the supported names.
     */
    public double[][][] getVoxelStatistic(String stat) {

        // Reject unknown names up front, rather than returning an image of zeros.
        if (!isSupportedStatistic(stat)) {
            throw new IllegalArgumentException("Unsupported statistic: " + stat);
        }

        double[][][] result = new double[dataDims[0]][dataDims[1]][dataDims[2]];

        for (int i = 0; i < dataDims[0]; i++) {
            for (int j = 0; j < dataDims[1]; j++) {
                for (int k = 0; k < dataDims[2]; k++) {
                    if (vectorLengths[i][j][k] > 0) {
                        result[i][j][k] = computeStatistic(stat, i, j, k);
                    }
                }
            }
        }

        return result;
    }


    /**
     * Tests whether a statistic name is one that getVoxelStatistic can compute.
     *
     * @return true if stat is "mean", "max", "min", "median" or "var"; false otherwise,
     * including when stat is null.
     */
    private static boolean isSupportedStatistic(String stat) {
        return "mean".equals(stat) || "max".equals(stat) || "min".equals(stat)
            || "median".equals(stat) || "var".equals(stat);
    }


    /**
     * Computes one statistic for a single voxel that contains at least one value.
     *
     * @param stat a name accepted by isSupportedStatistic.
     */
    private double computeStatistic(String stat, int i, int j, int k) {

        double[] dataVec = valuesAt(data, i, j, k);

        if (stat.equals("max")) {
            return ArrayOps.max(dataVec);
        }
        if (stat.equals("min")) {
            return ArrayOps.min(dataVec);
        }
        if (stat.equals("median")) {
            return ArrayOps.median(dataVec);
        }

        // The remaining statistics are weighted.
        double[] weightVec = valuesAt(weights, i, j, k);

        double mean = ArrayOps.weightedMean(dataVec, weightVec);

        if (stat.equals("mean")) {
            return mean;
        }

        // stat is "var"; the name was validated by the caller.
        return ArrayOps.weightedVar(dataVec, weightVec, mean);
    }


    /**
     * Gets the subset of array vol[i][j][k] that contains data.
     *
     * @param vol either the data or the weights array.
     *
     * @return a copy of the populated part of the vector for voxel i,j,k; a zero-length
     * array if nothing has been added to the voxel.
     *
     */
    private double[] valuesAt(double[][][][] vol, int i, int j, int k) {

        int count = vectorLengths[i][j][k];

        // An empty voxel has no backing array, so it must not reach arraycopy.
        if (count == 0) {
            return new double[0];
        }

        double[] values = new double[count];

        System.arraycopy(vol[i][j][k], 0, values, 0, count);

        return values;
    }


    /**
     * Replaces the data and weight arrays of voxel (i,j,k) with larger copies.
     *
     * @param factor multiplier applied to the current capacity. The new capacity is
     * always at least one larger than the old one.
     */
    private void growDataVectors(int i, int j, int k, double factor) {

        int oldLength = data[i][j][k].length;

        // Guarantee progress even if the factor rounds down to no increase.
        int newLength = Math.max(oldLength + 1, (int)(oldLength * factor));

        double[] biggerData = new double[newLength];
        System.arraycopy(data[i][j][k], 0, biggerData, 0, oldLength);
        data[i][j][k] = biggerData;

        double[] biggerWeights = new double[newLength];
        System.arraycopy(weights[i][j][k], 0, biggerWeights, 0, oldLength);
        weights[i][j][k] = biggerWeights;
    }

}
