package edu.vanderbilt.masi.LabelFusion;

/**
 * Common base for label fusion algorithms.
 * <p>
 * Subclasses implement {@link #run()}. This base class provides helpers that
 * accumulate votes from {@link #obs} into per-voxel tallies, normalize those
 * tallies, and pick the highest-scoring label per voxel (written to
 * {@link #estimate}).
 */
public abstract class LabelFusionBase {

	// public variables

	/**
	 * Per-voxel label estimate, indexed [x][y][z].
	 * Must be allocated (at least obs.dims in size) before
	 * set_estimate_from_voxelwise_data is called; that method only writes into it.
	 */
	public int [][][] estimate;

	/** Runs the fusion algorithm and returns a label estimate indexed [x][y][z]. */
	public abstract int [][][] run();

	// protected variables

	/** Observations the fusion is computed from (dimensions, label count, votes). */
	protected ObservationBase obs;

	/**
	 * Returns the index of the largest entry among the first obs.num_labels
	 * entries of labelprobs. Ties go to the lowest index. Returns 0 when no
	 * entry exceeds -1, e.g. when every entry is NaN.
	 *
	 * @param labelprobs per-label scores; needs at least obs.num_labels entries
	 * @return index of the best-scoring label
	 */
	protected int get_estimate_value(double [] labelprobs) {
		double max_val = -1;
		int est = 0;
		for (int l = 0; l < obs.num_labels; l++) {
			if (labelprobs[l] > max_val) {
				max_val = labelprobs[l];
				est = l;
			}
		}
		return est;
	}

	/**
	 * Adds one unweighted vote per observation into level 0 of voxeldata, then
	 * normalizes each voxel so its entries sum to 1.
	 *
	 * @param voxeldata tallies indexed [x][y][z][label][level]. Votes are added
	 *                  to the existing values, so the caller presumably passes a
	 *                  zero-initialized array.
	 */
	protected void set_voxelwise_data(double [][][][][] voxeldata) {

		// accumulate the votes
		obs.iterate_votes(new VoxelwiseDataAdd(voxeldata));

		// turn the tallies into per-voxel fractions
		normalize_voxelwise_data(voxeldata);
	}

