package edu.vanderbilt.masi.LabelFusion;

/**
 * Base class for iterative statistical label fusion algorithms.
 *
 * <p>Subclasses implement {@link #calc_theta} and {@link #calc_W}; {@link #run()}
 * alternates between the two until {@link #calc_convergence_factor} reports a
 * value {@code <= epsilon}, and then builds the final estimate from {@code W}.
 * The class also provides shared helpers for initializing / normalizing
 * {@code theta}, majority-vote based estimates, and consensus / boundary masks.
 */
public abstract class StatisticalFusionBase extends LabelFusionBase {
	
	/** (dx, dy, dz) offsets of the six face-adjacent neighbours of a voxel. */
	private static final int [][] NEIGHBOR_OFFSETS = {
		{-1, 0, 0}, {1, 0, 0},
		{0, -1, 0}, {0, 1, 0},
		{0, 0, -1}, {0, 0, 1}
	};
	
	// protected variables
	
	// theta[s][l][r][u]; normalize_theta() makes every theta[*][l][r][u] sum to 1.
	// In get_majority_vote_theta the first index is the label a rater voted and
	// the second the label whose probability is accumulated.
	protected double [][][][] theta;
	// voxelwise state handed to set_estimate_from_voxelwise_data() at the end of run()
	protected double [][][][][] W;
	// 0: calc_W runs before calc_theta in each iteration; otherwise calc_theta runs first
	protected int init_flag;
	// run() stops once calc_convergence_factor() returns a value <= epsilon
	protected double epsilon;
	// actions passed through to the subclass implementations of calc_theta / calc_W
	protected LabelFusionAction calc_theta_action;
	protected LabelFusionAction calc_W_action;
	// when 1, calc_convergence_factor() prints every convergence factor to stdout
	protected int debug = 1;
	
	/**
	 * Runs the alternating theta / W updates until convergence.
	 *
	 * <p>Each iteration snapshots {@code theta}, updates both {@code theta} and
	 * {@code W} (order chosen by {@code init_flag}), and compares the new
	 * {@code theta} with the snapshot. There is no iteration cap: the loop ends
	 * only when the convergence factor is {@code <= epsilon} (or is NaN, since
	 * {@code NaN > epsilon} is false).
	 *
	 * @return the estimate produced from the final {@code W}
	 */
	public int [][][] run() {
		
		int num_u = theta[0][0][0].length;
		double [][][][] theta_prev =
			new double [obs.num_labels][obs.num_labels][obs.num_raters][num_u];
		
		double convergence_factor = Double.MAX_VALUE;
		
		// run the basic statistical fusion algorithm
		while (convergence_factor > epsilon) {
			
			// snapshot theta so the size of this iteration's update can be measured
			for (int s = 0; s < obs.num_labels; s++)
				for (int l = 0; l < obs.num_labels; l++)
					for (int r = 0; r < obs.num_raters; r++)
						System.arraycopy(theta[s][l][r], 0, theta_prev[s][l][r], 0, num_u);
			
			// determine whether to calculate theta or W first
			if (init_flag == 0) {
				calc_W(calc_W_action);
				calc_theta(calc_theta_action);
			} else {
				calc_theta(calc_theta_action);
				calc_W(calc_W_action);
			}
			
			convergence_factor = calc_convergence_factor(theta_prev);
		}
		
		// set the estimate from the final value for W
		set_estimate_from_voxelwise_data(W);
		
		return(estimate);
	}
	
	/** @return the (live, not copied) theta array */
	public double [][][][] get_theta() { return(theta); }
	
	/** @return the (live, not copied) W array */
	public double [][][][][] get_W() { return(W); }
	
	/**
	 * Updates {@code theta}; implemented by the concrete fusion algorithm.
	 *
	 * @param lfa action supplied from {@code calc_theta_action} by {@link #run()}
	 */
	protected abstract void calc_theta(LabelFusionAction lfa);
	
	/**
	 * Updates {@code W}; implemented by the concrete fusion algorithm.
	 *
	 * @param lfa action supplied from {@code calc_W_action} by {@link #run()}
	 */
	protected abstract void calc_W(LabelFusionAction lfa);

	/**
	 * Measures how much {@code theta} changed relative to {@code theta_prev}.
	 *
	 * <p>For every rater {@code r} and trailing index {@code u} the diagonal
	 * entries {@code theta[s][s][r][u]} are summed; the result is
	 * {@code |sum(theta) - sum(theta_prev)|} divided by
	 * {@code num_raters * num_labels * num_u}. Printed to stdout when
	 * {@code debug == 1}.
	 *
	 * @param theta_prev snapshot of {@code theta} from before the update
	 * @return the non-negative convergence factor
	 */
	protected double calc_convergence_factor(double [][][][] theta_prev) {
		double current_sum = 0;
		double previous_sum = 0;
		int num = theta[0][0][0].length;

		for (int u = 0; u < num; u++)
			for (int r = 0; r < obs.num_raters; r++) {
				
				double theta_trace = 0;
				double theta_prev_trace = 0;
				
				for (int s = 0; s < obs.num_labels; s++) {
					theta_trace += theta[s][s][r][u];
					theta_prev_trace += theta_prev[s][s][r][u];
				}
				current_sum += theta_trace;
				previous_sum += theta_prev_trace;
			}

		// computed in double so the element count cannot overflow an int
		double num_elements = (double) obs.num_raters * obs.num_labels * num;
		double convergence_factor = Math.abs(current_sum - previous_sum) / num_elements;
		
		if (debug == 1)
			System.out.println(convergence_factor);
		
		return(convergence_factor);
	}
	
