package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;
import edu.vanderbilt.masi.algorithms.labelfusion.PerformanceParameters;
import edu.vanderbilt.masi.algorithms.labelfusion.StatisticalFusionBase;

/**
 * Expectation-Maximization driver for STAPLE-style statistical label fusion.
 *
 * Each iteration visits every non-consensus voxel, computes the label
 * probabilities from the previous performance parameters (E-step) and
 * accumulates those probabilities into the current performance parameters
 * (M-step). Iteration stops once the convergence factor reported by
 * {@code theta} is no longer above {@code epsilon}, or once {@code maxiter}
 * iterations have been run.
 *
 * The members {@code obs}, {@code theta}, {@code estimate}, {@code epsilon},
 * {@code maxiter}, {@code isConverged} and {@code convergence_threshold} are
 * not declared in this class (presumably inherited from
 * {@link StatisticalFusionBase} -- confirm there).
 */
public class STAPLE extends StatisticalFusionBase {
	
	/** Performance parameters from the previous iteration; read by the E-step. */
	protected PerformanceParametersBase theta_prev;
	
	/** When true, {@link #run()} reports progress once per x index. */
	protected boolean print_status;
		
	/**
	 * Creates a STAPLE instance that allocates its own performance parameters.
	 * Delegates to the six-argument constructor with {@code quiet = false}.
	 *
	 * @param obs_in      the rater observations
	 * @param epsilon_in  convergence threshold, forwarded to the superclass
	 * @param maxiter_in  iteration cap, forwarded to the superclass
	 * @param priortype_in prior type selector, forwarded to the superclass
	 * @param outname     output name, forwarded to the superclass
	 */
	public STAPLE (ObservationBase obs_in,
				   float epsilon_in,
				   int maxiter_in,
				   int priortype_in,
				   String outname) {
		
		this(obs_in, epsilon_in, maxiter_in, priortype_in, outname, false);
	}
	
	/**
	 * Creates a STAPLE instance, optionally skipping the allocation of the
	 * performance parameters.
	 *
	 * @param obs_in      the rater observations
	 * @param epsilon_in  convergence threshold, forwarded to the superclass
	 * @param maxiter_in  iteration cap, forwarded to the superclass
	 * @param priortype_in prior type selector, forwarded to the superclass
	 * @param outname     output name, forwarded to the superclass
	 * @param quiet       when true, {@code theta} and {@code theta_prev} are not
	 *                    assigned by this constructor. NOTE(review): presumably
	 *                    so that a subclass can install its own parameter
	 *                    objects -- confirm against the subclasses.
	 */
	public STAPLE (ObservationBase obs_in,
				   float epsilon_in,
				   int maxiter_in,
				   int priortype_in,
				   String outname,
				   boolean quiet) {
	
		super(obs_in, epsilon_in, maxiter_in, priortype_in, outname);
		
		JistLogger.logOutput(JistLogger.WARNING, "\n+++ Initializing STAPLE +++");
		setLabel("STAPLE");
		
		if (!quiet) {
			theta = get_initial_theta();
			theta_prev = new PerformanceParameters(obs.num_labels(), obs.num_raters());
		}
		
		// progress reporting is enabled regardless of the quiet flag
		print_status = true;
	}

	/**
	 * Runs the EM iterations and returns the fused estimate.
	 *
	 * At most {@code maxiter} iterations are executed; a non-positive
	 * {@code maxiter} therefore runs none and {@code estimate} is returned
	 * as it was before the call (after remapping).
	 *
	 * @return {@code estimate}, after {@code obs.remap_estimate(estimate)}
	 */
	public ImageData run () {
		
		float convergence_factor = Float.MAX_VALUE;
		float [] lp = new float [obs.num_labels()];
		int numiter = 0;
		
		// initialize all of the priors
		initialize_priors(lp);
		
		// iterate until convergence
		JistLogger.logOutput(JistLogger.WARNING, "\n-> Running the Expectation-Maximization Algorithm");
		
		// numiter counts completed iterations, so a strict comparison is needed
		// to stop after exactly maxiter of them
		while (convergence_factor > epsilon && numiter < maxiter) {
			
			long time_start = System.nanoTime();
			
			// increment the number of iterations
			numiter++;
			
			// iterate over all non-consensus voxels
			run_pre_EM();
			for (int x = 0; x < obs.dimx(); x++) {
				if (print_status)
					print_status(x, obs.dimx());
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int v = 0; v < obs.dimv(); v++)
							if (!obs.is_consensus(x, y, z, v))
								run_EM_voxel(x, y, z, v, lp);
			}
			run_post_EM();
			
			// calculate the convergence factor
			convergence_factor = theta.get_convergence_factor(theta_prev);
			
			// calculate the time that has elapsed for this iteration
			double elapsed_time = ((double)(System.nanoTime() - time_start)) / 1e9;
						
			JistLogger.logOutput(JistLogger.WARNING, String.format("Convergence Factor (%d, %.3fs): %f", numiter, elapsed_time, convergence_factor));
		}
		
