package edu.vanderbilt.masi.LabelFusion;

/**
 * Weighted-vote label fusion.
 *
 * An instance is built from an observation set plus one of three weight
 * arrays; the array's rank selects the weighting mode:
 * <ul>
 *   <li>{@code double[][][][]} - voxelwise weights</li>
 *   <li>{@code double[]}       - global weights</li>
 *   <li>{@code double[][]}     - global labelwise weights</li>
 * </ul>
 * NOTE(review): the exact index layout of each weight array is defined by
 * the VoxelwiseDataAdd / VotingVoxelAdd consumers, which are not visible
 * here - confirm against those classes.
 *
 * The fields {@code obs}, {@code fast}, {@code estimate} and {@code W} are
 * not declared in this class; presumably they are inherited from
 * VotingFusionBase.
 */
public class WeightedVote extends VotingFusionBase {

	// Weighting modes stored in `type` (values unchanged from the original
	// literal codes 1, 2 and 3).
	private static final byte TYPE_VOXELWISE = 1;
	private static final byte TYPE_GLOBAL = 2;
	private static final byte TYPE_LABELWISE = 3;

	// Only the array matching `type` is set; the other two stay null.
	private double [][][][] voxelweights;
	private double [] globalweights;
	private double [][] labelweights;
	private byte type;

	//
	// voxelwise weighted vote
	//

	/**
	 * Voxelwise weighted vote with {@code fast} set to 0.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in voxelwise weights
	 */
	public WeightedVote (ObservationBase obs_in, double [][][][] weights_in) {
		this(obs_in, weights_in, 0);
	}

	/**
	 * Voxelwise weighted vote.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in voxelwise weights
	 * @param fast_in    0 makes {@link #run()} build the full voxelwise
	 *                   array; any other value makes it fuse voxel by voxel
	 */
	public WeightedVote (ObservationBase obs_in, double [][][][] weights_in, int fast_in) {
		voxelweights = weights_in;
		init_weighted_vote(obs_in, TYPE_VOXELWISE, fast_in);
	}

	//
	// globally weighted vote
	//

	/**
	 * Globally weighted vote with {@code fast} set to 0.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in global weights
	 */
	public WeightedVote (ObservationBase obs_in, double [] weights_in) {
		this(obs_in, weights_in, 0);
	}

	/**
	 * Globally weighted vote.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in global weights
	 * @param fast_in    0 makes {@link #run()} build the full voxelwise
	 *                   array; any other value makes it fuse voxel by voxel
	 */
	public WeightedVote (ObservationBase obs_in, double [] weights_in, int fast_in) {
		globalweights = weights_in;
		init_weighted_vote(obs_in, TYPE_GLOBAL, fast_in);
	}

	//
	// globally labelwise weighted vote
	//

	/**
	 * Globally labelwise weighted vote with {@code fast} set to 0.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in labelwise weights
	 */
	public WeightedVote (ObservationBase obs_in, double [][] weights_in) {
		this(obs_in, weights_in, 0);
	}

	/**
	 * Globally labelwise weighted vote.
	 *
	 * @param obs_in     observations to fuse
	 * @param weights_in labelwise weights
	 * @param fast_in    0 makes {@link #run()} build the full voxelwise
	 *                   array; any other value makes it fuse voxel by voxel
	 */
	public WeightedVote (ObservationBase obs_in, double [][] weights_in, int fast_in) {
		labelweights = weights_in;
		init_weighted_vote(obs_in, TYPE_LABELWISE, fast_in);
	}

	/**
	 * Shared constructor tail: stores the observations, mode and fast flag,
	 * and allocates the estimate volume sized by the first three entries of
	 * {@code obs.dims}.
	 */
	private void init_weighted_vote (ObservationBase obs_in, byte type_in, int fast_in) {
		obs = obs_in;
		type = type_in;
		fast = fast_in;
		estimate = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]];
	}

	/**
	 * Runs the fusion and returns the estimate volume.
	 *
	 * @return the {@code estimate} array allocated at construction, indexed
	 *         [x][y][z]; the same array instance is returned on every call
	 */
	public int [][][] run () {
		if (fast == 0)
			run_full_volume();
		else
			run_per_voxel();

		// return the estimate
		return estimate;
	}

	/**
	 * Full-volume path: accumulates votes for every voxel and label into
	 * {@code W}, then normalizes it and derives the estimate from it.
	 */
	private void run_full_volume () {
		// get the per voxel votes
		W = new double [obs.dims[0]][obs.dims[1]][obs.dims[2]][obs.num_labels][1];

		// calculate the voxelwise data, using the weights matching the mode
		switch (type) {
			case TYPE_VOXELWISE:
				obs.iterate_votes(new VoxelwiseDataAdd(W, voxelweights));
				break;
			case TYPE_GLOBAL:
				obs.iterate_votes(new VoxelwiseDataAdd(W, globalweights));
				break;
			case TYPE_LABELWISE:
				obs.iterate_votes(new VoxelwiseDataAdd(W, labelweights));
				break;
			default:
				// unknown mode: no votes are accumulated and W stays all
				// zeros, exactly as the original if/else-if chain behaved
				break;
		}

		// normalize the data
		normalize_voxelwise_data(W);

		// set the estimate
		set_estimate_from_voxelwise_data(W);
	}

	/**
	 * Per-voxel path: for each voxel, accumulates votes into a temporary
	 * per-label array and stores the value chosen by
	 * {@code get_estimate_value}. No full voxelwise array is allocated.
	 */
	private void run_per_voxel () {
		// iterate over every voxel
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					// fresh accumulator per voxel, one slot per label
					double [] labelprobs = new double [obs.num_labels];

					switch (type) {
						case TYPE_VOXELWISE:
							obs.iterate_voxel(new VotingVoxelAdd(labelprobs, voxelweights), x, y, z);
							break;
						case TYPE_GLOBAL:
							obs.iterate_voxel(new VotingVoxelAdd(labelprobs, globalweights), x, y, z);
							break;
						case TYPE_LABELWISE:
							obs.iterate_voxel(new VotingVoxelAdd(labelprobs, labelweights), x, y, z);
							break;
						default:
							// unknown mode: labelprobs stays all zeros
							break;
					}

					estimate[x][y][z] = get_estimate_value(labelprobs);
				}
	}
}