	/**
	 * Initializes {@code theta} to a near-identity matrix (0.9999 on the
	 * diagonal, 0.0001 elsewhere) for every rater and trailing index, and then
	 * normalizes it with {@link #normalize_theta()}.
	 */
	protected void init_theta() {
		
		int nu = theta[0][0][0].length;
		for (int i = 0; i < obs.num_labels; i++)
			for (int j = 0; j < obs.num_labels; j++)
				for (int k = 0; k < obs.num_raters; k++)
					for (int u = 0; u < nu; u++)
						theta[i][j][k][u] = (i == j) ? 0.9999 : 0.0001;
		
		normalize_theta();
	}
	
	/**
	 * Normalizes {@code theta} in place so that each {@code theta[*][l][r][u]}
	 * sums to 1. A column whose sum is 0 is replaced by the unit vector with a
	 * 1 at {@code s == l}.
	 */
	protected void normalize_theta() {
		
		int num = theta[0][0][0].length;
		for (int l = 0; l < obs.num_labels; l++)
			for (int r = 0; r < obs.num_raters; r++)
				for (int u = 0; u < num; u++) {
					
					double sum = 0;
					for (int s = 0; s < obs.num_labels; s++)
						sum += theta[s][l][r][u];
					
					if (sum == 0) {
						// degenerate column: fall back to the identity
						for (int s = 0; s < obs.num_labels; s++)
							theta[s][l][r][u] = (s == l) ? 1 : 0;
					} else {
						for (int s = 0; s < obs.num_labels; s++)
							theta[s][l][r][u] /= sum;
					}
				}
	}

	/**
	 * Normalizes a three-dimensional theta {@code th[s][l][r]} in place so that
	 * each {@code th[*][l][r]} sums to 1. A column whose sum is 0 is replaced by
	 * the unit vector with a 1 at {@code s == l}.
	 *
	 * @param th array to normalize, indexed {@code [num_labels][num_labels][num_raters]}
	 */
	protected void normalize_theta(double [][][] th) {

		for (int l = 0; l < obs.num_labels; l++)
			for (int r = 0; r < obs.num_raters; r++) {
				
				double sum = 0;
				for (int s = 0; s < obs.num_labels; s++)
					sum += th[s][l][r];
				
				if (sum == 0) {
					// degenerate column: fall back to the identity
					for (int s = 0; s < obs.num_labels; s++)
						th[s][l][r] = (s == l) ? 1 : 0;
				} else {
					for (int s = 0; s < obs.num_labels; s++)
						th[s][l][r] /= sum;
				}
			}
	}
	
	/**
	 * Builds a label prior by accumulating every vote with
	 * {@code LabelwiseDataAdd} and normalizing the result to sum to 1.
	 *
	 * @return array of length {@code num_labels}; uniform if nothing was accumulated
	 */
	protected double [] get_labelwise_t_prior() {
		double [] labelwise_t_prior = new double [obs.num_labels];
		
		obs.iterate_votes(new LabelwiseDataAdd(labelwise_t_prior));
		
		double sum_l_prior = 0;
		for (int l = 0; l < obs.num_labels; l++)
			sum_l_prior += labelwise_t_prior[l];
		
		for (int l = 0; l < obs.num_labels; l++)
			// avoid 0/0 = NaN when no weight was accumulated at all
			labelwise_t_prior[l] = (sum_l_prior == 0)
				? 1.0 / obs.num_labels
				: labelwise_t_prior[l] / sum_l_prior;
		
		return labelwise_t_prior;
	}
	