		// remap the estimate to the original label space
		obs.remap_estimate(estimate);
		
		return(estimate);
	}
	
	/**
	 * Fills {@code lp} with the label probabilities of one voxel by running the
	 * E-step against the current {@code theta}.
	 *
	 * @param x  voxel index along x
	 * @param y  voxel index along y
	 * @param z  voxel index along z
	 * @param v  voxel index along the fourth dimension
	 * @param lp output buffer with one entry per label
	 */
	protected void set_label_probabilities(int x,
										   int y,
										   int z,
										   int v,
										   float [] lp) {
		run_E_step_voxel(x, y, z, v, lp, theta);
	}
	
	/**
	 * Prepares an iteration: saves the current parameters into
	 * {@code theta_prev} and resets {@code theta} so the M-step can accumulate
	 * into it.
	 */
	protected void run_pre_EM() {
		// reset the performance level parameters
		theta_prev.copy(theta);
		theta.reset();
	}
	
	/**
	 * Finishes an iteration by normalizing the accumulated {@code theta}.
	 */
	protected void run_post_EM() {
		// normalize the performance level parameters
		theta.normalize();
	}
	
	/**
	 * Runs one EM update for a single voxel.
	 *
	 * Voxels not yet flagged in {@code isConverged} get a fresh E-step (using
	 * {@code theta_prev}), have their estimate written, and are flagged as
	 * converged when their largest label probability exceeds
	 * {@code convergence_threshold}. Voxels already flagged reuse the stored
	 * estimate as a one-hot probability vector. In both cases the M-step is
	 * then applied.
	 *
	 * @param x  voxel index along x
	 * @param y  voxel index along y
	 * @param z  voxel index along z
	 * @param v  voxel index along the fourth dimension
	 * @param lp scratch buffer with one entry per label; overwritten
	 */
	protected void run_EM_voxel(int x, int y, int z, int v, float [] lp) {
		
		// if this voxel is already converged, then we can skip the E-step
		if (!isConverged[x][y][z][v]) {
			// run the E-step for the voxel
			run_E_step_voxel(x, y, z, v, lp, theta_prev);
			
			// set the estimate for this voxel (estimate is indexed with the obs offsets added)
			estimate.set(x+obs.offx(), y+obs.offy(), z+obs.offz(), v+obs.offv(), get_estimate_voxel(lp));
			
			// see if we are converged
			if (get_max_label_probability(lp) > convergence_threshold)
				isConverged[x][y][z][v] = true;
						
		} else {
			
			// if it is converged, the probability is 1 at the estimate and 0 otherwise
			Arrays.fill(lp, 0f);
			lp[estimate.getShort(x+obs.offx(), y+obs.offy(), z+obs.offz(), v+obs.offv())] = 1;
		}
		
		// run the M-step for the voxel
		run_M_step_voxel(x, y, z, v, lp);
	}
	
	/**
	 * Computes the normalized label probabilities of one voxel (E-step).
	 *
	 * Starting from the values written by
	 * {@code initialize_label_probabilities}, the log-contribution of every
	 * locally selected rater is added per label, the largest log value is
	 * subtracted before exponentiating, and the result is normalized to sum
	 * to one. If every label ends up with zero probability a SEVERE message
	 * is logged and the entries of {@code lp} become NaN.
	 *
	 * @param x         voxel index along x
	 * @param y         voxel index along y
	 * @param z         voxel index along z
	 * @param v         voxel index along the fourth dimension
	 * @param lp        output buffer with one entry per label
	 * @param currtheta performance parameters used to score the raters
	 */
	protected void run_E_step_voxel(int x,
								   int y,
								   int z,
								   int v,
								   float [] lp,
								   PerformanceParametersBase currtheta) {
		
		// Note: because we are potentially multiplying a very large collection of numbers
		//       the E-step is done in log-space.
		
		// initialize the label probabilities to the prior
		initialize_label_probabilities(x, y, z, v, lp);
		
		// temporarily do everything as a double
		double [] lpd = new double [obs.num_labels()];
		double normfact = 0;
		
		// Double.MIN_VALUE is the smallest POSITIVE double, so it can never be
		// exceeded by a log-probability; the running maximum has to start at
		// negative infinity for the shift below to have any effect.
		double maxfact = Double.NEGATIVE_INFINITY;
		
		// set the label probabilities (E-step)
		for (int s = 0; s < obs.num_labels(); s++) {
			
			// set the initial value 
			lpd[s] = Math.log(lp[s]);
			
			// if the probability is non-zero, iterate over all raters
			if (lp[s] > 0) {
				for (int j = 0; j < obs.num_raters(); j++) {
					
					// use the local selection
					if (!obs.get_local_selection(x, y, z, v, j))
						continue;
					
					// see if the probability is already zero here
					if (lpd[s] == Double.NEGATIVE_INFINITY)
						continue;
					
					// get the information for the current rater
					short [] obslabels = obs.get_all(x, y, z, v, j);
					float [] obsvals = obs.get_all_vals(x, y, z, v, j);
		
					// get the contribution from this rater
					lpd[s] += currtheta.get_log(j, s, obslabels, obsvals);
				}
			}
			
			// keep track of the max value so that we can properly normalize
			if (lpd[s] > maxfact)
				maxfact = lpd[s]; 
			
		}
		
		// If no label has a usable log-probability there is no maximum to shift
		// by; subtracting negative infinity would turn every entry into NaN and
		// hide the zero-sum diagnostic below, so shift by zero instead.
		double shift = (maxfact == Double.NEGATIVE_INFINITY) ? 0 : maxfact;
		
		// calculate the normalization constant and go back to linear space
		for (int s = 0; s < obs.num_labels(); s++) {
			lpd[s] = Math.exp(lpd[s] - shift);
			normfact += lpd[s];
		}
		
		if (normfact == 0)
			JistLogger.logOutput(JistLogger.SEVERE, "XXXXX - Problem Found - XXXXX");
					
		// normalize across the label probabilities
		for (int s = 0; s < obs.num_labels(); s++)
			lp[s] = (float) (lpd[s] / normfact);

	}
	
	/**
	 * Accumulates one voxel's label probabilities into {@code theta} (M-step).
	 *
	 * For every label with positive probability and every locally selected
	 * rater, each observed label/value pair of that rater adds
	 * {@code value * lp[label]} to {@code theta}.
	 *
	 * @param x  voxel index along x
	 * @param y  voxel index along y
	 * @param z  voxel index along z
	 * @param v  voxel index along the fourth dimension
	 * @param lp label probabilities of the voxel, one entry per label
	 */
	protected void run_M_step_voxel(int x, int y, int z, int v, float [] lp) {
		
		// add the impact to theta (M-step)
		for (int s = 0; s < obs.num_labels(); s++)
			if (lp[s] > 0)
				for (int j = 0; j < obs.num_raters(); j++) {
					
					// use the local selection
					if (!obs.get_local_selection(x, y, z, v, j))
						continue;
					
					// get the rater observations
					short [] obslabels = obs.get_all(x, y, z, v, j);
					float [] obsvals = obs.get_all_vals(x, y, z, v, j);
					
					// add the impact to theta
					for (int l = 0; l < obslabels.length; l++)
						theta.add(j, obslabels[l], s, obsvals[l]*lp[s]);
				}
	}

}
