package edu.vanderbilt.masi.LabelFusion;

import java.lang.Error;

/**
 * Statistical label fusion in which every voxel/label weight is additionally
 * indexed by a consensus level. At consensus level c, a rater's vote is blended
 * toward agreement with the label using weight alphas[c] (see COLLATECalcW).
 *
 * NOTE(review): the class name suggests this implements COLLATE
 * (Asman &amp; Landman) — confirm against the paper before citing it.
 *
 * Array layouts used throughout:
 *   W[x][y][z][label][consensus level]
 *   theta[vote][label][rater][0]
 */
public class COLLATE extends StatisticalFusionBase {

	/** Stop the per-(rater, label) bisection in calc_theta once |1 - guess| is at most this. */
	private static final double BISECTION_TOL = 0.001;

	/** Upper bound on bisection steps per (rater, label) in calc_theta. */
	private static final int MAX_BISECTION_ITERS = 21;

	/** 0 - global (label-wise) truth prior; 1 - voxel-wise truth prior. */
	private int prior_flag;
	/** Label-wise truth prior, indexed [label]; used when prior_flag == 0. */
	private double [] labelwise_t_prior;
	/** Voxel-wise truth prior, indexed [x][y][z][label][0]; used when prior_flag != 0. */
	private double [][][][][] voxelwise_t_prior;
	/** Per-label distribution over consensus levels, indexed [label][consensus level]. */
	private double [][] labelwise_c_prior;
	/** Optional caller-supplied prior, indexed like W; when non-null it replaces the default priors. */
	private double [][][][][] voxelwise_user_prior;
	/** One blending weight per consensus level. */
	private double [] alphas;
	/**
	 * Thresholds on the normalized plurality-vote fraction that separate consensus
	 * levels (num_con_levels - 1 entries; null when there is one level).
	 * Presumably ascending — TODO confirm with callers.
	 */
	private double [] cval;
	/** Off-diagonal statistics for theta, indexed [vote][label][rater]. */
	private double [][][] offd_Q;
	/** On-diagonal statistics for theta, indexed [consensus level][label][rater]. */
	private double [][][] ond_Q;
	/** Number of consensus levels (== alphas.length). */
	private int num_con_levels;

	/**
	 * Creates an estimator that uses the default, vote-derived priors.
	 *
	 * @param obs_in        the rater observations
	 * @param epsilon_in    value stored in the base-class epsilon field
	 *                      (presumably the outer convergence threshold — confirm)
	 * @param prior_flag_in 0 - global truth prior; 1 - voxel-wise truth prior
	 * @param init_flag_in  0 - initialize theta; 1 - initialize W
	 *                      (W is only initialized when prior_flag_in != 0)
	 * @param alphas_in     one weight per consensus level
	 * @param cval_in       consensus thresholds, exactly alphas_in.length - 1
	 *                      entries; may be null when alphas_in has one entry
	 * @throws Error if cval_in has the wrong length
	 */
	public COLLATE (ObservationBase obs_in,
				    double epsilon_in,
				    int prior_flag_in,
				    int init_flag_in,
				    double [] alphas_in,
				    double [] cval_in) {
		epsilon = epsilon_in;
		init_flag = init_flag_in;
		prior_flag = prior_flag_in;
		obs = obs_in;
		alphas = alphas_in;
		cval = cval_in;
		initialize();
	}

	/**
	 * Single-consensus-level variant of
	 * {@link #COLLATE(ObservationBase, double, int, int, double[], double[])}.
	 *
	 * @param alphas_in the weight for the single consensus level
	 */
	public COLLATE (ObservationBase obs_in,
				    double epsilon_in,
				    int prior_flag_in,
				    int init_flag_in,
				    double alphas_in) {
		this(obs_in, epsilon_in, prior_flag_in, init_flag_in, new double [] { alphas_in }, null);
	}

