package edu.vanderbilt.masi.LabelFusion;

import edu.vanderbilt.masi.LabelFusion.LabelFusionTools;

/**
 * Iterative label fusion that repeatedly (1) scores every rater against the
 * current fused estimate, (2) zeroes the weights of raters whose mean score
 * falls below {@code mean - alpha * std}, and (3) re-fuses the observations
 * with a weighted vote, until two consecutive estimates are identical.
 *
 * <p>The initial estimate comes from {@link MajorityVote}; subsequent
 * estimates come from {@link WeightedVote} using the per-rater, per-label
 * {@code performance} table maintained by this class.
 */
public class SIMPLE extends LabelFusionBase {
	
	/** Number of standard deviations below the mean at which a rater is rejected. */
	private final double alpha;
	/** Raters are only rejected once {@code iter} exceeds this iteration count. */
	private final int num_iter_keep;
	/** Copy of the estimate from the previous iteration, used for the convergence test. */
	private final int [][][] prev_estimate;
	/** Per-rater, per-label weights, indexed {@code [rater][label]}. */
	private final double [][] performance;
	/** Number of iterations started so far. */
	private int iter;
	/** Rejection threshold recomputed on every iteration by {@link #set_threshold(double[])}. */
	private double thresh;
	/** Selects the similarity measure: 1 = dice, 2 = jaccard, anything else = sensitivity. */
	private final int performance_type;
	
	/**
	 * Creates a fusion object that uses the default performance measure
	 * (performance type 0, i.e. {@code LabelFusionTools.sensitivity}).
	 *
	 * @param obs_in the observations to fuse
	 * @param num_iter_keep_in number of initial iterations during which no rater is rejected
	 * @param alpha_in multiplier on the standard deviation in the rejection threshold
	 */
	public SIMPLE (ObservationBase obs_in,
				   int num_iter_keep_in,
				   double alpha_in) {
		this(obs_in, num_iter_keep_in, alpha_in, 0);
	}
	
	/**
	 * Creates a fusion object with an explicit performance measure.
	 *
	 * @param obs_in the observations to fuse
	 * @param num_iter_keep_in number of initial iterations during which no rater is rejected
	 * @param alpha_in multiplier on the standard deviation in the rejection threshold
	 * @param performance_type_in 1 for dice, 2 for jaccard, any other value for sensitivity
	 */
	public SIMPLE (ObservationBase obs_in,
				   int num_iter_keep_in,
				   double alpha_in,
				   int performance_type_in) {
		
		// take in the input arguments
		obs = obs_in;
		alpha = alpha_in;
		num_iter_keep = num_iter_keep_in;
		performance_type = performance_type_in;
		
		// allocate the working storage
		estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		prev_estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		performance = new double [obs.num_raters][obs.num_labels];
		iter = 0;
	}
	
	/**
	 * Runs the fusion until the estimate stops changing between iterations.
	 *
	 * <p>There is no iteration cap: the loop ends only when two consecutive
	 * estimates are identical at every voxel.
	 *
	 * @return the final fused estimate
	 */
	public int [][][] run () {
		
		boolean converge = false;
		
		// seed the iteration with a majority vote
		VotingFusionBase fb = new MajorityVote(obs, 1);
		estimate = fb.run();

		while (!converge) {
			
			iter++;
			
			// remember the current estimate so the new one can be compared against it
			save_estimate();
			
			// set the performance parameters (this also refreshes the threshold)
			calc_performance();
			
			// calculate the new estimate
			fb = new WeightedVote(obs, performance, 1);
			estimate = fb.run();
			
			// determine whether or not we have converged
			converge = calc_convergence();
			
			System.out.println("SIMPLE Iteration: " + iter);
		}
		
		// return the estimate
		return(estimate);
	}
	
	/** Copies the current estimate into {@code prev_estimate}. */
	private void save_estimate() {
		for (int x = 0; x < obs.dims[0]; x++) {
			for (int y = 0; y < obs.dims[1]; y++) {
				System.arraycopy(estimate[x][y], 0, prev_estimate[x][y], 0, obs.dims[2]);
			}
		}
	}
	
	/**
	 * Recomputes the performance table against the current estimate, updates
	 * the rejection threshold, and zeroes the weights of every rater whose mean
	 * performance is below the threshold once {@code iter > num_iter_keep}.
	 */
	private void calc_performance() {
		// NOTE(review): the LabelFusionTools helpers are presumed to fill
		// performance[r] in place -- their return values are not used here.
		for (int r = 0; r < obs.num_raters; r++) {
			switch (performance_type) {
				case 1:
					LabelFusionTools.dice(obs, estimate, r, performance[r]);
					break;
				case 2:
					LabelFusionTools.jaccard(obs, estimate, r, performance[r]);
					break;
				default:
					LabelFusionTools.sensitivity(obs, estimate, r, performance[r]);
					break;
			}
		}
		
		// the threshold is derived from the un-thresholded performance values
		double [] rater_performance = mean_rater_performance();
		set_threshold(rater_performance);
		
		// during the first num_iter_keep iterations every rater is kept
		if (iter <= num_iter_keep) {
			return;
		}
		
		for (int r = 0; r < obs.num_raters; r++) {
			if (rater_performance[r] < thresh) {
				for (int l = 0; l < obs.num_labels; l++) {
					performance[r][l] = 0;
				}
			}
		}
	}
	
	/**
	 * Compares the current estimate with the previous one.
	 *
	 * @return {@code true} if the two estimates agree at every voxel
	 */
	private boolean calc_convergence() {
		for (int x = 0; x < obs.dims[0]; x++) {
			for (int y = 0; y < obs.dims[1]; y++) {
				for (int z = 0; z < obs.dims[2]; z++) {
					if (prev_estimate[x][y][z] != estimate[x][y][z]) {
						// one differing voxel is enough; no need to scan the rest
						return false;
					}
				}
			}
		}
		return true;
	}
	
	/**
	 * Averages each rater's row of the performance table over all labels.
	 *
	 * @return an array of length {@code obs.num_raters} holding each rater's mean performance
	 */
	private double [] mean_rater_performance() {
		double [] rater_performance = new double [obs.num_raters];
		for (int r = 0; r < obs.num_raters; r++) {
			for (int l = 0; l < obs.num_labels; l++) {
				rater_performance[r] += performance[r][l];
			}
			rater_performance[r] /= obs.num_labels;
		}
		return rater_performance;
	}

	/**
	 * Sets {@code thresh} to {@code mean - alpha * std} of the supplied
	 * per-rater mean performances, where {@code std} is the population
	 * standard deviation (divisor {@code obs.num_raters}).
	 *
	 * @param rater_performance mean performance of each rater
	 */
	private void set_threshold(double [] rater_performance) {
		
		// calculate the mean performance
		double mean_performance = 0;
		for (int r = 0; r < obs.num_raters; r++) {
			mean_performance += rater_performance[r];
		}
		mean_performance /= obs.num_raters;
		
		// calculate the std of the performance
		double sum_sq_dev = 0;
		for (int r = 0; r < obs.num_raters; r++) {
			double dev = rater_performance[r] - mean_performance;
			sum_sq_dev += dev * dev;
		}
		// Divide in floating point: the integer expression 1/num_raters is 0
		// for more than one rater, which would force the std to 0.
		double std_performance = Math.sqrt(sum_sq_dev / obs.num_raters);
		
		// set the threshold
		thresh = mean_performance - alpha * std_performance;
	}
}
