package edu.vanderbilt.masi.LabelFusion;

/**
 * Static helpers that score a labelled volume against a reference ("truth")
 * volume using per-label overlap statistics: the Dice similarity coefficient,
 * the Jaccard index and the sensitivity.
 *
 * Every measure comes in four overloads: the estimate is either the votes of
 * one rater stored in an ObservationBase or a plain int[][][] volume, and the
 * per-label values are either written to a caller-supplied array or discarded.
 * All overloads return the mean of the per-label values (see get_mean_val).
 *
 * Labels outside the range [0, num_labels) are ignored when counting. A label
 * whose denominator is zero gets the value 0.
 */
public class LabelFusionTools {

	// row indices into the tally tables built by tally_votes / tally_volumes
	private static final int OVERLAP = 0;
	private static final int TRUTH = 1;
	private static final int ESTIMATE = 2;

	/**
	 * Dice coefficient, 2 * overlap / (truth count + vote count), of one
	 * rater's votes against the truth volume.
	 *
	 * @param obs   the observations; obs.dims gives the volume extent,
	 *              obs.num_labels the number of labels
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @param dsc   receives the per-label values in entries
	 *              [0, obs.num_labels); may be null, in which case a
	 *              temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double dice(ObservationBase obs,
							  int [][][] truth,
							  int rater,
							  double [] dsc) {
		long [][] tally = tally_votes(obs, truth, rater);
		long [] denom = new long [obs.num_labels];
		for (int l = 0; l < obs.num_labels; l++)
			denom[l] = tally[TRUTH][l] + tally[ESTIMATE][l];
		return(fill_ratios(dsc, obs.num_labels, tally[OVERLAP], denom, 2));
	}

	/**
	 * Dice coefficient of one rater's votes against the truth volume,
	 * without reporting the per-label values.
	 *
	 * @param obs   the observations
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double dice(ObservationBase obs,
							  int [][][] truth,
							  int rater) {
		// the per-label values need real storage; passing null used to crash
		return(dice(obs, truth, rater, new double [obs.num_labels]));
	}

	/**
	 * Dice coefficient, 2 * overlap / (truth count + estimate count), of an
	 * estimated volume against the truth volume.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated; read at every index
	 *                   visited in truth
	 * @param num_labels the number of labels; only labels in
	 *                   [0, num_labels) are counted
	 * @param dsc        receives the per-label values in entries
	 *                   [0, num_labels); may be null, in which case a
	 *                   temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double dice(int [][][] truth,
							  int [][][] estimate,
							  int num_labels,
							  double [] dsc) {
		long [][] tally = tally_volumes(truth, estimate, num_labels);
		long [] denom = new long [num_labels];
		for (int l = 0; l < num_labels; l++)
			denom[l] = tally[TRUTH][l] + tally[ESTIMATE][l];
		return(fill_ratios(dsc, num_labels, tally[OVERLAP], denom, 2));
	}

	/**
	 * Dice coefficient of an estimated volume against the truth volume,
	 * without reporting the per-label values.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated
	 * @param num_labels the number of labels
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double dice(int [][][] truth,
							  int [][][] estimate,
							  int num_labels) {
		return(dice(truth, estimate, num_labels, new double [num_labels]));
	}

	/**
	 * Jaccard index, overlap / union, of one rater's votes against the
	 * truth volume.
	 *
	 * @param obs   the observations; obs.dims gives the volume extent,
	 *              obs.num_labels the number of labels
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @param ji    receives the per-label values in entries
	 *              [0, obs.num_labels); may be null, in which case a
	 *              temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double jaccard(ObservationBase obs,
								 int [][][] truth,
								 int rater,
								 double [] ji) {
		long [][] tally = tally_votes(obs, truth, rater);
		return(fill_ratios(ji, obs.num_labels, tally[OVERLAP], union_of(tally), 1));
	}

	/**
	 * Jaccard index of one rater's votes against the truth volume, without
	 * reporting the per-label values.
	 *
	 * @param obs   the observations
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double jaccard(ObservationBase obs,
								 int [][][] truth,
								 int rater) {
		return(jaccard(obs, truth, rater, new double [obs.num_labels]));
	}

	/**
	 * Jaccard index, overlap / union, of an estimated volume against the
	 * truth volume.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated; read at every index
	 *                   visited in truth
	 * @param num_labels the number of labels; only labels in
	 *                   [0, num_labels) are counted
	 * @param ji         receives the per-label values in entries
	 *                   [0, num_labels); may be null, in which case a
	 *                   temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double jaccard(int [][][] truth,
								 int [][][] estimate,
								 int num_labels,
								 double [] ji) {
		long [][] tally = tally_volumes(truth, estimate, num_labels);
		return(fill_ratios(ji, num_labels, tally[OVERLAP], union_of(tally), 1));
	}

	/**
	 * Jaccard index of an estimated volume against the truth volume,
	 * without reporting the per-label values.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated
	 * @param num_labels the number of labels
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double jaccard(int [][][] truth,
								 int [][][] estimate,
								 int num_labels) {
		return(jaccard(truth, estimate, num_labels, new double [num_labels]));
	}

	/**
	 * Sensitivity, overlap / truth count, of one rater's votes against the
	 * truth volume.
	 *
	 * @param obs   the observations; obs.dims gives the volume extent,
	 *              obs.num_labels the number of labels
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @param sens  receives the per-label values in entries
	 *              [0, obs.num_labels); may be null, in which case a
	 *              temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double sensitivity(ObservationBase obs,
									 int [][][] truth,
									 int rater,
									 double [] sens) {
		long [][] tally = tally_votes(obs, truth, rater);
		return(fill_ratios(sens, obs.num_labels, tally[OVERLAP], tally[TRUTH], 1));
	}

	/**
	 * Sensitivity of one rater's votes against the truth volume, without
	 * reporting the per-label values.
	 *
	 * @param obs   the observations
	 * @param truth the reference labels, indexed [x][y][z]
	 * @param rater the rater index handed to obs.get_vote
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double sensitivity(ObservationBase obs,
									 int [][][] truth,
									 int rater) {
		return(sensitivity(obs, truth, rater, new double [obs.num_labels]));
	}

	/**
	 * Sensitivity, overlap / truth count, of an estimated volume against
	 * the truth volume.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated; read at every index
	 *                   visited in truth
	 * @param num_labels the number of labels; only labels in
	 *                   [0, num_labels) are counted
	 * @param sens       receives the per-label values in entries
	 *                   [0, num_labels); may be null, in which case a
	 *                   temporary array is used
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double sensitivity(int [][][] truth,
									 int [][][] estimate,
									 int num_labels,
									 double [] sens) {
		long [][] tally = tally_volumes(truth, estimate, num_labels);
		return(fill_ratios(sens, num_labels, tally[OVERLAP], tally[TRUTH], 1));
	}

	/**
	 * Sensitivity of an estimated volume against the truth volume, without
	 * reporting the per-label values.
	 *
	 * @param truth      the reference labels, indexed [x][y][z]
	 * @param estimate   the labels being evaluated
	 * @param num_labels the number of labels
	 * @return the mean of the per-label values (see get_mean_val)
	 */
	public static double sensitivity(int [][][] truth,
									 int [][][] estimate,
									 int num_labels) {
		return(sensitivity(truth, estimate, num_labels, new double [num_labels]));
	}