	/**
	 * Creates an estimator that starts from a caller-supplied voxel-wise prior
	 * (prior_flag is forced to 1).
	 *
	 * @param voxelwise_user_prior_in prior indexed [x][y][z][label][consensus level].
	 *                                If null, the default voxel-wise priors are used.
	 * @param init_flag_in            0 - initialize theta; 1 - initialize W
	 * @param alphas_in               one weight per consensus level
	 * @param cval_in                 consensus thresholds, exactly alphas_in.length - 1
	 *                                entries; may be null with a single level
	 * @throws Error if cval_in has the wrong length
	 */
	public COLLATE (ObservationBase obs_in,
				    double epsilon_in,
				    double [][][][][] voxelwise_user_prior_in,
				    int init_flag_in,
				    double [] alphas_in,
				    double [] cval_in) {
		epsilon = epsilon_in;
		init_flag = init_flag_in;
		prior_flag = 1;
		voxelwise_user_prior = voxelwise_user_prior_in;
		obs = obs_in;
		alphas = alphas_in;
		cval = cval_in;
		initialize();
	}

	/**
	 * Single-consensus-level variant of
	 * {@link #COLLATE(ObservationBase, double, double[][][][][], int, double[], double[])}.
	 *
	 * @param alphas_in the weight for the single consensus level
	 */
	public COLLATE (ObservationBase obs_in,
				    double epsilon_in,
				    double [][][][][] voxelwise_user_prior_in,
				    int init_flag_in,
				    double alphas_in) {
		this(obs_in, epsilon_in, voxelwise_user_prior_in, init_flag_in,
			 new double [] { alphas_in }, null);
	}

	/**
	 * Shared constructor tail. It validates the consensus configuration,
	 * allocates the working arrays, and builds the default priors unless a
	 * user prior was supplied. It then initializes theta or W and installs the
	 * fusion actions.
	 *
	 * Expects obs, alphas, cval, prior_flag, init_flag and voxelwise_user_prior
	 * to be assigned already.
	 */
	private void initialize() {

		// set the number of consensus levels
		num_con_levels = alphas.length;

		// do a little error checking (fail before allocating the large arrays);
		// a null cval is valid when there is exactly one consensus level
		int num_cvals = (cval == null) ? 0 : cval.length;
		// NOTE(review): IllegalArgumentException would be more idiomatic; Error is
		// kept so existing callers see the same exception type.
		if (num_cvals != num_con_levels - 1)
			throw new Error("Length of pcon should be one less than " +
							"the number of consensus levels");

		// allocate space for the estimate, W, theta and the theta statistics
		estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		W = new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][num_con_levels];
		theta = new double [obs.num_labels][obs.num_labels][obs.num_raters][1];
		offd_Q = new double [obs.num_labels][obs.num_labels][obs.num_raters];
		ond_Q = new double [num_con_levels][obs.num_labels][obs.num_raters];

		// the default priors are only read when no user prior is supplied
		if (voxelwise_user_prior == null)
			set_default_priors();

		// initialize the estimates
		if (init_flag == 0 || prior_flag == 0)
			super.init_theta();
		else
			set_W();