	/**
	 * Majority-vote theta estimate over every voxel.
	 *
	 * @return normalized {@code [num_labels][num_labels][num_raters]} array
	 */
	protected double [][][] get_majority_vote_theta() {
		boolean [][][] consensus = new boolean [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		return(get_majority_vote_theta(consensus));
	}
	
	/**
	 * Majority-vote theta estimate over the voxels NOT flagged in {@code consensus}.
	 *
	 * <p>At each such voxel the per-label probabilities come from
	 * {@code VotingVoxelAdd} followed by {@code normalize_label_probabilities}.
	 * For every label a rater voted there, those probabilities are added to
	 * {@code mvtheta[vote][*][r]}. The result is then normalized with
	 * {@link #normalize_theta(double[][][])}.
	 *
	 * @param consensus voxels to skip, sized {@code [dims[0]][dims[1]][dims[2]]}
	 * @return normalized {@code [num_labels][num_labels][num_raters]} array
	 */
	protected double [][][] get_majority_vote_theta(boolean [][][] consensus) {
		double [][][] mvtheta = new double [obs.num_labels][obs.num_labels][obs.num_raters];
		
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					if (consensus[x][y][z])
						continue;
					
					// get the label probabilities from a majority vote
					double [] labelprobs = new double [obs.num_labels];
					obs.iterate_voxel(new VotingVoxelAdd(labelprobs), x, y, z);
					normalize_label_probabilities(labelprobs);
					
					// add the theta values in the appropriate places
					for (int r = 0; r < obs.num_raters; r++)
						for (int voted : obs.get_vote(x, y, z, r))
							for (int l = 0; l < obs.num_labels; l++)
								mvtheta[voted][l][r] += labelprobs[l];
				}
		
		normalize_theta(mvtheta);
		
		return(mvtheta);
	}
	
	/**
	 * Computes the consensus mask: a voxel is flagged when, after
	 * {@code normalize_label_probabilities}, some label has probability exactly 1.
	 *
	 * @return a new {@code [dims[0]][dims[1]][dims[2]]} mask
	 */
	protected boolean [][][] get_consensus_voxels() {
		boolean [][][] consensus = new boolean [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		mark_consensus_voxels(consensus);
		return(consensus);
	}
	
	/**
	 * Sets {@code consensus[x][y][z] = true} for every consensus voxel (see
	 * {@link #get_consensus_voxels()}). Entries that are already true are left
	 * untouched; no entry is ever reset to false.
	 *
	 * @param consensus mask to update, sized {@code [dims[0]][dims[1]][dims[2]]}
	 */
	protected void get_consensus_voxels(boolean [][][] consensus) {
		mark_consensus_voxels(consensus);
	}
	
	/** Shared implementation of both get_consensus_voxels overloads. */
	private void mark_consensus_voxels(boolean [][][] consensus) {
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					
					double [] labelprobs = new double [obs.num_labels];
					obs.iterate_voxel(new VotingVoxelAdd(labelprobs), x, y, z);
					normalize_label_probabilities(labelprobs);
					
					for (int l = 0; l < obs.num_labels; l++)
						if (labelprobs[l] == 1) {
							consensus[x][y][z] = true;
							break;
						}
				}
	}
	
	/**
	 * Computes, per rater, which voxels lie on a label boundary.
	 *
	 * <p>A voxel is a boundary voxel for rater {@code r} when it lies on a face of
	 * the volume, or when the rater voted at least one label there and either
	 * that voxel or one of its six face neighbours carries a vote that differs
	 * from the rater's first vote at the voxel.
	 *
	 * @return mask sized {@code [dims[0]][dims[1]][dims[2]][num_raters]}
	 */
	protected boolean [][][][] get_boundary_voxels() {
		
		// indexed by x, y, z, so the extents must be dims[0..2]
		boolean [][][][] boundary =
			new boolean [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_raters];
		
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++)
					for (int r = 0; r < obs.num_raters; r++)
						boundary[x][y][z][r] = is_boundary_voxel(x, y, z, r);
		
		return(boundary);
	}
	
	/** Decides whether (x, y, z) is a boundary voxel for rater r. */
	private boolean is_boundary_voxel(int x, int y, int z, int r) {
		
		// voxels on the faces of the volume have no complete neighbourhood
		if (x == 0 || x == obs.dims[0]-1 ||
			y == 0 || y == obs.dims[1]-1 ||
			z == 0 || z == obs.dims[2]-1)
			return true;
		
		int [] vote = obs.get_vote(x, y, z, r);
		if (vote.length == 0)
			return false;
		
		// the rater's first vote at this voxel is the reference label
		int label = vote[0];
		if (contains_other_label(vote, label))
			return true;
		
		// check the six immediate neighbours
		for (int [] d : NEIGHBOR_OFFSETS)
			if (contains_other_label(obs.get_vote(x + d[0], y + d[1], z + d[2], r), label))
				return true;
		
		return false;
	}
	
	/** @return true if any entry of {@code vote} differs from {@code label} */
	private static boolean contains_other_label(int [] vote, int label) {
		for (int v : vote)
			if (v != label)
				return true;
		return false;
	}
		
	/**
	 * Allocates a {@code [dims[0]][dims[1]][dims[2]][num_labels][1]} array and
	 * fills it via {@code set_voxelwise_data}.
	 *
	 * @return the voxelwise prior array
	 */
	protected double [][][][][] get_voxelwise_t_prior() {
		double [][][][][] voxelwise_t_prior =
			new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];
		set_voxelwise_data(voxelwise_t_prior);
		return(voxelwise_t_prior);
	}

}
