package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;

/**
 * Simplified variant of {@link SpatialSTAPLE}.
 * <p>
 * On top of the parent's local performance model, this class keeps one scalar
 * "agreement" value per rater, re-estimated on every call to
 * {@link #run_M_step_voxel_local}. For a voxel (x, y, z, v), the scalar for
 * rater j is computed as follows:
 * <ul>
 *   <li>It is the average, over the non-consensus voxels of the clamped window
 *       {@code [x +/- hws[0]] x [y +/- hws[1]] x [z +/- hws[2]] x [v +/- hws[3]]},
 *       of the observation weight rater j puts on each label.</li>
 *   <li>Each label's weight is multiplied by that label's probability in
 *       {@code sparseW_prev}.</li>
 * </ul>
 * That scalar is then used as a symmetric "correct / incorrect" likelihood and
 * added (in log space) to the parent's combined value.
 * <p>
 * NOTE(review): {@code local_scalar_perf} is a single shared buffer that is
 * overwritten by every M-step call. The likelihood methods therefore use
 * whichever voxel was processed last, and instances are not safe for
 * concurrent use. Presumably the caller evaluates a voxel's likelihood right
 * after its M-step; confirm against {@code SpatialSTAPLE}.
 */
public class SpatialSTAPLESimple extends SpatialSTAPLE {

	/**
	 * Per-rater value used when the local window contains no non-consensus
	 * voxels (so there is nothing to estimate from).
	 * <p>
	 * With p = 0.5, {@link #get_local_scalar_perf_log} weighs matching and
	 * non-matching observations equally. The local term is then the same for
	 * every candidate label and does not bias the estimate. Without this
	 * fallback the estimate would be 0/0 = NaN.
	 */
	private static final float UNINFORMATIVE_PERF = 0.5f;

	/** Per-rater scalar agreement estimate for the most recently processed voxel. */
	protected float [] local_scalar_perf;

	/**
	 * Creates a non-quiet instance. Equivalent to the full constructor with
	 * {@code quiet == false}.
	 *
	 * @param obs_in       the rater observations
	 * @param hws_in       window half-widths, forwarded to {@link SpatialSTAPLE}
	 *                     (four entries are read as {@code hws[0..3]})
	 * @param epsilon_in   forwarded to {@link SpatialSTAPLE}
	 * @param bias_in      forwarded to {@link SpatialSTAPLE}
	 * @param maxiter_in   forwarded to {@link SpatialSTAPLE}
	 * @param priortype_in forwarded to {@link SpatialSTAPLE}
	 * @param outname      forwarded to {@link SpatialSTAPLE}
	 */
	public SpatialSTAPLESimple (ObservationBase obs_in,
						  		int [] hws_in,
						  		float epsilon_in,
						  		float bias_in,
						  		int maxiter_in,
						  		int priortype_in,
						  		String outname) {
		this(obs_in, hws_in, epsilon_in, bias_in, maxiter_in, priortype_in, outname, false);
	}

	/**
	 * Creates the estimator, labels it "Simplified Spatial STAPLE" and
	 * allocates one scalar performance slot per rater.
	 *
	 * @param obs_in       the rater observations
	 * @param hws_in       window half-widths, forwarded to {@link SpatialSTAPLE}
	 * @param epsilon_in   forwarded to {@link SpatialSTAPLE}
	 * @param bias_in      forwarded to {@link SpatialSTAPLE}
	 * @param maxiter_in   forwarded to {@link SpatialSTAPLE}
	 * @param priortype_in forwarded to {@link SpatialSTAPLE}
	 * @param outname      forwarded to {@link SpatialSTAPLE}
	 * @param quiet        forwarded to {@link SpatialSTAPLE}
	 */
	public SpatialSTAPLESimple (ObservationBase obs_in,
						  		int [] hws_in,
						  		float epsilon_in,
						  		float bias_in,
						  		int maxiter_in,
						  		int priortype_in,
						  		String outname,
						  		boolean quiet) {

		super(obs_in, hws_in, epsilon_in, bias_in, maxiter_in, priortype_in, outname, quiet);

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Simplified Spatial STAPLE +++");
		setLabel("Simplified Spatial STAPLE");

		// one scalar performance value per rater, filled by the M-step
		local_scalar_perf = new float [obs.num_raters()];
	}