		// set the label fusion actions (these read W's shape, so W must exist)
		calc_theta_action = new COLLATECalcTheta(this);
		calc_W_action = new COLLATECalcW(this, alphas);
	}

	/**
	 * Computes the default consensus prior (labelwise_c_prior) and the truth
	 * prior (labelwise or voxelwise, depending on prior_flag).
	 *
	 * With more than one consensus level, each voxel with at least one vote
	 * works as follows:
	 *   - It is assigned to its plurality label.
	 *   - Its consensus level is the first c where the normalized plurality
	 *     fraction is below cval[c]. If there is no such c, it is the top level.
	 * The per-label counts are then normalized over consensus levels.
	 */
	private void set_default_priors() {

		// calculate the consensus prior
		labelwise_c_prior = new double [obs.num_labels][num_con_levels];

		if (num_con_levels == 1) {
			for (int l = 0; l < obs.num_labels; l++)
				labelwise_c_prior[l][0] = 1;
		} else {
			// iterate over every voxel
			for (int x = 0; x < obs.dims[0]; x++)
				for (int y = 0; y < obs.dims[1]; y++)
					for (int z = 0; z < obs.dims[2]; z++) {

						// accumulate the votes for this voxel
						double [] labelprobs = new double [obs.num_labels];
						obs.iterate_voxel(new VotingVoxelAdd(labelprobs), x, y, z);

						double lsum = 0;
						for (int l = 0; l < obs.num_labels; l++)
							lsum += labelprobs[l];

						// a voxel without votes carries no consensus information;
						// normalizing it would divide by zero
						if (lsum <= 0)
							continue;

						int est = get_estimate_value(labelprobs);
						double max_label_prob = labelprobs[est] / lsum;

						// first threshold the plurality fraction falls below
						int crange = num_con_levels - 1;
						for (int c = 0; c < num_con_levels - 1; c++) {
							if (max_label_prob < cval[c]) {
								crange = c;
								break;
							}
						}
						labelwise_c_prior[est][crange] += 1;
					}

			// normalize the labelwise_c_prior; a label that never wins the
			// plurality has no counts, so use a uniform distribution instead of
			// 0/0 = NaN (which would poison W for every voxel)
			for (int l = 0; l < obs.num_labels; l++) {
				double psum = 0;
				for (int c = 0; c < num_con_levels; c++)
					psum += labelwise_c_prior[l][c];
				for (int c = 0; c < num_con_levels; c++)
					labelwise_c_prior[l][c] = (psum > 0)
						? labelwise_c_prior[l][c] / psum
						: 1.0 / num_con_levels;
			}
		}

		// set the truth prior
		if (prior_flag == 0)
			labelwise_t_prior = get_labelwise_t_prior();
		else
			voxelwise_t_prior = get_voxelwise_t_prior();
	}

	/**
	 * Resets every entry of W to its prior value. The source depends on what
	 * is available:
	 *   - the user prior, when one was supplied;
	 *   - otherwise, when prior_flag == 0, labelwise_t_prior[l] * labelwise_c_prior[l][c];
	 *   - otherwise, voxelwise_t_prior[x][y][z][l][0] * labelwise_c_prior[l][c].
	 */
	private void set_W() {
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++)
					for (int l = 0; l < obs.num_labels; l++)
						for (int c = 0; c < num_con_levels; c++)
							if (voxelwise_user_prior != null)
								W[x][y][z][l][c] = voxelwise_user_prior[x][y][z][l][c];
							else if (prior_flag == 0)
								W[x][y][z][l][c] = labelwise_t_prior[l] * labelwise_c_prior[l][c];
							else
								W[x][y][z][l][c] = voxelwise_t_prior[x][y][z][l][0] * labelwise_c_prior[l][c];
	}

	/**
	 * Recomputes W in three steps:
	 *   1. Start from the prior (see set_W).
	 *   2. Apply every vote through lfa (COLLATECalcW multiplies in each vote's
	 *      likelihood factor).
	 *   3. Normalize with the base-class helper.
	 */
	@Override
	protected void calc_W (LabelFusionAction lfa) {

		// set the initial value of W to the prior
		set_W();

		// update
		obs.iterate_votes(lfa);

		// normalize the data
		normalize_voxelwise_data(W);
	}

	/**
	 * Recomputes theta from the current W.
	 *
	 * The statistics offd_Q / ond_Q are first accumulated through lfa
	 * (COLLATECalcTheta). Then, for each (rater r, label s), the diagonal entry
	 * d = theta[s][s][r][0] is found by bisection on [0, 1] so that
	 *   d + sum_i offd_Q[i][s][r] / D(d) = 1,
	 *   where D(d) = sum_c ond_Q[c][s][r] / (alphas[c] + d).
	 * Each off-diagonal entry is offd_Q[l][s][r] / D(d). Finally the base-class
	 * normalize_theta is applied.
	 */
	@Override
	protected void calc_theta(LabelFusionAction lfa) {

		// reset the statistics (theta itself is fully overwritten below)
		for (int v = 0; v < obs.num_labels; v++)
			for (int l = 0; l < obs.num_labels; l++)
				for (int r = 0; r < obs.num_raters; r++)
					offd_Q[v][l][r] = 0;

		for (int c = 0; c < num_con_levels; c++)
			for (int l = 0; l < obs.num_labels; l++)
				for (int r = 0; r < obs.num_raters; r++)
					ond_Q[c][l][r] = 0;

		// iterate over all of the votes and accumulate the statistics
		obs.iterate_votes(lfa);

		// perform the 1-D optimization for each label and each rater
		for (int r = 0; r < obs.num_raters; r++)
			for (int s = 0; s < obs.num_labels; s++) {

				double sum_offd_Q = 0;
				for (int i = 0; i < obs.num_labels; i++)
					sum_offd_Q += offd_Q[i][s][r];

				double lower_bound = 0;
				double upper_bound = 1;
				double diag = 0;   // candidate for theta[s][s][r][0]
				double denom = 0;  // D(diag); the negated Lagrange multiplier

				for (int iter = 0; iter < MAX_BISECTION_ITERS; iter++) {
					diag = (upper_bound + lower_bound) / 2;

					denom = 0;
					for (int c = 0; c < num_con_levels; c++)
						denom += ond_Q[c][s][r] / (alphas[c] + diag);

					// the column sum implied by this diagonal value
					double guess = diag + sum_offd_Q / denom;

					// guess increases with diag, so shrink toward the root
					if (guess > 1)
						upper_bound = diag;
					else
						lower_bound = diag;

					// written as !(... > tol) so a NaN guess also stops the search
					if (!(Math.abs(1 - guess) > BISECTION_TOL))
						break;
				}

				for (int l = 0; l < obs.num_labels; l++)
					theta[l][s][r][0] = (l == s) ? diag : offd_Q[l][s][r] / denom;
			}

		normalize_theta();
	}

	/**
	 * Vote action for calc_W. For each consensus level c and label s, it
	 * multiplies W[x][y][z][s][c] by the likelihood of rater r's vote v:
	 *   (alphas[c] + theta[v][s][r][0]) / (1 + alphas[c])   when v == s,
	 *   theta[v][s][r][0] / (1 + alphas[c])                 otherwise.
	 */
	private class COLLATECalcW implements LabelFusionAction {
		private StatisticalFusionBase sfb;
		private int num_con_levels;
		private double [] alphas;

		public COLLATECalcW(StatisticalFusionBase sfb_in, double [] alphas_in) {
			alphas = alphas_in;
			sfb = sfb_in;
			// number of consensus levels, read from the allocated W
			num_con_levels = W[0][0][0][0].length;
		}

		public void run(int x, int y, int z, int r, int v) {

			// iterate over all consensus levels
			for (int c = 0; c < num_con_levels; c++)
				// iterate over all labels
				for (int s = 0; s < sfb.obs.num_labels; s++) {
					double label_factor;
					if (v == s)
						label_factor = (alphas[c] + sfb.theta[v][s][r][0]) / (1 + alphas[c]);
					else
						label_factor = (sfb.theta[v][s][r][0]) / (1 + alphas[c]);

					W[x][y][z][s][c] *= label_factor;
				}
		}
	}

	/**
	 * Vote action for calc_theta. For a vote v by rater r at (x, y, z), it adds
	 * W[x][y][z][s][c] to one of two statistics:
	 *   - offd_Q[v][s][r] when s != v (summed over all c);
	 *   - ond_Q[c][s][r] when s == v (kept per consensus level).
	 */
	private class COLLATECalcTheta implements LabelFusionAction {
		private StatisticalFusionBase sfb;
		private int num_con_levels;

		public COLLATECalcTheta(StatisticalFusionBase sfb_in) {
			sfb = sfb_in;
			// number of consensus levels, read from the allocated W
			num_con_levels = W[0][0][0][0].length;
		}

		public void run(int x, int y, int z, int r, int v) {
			for (int s = 0; s < sfb.obs.num_labels; s++)
				for (int c = 0; c < num_con_levels; c++) {
					double W_fact = sfb.W[x][y][z][s][c];
					if (s != v)
						offd_Q[v][s][r] += W_fact;
					else
						ond_Q[c][s][r] += W_fact;
				}
		}
	}

}
