package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Sparse 5-D array of float values indexed by (x, y, z, v, label).
 * <p>
 * The first four dimensions are allocated densely; for each (x, y, z, v)
 * position ("voxel") only a list of (label, value) pairs is stored, in two
 * parallel arrays. A voxel that has never been initialized holds {@code null};
 * every accessor treats such a voxel as having no stored entries (lookups
 * return 0, updates are ignored).
 * <p>
 * Lookups by label ({@code get_val}, {@code set_val}) use binary search and
 * therefore require each voxel's labels to be in ascending order.
 * {@code init_voxel(x, y, z, v, lp)} always stores them that way; callers of
 * {@code init_voxel(x, y, z, v, labels_in, vals_in)} must pass them sorted.
 * <p>
 * Performs no synchronization.
 */
public class SparseMatrix5D extends AbstractCalculation {
	
	/** Stored labels per voxel; {@code null} for voxels never initialized. */
	private final short [][][][][] labels;
	/** Stored values per voxel, parallel to {@code labels}. */
	private final float [][][][][] vals;
	/** Extents of the four densely allocated dimensions. */
	private final int d1, d2, d3, d4;
	/** Expected length of the dense vector passed to init_voxel(x, y, z, v, lp). */
	private final short maxd5;
	
	/**
	 * Creates a matrix in which every voxel is uninitialized.
	 *
	 * @param dim1    extent of the first dimension
	 * @param dim2    extent of the second dimension
	 * @param dim3    extent of the third dimension
	 * @param dim4    extent of the fourth dimension
	 * @param maxdim5 expected length of a dense label vector (see init_voxel)
	 */
	public SparseMatrix5D (int dim1,
						   int dim2,
						   int dim3,
						   int dim4,
						   short maxdim5) {
		
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing new Sparse Matrix 5D");
		d1 = dim1;
		d2 = dim2;
		d3 = dim3;
		d4 = dim4;
		maxd5 = maxdim5;
		
		// only the outer four dimensions are allocated; voxels stay null until initialized
		labels = new short [d1][d2][d3][d4][];
		vals = new float [d1][d2][d3][d4][];
	}
	
	/**
	 * Creates a matrix holding deep copies of every initialized voxel of
	 * {@code sparse_copy}.
	 *
	 * @param dim1        extent of the first dimension
	 * @param dim2        extent of the second dimension
	 * @param dim3        extent of the third dimension
	 * @param dim4        extent of the fourth dimension
	 * @param maxdim5     expected length of a dense label vector (see init_voxel)
	 * @param sparse_copy matrix to copy from; it is read at every index up to
	 *                    (dim1, dim2, dim3, dim4), so it must be at least that large
	 */
	public SparseMatrix5D (int dim1,
						   int dim2,
						   int dim3,
						   int dim4,
						   short maxdim5,
						   SparseMatrix5D sparse_copy) {
		
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing new Sparse Matrix 5D (Full Copy)");
		d1 = dim1;
		d2 = dim2;
		d3 = dim3;
		d4 = dim4;
		maxd5 = maxdim5;
		
		labels = new short [d1][d2][d3][d4][];
		vals = new float [d1][d2][d3][d4][];
		
		for (int x = 0; x < d1; x++) {
			for (int y = 0; y < d2; y++) {
				for (int z = 0; z < d3; z++) {
					for (int v = 0; v < d4; v++) {
						float [] src_vals = sparse_copy.get_all_vals(x, y, z, v);
						if (src_vals != null) {
							// private helper: avoids invoking an overridable method from a constructor
							store_voxel(x, y, z, v, sparse_copy.get_all_labels(x, y, z, v), src_vals);
						}
					}
				}
			}
		}
	}
	
	/**
	 * Overwrites, position by position, the values of every initialized voxel of
	 * this matrix with the values stored at the same voxel of {@code mat2}.
	 * Labels are not touched, so this presumably assumes both matrices share the
	 * same per-voxel label layout (e.g. one built from the other with the copy
	 * constructor). Voxels uninitialized in either matrix are skipped.
	 *
	 * @param mat2 source of the values
	 */
	public void copy(SparseMatrix5D mat2) {
		
		for (int x = 0; x < d1; x++) {
			for (int y = 0; y < d2; y++) {
				for (int z = 0; z < d3; z++) {
					for (int v = 0; v < d4; v++) {
						if (vals[x][y][z][v] == null) {
							continue;
						}
						float [] src = mat2.get_all_vals(x, y, z, v);
						// the source voxel may be uninitialized; passing null on would throw
						if (src != null) {
							set_all_vals_inds(x, y, z, v, src);
						}
					}
				}
			}
		}
	}
	
	/**
	 * Initializes one voxel with copies of the given labels and values.
	 * The labels should be in ascending order for label lookups to work.
	 *
	 * @param labels_in labels to store (copied)
	 * @param vals_in   values parallel to {@code labels_in}; only the first
	 *                  {@code labels_in.length} entries are used
	 * @throws IndexOutOfBoundsException if {@code vals_in} is shorter than {@code labels_in}
	 */
	public void init_voxel(int x,
						   int y,
						   int z,
						   int v,
						   short [] labels_in,
						   float [] vals_in) {
		
		store_voxel(x, y, z, v, labels_in, vals_in);
	}
	
