package edu.vanderbilt.masi.LabelFusion;

/**
 * Spatial STAPLE label fusion.
 *
 * Extends the statistical fusion framework with spatially varying rater
 * performance: a separate confusion matrix (theta) is estimated for every
 * update region defined by {@link SpatialData}, and W is computed from the
 * interpolated regional thetas. Regions in which a rater produced a given
 * label fewer than {@code bias} times are pulled toward a biasing theta.
 */
public class SpatialSTAPLE extends StatisticalFusionBase {

	// prior_flag values
	private static final int PRIOR_LABELWISE = 0;   // label-wise truth prior estimated from the observations
	private static final int PRIOR_VOXELWISE = 1;   // voxel-wise truth prior (estimated or supplied)
	private static final int PRIOR_SUPPLIED = 2;    // transient: caller supplied a voxel-wise prior; rewritten to PRIOR_VOXELWISE in initialize()

	// when a voxel-wise prior is supplied and cons_flag > 1, voxels whose
	// largest prior probability exceeds this value are marked as consensus
	private static final double PRIOR_CONSENSUS_THRESHOLD = 0.95;

	private int prior_flag;
	private double [] labelwise_t_prior;
	// voxel-wise truth prior, indexed [x][y][z][label][0]
	private double [][][][][] voxelwise_t_prior;
	private SpatialData sd;
	// biasing theta, indexed [observed label][true label][rater]
	private double [][][] bias_theta;
	// per-region label count below which bias_theta is mixed into theta
	private double bias;
	// index of the update region currently iterated by calc_theta (read by SpatialSTAPLECalcTheta)
	private int current_u;
	// [x][y][z] consensus mask computed in initialize()
	private boolean [][][] consensus;
	private int interp_type;
	// [label][rater][region] number of votes seen in each region, recomputed by calc_theta
	private int [][][] num_label_rater_region;
	private int cons_flag;

	/**
	 * Creates an instance whose truth prior is estimated from the observations
	 * and whose biasing theta is the majority-vote estimate.
	 *
	 * @param obs_in         the rater observations
	 * @param epsilon_in     stored in the inherited {@code epsilon} field (presumably the convergence threshold)
	 * @param prior_flag_in  0 = label-wise prior, 1 = voxel-wise prior
	 * @param init_flag_in   0 = initialize via {@code init_theta()}, otherwise via {@code set_voxelwise_data(W)}
	 * @param cons_flag_in   consensus handling mode (0 = consensus voxels treated like any other,
	 *                       1 = consensus voxels fixed to the observed label and excluded from theta)
	 * @param interp_type_in interpolation type passed to {@link SpatialData}
	 * @param num_up         passed to {@link SpatialData}
	 * @param win_dims       passed to {@link SpatialData}
	 * @param bias_in        per-region label count below which the biasing theta is mixed in
	 * @throws IllegalArgumentException if {@code prior_flag_in} is not 0 or 1
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  double epsilon_in,
						  int prior_flag_in,
						  int init_flag_in,
						  int cons_flag_in,
						  int interp_type_in,
						  int [] num_up,
						  int [] win_dims,
						  double bias_in) {
		prior_flag = check_estimated_prior_flag(prior_flag_in);
		set_common_args(obs_in, epsilon_in, init_flag_in, cons_flag_in, interp_type_in, bias_in);

		initialize(num_up, win_dims);

		// bias toward the majority-vote theta estimate (computed with the consensus mask)
		bias_theta = get_majority_vote_theta(consensus);
	}

	/**
	 * Creates an instance with a caller-supplied voxel-wise truth prior and a
	 * majority-vote biasing theta.
	 *
	 * @param voxelwise_t_prior_in prior indexed [x][y][z][label]; copied, not aliased
	 * (other parameters as in the estimated-prior constructor)
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  double epsilon_in,
						  double [][][][] voxelwise_t_prior_in,
						  int init_flag_in,
						  int cons_flag_in,
						  int interp_type_in,
						  int [] num_up,
						  int [] win_dims,
						  double bias_in) {
		set_common_args(obs_in, epsilon_in, init_flag_in, cons_flag_in, interp_type_in, bias_in);
		prior_flag = PRIOR_SUPPLIED;
		copy_voxelwise_prior(voxelwise_t_prior_in);

		initialize(num_up, win_dims);

		// bias toward the majority-vote theta estimate (computed with the consensus mask)
		bias_theta = get_majority_vote_theta(consensus);
	}

	/**
	 * Creates an instance whose truth prior is estimated from the observations
	 * and whose biasing theta is supplied by the caller.
	 *
	 * @param bias_theta_in biasing theta indexed [observed label][true label][rater]
	 * @throws IllegalArgumentException if {@code prior_flag_in} is not 0 or 1
	 * (other parameters as in the estimated-prior constructor)
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  double epsilon_in,
						  int prior_flag_in,
						  int init_flag_in,
						  int cons_flag_in,
						  int interp_type_in,
						  int [] num_up,
						  int [] win_dims,
						  double bias_in,
						  double [][][] bias_theta_in) {
		prior_flag = check_estimated_prior_flag(prior_flag_in);
		set_common_args(obs_in, epsilon_in, init_flag_in, cons_flag_in, interp_type_in, bias_in);

		initialize(num_up, win_dims);

		// use the caller-supplied biasing theta
		bias_theta = bias_theta_in;
	}

	/**
	 * Creates an instance with a caller-supplied voxel-wise truth prior and a
	 * caller-supplied biasing theta.
	 *
	 * @param voxelwise_t_prior_in prior indexed [x][y][z][label]; copied, not aliased
	 * @param bias_theta_in        biasing theta indexed [observed label][true label][rater]
	 * (other parameters as in the estimated-prior constructor)
	 */
	public SpatialSTAPLE (ObservationBase obs_in,
						  double epsilon_in,
						  double [][][][] voxelwise_t_prior_in,
						  int init_flag_in,
						  int cons_flag_in,
						  int interp_type_in,
						  int [] num_up,
						  int [] win_dims,
						  double bias_in,
						  double [][][] bias_theta_in) {
		set_common_args(obs_in, epsilon_in, init_flag_in, cons_flag_in, interp_type_in, bias_in);
		prior_flag = PRIOR_SUPPLIED;
		copy_voxelwise_prior(voxelwise_t_prior_in);

		initialize(num_up, win_dims);

		// use the caller-supplied biasing theta
		bias_theta = bias_theta_in;
	}

