package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.labelfusion.PerformanceParameters;

public abstract class StatisticalFusionBase extends LabelFusionBase {
	
	// Convergence threshold supplied at construction (logged as "Convergence Threshold").
	// Not read anywhere else in this class -- presumably consumed by subclasses.
	protected float epsilon;
	// Maximum number of iterations supplied at construction.
	// Not read anywhere else in this class -- presumably consumed by subclasses.
	protected int maxiter;
	// Selected prior distribution; expected to be one of the PRIORTYPE_* constants below.
	protected int priortype;
	// Performance parameters exposed through get_theta(). Never assigned in this class --
	// presumably set by subclasses (TODO confirm).
	protected PerformanceParametersBase theta;
	// Per-label prior (length obs.num_labels()); only allocated when priortype == PRIORTYPE_GLOBAL.
	protected float [] global_prior;
	// Per-rater values (length obs.num_raters()); stays null unless priortype == PRIORTYPE_ADAPTIVE.
	protected float [] adaptive_vals;
	// Per-voxel flags indexed [x][y][z][v]; initialized to true exactly at consensus voxels.
	protected boolean [][][][] isConverged;
	// NOTE(review): not referenced in this file; presumably the probability level at which
	// subclasses mark a voxel as converged -- confirm against the subclasses.
	protected final float convergence_threshold = 0.99999f; 
		
	// constant values for the available prior types
	public static final int PRIORTYPE_GLOBAL = 0;
	public static final int PRIORTYPE_VOXELWISE = 1;
	public static final int PRIORTYPE_ADAPTIVE = 2;
	public static final int PRIORTYPE_WEIGHTED_VOXELWISE = 3;
	
	/**
	 * Stores the algorithm settings, builds the per-voxel convergence flags and,
	 * for the global and adaptive prior types, precomputes the prior.
	 *
	 * @param obs_in the rater observations (forwarded to the superclass)
	 * @param epsilon_in the convergence threshold
	 * @param maxiter_in the maximum number of iterations
	 * @param priortype_in the prior type; expected to be one of the {@code PRIORTYPE_*} constants
	 * @param outname forwarded unchanged to the superclass constructor
	 */
	public StatisticalFusionBase (ObservationBase obs_in,
								  float epsilon_in,
								  int maxiter_in,
								  int priortype_in,
								  String outname) {
		super(obs_in, outname);

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Generic Statistical Fusion Algorithm +++");

		// keep the caller-supplied settings
		this.epsilon = epsilon_in;
		this.maxiter = maxiter_in;
		this.priortype = priortype_in;
		this.adaptive_vals = null;

		initialize_convergence_detector();

		// only the global and adaptive priors are precomputed here
		switch (this.priortype) {
			case PRIORTYPE_GLOBAL:
				set_global_prior();
				break;
			case PRIORTYPE_ADAPTIVE:
				set_adaptive_prior();
				break;
			default:
				// the remaining prior types need no precomputation in the constructor
				break;
		}

		// report the configuration
		JistLogger.logOutput(JistLogger.INFO, "-> The following parameters were found");
		JistLogger.logOutput(JistLogger.INFO, "Convergence Threshold: " + this.epsilon);
		JistLogger.logOutput(JistLogger.INFO, "Maximum Number of Iterations: " + this.maxiter);
		JistLogger.logOutput(JistLogger.INFO, "Prior Distribution Type: " + get_priortype_string());
	}
	
	/**
	 * Subclass hook that operates on the single voxel at (x, y, z, v).
	 * NOTE(review): the name suggests this runs the per-voxel EM update with {@code lp}
	 * as the label-probability buffer; this class never calls it, so confirm the exact
	 * contract against the implementing subclasses.
	 *
	 * @param x voxel index along the first dimension
	 * @param y voxel index along the second dimension
	 * @param z voxel index along the third dimension
	 * @param v voxel index along the fourth dimension
	 * @param lp per-label buffer -- presumably of length {@code obs.num_labels()} (TODO confirm)
	 */
	protected abstract void run_EM_voxel(int x, int y, int z, int v, float [] lp);
	
