package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

public class AdaptiveLabelFusion extends LabelFusionBase {
	
	float [] adaptive_vals;
	
	public AdaptiveLabelFusion (ObservationBase obs_in,
						 		String outname) {
		
		super(obs_in, outname);
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Adaptive Label Fusion +++");
		setLabel("AdaptiveLabelFusion");
	}
	
	public ImageData run () {
		
		JistLogger.logOutput(JistLogger.INFO, "-> Running Expectation-Maximization Algorithm");
		
		// allocate space for the adaptive values
		adaptive_vals = new float [obs.num_raters()];
		for (int j = 0; j < obs.num_raters(); j++)
			adaptive_vals[j] = 0.9f;
		
		// initialize some temporary space
		float [] prev_adaptive_vals = new float [obs.num_raters()];
		float [] lp = new float [obs.num_labels()];
		
		int numiter = 0;
		
		// iterate until convergence
		float convergence_factor = Float.MAX_VALUE;
		while(convergence_factor > 1e-4f) {

			long time_start = System.nanoTime();
			
			// increment the number of iterations
			numiter++;
			
			int numvox = 0;
			
			// initialize the adaptive values
			for (int j = 0; j < obs.num_raters(); j++) {
				prev_adaptive_vals[j] = adaptive_vals[j];
				adaptive_vals[j] = 0;
			}
			
			// iterate over the voxels
			for (int x = 0; x < obs.dimx(); x++)
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int v = 0; v < obs.dimv(); v++)
							if (!obs.is_consensus(x, y, z, v)) {
								numvox++;
								
								// set the label probabilities
								obs.initialize_adaptive_probabilities(x, y, z, v, 1, lp);
								float [][] obsvals = obs.get_all_vals_full(x, y, z, v);
								for (int j = 0; j < obs.num_raters(); j++)
									for (int l = 0; l < obs.num_labels(); l++)
										lp[l] *= (obsvals[j][l] * prev_adaptive_vals[j]) + ((1 - obsvals[j][l]) * (1 - prev_adaptive_vals[j]));
								normalize_label_probabilities(lp);
								
								// add the impact from this voxel
								for (int j = 0; j < obs.num_raters(); j++)
									for (int l = 0; l < obs.num_labels(); l++)
										adaptive_vals[j] += obsvals[j][l] * lp[l];
							}
			
			// normalize the adaptive values
			for (int j = 0; j < obs.num_raters; j++)
				adaptive_vals[j] /= (float)numvox;
			
			convergence_factor = 0;
			for (int j = 0; j < obs.num_raters; j++)
				convergence_factor += Math.abs(adaptive_vals[j] - prev_adaptive_vals[j]);
			
			// calculate the time that has elapsed for this iteration
			double elapsed_time = ((double)(System.nanoTime() - time_start)) / 1e9;
						
			JistLogger.logOutput(JistLogger.INFO, String.format("Convergence Factor (%d, %.3fs): %f", numiter, elapsed_time, convergence_factor));
			
		}
		
		// set the final estimate
		JistLogger.logOutput(JistLogger.INFO, "-> Setting the final segmentation estimate");
		for (int x = 0; x < obs.dimx(); x++)
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++) 
					for (int v = 0; v < obs.dimv(); v++)
						if (!obs.is_consensus(x, y, z, v)) {
							
							// set the label probabilities for this voxel
							set_label_probabilities(x, y, z, v, lp);
							
							// save the estimate
							estimate.set(x+obs.offx(), y+obs.offy(), z+obs.offz(), v+obs.offv(), get_estimate_voxel(lp));
						}
		
		// remap the estimate to the original label space
		obs.remap_estimate(estimate);
		
		return(estimate);
	}
	
	public SparseMatrix5D get_sparse_label_probabilites() {
		
		// initialize some variables
		SparseMatrix5D sparseW;
		sparseW = new SparseMatrix5D(obs.dimx(), obs.dimy(), obs.dimz(), obs.dimv(), obs.num_labels());
		float [] lp = new float [obs.num_labels()];
		
		// initialize the sparse W matrix
		for (int x = 0; x < obs.dimx(); x++) {
			print_status(x, obs.dimx());
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++)
					for (int v = 0; v < obs.dimv(); v++)
						if (!obs.is_consensus(x, y, z, v)) {
							set_label_probabilities(x, y, z, v, lp);	
							sparseW.init_voxel(x, y, z, v, lp);
						}
		}
		return(sparseW);
	}

	protected void set_label_probabilities(int x,
			                               int y,
			                               int z,
			                               int v,
			                               float [] lp) {
		set_label_probabilities(x, y, z, v, lp, adaptive_vals);
	}
	
	private void set_label_probabilities(int x,
										 int y,
										 int z,
										 int v,
										 float [] lp,
										 float [] vals) {
		
		// set the label probabilities
		obs.initialize_adaptive_probabilities(x, y, z, v, 1, lp);
		float [][] obsvals = obs.get_all_vals_full(x, y, z, v);
		for (int j = 0; j < obs.num_raters(); j++)
			for (int l = 0; l < obs.num_labels(); l++)
				lp[l] *= (obsvals[j][l] * vals[j]) + ((1 - obsvals[j][l]) * (1 - vals[j]));
		normalize_label_probabilities(lp);
	}
	
}