	/**
	 * Counts, per label, how often the rater's votes and the truth agree
	 * (row OVERLAP), how often the truth carries the label (row TRUTH) and
	 * how often a vote carries it (row ESTIMATE).
	 *
	 * A voxel contributes once per entry of the array returned by
	 * obs.get_vote, so its truth label is counted once for every vote.
	 */
	private static long [][] tally_votes(ObservationBase obs,
										 int [][][] truth,
										 int rater) {
		long [][] tally = new long [3][obs.num_labels];
		for (int x = 0; x < obs.dims[0]; x++)
			for (int y = 0; y < obs.dims[1]; y++)
				for (int z = 0; z < obs.dims[2]; z++) {
					int [] votes = obs.get_vote(x, y, z, rater);
					for (int vi = 0; vi < votes.length; vi++)
						record(tally, obs.num_labels, truth[x][y][z], votes[vi]);
				}
		return(tally);
	}

	/**
	 * Counts, per label, the voxels where both volumes carry the label
	 * (row OVERLAP), where truth carries it (row TRUTH) and where estimate
	 * carries it (row ESTIMATE).
	 *
	 * The extent visited is truth.length x truth[0].length x
	 * truth[0][0].length, i.e. truth is treated as rectangular.
	 */
	private static long [][] tally_volumes(int [][][] truth,
										   int [][][] estimate,
										   int num_labels) {
		long [][] tally = new long [3][num_labels];
		for (int x = 0; x < truth.length; x++)
			for (int y = 0; y < truth[0].length; y++)
				for (int z = 0; z < truth[0][0].length; z++)
					record(tally, num_labels, truth[x][y][z], estimate[x][y][z]);
		return(tally);
	}

	/**
	 * Adds one (truth label, estimated label) pair to the tally. Indexing
	 * the tables directly replaces a scan over all labels; labels outside
	 * [0, num_labels) match no table entry and are skipped.
	 */
	private static void record(long [][] tally, int num_labels, int t, int e) {
		boolean t_valid = t >= 0 && t < num_labels;
		boolean e_valid = e >= 0 && e < num_labels;
		if (t_valid)
			tally[TRUTH][t]++;
		if (e_valid)
			tally[ESTIMATE][e]++;
		// t == e together with t_valid implies e is valid as well
		if (t_valid && t == e)
			tally[OVERLAP][t]++;
	}

	/**
	 * Per-label size of the union of truth and estimate: both counts added
	 * up, minus the overlap that would otherwise be counted twice.
	 */
	private static long [] union_of(long [][] tally) {
		long [] union = new long [tally[OVERLAP].length];
		for (int l = 0; l < union.length; l++)
			union[l] = tally[TRUTH][l] + tally[ESTIMATE][l] - tally[OVERLAP][l];
		return(union);
	}

	/**
	 * Writes scale * numer[l] / denom[l] (or 0 where denom[l] is zero) into
	 * entries [0, num_labels) of out and returns the mean of out.
	 *
	 * A null out is replaced by a temporary array of length num_labels so
	 * that callers not interested in the per-label values can pass null.
	 */
	private static double fill_ratios(double [] out,
									  int num_labels,
									  long [] numer,
									  long [] denom,
									  double scale) {
		double [] vals = (out != null) ? out : new double [num_labels];
		for (int l = 0; l < num_labels; l++)
			vals[l] = (denom[l] == 0) ? 0 : scale * numer[l] / denom[l];
		return(get_mean_val(vals));
	}

	/**
	 * Mean of the per-label values in arr, taken over the full length of
	 * the array.
	 *
	 * NOTE: with more than one entry we assume that label 0 is the
	 *       background and leave it out of the mean. A single entry is
	 *       returned as is; an empty array yields 0.
	 */
	private static double get_mean_val(double [] arr) {
		int num_labels = arr.length;
		if (num_labels == 0)
			return(0);
		if (num_labels == 1)
			return(arr[0]);

		double mean = 0;
		for (int l = 1; l < num_labels; l++)
			mean += arr[l];
		mean /= num_labels - 1;
		return(mean);
	}

}
