package edu.vanderbilt.masi.LabelFusion;

/**
 * STAPLER statistical label fusion.
 *
 * <p>Alternates between two estimates:
 * <ul>
 *   <li>{@code W[x][y][z][label][0]}: per-voxel label probabilities. These are rebuilt from the
 *       truth prior, multiplied by {@code theta} for every vote, then normalized.</li>
 *   <li>{@code theta[vote][label][rater][0]}: per-rater weights. These are rebuilt from
 *       {@code bias_theta}, accumulated from {@code W} for every vote, then normalized.</li>
 * </ul>
 *
 * <p>Voxels flagged in {@code consensus} have {@code W} pinned to the observed vote. They are
 * excluded from the {@code theta} update. The outer iteration and convergence test presumably
 * live in {@link StatisticalFusionBase} (not visible here).
 */
public class STAPLER extends StatisticalFusionBase {

	// prior mode: 0 uses labelwise_t_prior, any other value uses voxelwise_t_prior
	private int prior_flag;
	// global truth prior indexed [label]; only populated when prior_flag == 0
	public double [] labelwise_t_prior;
	// per-voxel truth prior indexed [x][y][z][label][0]; only populated when prior_flag != 0
	public double [][][][][] voxelwise_t_prior;
	// starting value for theta at every calc_theta, indexed [vote][label][rater];
	// null means no bias (theta starts from zero)
	public double [][][] bias_theta;
	// true where W is pinned to the observed vote and the voxel is skipped by the theta update
	private boolean [][][] consensus;

	/**
	 * Creates a fusion whose truth prior is derived by the base class.
	 *
	 * @param obs_in            the rater observations
	 * @param epsilon_in        stored in {@code epsilon} (presumably the base class's convergence
	 *                          threshold; TODO confirm)
	 * @param prior_flag_in     0 = global label-wise prior, any other value = voxel-wise prior
	 * @param init_flag_in      0 = initialize theta, any other value = initialize W
	 * @param consensus_flag_in 1 = detect consensus voxels, any other value = no consensus voxels
	 * @param bias_theta_in     starting value for theta before each update, indexed
	 *                          [vote][label][rater]; may be null for no bias
	 */
	public STAPLER (ObservationBase obs_in,
				   double epsilon_in,
				   int prior_flag_in,
				   int init_flag_in,
				   int consensus_flag_in,
				   double [][][] bias_theta_in) {
		init_stapler_state(obs_in, epsilon_in, prior_flag_in, init_flag_in,
				consensus_flag_in, bias_theta_in);

		// set the truth prior
		if (prior_flag == 0) {
			labelwise_t_prior = get_labelwise_t_prior();
		} else {
			voxelwise_t_prior = get_voxelwise_t_prior();
		}

		init_stapler_estimates();
	}

	/**
	 * Creates a fusion with an explicit voxel-wise truth prior.
	 *
	 * @param obs_in                the rater observations
	 * @param epsilon_in            stored in {@code epsilon} (presumably the base class's
	 *                              convergence threshold; TODO confirm)
	 * @param voxelwise_t_prior_in  truth prior indexed [x][y][z][label]; it is copied
	 * @param init_flag_in          0 = initialize theta, any other value = initialize W
	 * @param consensus_flag_in     1 = detect consensus voxels, any other value = no consensus voxels
	 * @param bias_theta_in         starting value for theta before each update, indexed
	 *                              [vote][label][rater]; may be null for no bias
	 */
	public STAPLER (ObservationBase obs_in,
			   		double epsilon_in,
			   		double [][][][] voxelwise_t_prior_in,
			   		int init_flag_in,
			   		int consensus_flag_in,
			   		double [][][] bias_theta_in) {
		// an explicit per-voxel prior always means voxel-wise mode
		init_stapler_state(obs_in, epsilon_in, 1, init_flag_in,
				consensus_flag_in, bias_theta_in);
		voxelwise_t_prior = copy_voxelwise_t_prior(voxelwise_t_prior_in);
		init_stapler_estimates();
	}

	// Stores the shared constructor arguments and allocates estimate, consensus, W and theta.
	private void init_stapler_state(ObservationBase obs_in,
									double epsilon_in,
									int prior_flag_in,
									int init_flag_in,
									int consensus_flag_in,
									double [][][] bias_theta_in) {
		epsilon = epsilon_in;
		init_flag = init_flag_in;
		prior_flag = prior_flag_in;
		obs = obs_in;
		bias_theta = bias_theta_in;
		estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];