	/**
	 * Hook that does nothing in this class; the body is intentionally empty.
	 * It is not called from within this file -- presumably subclasses override it
	 * to set up prior-related state (TODO confirm).
	 *
	 * @param lp per-label buffer; ignored by this default implementation
	 */
	protected void initialize_priors(float [] lp) {};
	
	/**
	 * Translates the configured prior type into the name used in the log output.
	 *
	 * @return the display name for {@code priortype}, or "Unknown" if the value
	 *         matches none of the {@code PRIORTYPE_*} constants
	 */
	private String get_priortype_string() {
		if (priortype == PRIORTYPE_GLOBAL)
			return "Global";
		if (priortype == PRIORTYPE_VOXELWISE)
			return "Voxelwise";
		if (priortype == PRIORTYPE_ADAPTIVE)
			return "Adaptive";
		if (priortype == PRIORTYPE_WEIGHTED_VOXELWISE)
			return "Weighted-Voxelwise";
		return "Unknown";
	}
	
	/**
	 * Allocates {@code isConverged} with the dimensions of the observations and marks
	 * a voxel as converged exactly when the observations report a consensus there.
	 */
	private void initialize_convergence_detector() {
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing convergence detection for algorithmic speedup");

		final int nx = obs.dimx();
		final int ny = obs.dimy();
		final int nz = obs.dimz();
		final int nv = obs.dimv();
		isConverged = new boolean[nx][ny][nz][nv];

		for (int x = 0; x < nx; x++) {
			for (int y = 0; y < ny; y++) {
				for (int z = 0; z < nz; z++) {
					boolean [] flags = isConverged[x][y][z];
					for (int v = 0; v < nv; v++) {
						// consensus voxels need no further processing
						flags[v] = obs.is_consensus(x, y, z, v);
					}
				}
			}
		}
	}
		
	/**
	 * Fills {@code lp} with the prior label probabilities of one voxel, using the
	 * prior type this instance was configured with.
	 *
	 * @param x voxel index along the first dimension
	 * @param y voxel index along the second dimension
	 * @param z voxel index along the third dimension
	 * @param v voxel index along the fourth dimension
	 * @param lp output buffer indexed by label
	 */
	protected void initialize_label_probabilities(int x, int y, int z, int v, float [] lp) {
		final int configured_type = this.priortype;
		initialize_label_probabilities(x, y, z, v, lp, configured_type);
	}
	
	/**
	 * Fills {@code lp} with the prior label probabilities of one voxel for an
	 * explicitly given prior type.
	 *
	 * @param x voxel index along the first dimension
	 * @param y voxel index along the second dimension
	 * @param z voxel index along the third dimension
	 * @param v voxel index along the fourth dimension
	 * @param lp output buffer indexed by label
	 * @param priortype one of the {@code PRIORTYPE_*} constants (shadows the field of the same name)
	 * @throws RuntimeException if {@code priortype} is none of the known constants
	 */
	protected void initialize_label_probabilities(int x, int y, int z, int v, float [] lp, int priortype) {

		switch (priortype) {

			case PRIORTYPE_WEIGHTED_VOXELWISE: {
				// the observations compute this prior themselves
				obs.get_weighted_prior(x, y, z, v, lp);
				break;
			}

			case PRIORTYPE_VOXELWISE: {
				// sum the observations of the locally selected raters ...
				Arrays.fill(lp, 0f);
				for (int j = 0; j < obs.num_raters(); j++) {
					if (!obs.get_local_selection(x, y, z, v, j))
						continue;
					short [] labels = obs.get_all(x, y, z, v, j);
					float [] weights = obs.get_all_vals(x, y, z, v, j);
					for (int l = 0; l < labels.length; l++)
						lp[labels[l]] += weights[l];
				}
				// ... and divide by the total number of raters (selected or not)
				for (int s = 0; s < obs.num_labels(); s++)
					lp[s] /= obs.num_raters();
				break;
			}

			case PRIORTYPE_GLOBAL: {
				// every voxel shares the precomputed global prior
				for (int s = 0; s < obs.num_labels(); s++)
					lp[s] = global_prior[s];
				break;
			}

			case PRIORTYPE_ADAPTIVE: {
				obs.initialize_adaptive_probabilities(x, y, z, v, 1, lp);
				float [][] rater_vals = obs.get_all_vals_full(x, y, z, v);
				for (int l = 0; l < obs.num_labels(); l++) {
					// labels without any initial probability are left untouched
					if (!(lp[l] > 0))
						continue;
					for (int j = 0; j < obs.num_raters(); j++) {
						float agree = rater_vals[j][l] * adaptive_vals[j];
						float disagree = (1 - rater_vals[j][l]) * (1 - adaptive_vals[j]);
						lp[l] *= agree + disagree;
					}
					// NOTE(review): the cube root is applied regardless of the number of raters
					lp[l] = (float)Math.cbrt(lp[l]);
				}
				normalize_label_probabilities(lp);
				break;
			}

			default: {
				String errstr = "Error: Invalid Prior Type";
				JistLogger.logOutput(JistLogger.SEVERE, errstr);
				throw new RuntimeException(errstr);
			}
		}
	}
	
