package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Spatial STAPLE label fusion.
 * <p>
 * Extends {@link STAPLE} by estimating, for every non-consensus voxel, a set of
 * local rater performance parameters from a rectangular window (half-widths
 * {@link #hws}) around that voxel. In the E-step the log-contribution of each
 * rater is a mix of the local parameters and the global parameters
 * ({@code theta_prev}), weighted by {@link #bias}:
 * {@code (1-bias) * local + bias * global}.
 * <p>
 * NOTE(review): the inherited members used here ({@code obs}, {@code theta_prev},
 * {@code estimate}, {@code isConverged}, {@code priortype}, ...) are declared in
 * STAPLE; their exact semantics are assumed, not verified, in these comments.
 */
public class SpatialSTAPLE extends STAPLE {
	
	// half-window radii of the local window, one per dimension (x, y, z, v)
	protected int [] hws;
	// global/local mixing weight, clamped to [0, 1]; 1 = global only, 0 = local only
	protected float bias;
	// current label probabilities of the non-consensus voxels
	protected SparseMatrix5D sparseW;
	// label probabilities from the previous EM iteration (used by the local M-step)
	protected SparseMatrix5D sparseW_prev;
	// scratch local performance parameters, recomputed for each voxel
	protected PerformanceParametersBase local_theta = null;
	// running-average LUT of exponential normalisation values,
	// indexed by [quantised vector sum][number of non-zero entries - 1]
	protected double [][] normval_lut;
	// number of samples averaged into each normval_lut entry (updates stop at 1000)
	protected int [][] normval_count;
	// number of quantisation bins used for the vector sum in normval_lut
	protected final int LUT_SIZE = 20;
	
	/**
	 * Creates a non-quiet Spatial STAPLE instance.
	 *
	 * @param obs_in       the rater observations
	 * @param hws_in       half-window radii (x, y, z, v); must have at least 4 entries
	 * @param epsilon_in   convergence epsilon (forwarded to STAPLE)
	 * @param bias_in      global/local mixing weight (clamped to [0, 1])
	 * @param maxiter_in   maximum number of EM iterations (forwarded to STAPLE)
	 * @param priortype_in prior type (forwarded to STAPLE)
	 * @param outname      output name (forwarded to STAPLE)
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  int [] hws_in,
						  float epsilon_in,
						  float bias_in,
						  int maxiter_in,
						  int priortype_in,
						  String outname) {
		this(obs_in, hws_in, epsilon_in, bias_in, maxiter_in, priortype_in, outname, false);
	}
		
	/**
	 * Creates a Spatial STAPLE instance.
	 *
	 * @param obs_in       the rater observations
	 * @param hws_in       half-window radii (x, y, z, v); must have at least 4 entries
	 * @param epsilon_in   convergence epsilon (forwarded to STAPLE)
	 * @param bias_in      global/local mixing weight (clamped to [0, 1])
	 * @param maxiter_in   maximum number of EM iterations (forwarded to STAPLE)
	 * @param priortype_in prior type (forwarded to STAPLE)
	 * @param outname      output name (forwarded to STAPLE)
	 * @param quiet1       quiet flag (forwarded to STAPLE)
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  int [] hws_in,
						  float epsilon_in,
						  float bias_in,
						  int maxiter_in,
						  int priortype_in,
						  String outname,
						  boolean quiet1) {
		
		super(obs_in, epsilon_in, maxiter_in, priortype_in, outname, quiet1);
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Spatial STAPLE +++");
		setLabel("Spatial STAPLE");

		/*
		 * Handle the local performance parameters
		 */
		
		// set the half-window radii
		hws = hws_in;
		
		// set the bias amount (clamped to [0, 1])
		bias = (bias_in < 0) ? 0 : ((bias_in > 1) ? 1 : bias_in);
		
		// Normalisation LUT. Only the first num_labels columns are ever indexed
		// (by veclength-1 <= num_labels-1), so only those are seeded.
		normval_lut = new double [LUT_SIZE+1][2*obs.num_labels()];
		normval_count = new int [LUT_SIZE+1][2*obs.num_labels()];
		for (int i = 0; i < LUT_SIZE+1; i++) {
			Arrays.fill(normval_lut[i], 0, obs.num_labels(), 0.3);
			Arrays.fill(normval_count[i], 0, obs.num_labels(), 1);
		}
		
		// initialize the local theta
		local_theta = new PerformanceParameters(obs.num_labels(), obs.num_raters());
		local_theta.initialize();
		
		// initialize the sparse label probability matrices
		sparseW = new SparseMatrix5D(obs.dimx(), obs.dimy(), obs.dimz(), obs.dimv(), obs.num_labels());
		sparseW_prev = new SparseMatrix5D(obs.dimx(), obs.dimy(), obs.dimz(), obs.dimv(), obs.num_labels());
		
		print_status = true;
		
		JistLogger.logOutput(JistLogger.INFO, "-> The following parameters were found");
		JistLogger.logOutput(JistLogger.INFO, String.format("Spatial Window Radius (Voxels): [%d %d %d %d]", hws[0], hws[1], hws[2], hws[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Spatial Window Dimensions (Voxels): [%d %d %d %d]", 2*hws[0]+1, 2*hws[1]+1, 2*hws[2]+1, 2*hws[3]+1));
		JistLogger.logOutput(JistLogger.INFO, "Bias fraction: " + bias);
	}
	
	/**
	 * Runs the parent pre-EM step, then snapshots the current label
	 * probabilities into {@link #sparseW_prev} for use by the local M-step.
	 */
	@Override
	protected void run_pre_EM() {
		super.run_pre_EM();

		// reset the previous sparseW
		sparseW_prev.copy(sparseW);
	}
	
	/**
	 * Initialises {@link #sparseW} and {@link #sparseW_prev} for every
	 * non-consensus voxel.
	 * <p>
	 * A global prior is temporarily switched to a voxelwise prior so that each
	 * voxel starts from its own prior. The original prior type is always
	 * restored, even if initialisation throws.
	 *
	 * @param lp scratch buffer of length num_labels, overwritten per voxel
	 */
	protected void initialize_priors(float [] lp) {
		
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing the priors");
		
		// make sure that we initialize W with a voxelwise prior
		int prev_priortype = priortype;
		if (priortype == PRIORTYPE_GLOBAL)
			priortype = PRIORTYPE_VOXELWISE; // temporary
		
		try {
			// initialize the sparse W matrix
			for (int x = 0; x < obs.dimx(); x++)
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int v = 0; v < obs.dimv(); v++)
							if (!obs.is_consensus(x, y, z, v)) {
								initialize_label_probabilities(x, y, z, v, lp);
								sparseW.init_voxel(x, y, z, v, lp);
								sparseW_prev.init_voxel(x, y, z, v, lp);
							}
		} finally {
			// if it was a global prior, reset it appropriately
			priortype = prev_priortype;
		}
	}
	
	/**
	 * Copies the current label probabilities of voxel (x, y, z, v) from
	 * {@link #sparseW} into {@code lp}.
	 *
	 * @param lp output buffer of length num_labels
	 */
	protected void set_label_probabilities(int x,
										   int y,
										   int z,
										   int v,
										   float [] lp) {
		for (short l = 0; l < obs.num_labels(); l++)
			lp[l] = sparseW.get_val(x, y, z, v, l);
	}
	
	/**
	 * Runs one EM update for voxel (x, y, z, v).
	 * <p>
	 * For an unconverged voxel: local M-step, E-step, an update of the label
	 * estimate, and a convergence check against {@code convergence_threshold}.
	 * For a converged voxel, {@code lp} is set to a one-hot vector at the
	 * current estimate. In both cases the global M-step is then run with
	 * {@code lp}.
	 *
	 * @param lp scratch/output buffer of length num_labels
	 */
	protected void run_EM_voxel(int x, int y, int z, int v, float [] lp) {
		
		if (!isConverged[x][y][z][v]) {
			
			// get the prior for this voxel
			initialize_label_probabilities(x, y, z, v, lp);
		
			// run the local M step first
			run_M_step_voxel_local(x, y, z, v, lp);
	
			// run the E step
			run_E_step_voxel(x, y, z, v, lp);
			
			// get some summary information from the estimated label probabilities
			short est_label = sparseW.get_max_label(x, y, z, v);
			float max_val = sparseW.get_max_val(x, y, z, v);
			
			// update the estimate (estimate is indexed in offset coordinates)
			estimate.set(x + obs.offx(),
					     y + obs.offy(),
					     z + obs.offz(),
					     v + obs.offv(),
					     est_label);
			
			// check if we are converged
			if (max_val > convergence_threshold) {
				
				isConverged[x][y][z][v] = true;
				
				// NOTE(review): this writes into the array returned by
				// get_all_vals; it only affects sparseW if that array is not a
				// copy — confirm against SparseMatrix5D.
				short [] voxlabs = sparseW.get_all_labels(x, y, z, v);
				float [] voxvals = sparseW.get_all_vals(x, y, z, v);
				
				for (int ii = 0; ii < voxlabs.length; ii++)
					voxvals[ii] = (voxlabs[ii] == est_label) ? 1 : 0;
			}
			
		} else {
			// if it is converged, the probability is 1 at the estimate and 0 otherwise
			Arrays.fill(lp, 0f);
			lp[estimate.getShort(x+obs.offx(), y+obs.offy(), z+obs.offz(), v+obs.offv())] = 1;
		}
		
		// always run the global M-step
		run_M_step_voxel(x, y, z, v, lp);
	}
	
	/**
	 * Estimates {@link #local_theta} for voxel (x, y, z, v) from its window.
	 * <p>
	 * Accumulates the previous-iteration label probabilities of every
	 * non-consensus neighbour, weighted by the observations of the raters
	 * selected at the centre voxel. Missing mass up to the window size is
	 * filled in from {@code theta_prev} (the "bias"), and the result is
	 * normalised. Does nothing when {@code bias == 1}, since the local
	 * parameters are then unused.
	 *
	 * @param lp current prior for the centre voxel; only labels with
	 *           {@code lp > 0} are considered
	 */
	protected void run_M_step_voxel_local(int x,
								  	      int y,
								  	      int z,
								  	      int v,
								  	      float [] lp) {
		
		if (bias == 1)
			return;
		
		// set the current region of interest (clipped to the volume)
		int xl = Math.max(x - hws[0], 0);
		int xh = Math.min(x + hws[0], obs.dimx()-1);
		int yl = Math.max(y - hws[1], 0);
		int yh = Math.min(y + hws[1], obs.dimy()-1);
		int zl = Math.max(z - hws[2], 0);
		int zh = Math.min(z + hws[2], obs.dimz()-1);
		int vl = Math.max(v - hws[3], 0);
		int vh = Math.min(v + hws[3], obs.dimv()-1);
		
		// first reset the theta
		local_theta.reset();
		
		// The rater selection depends only on the centre voxel, so look it up
		// once instead of once per neighbour and label.
		final int num_raters = obs.num_raters();
		boolean [] selected = new boolean [num_raters];
		for (int j = 0; j < num_raters; j++)
			selected[j] = obs.get_local_selection(x, y, z, v, j);
		
		float [] num_obs = new float [obs.num_labels()];
		short [] voxlabs;
		float [] voxvals;
		short [] obslabels;
		float [] obsvals;
		float biasval;
		float winsize = 0;
		short tmp_lab;
		float tmp_val;
		
		// iterate over all non-consensus voxels in the neighborhood
		for (int xi = xl; xi <= xh; xi++)
			for (int yi = yl; yi <= yh; yi++)
				for (int zi = zl; zi <= zh; zi++)
					for (int vi = vl; vi <= vh; vi++)
						if (!obs.is_consensus(xi, yi, zi, vi)) {
							winsize++;
							
							// get the labels we are considering for this voxel
							voxlabs = sparseW_prev.get_all_labels(xi, yi, zi, vi);
							voxvals = sparseW_prev.get_all_vals(xi, yi, zi, vi);
							
							// iterate over the labels with nonzero probability
							for (int ii = 0; ii < voxlabs.length; ii++)
								if (lp[voxlabs[ii]] > 0 && voxvals[ii] > 0) {
									
									tmp_lab = voxlabs[ii];
									tmp_val = voxvals[ii];
									num_obs[tmp_lab] += tmp_val;
									
									// iterate over the selected raters
									for (int j = 0; j < num_raters; j++)
										if (selected[j]) {
											
											// get the observations from this rater
											obslabels = obs.get_all(xi, yi, zi, vi, j);
											obsvals = obs.get_all_vals(xi, yi, zi, vi, j);
											
											for (int l = 0; l < obslabels.length; l++)
												local_theta.add(j, obslabels[l], tmp_lab, obsvals[l]*tmp_val);
										}
								}
						}
		
		// add back in the bias: fill the missing mass (winsize - num_obs) with theta_prev
		for (int ss = 0; ss < obs.num_labels(); ss++)
			if (lp[ss] > 0)
				for (int j = 0; j < num_raters; j++)
					for (int s = 0; s < obs.num_labels(); s++) {
						biasval = theta_prev.get(j, s, ss);
						if (biasval > 0)
							local_theta.add(j, s, ss, (winsize - num_obs[ss])*biasval);
					}
		
		// normalize theta
		local_theta.normalize(lp);
		
	}
	
	/**
	 * Runs the E-step for voxel (x, y, z, v).
	 * <p>
	 * Computes in log-space the (unnormalised) posterior of each label: the log
	 * prior plus the combined rater contributions from
	 * {@link #get_combined_val_log}. It then normalises with the log-sum-exp
	 * trick and stores the result both in {@code lp} and in {@link #sparseW}.
	 * If every label has zero mass, a SEVERE message is logged.
	 *
	 * @param lp on input, the prior for this voxel; on output, the posterior
	 */
	protected void run_E_step_voxel(int x,
								    int y,
								    int z,
								    int v,
								    float [] lp) {
		
		// Note: because we are potentially multiplying a very large collection of numbers
		//       the E-step is done in log-space.
		
		// temporarily do everything as a double
		double [] lpd = new double [obs.num_labels()];
		double normfact = 0;
		// The running maximum must start at -infinity. Double.MIN_VALUE is the
		// smallest *positive* double, so with all-negative log values it would
		// never be replaced and the max-shift below would not prevent underflow.
		double maxfact = Double.NEGATIVE_INFINITY;
		
		// set the label probabilities (E-step)
		for (int s = 0; s < obs.num_labels(); s++) {
			
			// set the initial value (log(0) = -inf for impossible labels)
			lpd[s] = Math.log(lp[s]);
			
			// if the probability is non-zero, iterate over all raters
			if (lp[s] > 0) {
				for (int j = 0; j < obs.num_raters(); j++) {
					
					// use the local selection
					if (!obs.get_local_selection(x, y, z, v, j))
						continue;
					
					// once the label is impossible, further raters cannot change that
					if (lpd[s] == Double.NEGATIVE_INFINITY)
						continue;
					
					// get the rater information
					short [] obslabels = obs.get_all(x, y, z, v, j);
					float [] obsvals = obs.get_all_vals(x, y, z, v, j);
					
					// add the impact from this rater
					lpd[s] += get_combined_val_log(j, s, obslabels, obsvals);
				}
			}
			
			// keep track of the max value so that we can properly normalize
			if (lpd[s] > maxfact)
				maxfact = lpd[s]; 
			
		}
		
		// Calculate the normalisation constant and go back to linear space.
		// If every label is -inf, (-inf) - (-inf) would be NaN, so treat all as 0.
		for (int s = 0; s < obs.num_labels(); s++) {
			lpd[s] = (maxfact == Double.NEGATIVE_INFINITY) ? 0 : Math.exp(lpd[s] - maxfact);
			normfact += lpd[s];
		}
		
		if (normfact == 0)
			JistLogger.logOutput(JistLogger.SEVERE, "XXXXX - Problem Found - XXXXX");
		
		// normalize the label probabilities and convert back to float
		for (int s = 0; s < obs.num_labels(); s++)
			lp[s] = (float) (lpd[s] / normfact);
		
		sparseW.set_all_vals(x, y, z, v, lp);
	}
	
	/**
	 * Returns the local log-performance of rater {@code j} for true label
	 * {@code s}, given the rater's observations. Delegates to
	 * {@code local_theta.get_log}.
	 */
	protected double get_local_theta_log(int j, int s, short [] obslabels, float [] obsvals) {
		return(local_theta.get_log(j, s, obslabels, obsvals));
	}
	
	/**
	 * Returns the local performance entry {@code local_theta.get(j, s1, s2)}.
	 */
	protected float get_local_theta_val(int j, int s1, int s2) {
		return(local_theta.get(j, s1, s2));
	}
	
	/**
	 * Returns the combined log-contribution of rater {@code j} for true label
	 * {@code s}.
	 * <p>
	 * When {@code bias} is exactly 1 or 0, only the global or only the local
	 * parameters are used. Otherwise the result is
	 * {@code ((1-bias)*local_log + bias*global_log) * expnorm}. Here
	 * {@code expnorm} comes from
	 * {@code PerformanceParametersBase.get_exponential_labelnorm}, seeded from
	 * (and fed back into) {@link #normval_lut}, whenever the geometric mix of
	 * the two parameter columns sums to less than 0.99. Otherwise
	 * {@code expnorm} is 1.
	 */
	protected double get_combined_val_log(int j, int s, short [] obslabels, float [] obsvals) {
		
		// first let's make sure we can't skip this stuff
		if (bias == 1)
			return(theta_prev.get_log(j, s, obslabels, obsvals));
		if (bias == 0)
			return(get_local_theta_log(j, s, obslabels, obsvals));
		
		// allocate some space
		double [] vec = new double [obs.num_labels()];
		int veclength = 0;
		double val;
		double expnorm;
		double vecsum = 0;

		// get the part of the performance parameters that we have to normalize
		for (int s2 = 0; s2 < obs.num_labels(); s2++) {
			val = Math.pow(theta_prev.get(j, s2, s), bias);
			val *= Math.pow(get_local_theta_val(j, s2, s), 1-bias);
			if (val > 0) {
				vec[veclength] = val;
				veclength++;
				vecsum += val;
			}
		}
		
		// veclength == 0 (all mixed entries zero) would index column -1 of the
		// LUT; in that case there is nothing to normalise, so use expnorm = 1.
		if (veclength > 0 && vecsum < 0.99) {
		
			// given the sum, estimate the norm value (index in [0, LUT_SIZE-1] since vecsum < 0.99)
			int index = (int)(LUT_SIZE * vecsum);
			double estvalue = normval_lut[index][veclength-1];
		
			// get the exponential norm
			expnorm = PerformanceParametersBase.get_exponential_labelnorm(vec, veclength, estvalue);
			
			// fold this value into the running average of the LUT, up to 1000 samples
			if (normval_count[index][veclength-1] < 1000) {
				double currval = normval_lut[index][veclength-1];
				double currnum = normval_count[index][veclength-1];
				normval_lut[index][veclength-1] = ((currval * currnum) + expnorm) / (currnum+1);
				normval_count[index][veclength-1]++;
			}
		} else
			expnorm = 1;
		
		// get the combined impact from this rater
		val = 0;
		val += (1-bias) * get_local_theta_log(j, s, obslabels, obsvals);
		val += bias * theta_prev.get_log(j, s, obslabels, obsvals);
		val *= expnorm;
		return(val);
		
	}
}