	/** Returns the seed points of the underlying {@link SpatialData}. */
	public int [][] get_seed_points() { return sd.seed_points; }

	/**
	 * Validates a prior flag passed to a constructor that estimates the prior.
	 * Only 0 and 1 are usable there: any other value leaves voxelwise_t_prior
	 * unset and later fails with a NullPointerException.
	 */
	private static int check_estimated_prior_flag(int flag) {
		if (flag != PRIOR_LABELWISE && flag != PRIOR_VOXELWISE)
			throw new IllegalArgumentException(
				"prior_flag must be 0 (label-wise) or 1 (voxel-wise); got " + flag);
		return flag;
	}

	/** Stores the constructor arguments shared by every constructor. */
	private void set_common_args(ObservationBase obs_in,
								 double epsilon_in,
								 int init_flag_in,
								 int cons_flag_in,
								 int interp_type_in,
								 double bias_in) {
		obs = obs_in;
		epsilon = epsilon_in;
		init_flag = init_flag_in;
		cons_flag = cons_flag_in;
		interp_type = interp_type_in;
		bias = bias_in;
	}

	/**
	 * Copies a caller-supplied [x][y][z][label] prior into the internal
	 * [x][y][z][label][0] layout. Requires {@code obs} to be set.
	 */
	private void copy_voxelwise_prior(double [][][][] prior_in) {
		voxelwise_t_prior = new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++)
					for (int l = 0; l < obs.num_labels; l++)
						voxelwise_t_prior[x][y][z][l][0] = prior_in[x][y][z][l];
	}

	/**
	 * Allocates the working arrays, builds the spatial regions and consensus
	 * mask, sets the truth prior, initializes the estimates and installs the
	 * Spatial STAPLE actions.
	 */
	private void initialize (int [] num_up,
							 int [] win_dims) {
		estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];

		sd = new SpatialData(obs, num_up, win_dims, interp_type);
		num_label_rater_region = new int [obs.num_labels][obs.num_raters][sd.num_upds];

		// allocate space for W and the per-region theta
		W = new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];
		theta = new double [obs.num_labels][obs.num_labels][obs.num_raters][sd.num_upds];

		// consensus comes from the observations unless a prior was supplied
		// and cons_flag > 1, in which case it comes from that prior
		consensus = new boolean [obs.dims[0]][obs.dims[1]][obs.dims[2]];
		if (prior_flag <= PRIOR_VOXELWISE || cons_flag <= 1)
			get_consensus_voxels(consensus);
		else
			set_consensus_from_prior();

		// set the truth prior; a supplied prior is already in voxelwise_t_prior
		if (prior_flag == PRIOR_LABELWISE)
			labelwise_t_prior = get_labelwise_t_prior();
		else if (prior_flag == PRIOR_VOXELWISE)
			voxelwise_t_prior = get_voxelwise_t_prior();
		else
			prior_flag = PRIOR_VOXELWISE;

		// initialize the estimates
		if (init_flag == 0)
			super.init_theta();
		else
			super.set_voxelwise_data(W);

		// set the label fusion actions
		calc_theta_action = new SpatialSTAPLECalcTheta(this);
		calc_W_action = new SpatialSTAPLECalcW(this);
	}

	/** Marks voxels whose largest voxel-wise prior probability exceeds the threshold. */
	private void set_consensus_from_prior() {
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					double maxprob = -1;
					for (int l = 0; l < obs.num_labels; l++)
						if (voxelwise_t_prior[x][y][z][l][0] > maxprob)
							maxprob = voxelwise_t_prior[x][y][z][l][0];
					consensus[x][y][z] = maxprob > PRIOR_CONSENSUS_THRESHOLD;
				}
	}

	/**
	 * E-step: resets W to the truth prior, multiplies in every rater's vote
	 * via {@code lfa}, then normalizes W.
	 */
	@Override
	protected void calc_W(LabelFusionAction lfa) {
		boolean labelwise = (prior_flag == PRIOR_LABELWISE);

		// set the initial value of W to the truth prior
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++)
					for (int l = 0; l < obs.num_labels; l++)
						W[x][y][z][l][0] = labelwise ? labelwise_t_prior[l]
													 : voxelwise_t_prior[x][y][z][l][0];

		// accumulate the rater votes
		obs.iterate_votes(lfa);

		// normalize the data
		normalize_voxelwise_data(W);
	}

	/**
	 * M-step: accumulates the per-region theta from W, mixes in the biasing
	 * theta for under-observed (label, rater, region) triples, then normalizes.
	 */
	@Override
	protected void calc_theta(LabelFusionAction lfa) {
		// reset theta and the per-region vote counts
		for (int s = 0; s < obs.num_labels; s++)
			for (int r = 0; r < obs.num_raters; r++)
				for (int u = 0; u < sd.num_upds; u++) {
					num_label_rater_region[s][r][u] = 0;
					for (int l = 0; l < obs.num_labels; l++)
						theta[s][l][r][u] = 0;
				}

		// accumulate theta region by region; current_u tells the action which region it is in
		for (int u = 0; u < sd.num_upds; u++) {
			current_u = u;
			obs.iterate_region(lfa,
							   sd.start_coords[u][0],
							   sd.start_coords[u][1],
							   sd.start_coords[u][2],
							   sd.end_coords[u][0],
							   sd.end_coords[u][1],
							   sd.end_coords[u][2]);
		}

		// add (bias - count) weight of the biasing theta wherever fewer than bias votes were seen
		for (int s = 0; s < obs.num_labels; s++)
			for (int r = 0; r < obs.num_raters; r++)
				for (int u = 0; u < sd.num_upds; u++)
					if (num_label_rater_region[s][r][u] < bias) {
						double bias_amt = bias - num_label_rater_region[s][r][u];
						for (int l = 0; l < obs.num_labels; l++)
							theta[s][l][r][u] += bias_amt * bias_theta[s][l][r];
					}

		// normalize the theta estimate
		normalize_theta();
	}

	/**
	 * Per-vote E-step action. For a consensus voxel with cons_flag == 1, W is
	 * set to the one-hot vector of the observed label; for a non-consensus
	 * voxel (or cons_flag == 0), W is multiplied by the interpolated regional
	 * theta; otherwise W is left at the prior.
	 */
	private class SpatialSTAPLECalcW implements LabelFusionAction {
		private StatisticalFusionBase sfb;

		public SpatialSTAPLECalcW(StatisticalFusionBase sfb_in) {
			sfb = sfb_in;
		}

		public void run(int x, int y, int z, int r, int v) {
			if (consensus[x][y][z] && cons_flag == 1)
				for (int s = 0; s < sfb.obs.num_labels; s++)
					sfb.W[x][y][z][s][0] = (s == v) ? 1 : 0;
			else if (!consensus[x][y][z] || cons_flag == 0) {
				double [] rater_theta = sd.get_interp_theta(x, y, z, r, v, sfb.theta);
				for (int s = 0; s < sfb.obs.num_labels; s++)
					sfb.W[x][y][z][s][0] *= rater_theta[s];
			}
		}
	}

	/**
	 * Per-vote M-step action for region current_u: counts the vote and, unless
	 * the voxel is consensus with cons_flag != 0, adds W into theta[v][*][r][u].
	 */
	private class SpatialSTAPLECalcTheta implements LabelFusionAction {
		private StatisticalFusionBase sfb;

		public SpatialSTAPLECalcTheta(StatisticalFusionBase sfb_in) {
			sfb = sfb_in;
		}

		public void run(int x, int y, int z, int r, int v) {
			// NOTE(review): consensus votes are counted even when they do not
			// contribute to theta, which reduces the bias applied — confirm intent
			num_label_rater_region[v][r][current_u]++;
			if (!consensus[x][y][z] || cons_flag == 0) {
				for (int s = 0; s < sfb.obs.num_labels; s++)
					sfb.theta[v][s][r][current_u] += sfb.W[x][y][z][s][0];
			}
		}
	}
}