	/**
	 * Builds the global prior: for every non-consensus voxel the observation values of
	 * all raters are added to the entry of the observed label, and the resulting
	 * per-label sums are normalized.
	 */
	private void set_global_prior() {

		JistLogger.logOutput(JistLogger.INFO, "-> Initializing global prior");

		// one accumulator per label (zero-initialized by the allocation)
		global_prior = new float [obs.num_labels()];

		for (int x = 0; x < obs.dimx(); x++) {
			for (int y = 0; y < obs.dimy(); y++) {
				for (int z = 0; z < obs.dimz(); z++) {
					for (int v = 0; v < obs.dimv(); v++) {

						// consensus voxels do not contribute to the prior
						if (obs.is_consensus(x, y, z, v))
							continue;

						for (int j = 0; j < obs.num_raters(); j++) {
							short [] labels = obs.get_all(x, y, z, v, j);
							float [] weights = obs.get_all_vals(x, y, z, v, j);
							for (int l = 0; l < labels.length; l++)
								global_prior[labels[l]] += weights[l];
						}
					}
				}
			}
		}

		// turn the accumulated sums into a distribution
		normalize_label_probabilities(global_prior);
	}

	/**
	 * Estimates one value per rater ({@code adaptive_vals}) by fixed-point iteration
	 * over all non-consensus voxels.
	 *
	 * Every value starts at 0.9. In each pass the label probabilities of a voxel are
	 * computed from the previous estimates, and each rater's new estimate is the mean,
	 * over the non-consensus voxels, of the rater's observation values weighted by
	 * those probabilities. Iteration stops once the summed absolute change over all
	 * raters is at most 1e-4.
	 *
	 * If there are no non-consensus voxels at all, the previous estimates are kept and
	 * the iteration stops, instead of dividing by zero and storing NaN for every rater.
	 *
	 * NOTE(review): the loop has no iteration cap; it relies on the update converging.
	 */
	private void set_adaptive_prior() {

		JistLogger.logOutput(JistLogger.INFO, "-> Initializing adaptive prior");

		final float initial_value = 0.9f;
		final float tolerance = 1e-4f;
		final int n_raters = obs.num_raters();
		final int n_labels = obs.num_labels();

		// allocate space for the adaptive values
		adaptive_vals = new float [n_raters];
		Arrays.fill(adaptive_vals, initial_value);

		// initialize some temporary space
		float [] prev_adaptive_vals = new float [n_raters];
		float [] lp = new float [n_labels];

		// iterate until convergence
		float convergence_factor = Float.MAX_VALUE;
		while (convergence_factor > tolerance) {

			// number of non-consensus voxels seen in this pass
			int num = 0;

			// remember the previous estimate and clear the accumulators
			System.arraycopy(adaptive_vals, 0, prev_adaptive_vals, 0, n_raters);
			Arrays.fill(adaptive_vals, 0f);

			// iterate over the voxels
			for (int x = 0; x < obs.dimx(); x++)
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int v = 0; v < obs.dimv(); v++) {

							if (obs.is_consensus(x, y, z, v))
								continue;
							num++;

							// set the label probabilities from the previous estimates
							obs.initialize_adaptive_probabilities(x, y, z, v, 1, lp);
							float [][] obsvals = obs.get_all_vals_full(x, y, z, v);
							for (int j = 0; j < n_raters; j++)
								for (int l = 0; l < n_labels; l++)
									lp[l] *= (obsvals[j][l] * prev_adaptive_vals[j]) + ((1 - obsvals[j][l]) * (1 - prev_adaptive_vals[j]));
							normalize_label_probabilities(lp);

							// add the impact from this voxel
							for (int j = 0; j < n_raters; j++)
								for (int l = 0; l < n_labels; l++)
									adaptive_vals[j] += obsvals[j][l] * lp[l];
						}

			// nothing to estimate from: keep the previous values rather than produce 0/0 = NaN
			if (num == 0) {
				System.arraycopy(prev_adaptive_vals, 0, adaptive_vals, 0, n_raters);
				JistLogger.logOutput(JistLogger.INFO, "-> No non-consensus voxels found; keeping the current adaptive prior values");
				break;
			}

			// normalize the adaptive values
			for (int j = 0; j < n_raters; j++)
				adaptive_vals[j] /= (float)num;

			// total absolute change with respect to the previous pass
			convergence_factor = 0;
			for (int j = 0; j < n_raters; j++)
				convergence_factor += Math.abs(adaptive_vals[j] - prev_adaptive_vals[j]);

		}