	/** Stores copies of labels_in and of the first labels_in.length entries of vals_in at one voxel. */
	private void store_voxel(int x,
							 int y,
							 int z,
							 int v,
							 short [] labels_in,
							 float [] vals_in) {
		
		int numkeep = labels_in.length;
		short [] new_labels = new short [numkeep];
		float [] new_vals = new float [numkeep];
		System.arraycopy(labels_in, 0, new_labels, 0, numkeep);
		// throws IndexOutOfBoundsException if vals_in is shorter than labels_in
		System.arraycopy(vals_in, 0, new_vals, 0, numkeep);
		labels[x][y][z][v] = new_labels;
		vals[x][y][z][v] = new_vals;
	}
	
	/**
	 * Initializes one voxel from a dense vector indexed by label, keeping only
	 * the labels whose value is strictly positive (stored in ascending label
	 * order).
	 * <p>
	 * If {@code lp.length != maxd5} a warning is printed to stderr; entries
	 * beyond {@code maxd5} are ignored and labels missing from a short
	 * {@code lp} are treated as absent (instead of throwing).
	 *
	 * @param lp dense vector; {@code lp[l]} is the value for label {@code l}
	 */
	public void init_voxel(int x,
						   int y,
						   int z,
						   int v,
						   float [] lp) {
		
		if (lp.length != maxd5) {
			System.err.println("Cannot init voxel");
		}
		
		// scan only indices valid for both lp and the label range
		int nlabels = Math.min(lp.length, maxd5);
		
		int numkeep = 0;
		for (int l = 0; l < nlabels; l++) {
			if (lp[l] > 0) {
				numkeep++;
			}
		}
		
		short [] new_labels = new short [numkeep];
		float [] new_vals = new float [numkeep];
		int currind = 0;
		for (int l = 0; l < nlabels; l++) {
			if (lp[l] > 0) {
				new_labels[currind] = (short) l;
				new_vals[currind] = lp[l];
				currind++;
			}
		}
		labels[x][y][z][v] = new_labels;
		vals[x][y][z][v] = new_vals;
	}
	
	/**
	 * Sets the value stored for {@code label} at one voxel. Does nothing if the
	 * voxel is uninitialized or does not contain {@code label} (no entry is added).
	 */
	public void set_val(int x,
						int y,
						int z,
						int v,
						short label,
						float val) {
		
		short [] A = labels[x][y][z][v];
		if (A == null) {
			return;
		}
		
		int ind = binary_search(A, label, 0, (A.length-1));
		if (ind < 0) {
			return;
		}
		
		vals[x][y][z][v][ind] = val;
	}
	
	/**
	 * Sets the value at storage position {@code ind} of one voxel. Does nothing
	 * if the voxel is uninitialized or {@code ind} is out of range.
	 */
	public void set_val_ind(int x,
							int y,
							int z,
							int v,
							int ind,
							float val) {
		
		short [] A = labels[x][y][z][v];
		if (A == null || ind < 0 || ind >= A.length) {
			return;
		}
		vals[x][y][z][v][ind] = val;
	}
	
	/**
	 * Replaces every stored value of one voxel with {@code lp[label]} for its
	 * stored label. Does nothing if the voxel is uninitialized.
	 *
	 * @param lp dense vector indexed by label; must cover every stored label
	 */
	public void set_all_vals(int x,
							 int y,
							 int z,
							 int v,
							 float [] lp) {
		
		short [] A = labels[x][y][z][v];
		if (A == null) {
			return;
		}
		float [] V = vals[x][y][z][v];
		for (int i = 0; i < V.length; i++) {
			V[i] = lp[A[i]];
		}
	}
	
	/**
	 * Copies {@code indvals} into the stored values of one voxel, position by
	 * position. Does nothing if the voxel is uninitialized; entries beyond the
	 * voxel's stored length are ignored, like out-of-range indices in set_val_ind.
	 */
	public void set_all_vals_inds(int x,
								  int y,
								  int z,
								  int v,
								  float [] indvals) {
		
		float [] V = vals[x][y][z][v];
		if (V == null) {
			return;
		}
		System.arraycopy(indvals, 0, V, 0, Math.min(indvals.length, V.length));
	}
	
	/**
	 * Returns the value stored for {@code label} at one voxel, or 0 if the voxel
	 * is uninitialized or does not contain {@code label}.
	 */
	public float get_val(int x,
						 int y,
						 int z,
						 int v,
						 short label) {
		
		short [] A = labels[x][y][z][v];
		if (A == null) {
			return(0);
		}
		
		int ind = binary_search(A, label, 0, (A.length-1));
		return (ind < 0) ? 0 : vals[x][y][z][v][ind];
	}
	
	/**
	 * Returns the value at storage position {@code ind} of one voxel, or 0 if
	 * the voxel is uninitialized or {@code ind} is out of range.
	 */
	public float get_val_ind(int x,
							 int y,
							 int z,
							 int v,
							 int ind) {
		
		short [] A = labels[x][y][z][v];
		if (A == null || ind < 0 || ind >= A.length) {
			return(0);
		}
		return(vals[x][y][z][v][ind]);
	}
	
