package edu.vanderbilt.masi.algorithms.utilities;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

import java.util.Arrays;

/**
 * Static utilities for comparing two label volumes voxel by voxel
 * (per-label Dice overlap).
 */
public class LabelAnalysis extends AbstractCalculation {

	/**
	 * Computes per-label Dice overlap between a truth and an estimated label volume.
	 *
	 * <p>Both volumes are read voxel by voxel with {@code getShort}. Each image is
	 * disposed once its data has been copied, and also if reading or the dimension
	 * check fails.
	 *
	 * @param truthvol the reference label volume
	 * @param estvol   the estimated label volume; must have the same dimensions
	 * @return one row per label present in either volume, in ascending label order,
	 *         laid out as {@code [label, dice]}
	 * @throws RuntimeException if the two volumes' dimensions differ
	 */
	public static float[][] dice(ParamVolume truthvol,
							     ParamVolume estvol) {

		// load the truth volume; its dimensions are the reference
		ImageData truthimg = truthvol.getImageData(true);
		int [] dims;
		short [][][][] truth;
		try {
			dims = getDims(truthimg);
			truth = readVolume(truthimg, dims);
		} finally {
			truthimg.dispose();
		}

		// load the estimated volume, after making sure the dimensions match
		ImageData estimg = estvol.getImageData(true);
		short [][][][] est;
		try {
			if (!Arrays.equals(dims, getDims(estimg))) {
				String errstr = "Error: Truth/Estimate Dimensions do not match";
				JistLogger.logOutput(JistLogger.SEVERE, errstr);
				throw new RuntimeException(errstr);
			}
			est = readVolume(estimg, dims);
		} finally {
			// runs on the mismatch path too, so the image is never leaked
			estimg.dispose();
		}

		return(dice(truth, est, dims));
	}

	/**
	 * Returns the extents of {@code img} as {rows, cols, slices, components},
	 * each clamped to at least 1.
	 */
	private static int [] getDims(ImageData img) {
		return new int [] {
			Math.max(img.getRows(), 1),
			Math.max(img.getCols(), 1),
			Math.max(img.getSlices(), 1),
			Math.max(img.getComponents(), 1)
		};
	}

	/** Copies every voxel of {@code img} into a new array of size {@code dims}. */
	private static short [][][][] readVolume(ImageData img, int [] dims) {
		short [][][][] vol = new short [dims[0]][dims[1]][dims[2]][dims[3]];
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++)
						vol[x][y][z][v] = img.getShort(x, y, z, v);
		return vol;
	}

	/**
	 * Computes per-label Dice overlap, 2|T&cap;E| / (|T| + |E|), between two label arrays.
	 *
	 * <p>Each label's Dice value is logged at INFO level. The mean over all labels
	 * other than 0 (treated as background) is logged as well; it is NaN when no
	 * non-zero label is present.
	 *
	 * @param truth reference labels, indexed [x][y][z][v]
	 * @param est   estimated labels, same layout as {@code truth}
	 * @param dims  extents {x, y, z, v} to compare; both arrays must be at least this large
	 * @return one row per label present in either array, in ascending label order,
	 *         laid out as {@code [label, dice]}; empty if {@code dims} covers no voxels
	 */
	public static float[][] dice(short [][][][] truth,
								 short [][][][] est,
								 int [] dims) {

		// get the sorted array of unique values
		short [] unique_labels = get_unique_labels(truth, est, dims);
		int num_unique_labels = unique_labels.length;
		if (num_unique_labels == 0)
			return new float[0][2];

		// map each label value to its index in unique_labels; offsetting by the
		// smallest label keeps the table compact and allows negative labels
		int min_label = unique_labels[0];
		int max_label = unique_labels[num_unique_labels - 1];
		int [] labelremap = new int [max_label - min_label + 1];
		for (int l = 0; l < num_unique_labels; l++)
			labelremap[unique_labels[l] - min_label] = l;

		// per-label voxel counts in truth, in estimate, and in both
		int [] num_truth = new int[num_unique_labels];
		int [] num_est = new int[num_unique_labels];
		int [] num_both = new int[num_unique_labels];

		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++) {
						int truthidx = labelremap[truth[x][y][z][v] - min_label];
						int estidx = labelremap[est[x][y][z][v] - min_label];
						num_truth[truthidx]++;
						num_est[estidx]++;
						if (truthidx == estidx)
							num_both[truthidx]++;
					}

		// set the Dice information
		float [][] dicevals = new float [num_unique_labels][2];
		float dicesum = 0;
		int num_foreground = 0;
		for (int l = 0; l < num_unique_labels; l++) {
			dicevals[l][0] = (float)unique_labels[l];
			// compute in double: avoids int overflow of the sum and float rounding of large counts
			dicevals[l][1] = (float)(2.0 * num_both[l] / ((double)num_truth[l] + num_est[l]));
			JistLogger.logOutput(JistLogger.INFO, String.format("[LabelAnalysis] Dice (label %02d): %f", unique_labels[l], dicevals[l][1]));
			// exclude label 0 (background) by value, not by position in the list
			if (unique_labels[l] != 0) {
				dicesum += dicevals[l][1];
				num_foreground++;
			}
		}
		float meandice = (num_foreground > 0) ? dicesum / num_foreground : Float.NaN;
		JistLogger.logOutput(JistLogger.INFO, String.format("[LabelAnalysis] Mean Dice (Non-Background): %f", meandice));

		return(dicevals);
	}

	/**
	 * Returns the distinct label values occurring in either array.
	 *
	 * @param img1 first label array, indexed [x][y][z][v]
	 * @param img2 second label array, same layout as {@code img1}
	 * @param dims extents {x, y, z, v} to scan
	 * @return distinct labels in ascending order; empty if {@code dims} covers no voxels
	 */
	public static short [] get_unique_labels(short[][][][] img1, short[][][][] img2,
											 int [] dims) {

		JistLogger.logOutput(JistLogger.INFO, "[LabelAnalysis] Determining unique labels.");

		// first, find the label range over both images; int bounds avoid the
		// short overflow a loop up to Short.MAX_VALUE would otherwise hit
		int min_label = Integer.MAX_VALUE;
		int max_label = Integer.MIN_VALUE;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++) {
						int a = img1[x][y][z][v];
						int b = img2[x][y][z][v];
						min_label = Math.min(min_label, Math.min(a, b));
						max_label = Math.max(max_label, Math.max(a, b));
					}
		if (min_label > max_label)
			return new short[0]; // no voxels were scanned

		// second, flag which labels in that range occur, counting new ones as we go
		boolean [] present = new boolean [max_label - min_label + 1];
		int num_unique_labels = 0;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++) {
						int a = img1[x][y][z][v] - min_label;
						if (!present[a]) {
							present[a] = true;
							num_unique_labels++;
						}
						int b = img2[x][y][z][v] - min_label;
						if (!present[b]) {
							present[b] = true;
							num_unique_labels++;
						}
					}

		// finally, collect the present labels in ascending order
		short [] unique_labels = new short [num_unique_labels];
		int count = 0;
		for (int l = 0; l < present.length; l++)
			if (present[l])
				unique_labels[count++] = (short)(l + min_label);

		return(unique_labels);
	}

}