		for (int j = 0; j < n_raters; j++)
			JistLogger.logOutput(JistLogger.INFO, String.format("Likelihood for Rater %d: %f\n", j, adaptive_vals[j]));

	}
	
	/**
	 * Builds the initial performance parameters from the configured prior.
	 *
	 * With a global prior the parameters receive their standard initialization.
	 * Otherwise, for every non-consensus voxel, each locally selected rater's
	 * observations are accumulated weighted by the prior probability of each label
	 * with positive prior probability, and the result is normalized.
	 *
	 * @return the freshly built performance parameters
	 */
	protected PerformanceParametersBase get_initial_theta() {
		PerformanceParametersBase initial_theta = new PerformanceParameters(obs.num_labels(), obs.num_raters());

		// a global prior carries no voxelwise information: use the standard initialization
		if (priortype == PRIORTYPE_GLOBAL) {
			initial_theta.initialize();
			return initial_theta;
		}

		// scratch buffer for the per-voxel label probabilities
		float [] lp = new float [obs.num_labels()];

		for (int x = 0; x < obs.dimx(); x++) {
			for (int y = 0; y < obs.dimy(); y++) {
				for (int z = 0; z < obs.dimz(); z++) {
					for (int v = 0; v < obs.dimv(); v++) {

						// consensus voxels are skipped
						if (obs.is_consensus(x, y, z, v))
							continue;

						// set the probabilities from the prior
						initialize_label_probabilities(x, y, z, v, lp);

						for (int s = 0; s < obs.num_labels(); s++) {
							// only labels with positive prior probability contribute
							if (!(lp[s] > 0))
								continue;

							for (int j = 0; j < obs.num_raters(); j++) {
								if (!obs.get_local_selection(x, y, z, v, j))
									continue;

								// get the rater observations
								short [] labels = obs.get_all(x, y, z, v, j);
								float [] weights = obs.get_all_vals(x, y, z, v, j);

								// add the impact to theta
								for (int l = 0; l < labels.length; l++)
									initial_theta.add(j, labels[l], s, weights[l]*lp[s]);
							}
						}
					}
				}
			}
		}

		initial_theta.normalize();
		return initial_theta;
	}
	
	/**
	 * Builds performance parameters from a vote-style estimate of the labels.
	 *
	 * For every non-consensus voxel the (unnormalized) label weights are the summed
	 * observation values of the locally selected raters. Each selected rater's
	 * observations are then accumulated, weighted by the weight of every label whose
	 * weight is positive. The accumulated parameters are normalized before returning.
	 *
	 * @param obs the rater observations to derive the parameters from
	 * @return the normalized performance parameters
	 */
	public static PerformanceParametersBase get_majority_vote_theta(ObservationBase obs) {
		PerformanceParametersBase vote_theta = new PerformanceParameters(obs.num_labels(), obs.num_raters());

		// scratch buffer for the per-voxel label weights
		float [] lp = new float [obs.num_labels()];

		for (int x = 0; x < obs.dimx(); x++) {
			for (int y = 0; y < obs.dimy(); y++) {
				for (int z = 0; z < obs.dimz(); z++) {
					for (int v = 0; v < obs.dimv(); v++) {

						// consensus voxels are skipped
						if (obs.is_consensus(x, y, z, v))
							continue;

						// sum the observations of the locally selected raters
						Arrays.fill(lp, 0f);
						for (int j = 0; j < obs.num_raters(); j++) {
							if (!obs.get_local_selection(x, y, z, v, j))
								continue;
							short [] labels = obs.get_all(x, y, z, v, j);
							float [] weights = obs.get_all_vals(x, y, z, v, j);
							for (int l = 0; l < labels.length; l++)
								lp[labels[l]] += weights[l];
						}

						// add the impact to theta
						for (int s = 0; s < obs.num_labels(); s++) {
							if (!(lp[s] > 0))
								continue;

							for (int j = 0; j < obs.num_raters(); j++) {
								if (!obs.get_local_selection(x, y, z, v, j))
									continue;

								// get the rater observations
								short [] labels = obs.get_all(x, y, z, v, j);
								float [] weights = obs.get_all_vals(x, y, z, v, j);

								for (int l = 0; l < labels.length; l++)
									vote_theta.add(j, labels[l], s, weights[l]*lp[s]);
							}
						}
					}
				}
			}
		}

		vote_theta.normalize();
		return vote_theta;
	}
	
	/**
	 * Builds the initial vectorized performance parameters for a hierarchical model.
	 *
	 * With a global prior the parameters receive their standard initialization.
	 * Otherwise every non-consensus voxel contributes to each hierarchy level {@code i}
	 * for which {@code h_model.consensus_level(x, y, z, v) < i + 1}: each locally
	 * selected rater's observations are accumulated weighted by the prior probability
	 * of each label with positive prior probability. The result is normalized.
	 *
	 * @param h_model the hierarchical model supplying the levels and per-voxel consensus level
	 * @return the freshly built performance parameters
	 */
	protected PerformanceParametersBase get_initial_theta(HierarchicalModel h_model) {
		PerformanceParametersBase initial_theta = new PerformanceParametersVectorized(obs.num_raters(), h_model);

		// a global prior carries no voxelwise information: use the standard initialization
		if (priortype == PRIORTYPE_GLOBAL) {
			initial_theta.initialize();
			return initial_theta;
		}

		// scratch buffer for the per-voxel label probabilities
		float [] lp = new float [obs.num_labels()];

		for (int x = 0; x < obs.dimx(); x++) {
			for (int y = 0; y < obs.dimy(); y++) {
				for (int z = 0; z < obs.dimz(); z++) {
					for (int v = 0; v < obs.dimv(); v++) {

						// consensus voxels are skipped
						if (obs.is_consensus(x, y, z, v))
							continue;

						// set the probabilities from the prior
						initialize_label_probabilities(x, y, z, v, lp);

						// add the impact to theta (M-step)
						for (int i = 0; i < h_model.num_levels(); i++) {
							if (h_model.consensus_level(x, y, z, v) < i+1) {
								for (int s = 0; s < obs.num_labels(); s++) {
									// only labels with positive prior probability contribute
									if (!(lp[s] > 0))
										continue;

									for (int j = 0; j < obs.num_raters(); j++) {

										// use the local selection
										if (!obs.get_local_selection(x, y, z, v, j))
											continue;

										// get the rater observations
										short [] labels = obs.get_all(x, y, z, v, j);
										float [] weights = obs.get_all_vals(x, y, z, v, j);

										for (int l = 0; l < labels.length; l++)
											initial_theta.add(i, j, labels[l], s, weights[l]*lp[s]);
									}
								}
							}
						}
					}
				}
			}
		}

		initial_theta.normalize();
		return initial_theta;
	}
	
	/**
	 * @return the current performance parameters ({@code theta}); this class never
	 *         assigns the field itself
	 */
	public PerformanceParametersBase get_theta() {
		return this.theta;
	}

	/**
	 * @return the prior type this instance was constructed with
	 */
	public int get_prior_type() {
		return this.priortype;
	}
}