		// Java zero-initializes boolean arrays (all false), so only the consensus case needs work
		consensus = new boolean [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		if (consensus_flag_in == 1) {
			get_consensus_voxels(consensus);
		}

		// allocate space for W and theta
		W = new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];
		theta = new double [obs.num_labels][obs.num_labels][obs.num_raters][1];
	}

	// Seeds the first estimate (theta or W, per init_flag) and installs the per-vote update actions.
	private void init_stapler_estimates() {
		if (init_flag == 0) {
			super.init_theta();
		} else {
			super.set_voxelwise_data(W);
		}

		calc_theta_action = new STAPLERCalcTheta(this);
		calc_W_action = new STAPLERCalcW(this);
	}

	// Copies an [x][y][z][label] prior into the [x][y][z][label][0] layout that W uses.
	private double [][][][][] copy_voxelwise_t_prior(double [][][][] prior_in) {
		double [][][][][] prior =
				new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];
		for (int x = 0; x < obs.dims[0]; x++) {
			for (int y = 0; y < obs.dims[1]; y++) {
				for (int z = 0; z < obs.dims[2]; z++) {
					for (int l = 0; l < obs.num_labels; l++) {
						prior[x][y][z][l][0] = prior_in[x][y][z][l];
					}
				}
			}
		}
		return prior;
	}

	/**
	 * Recomputes W. It resets W to the truth prior, applies {@code lfa} to every vote, then
	 * normalizes.
	 *
	 * @param lfa the per-vote action (normally this object's STAPLERCalcW)
	 */
	@Override
	public void calc_W(LabelFusionAction lfa) {
		// reset W to the truth prior; the prior source is chosen once, outside the voxel loops
		if (prior_flag == 0) {
			for (int x = 0; x < obs.dims[0]; x++) {
				for (int y = 0; y < obs.dims[1]; y++) {
					for (int z = 0; z < obs.dims[2]; z++) {
						for (int l = 0; l < obs.num_labels; l++) {
							W[x][y][z][l][0] = labelwise_t_prior[l];
						}
					}
				}
			}
		} else {
			for (int x = 0; x < obs.dims[0]; x++) {
				for (int y = 0; y < obs.dims[1]; y++) {
					for (int z = 0; z < obs.dims[2]; z++) {
						for (int l = 0; l < obs.num_labels; l++) {
							W[x][y][z][l][0] = voxelwise_t_prior[x][y][z][l][0];
						}
					}
				}
			}
		}

		// fold in every vote
		obs.iterate_votes(lfa);

		// normalize the data
		normalize_voxelwise_data(W);
	}

	/**
	 * Recomputes theta. It resets theta to {@code bias_theta} (or zero when that is null),
	 * applies {@code lfa} to every vote, then normalizes.
	 *
	 * @param lfa the per-vote action (normally this object's STAPLERCalcTheta)
	 */
	@Override
	public void calc_theta(LabelFusionAction lfa) {
		// a null bias means "no bias": start every accumulator at zero
		final boolean has_bias = (bias_theta != null);
		for (int s = 0; s < obs.num_labels; s++) {
			for (int l = 0; l < obs.num_labels; l++) {
				for (int r = 0; r < obs.num_raters; r++) {
					theta[s][l][r][0] = has_bias ? bias_theta[s][l][r] : 0.0;
				}
			}
		}

		// iterate over all of the votes and accumulate theta
		obs.iterate_votes(lfa);

		// normalize the theta estimate
		normalize_theta();
	}

	// Per-vote W update.
	// Consensus voxels get a one-hot W on the vote v.
	// Other voxels multiply each label's W by theta[v][label][r].
	// Presumably r is the rater and v its vote (TODO confirm against ObservationBase.iterate_votes).
	private class STAPLERCalcW implements LabelFusionAction {
	    private StatisticalFusionBase sfb;

	    public STAPLERCalcW(StatisticalFusionBase sfb_in) {
	        sfb = sfb_in;
	    }

	    public void run(int x, int y, int z, int r, int v) {
	    	if (consensus[x][y][z]) {
	    		for (int s = 0; s < sfb.obs.num_labels; s++) {
	    			sfb.W[x][y][z][s][0] = (s == v) ? 1 : 0;
	    		}
	    	} else {
	    		for (int s = 0; s < sfb.obs.num_labels; s++) {
	    			sfb.W[x][y][z][s][0] *= sfb.theta[v][s][r][0];
	    		}
	    	}
	    }
	}

	// Per-vote theta update.
	// Non-consensus voxels add W[x][y][z][label] into theta[v][label][r].
	// Consensus voxels contribute nothing.
	private class STAPLERCalcTheta implements LabelFusionAction {
	    private StatisticalFusionBase sfb;

	    public STAPLERCalcTheta(StatisticalFusionBase sfb_in) {
	        sfb = sfb_in;
	    }

	    public void run(int x, int y, int z, int r, int v) {
	    	if (!consensus[x][y][z]) {
	    		for (int s = 0; s < sfb.obs.num_labels; s++) {
	    			sfb.theta[v][s][r][0] += sfb.W[x][y][z][s][0];
	    		}
	    	}
	    }
	}
}