	/**
	 * Returns the internal value array of one voxel (not a copy; writes affect
	 * this matrix), or {@code null} if the voxel is uninitialized.
	 */
	public float [] get_all_vals(int x,
								 int y,
								 int z,
								 int v) {
		return(vals[x][y][z][v]);
	}
	
	/**
	 * Scatters the stored values of one voxel into {@code lp} at their label
	 * indices. Entries of {@code lp} for labels not stored are left untouched,
	 * so callers presumably zero {@code lp} first. Does nothing if the voxel is
	 * uninitialized.
	 *
	 * @param lp dense output vector indexed by label; must cover every stored label
	 */
	public void get_label_probabilities(int x,
										int y,
										int z,
										int v,
										float [] lp) {
		
		short [] A = labels[x][y][z][v];
		if (A == null) {
			return;
		}
		float [] V = vals[x][y][z][v];
		for (int l = 0; l < A.length; l++) {
			lp[A[l]] = V[l];
		}
	}
	
	/**
	 * Returns the label with the largest strictly positive value at one voxel
	 * (the first such label on ties), or 0 if no value is positive or the voxel
	 * is uninitialized.
	 */
	public short get_max_label(int x,
							   int y,
							   int z,
							   int v) {
		float max_val = 0;
		short max_label = 0;
		
		float [] V = vals[x][y][z][v];
		if (V == null) {
			return(max_label);
		}
		short [] A = labels[x][y][z][v];
		for (int i = 0; i < V.length; i++) {
			if (V[i] > max_val) {
				max_val = V[i];
				max_label = A[i];
			}
		}
		return(max_label);
	}
	
	/**
	 * Returns the largest stored value at one voxel, clamped below at 0
	 * (0 if no value is positive or the voxel is uninitialized).
	 */
	public float get_max_val(int x,
							 int y,
							 int z,
							 int v) {
		float max_val = 0;
		
		float [] V = vals[x][y][z][v];
		if (V == null) {
			return(max_val);
		}
		for (float val : V) {
			if (val > max_val) {
				max_val = val;
			}
		}
		return(max_val);
	}
	
	/**
	 * Returns the internal label array of one voxel (not a copy; writes affect
	 * this matrix), or {@code null} if the voxel is uninitialized.
	 */
	public short [] get_all_labels(int x,
								   int y,
								   int z,
								   int v) {
		return(labels[x][y][z][v]);
	}
	
	/**
	 * Returns the mean absolute difference between this matrix's stored values
	 * and those of {@code prev}, compared by storage position (not by label) over
	 * every initialized voxel of this matrix. This presumably assumes both
	 * matrices share the same label layout. Returns 0 when this matrix has no
	 * stored entries (previously 0/0 = NaN).
	 *
	 * @param prev matrix to compare against
	 */
	public float get_convergence_factor(SparseMatrix5D prev) {
		
		// double / long accumulators: a float stops counting exactly past 2^24
		// entries and loses small differences once the running sum grows
		double convergence_factor = 0;
		long convergence_norm = 0;
		
		for (int x = 0; x < d1; x++) {
			for (int y = 0; y < d2; y++) {
				for (int z = 0; z < d3; z++) {
					for (int v = 0; v < d4; v++) {
						float [] V = vals[x][y][z][v];
						if (V == null) {
							continue;
						}
						for (int i = 0; i < V.length; i++) {
							convergence_factor += Math.abs(V[i] - prev.get_val_ind(x, y, z, v, i));
						}
						convergence_norm += V.length;
					}
				}
			}
		}
		
		if (convergence_norm == 0) {
			return(0);
		}
		return (float) (convergence_factor / convergence_norm);
	}
	
	/**
	 * Divides every stored value of one voxel by their sum. Does nothing if the
	 * voxel is uninitialized or the sum is 0 (which would otherwise produce NaN
	 * or infinite values).
	 */
	public void normalize(int x,
						  int y,
						  int z,
						  int v) {
		
		float [] V = vals[x][y][z][v];
		if (V == null) {
			return;
		}
		float valsum = 0;
		for (float val : V) {
			valsum += val;
		}
		if (valsum == 0) {
			return;
		}
		for (int i = 0; i < V.length; i++) {
			V[i] /= valsum;
		}
	}
	
	/**
	 * Iterative binary search for {@code key} in the ascending range
	 * {@code A[imin..imax]} (inclusive).
	 *
	 * @return the index of {@code key}, or -1 if it is not present
	 */
	private int binary_search(short [] A,
							  int key,
							  int imin,
							  int imax) {
		
		while (imin <= imax) {
			// unsigned shift keeps the midpoint correct even if imin + imax overflows
			int imid = (imin + imax) >>> 1;
			if (A[imid] > key) {
				imax = imid - 1;
			} else if (A[imid] < key) {
				imin = imid + 1;
			} else {
				return(imid);
			}
		}
		return(-1);
	}

}
