package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Sparse six-dimensional array of {@code float} values.
 *
 * <p>The first four indices {@code (x, y, z, v)} address a dense {@code d1 x d2 x d3 x d4} grid
 * of voxels. A voxel is allocated lazily by {@link #init_voxel}. At that point it receives
 * {@code d5} slots indexed by {@code j}.
 *
 * <p>Each slot stores only the entries of the sixth dimension (labels {@code 0 .. maxd6-1}) that
 * were positive when the slot was initialized. They are kept as two parallel arrays: labels in
 * ascending order, and their values. Label lookups use binary search over the sorted label array.
 *
 * <p>NOTE(review): given the package name, the values are presumably per-label weights or
 * probabilities for label fusion — confirm against callers.
 */
public class SparseMatrix6D extends AbstractCalculation {

	/** Sorted label indices per voxel and slot; a voxel entry is {@code null} until initialized. */
	private final short [][][][][][] labels;
	/** Values parallel to {@link #labels}. */
	private final float [][][][][][] vals;
	/** Extents of the first five dimensions. */
	private final int d1, d2, d3, d4, d5;
	/** Dense length of the sixth (label) dimension that {@link #init_voxel} expects. */
	private final short maxd6;

	/**
	 * Creates an empty matrix. No per-voxel storage is allocated until {@link #init_voxel} is
	 * called.
	 *
	 * @param dim1    extent of the first dimension ({@code x})
	 * @param dim2    extent of the second dimension ({@code y})
	 * @param dim3    extent of the third dimension ({@code z})
	 * @param dim4    extent of the fourth dimension ({@code v})
	 * @param dim5    number of slots ({@code j}) per voxel
	 * @param maxdim6 dense length of the sixth dimension (range of label indices)
	 */
	public SparseMatrix6D (int dim1, int dim2, int dim3, int dim4, int dim5, short maxdim6) {

		JistLogger.logOutput(JistLogger.INFO, "-> Initializing new Sparse Matrix 6D +++");
		// set the dimensional arguments
		d1 = dim1;
		d2 = dim2;
		d3 = dim3;
		d4 = dim4;
		d5 = dim5;
		maxd6 = maxdim6;

		// Only the dense grid of voxel references is allocated here; the slot arrays are
		// created on demand by init_voxel.
		labels = new short [d1][d2][d3][d4][][];
		vals = new float [d1][d2][d3][d4][][];
	}

	/**
	 * Extracts slot {@code j} of every initialized voxel into a new {@link SparseMatrix5D}.
	 *
	 * <p>This method does not copy the label and value arrays; it passes them directly to
	 * {@code SparseMatrix5D.init_voxel}. A slot that was never initialized is passed as
	 * {@code null}. NOTE(review): whether SparseMatrix5D copies or shares these arrays is not
	 * visible here — confirm.
	 *
	 * @param j slot index in {@code [0, d5)}
	 * @return a new 5D matrix of extents {@code (d1, d2, d3, d4, maxd6)}
	 */
	public SparseMatrix5D get_SparseMatrix5D(int j) {
		SparseMatrix5D sparseW = new SparseMatrix5D(d1, d2, d3, d4, maxd6);

		for (int x = 0; x < d1; x++) {
			for (int y = 0; y < d2; y++) {
				for (int z = 0; z < d3; z++) {
					for (int v = 0; v < d4; v++) {
						if (vals[x][y][z][v] != null) {
							sparseW.init_voxel(x, y, z, v,
											   get_all_labels(x, y, z, v, j),
											   get_all_vals(x, y, z, v, j));
						}
					}
				}
			}
		}

		return(sparseW);
	}

	/**
	 * Copies values from {@code sparseW} into slot {@code j} of every initialized voxel,
	 * position by position (see {@link #set_all_vals_inds}).
	 *
	 * <p>Labels are not changed, so {@code sparseW} is assumed to share the sparsity pattern of
	 * slot {@code j}. Presumably it was produced by {@link #get_SparseMatrix5D}; confirm at call
	 * sites.
	 *
	 * @param j       slot index in {@code [0, d5)}
	 * @param sparseW source of the per-voxel value arrays
	 */
	public void set_SparseMatrix5D(int j,
								   SparseMatrix5D sparseW) {

		for (int x = 0; x < d1; x++) {
			for (int y = 0; y < d2; y++) {
				for (int z = 0; z < d3; z++) {
					for (int v = 0; v < d4; v++) {
						if (vals[x][y][z][v] != null) {
							set_all_vals_inds(x, y, z, v, j, sparseW.get_all_vals(x, y, z, v));
						}
					}
				}
			}
		}
	}

	/**
	 * (Re)initializes slot {@code j} of voxel {@code (x, y, z, v)} from the dense vector
	 * {@code lp}. If the voxel has no storage yet, its {@code d5} slots are allocated first.
	 *
	 * <p>Only entries with {@code lp[l] > 0} are kept, in ascending label order. If no entry is
	 * positive, the slot stores a single placeholder entry: label {@code lab} with value
	 * {@code 0.00001f}.
	 *
	 * <p>If {@code lp.length != maxd6}, a message is printed to stderr. Only the overlapping
	 * prefix {@code [0, min(lp.length, maxd6))} is then scanned, instead of reading past the end
	 * of a short {@code lp}.
	 *
	 * @param x   first index
	 * @param y   second index
	 * @param z   third index
	 * @param v   fourth index
	 * @param j   slot index in {@code [0, d5)}
	 * @param lp  dense values, expected length {@code maxd6}
	 * @param lab label used for the placeholder entry when {@code lp} has no positive value
	 */
	public void init_voxel(int x,
						   int y,
						   int z,
						   int v,
						   int j,
						   float [] lp,
						   short lab) {

		if (labels[x][y][z][v] == null) {
			labels[x][y][z][v] = new short [d5][];
			vals[x][y][z][v] = new float [d5][];
		}

		// make sure it is the appropriate length
		if (lp.length != maxd6) {
			System.err.println("Cannot init voxel");
		}

		// Never index lp beyond its own length (the original crashed here when lp was short).
		final int n = Math.min(lp.length, maxd6);

		// count the number of values (labels) to keep for this voxel
		int numkeep = 0;
		for (int l = 0; l < n; l++) {
			if (lp[l] > 0) {
				numkeep++;
			}
		}

		if (numkeep == 0) {
			// Keep one tiny placeholder entry so the slot is never empty.
			labels[x][y][z][v][j] = new short [] { lab };
			vals[x][y][z][v][j] = new float [] { 0.00001f };
			return;
		}

		// Fill in ascending label order; binary_search relies on this ordering.
		final short [] keptLabels = new short [numkeep];
		final float [] keptVals = new float [numkeep];
		int currind = 0;
		for (int l = 0; l < n; l++) {
			if (lp[l] > 0) {
				// l < maxd6 <= Short.MAX_VALUE, so the narrowing cast is lossless.
				keptLabels[currind] = (short) l;
				keptVals[currind] = lp[l];
				currind++;
			}
		}
		labels[x][y][z][v][j] = keptLabels;
		vals[x][y][z][v][j] = keptVals;
	}

	/**
	 * Drops the storage of voxel {@code (x, y, z, v)}, all slots included. Afterwards the voxel
	 * is treated as uninitialized.
	 */
	public void free(int x,
					 int y,
					 int z,
					 int v) {
		labels[x][y][z][v] = null;
		vals[x][y][z][v] = null;
	}

	/**
	 * Sets the value stored for {@code label} in slot {@code j}. Nothing happens if the voxel or
	 * slot is unallocated, or if {@code label} is not one of the slot's stored labels. New labels
	 * cannot be added this way.
	 */
	public void set_val(int x,
						int y,
						int z,
						int v,
						int j,
						short label,
						float val) {

		final short [] A = slot_labels(x, y, z, v, j);
		if (A == null) {
			return;
		}

		final int ind = binary_search(A, label, 0, (A.length - 1));
		if (ind < 0) {
			return;
		}

		vals[x][y][z][v][j][ind] = val;
	}

	/**
	 * Sets the value at storage position {@code ind} of slot {@code j}. Nothing happens if the
	 * voxel or slot is unallocated, or if {@code ind} is out of range.
	 */
	public void set_val_ind(int x,
							int y,
							int z,
							int v,
							int j,
							int ind,
							float val) {

		final float [] V = slot_vals(x, y, z, v, j);
		if (V == null || ind < 0 || ind >= V.length) {
			return;
		}
		V[ind] = val;
	}

	/**
	 * Overwrites every stored value of slot {@code j} from a dense vector: each entry becomes
	 * {@code lp[label]}. {@code lp} must be indexable by every stored label.
	 */
	public void set_all_vals(int x,
							 int y,
							 int z,
							 int v,
							 int j,
							 float [] lp) {

		final short [] L = labels[x][y][z][v][j];
		final float [] V = vals[x][y][z][v][j];
		for (int i = 0; i < V.length; i++) {
			V[i] = lp[L[i]];
		}
	}

	/**
	 * Copies {@code indvals} position by position into slot {@code j}. Only the overlapping
	 * length is copied. Nothing happens if the voxel or slot is unallocated, or if
	 * {@code indvals} is {@code null}.
	 */
	public void set_all_vals_inds(int x,
								  int y,
								  int z,
								  int v,
								  int j,
								  float [] indvals) {

		final float [] V = slot_vals(x, y, z, v, j);
		if (V == null || indvals == null) {
			return;
		}
		System.arraycopy(indvals, 0, V, 0, Math.min(indvals.length, V.length));
	}

	/**
	 * Returns the value stored for {@code label} in slot {@code j}. Returns 0 if the voxel or
	 * slot is unallocated, or if the label is not stored.
	 */
	public float get_val(int x,
						 int y,
						 int z,
						 int v,
						 int j,
						 short label) {

		final short [] A = slot_labels(x, y, z, v, j);
		if (A == null) {
			return(0);
		}

		final int ind = binary_search(A, label, 0, (A.length - 1));
		return (ind < 0) ? 0 : vals[x][y][z][v][j][ind];
	}

	/**
	 * Returns the value at storage position {@code ind} of slot {@code j}. Returns 0 if the voxel
	 * or slot is unallocated, or if {@code ind} is out of range.
	 */
	public float get_val_ind(int x,
							 int y,
							 int z,
							 int v,
							 int j,
							 int ind) {

		final float [] V = slot_vals(x, y, z, v, j);
		if (V == null || ind < 0 || ind >= V.length) {
			return(0);
		}
		return(V[ind]);
	}

	/**
	 * Returns the internal value array of slot {@code j} (not a copy). It is {@code null} if the
	 * slot was never initialized.
	 */
	public float [] get_all_vals(int x,
								 int y,
								 int z,
								 int v,
								 int j) {
		return(vals[x][y][z][v][j]);
	}

	/**
	 * Returns the label with the largest strictly positive value in slot {@code j}. Returns 0 if
	 * no stored value is positive. On ties the first stored (i.e. lowest) label wins.
	 */
	public short get_max_label(int x,
							   int y,
							   int z,
							   int v,
							   int j) {
		final float [] V = vals[x][y][z][v][j];
		final short [] L = labels[x][y][z][v][j];
		float max_val = 0;
		short max_label = 0;

		for (int i = 0; i < V.length; i++) {
			if (V[i] > max_val) {
				max_val = V[i];
				max_label = L[i];
			}
		}
		return(max_label);
	}

	/**
	 * Returns the internal, ascending label array of slot {@code j} (not a copy). It is
	 * {@code null} if the slot was never initialized.
	 */
	public short [] get_all_labels(int x,
								   int y,
								   int z,
								   int v,
								   int j) {
		return(labels[x][y][z][v][j]);
	}

	/**
	 * Divides every stored value of slot {@code j} by their sum. If the sum is exactly zero the
	 * slot is left unchanged; the original wrote NaN/Infinity into it.
	 */
	public void normalize(int x,
						  int y,
						  int z,
						  int v,
						  int j) {

		final float [] V = vals[x][y][z][v][j];
		float valsum = 0;
		for (float f : V) {
			valsum += f;
		}
		if (valsum == 0) {
			return;
		}
		for (int i = 0; i < V.length; i++) {
			V[i] /= valsum;
		}
	}

	/** Returns the label array of a slot, or {@code null} if the voxel or slot is unallocated. */
	private short [] slot_labels(int x, int y, int z, int v, int j) {
		final short [][] voxel = labels[x][y][z][v];
		return (voxel == null) ? null : voxel[j];
	}

	/** Returns the value array of a slot, or {@code null} if the voxel or slot is unallocated. */
	private float [] slot_vals(int x, int y, int z, int v, int j) {
		final float [][] voxel = vals[x][y][z][v];
		return (voxel == null) ? null : voxel[j];
	}

	/**
	 * Iterative binary search for {@code key} in the ascending array {@code A} within
	 * {@code [imin, imax]}.
	 *
	 * @return an index holding {@code key}, or -1 if it is absent
	 */
	private int binary_search(short [] A,
							  int key,
							  int imin,
							  int imax) {

		int lo = imin;
		int hi = imax;
		while (lo <= hi) {
			// Unsigned shift avoids int overflow of (lo + hi).
			final int mid = (lo + hi) >>> 1;
			if (A[mid] > key) {
				hi = mid - 1;
			} else if (A[mid] < key) {
				lo = mid + 1;
			} else {
				return(mid);
			}
		}
		return(-1);
	}

}