	/**
	 * Normalizes voxeldata in place so that, at every voxel, the entries over
	 * the first obs.num_labels labels and all levels sum to 1.
	 * <p>
	 * Voxels whose entries sum to exactly 0 are left untouched (all zeros).
	 * Dividing them would produce NaN.
	 *
	 * @param voxeldata tallies indexed [x][y][z][label][level]
	 */
	protected void normalize_voxelwise_data(double [][][][][] voxeldata) {
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					double [][] labels = voxeldata[x][y][z];

					// total mass at this voxel over all labels and levels
					double num_votes = 0;
					for (int l = 0; l < obs.num_labels; l++)
						for (double val : labels[l])
							num_votes += val;

					// a voxel with no mass would otherwise become 0/0 = NaN
					if (num_votes == 0)
						continue;

					for (int l = 0; l < obs.num_labels; l++)
						for (int c = 0; c < labels[l].length; c++)
							labels[l][c] /= num_votes;
				}
	}

	/**
	 * Writes into {@link #estimate} the label with the largest total score at each
	 * voxel. The score is summed over all levels; ties go to the lowest label index.
	 * A voxel is assigned -1 if no label's score exceeds -1 (e.g. NaN input).
	 *
	 * @param voxeldata scores indexed [x][y][z][label][level]
	 */
	protected void set_estimate_from_voxelwise_data(double [][][][][] voxeldata) {
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					double [][] labels = voxeldata[x][y][z];

					// find the label with the largest summed score
					double max = -1;
					int max_ind = -1;
					for (int s = 0; s < obs.num_labels; s++) {
						double val = 0;
						for (double level : labels[s])
							val += level;

						if (val > max) {
							max = val;
							max_ind = s;
						}
					}

					estimate[x][y][z] = max_ind;
				}
	}

	/**
	 * Normalizes the first obs.num_labels entries of labelprobs in place so they
	 * sum to 1. If they sum to exactly 0, the array is left untouched rather than
	 * filled with NaN.
	 *
	 * @param labelprobs per-label scores
	 */
	protected void normalize_label_probabilities(double [] labelprobs) {
		double sum = 0;
		for (int l = 0; l < obs.num_labels; l++)
			sum += labelprobs[l];

		// avoid 0/0 = NaN when there were no votes
		if (sum == 0)
			return;

		for (int l = 0; l < obs.num_labels; l++)
			labelprobs[l] /= sum;
	}

	/**
	 * Weight of a single vote (rater r voting label v at voxel x,y,z).
	 * The first non-null weight source wins: per-voxel, then global, then
	 * per-label. If all are null, the weight is 1.
	 */
	private static double vote_weight(double [][][][] voxelweights,
									  double [] globalweights,
									  double [][] labelweights,
									  int x, int y, int z, int r, int v) {
		if (voxelweights != null)
			return voxelweights[x][y][z][r];
		if (globalweights != null)
			return globalweights[r];
		if (labelweights != null)
			return labelweights[r][v];
		return 1;
	}

	/**
	 * Vote action that adds each vote's weight into level 0 of a per-voxel
	 * tally array indexed [x][y][z][label][level].
	 */
	protected class VoxelwiseDataAdd implements LabelFusionAction {
		private final double [][][][][] data;
		private final double [][][][] voxelweights;
		private final double [] globalweights;
		private final double [][] labelweights;

		/** Unweighted votes (each vote adds 1). */
		public VoxelwiseDataAdd(double [][][][][] data_in) {
			this(data_in, null, null, null);
		}

		/** Votes weighted by voxelweights[x][y][z][r]. */
		public VoxelwiseDataAdd(double [][][][][] data_in,
								double [][][][] weights_in) {
			this(data_in, weights_in, null, null);
		}

		/** Votes weighted by a single weight per rater, weights_in[r]. */
		public VoxelwiseDataAdd(double [][][][][] data_in,
								double [] weights_in) {
			this(data_in, null, weights_in, null);
		}

		/** Votes weighted per rater and label, weights_in[r][v]. */
		public VoxelwiseDataAdd(double [][][][][] data_in,
								double [][] weights_in) {
			this(data_in, null, null, weights_in);
		}

		private VoxelwiseDataAdd(double [][][][][] data_in,
								 double [][][][] voxelweights_in,
								 double [] globalweights_in,
								 double [][] labelweights_in) {
			data = data_in;
			voxelweights = voxelweights_in;
			globalweights = globalweights_in;
			labelweights = labelweights_in;
		}

		/** Adds rater r's vote for label v at voxel (x, y, z). */
		public void run(int x, int y, int z, int r, int v) {
			data[x][y][z][v][0] += vote_weight(voxelweights, globalweights,
											   labelweights, x, y, z, r, v);
		}
	}

	/** Vote action that counts votes per label (unweighted), ignoring position and rater. */
	protected class LabelwiseDataAdd implements LabelFusionAction {
		private final double [] data;

		public LabelwiseDataAdd(double [] data_in) {
			data = data_in;
		}

		/** Adds 1 to the tally of label v. */
		public void run(int x, int y, int z, int r, int v) {
			data[v] += 1;
		}
	}

	/**
	 * Vote action that adds each vote's weight into a per-label score array.
	 * The voxel position is used only to look up per-voxel weights.
	 */
	protected class VotingVoxelAdd implements LabelFusionAction {
		private final double [] labelprobs;
		private final double [][][][] voxelweights;
		private final double [] globalweights;
		private final double [][] labelweights;

		/** Unweighted votes (each vote adds 1). */
		public VotingVoxelAdd(double [] lp_in) {
			this(lp_in, null, null, null);
		}

		/** Votes weighted by weights[x][y][z][r]. */
		public VotingVoxelAdd(double [] lp_in,
							  double [][][][] weights) {
			this(lp_in, weights, null, null);
		}

		/** Votes weighted by a single weight per rater, weights[r]. */
		public VotingVoxelAdd(double [] lp_in,
							  double [] weights) {
			this(lp_in, null, weights, null);
		}

		/** Votes weighted per rater and label, weights[r][v]. */
		public VotingVoxelAdd(double [] lp_in,
							  double [][] weights) {
			this(lp_in, null, null, weights);
		}

		private VotingVoxelAdd(double [] lp_in,
							   double [][][][] voxelweights_in,
							   double [] globalweights_in,
							   double [][] labelweights_in) {
			labelprobs = lp_in;
			voxelweights = voxelweights_in;
			globalweights = globalweights_in;
			labelweights = labelweights_in;
		}

		/** Adds rater r's vote for label v (at voxel x, y, z) to labelprobs[v]. */
		public void run(int x, int y, int z, int r, int v) {
			labelprobs[v] += vote_weight(voxelweights, globalweights,
										 labelweights, x, y, z, r, v);
		}
	}

}