	/**
	 * Runs the parent's local M-step for voxel (x, y, z, v), then re-estimates
	 * {@link #local_scalar_perf}.
	 * <p>
	 * For each rater selected at the centre voxel, the estimate is the sum,
	 * over non-consensus voxels in the clamped window, of
	 * {@code obsval * P_prev(label)}, divided by the number of such voxels.
	 * Unselected raters are left at 0. If the window has no non-consensus
	 * voxels, every rater gets {@link #UNINFORMATIVE_PERF} instead of NaN.
	 *
	 * @param x        voxel x index
	 * @param y        voxel y index
	 * @param z        voxel z index
	 * @param v        voxel v index
	 * @param prior_lp forwarded unchanged to the parent's M-step
	 */
	protected void run_M_step_voxel_local(int x,
								  	      int y,
								  	      int z,
								  	      int v,
								  	      float [] prior_lp) {
		super.run_M_step_voxel_local(x, y, z, v, prior_lp);

		final int num_raters = obs.num_raters();

		// start this voxel's estimate from zero
		Arrays.fill(local_scalar_perf, 0f);

		// Rater selection is queried at the centre voxel, so it does not change
		// across the window; look it up once rather than once per window voxel.
		boolean [] selected = new boolean [num_raters];
		for (int j = 0; j < num_raters; j++)
			selected[j] = obs.get_local_selection(x, y, z, v, j);

		// Dense scratch copy of the previous label probabilities at one voxel.
		// Entries are zeroed again after each voxel, so it is all-zero between uses.
		float [] lp = new float [obs.num_labels()];

		// number of non-consensus voxels that contributed to the estimate
		int num_keep = 0;

		// window around the voxel, clamped to the volume bounds
		int xl = Math.max(x - hws[0], 0);
		int xh = Math.min(x + hws[0], obs.dimx()-1);
		int yl = Math.max(y - hws[1], 0);
		int yh = Math.min(y + hws[1], obs.dimy()-1);
		int zl = Math.max(z - hws[2], 0);
		int zh = Math.min(z + hws[2], obs.dimz()-1);
		int vl = Math.max(v - hws[3], 0);
		int vh = Math.min(v + hws[3], obs.dimv()-1);

		for (int xi = xl; xi <= xh; xi++) {
			for (int yi = yl; yi <= yh; yi++) {
				for (int zi = zl; zi <= zh; zi++) {
					for (int vi = vl; vi <= vh; vi++) {
						// consensus voxels carry no information about rater quality
						if (obs.is_consensus(xi, yi, zi, vi))
							continue;

						num_keep++;

						// scatter the previous sparse label probabilities into lp
						short [] voxlabs = sparseW_prev.get_all_labels(xi, yi, zi, vi);
						float [] voxvals = sparseW_prev.get_all_vals(xi, yi, zi, vi);
						for (int ii = 0; ii < voxlabs.length; ii++)
							lp[voxlabs[ii]] = voxvals[ii];

						// accumulate each selected rater's agreement with lp
						for (int j = 0; j < num_raters; j++) {
							if (!selected[j])
								continue;
							short [] obslabels = obs.get_all(xi, yi, zi, vi, j);
							float [] obsvals = obs.get_all_vals(xi, yi, zi, vi, j);
							for (int l = 0; l < obslabels.length; l++)
								local_scalar_perf[j] += obsvals[l] * lp[obslabels[l]];
						}

						// clear only the entries we set, keeping lp all-zero
						for (int ii = 0; ii < voxlabs.length; ii++)
							lp[voxlabs[ii]] = 0;
					}
				}
			}
		}

		// Nothing to average over: dividing would yield 0/0 = NaN, so use a
		// value that leaves the local likelihood term label-independent.
		if (num_keep == 0) {
			Arrays.fill(local_scalar_perf, UNINFORMATIVE_PERF);
			return;
		}

		// normalize to an average over the contributing voxels
		for (int j = 0; j < num_raters; j++)
			local_scalar_perf[j] /= num_keep;
	}

	/**
	 * Returns the parent's combined log value for rater j and candidate label
	 * s, plus the log of this class's scalar-performance likelihood.
	 *
	 * @param j         rater index
	 * @param s         candidate label
	 * @param obslabels labels observed by rater j
	 * @param obsvals   weights of those observations (parallel to obslabels)
	 * @return the summed log value
	 */
	protected double get_combined_val_log(int j, int s, short [] obslabels, float [] obsvals) {
		return(super.get_combined_val_log(j, s, obslabels, obsvals) +
			    get_local_scalar_perf_log(j, s, obslabels, obsvals));
	}

	/**
	 * Computes {@code log(sum_ii obsvals[ii] * (obslabels[ii] == s ? p : 1 - p))},
	 * where {@code p = local_scalar_perf[j]}.
	 * <p>
	 * NOTE(review): the result is {@code -Infinity} when the sum is 0. This
	 * happens, for example, when p == 1 and rater j never observes s, or when
	 * {@code obslabels} is empty. Confirm the caller tolerates this.
	 *
	 * @param j         rater index
	 * @param s         candidate label
	 * @param obslabels labels observed by rater j
	 * @param obsvals   weights of those observations (parallel to obslabels)
	 * @return the log-likelihood contribution of the scalar performance model
	 */
	protected double get_local_scalar_perf_log(int j, int s, short [] obslabels, float [] obsvals) {
		double localval = 0;
		for (int ii = 0; ii < obslabels.length; ii++)
			if (obslabels[ii] == s)
				localval += obsvals[ii] * local_scalar_perf[j];
			else
				localval += obsvals[ii] * (1 - local_scalar_perf[j]);
		return(Math.log(localval));
	}
}
