package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;

/**
 * Rater observations read from probability volumes: every input volume supplies,
 * for each voxel, one probability per label.
 *
 * <p>Construction runs two passes over the input collection. The first pass
 * ({@link #determine_cropping_information}) validates that all volumes share the same
 * dimensions and computes a cropping region; the second pass loads the probabilities
 * inside that region into a {@link SparseMatrix6D}.
 *
 * <p>Two input layouts are supported, selected by {@link #probability_index}:
 * <ul>
 *   <li>{@code probability_index == 1}: the image is treated as 2D (rows x cols) and
 *       the labels are read along the slice (third) axis;</li>
 *   <li>otherwise: the image is treated as 3D (rows x cols x slices) and the labels
 *       are read along the component (fourth) axis.</li>
 * </ul>
 * When an input only holds a single probability per voxel it is expanded to two labels
 * (label 1 = the stored value, label 0 = one minus the stored value).
 */
public class ObservationVolumeProbability extends ObservationBase {

	/** Per-voxel, per-rater label probabilities, indexed in cropped coordinates. */
	protected SparseMatrix6D obs;

	/** Selects the axis that holds the labels (1 = slice axis, anything else = component axis). */
	protected int probability_index;

	/**
	 * Convenience constructor: zero-sized search/patch windows and non-quiet mode.
	 *
	 * @param obsvols              the rater probability volumes
	 * @param cons_thresh_in       consensus threshold, forwarded to {@link ObservationBase}
	 * @param probability_index_in layout selector, see {@link #probability_index}
	 */
	public ObservationVolumeProbability(ParamVolumeCollection obsvols,
							  			float cons_thresh_in,
							  			int probability_index_in) {
		this(obsvols, cons_thresh_in, probability_index_in, new int [4], new int [4], false);
	}

	/**
	 * Loads all rater probability volumes and, unless {@code quiet} is set, allocates the
	 * consensus structures, updates the consensus and logs a summary.
	 *
	 * @param obsvols              the rater probability volumes
	 * @param cons_thresh_in       consensus threshold, forwarded to {@link ObservationBase}
	 * @param probability_index_in layout selector, see {@link #probability_index}
	 * @param sv_in                search half-window per dimension, forwarded to {@link ObservationBase}
	 * @param pv_in                patch half-window per dimension, forwarded to {@link ObservationBase}
	 * @param quiet                when true, the consensus arrays are not allocated and nothing is summarized
	 * @throws RuntimeException if no volumes are given, the volume dimensions disagree, or
	 *                          a voxel's probabilities sum to more than 1.01
	 */
	public ObservationVolumeProbability(ParamVolumeCollection obsvols,
										float cons_thresh_in,
										int probability_index_in,
							            int [] sv_in,
							            int [] pv_in,
							            boolean quiet) {
		super(cons_thresh_in, sv_in, pv_in);
		setLabel("ObservationVolumeProbability");

		JistLogger.logOutput(JistLogger.INFO,"\n+++ Initializing Label Observations (Probability Type)  +++");

		probability_index = probability_index_in;

		// initialize the observation volumes
		determine_cropping_information(obsvols);
		load_probability_images(obsvols);

		if (!quiet) {
			// "consensus" relies on Java's default-false initialization; only the
			// estimate needs an explicit "unset" (-1) marker.
			consensus = new boolean[dims[0]][dims[1]][dims[2]][dims[3]];
			cons_estimate = new short[dims[0]][dims[1]][dims[2]][dims[3]];
			for (int x = 0; x < dims[0]; x++)
				for (int y = 0; y < dims[1]; y++)
					for (int z = 0; z < dims[2]; z++)
						for (int v = 0; v < dims[3]; v++)
							cons_estimate[x][y][z][v] = -1;
			update_consensus();

			// print out some information
			JistLogger.logOutput(JistLogger.INFO, "-> Determined the following information");
			JistLogger.logOutput(JistLogger.INFO, "Number of Raters: " + num_raters);
			JistLogger.logOutput(JistLogger.INFO, "Number of Labels: " + num_labels);
			JistLogger.logOutput(JistLogger.INFO, String.format("Original Dimensions (voxels): [%d %d %d %d]", orig_dims[0], orig_dims[1], orig_dims[2], orig_dims[3]));
			JistLogger.logOutput(JistLogger.INFO, String.format("Cropped Dimensions (voxels): [%d %d %d %d]", dims[0], dims[1], dims[2], dims[3]));
			JistLogger.logOutput(JistLogger.INFO, "Ignore Consensus: " + ignore_consensus());
		}
	}

	/**
	 * First pass over the input volumes: sets {@code num_raters}, {@code num_labels},
	 * {@code orig_dims}, {@code header}, {@code dimres}, the voxel-space {@code sv}/{@code pv}
	 * windows, {@code crop_inds} and {@code dims}.
	 *
	 * <p>A voxel is kept out of the region of interest only while every rater stores
	 * exactly 1 for the first label there (or exactly 0 when the volumes hold a single
	 * probability). The cropping region is the bounding box of all remaining voxels,
	 * grown by one voxel and by {@code sv + pv} in each dimension, clamped to the image.
	 *
	 * @param obsvols the rater probability volumes
	 * @throws RuntimeException if the collection is empty or the volume dimensions disagree
	 */
	protected void determine_cropping_information(ParamVolumeCollection obsvols) {

		// set the number of raters
		num_raters = obsvols.getParamVolumeList().size();

		JistLogger.logOutput(JistLogger.INFO, String.format("-> Attempting to determine valid cropping region (found %d label images)", num_raters));

		// without any volume there are no dimensions to work with; fail with a clear message
		if (num_raters == 0) {
			JistLogger.logOutput(JistLogger.SEVERE, "Error: No rater probability volumes were provided");
			throw new RuntimeException("Error: No rater probability volumes were provided");
		}

		// Allocate some space
		boolean [][][][] crop = null;
		dims = new int [4];
		orig_dims = new int [4];

		for (int i = 0; i < num_raters; i++) {

			// Let the user know where we are
			print_status(i, num_raters);

			// load the image information (quietly)
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			try {
				ImageData img = obsvols.getParamVolume(i).getImageData(true);

				// if this is the first volume, then get the pertinent information
				if (i == 0) {

					// set the dimensions
					orig_dims[0] = Math.max(img.getRows(), 1);
					orig_dims[1] = Math.max(img.getCols(), 1);
					if (probability_index == 1) {
						// labels live on the slice axis, so the spatial volume is 2D
						orig_dims[2] = 1;
						orig_dims[3] = 1;
						num_labels = (short)Math.max(img.getSlices(), 1);
					} else {
						// labels live on the component axis
						orig_dims[2] = Math.max(img.getSlices(), 1);
						orig_dims[3] = 1;
						num_labels = (short)Math.max(img.getComponents(), 1);
					}

					header = img.getHeader().clone();

					float [] tempdimres = header.getDimResolutions();
					int lengthdimres = tempdimres.length;
					for (int d = 0; d < 4; d++) {
						// missing or zero resolutions fall back to 1
						dimres[d] = (d < lengthdimres) ? tempdimres[d] : 1;
						if (dimres[d] == 0)
							dimres[d] = 1;

						// set the patch / search half-window size in mm / voxels:
						// non-negative values are divided by the resolution, negative
						// values are negated and used as-is
						sv[d] = (sv[d] >= 0) ? Math.min(Math.round((float)sv[d] / dimres[d]), (orig_dims[d]-1)/2) : -sv[d];
						pv[d] = (pv[d] >= 0) ? Math.min(Math.round((float)pv[d] / dimres[d]), (orig_dims[d]-1)/2) : -pv[d];

						// take care of any issues regarding image size
						sv[d] = Math.min(sv[d], (orig_dims[d]-1)/2);
						pv[d] = Math.min(pv[d], (orig_dims[d]-1)/2);
					}

					// allocate the crop mask; every voxel starts as "croppable"
					crop = new boolean [orig_dims[0]][orig_dims[1]][orig_dims[2]][orig_dims[3]];
					for (int x = 0; x < orig_dims[0]; x++)
						for (int y = 0; y < orig_dims[1]; y++)
							for (int z = 0; z < orig_dims[2]; z++)
								for (int v = 0; v < orig_dims[3]; v++)
									crop[x][y][z][v] = true;
				}

				// make sure that the dimensions match
				verify_rater_dimensions(img);

				// A voxel remains croppable only while this rater stores the expected
				// value at index 0: 0 for single-probability volumes, 1 otherwise.
				// (orig_dims[2] is 1 when the labels live on the slice axis, so the
				// same loop serves both layouts.)
				float croppable_value = (num_labels == 1) ? 0 : 1;
				for (int x = 0; x < orig_dims[0]; x++)
					for (int y = 0; y < orig_dims[1]; y++)
						for (int z = 0; z < orig_dims[2]; z++)
							if (crop[x][y][z][0] && img.getFloat(x, y, z, 0) != croppable_value)
								crop[x][y][z][0] = false;
			} finally {
				// free the current volume (quietly), even when validation failed
				release_probability_volume(obsvols, i, orig_level);
			}
		}

		// initialize the cropping indices
		crop_inds = new int [4][2];
		for (int i = 0; i < 4; i++)
			for (int j = 0; j < 2; j++)
				crop_inds[i][j] = -1;

		// determine which locations along each dimension are valid for cropping:
		// valid[d][i] stays true only if no voxel with index i along d must be kept
		boolean [][] valid = new boolean [4][];
		for (int d = 0; d < 4; d++) {
			valid[d] = new boolean[orig_dims[d]];
			for (int i = 0; i < orig_dims[d]; i++)
				valid[d][i] = true;
		}
		for (int x = 0; x < orig_dims[0]; x++)
			for (int y = 0; y < orig_dims[1]; y++)
				for (int z = 0; z < orig_dims[2]; z++)
					for (int v = 0; v < orig_dims[3]; v++)
						if (!crop[x][y][z][v]) {
							valid[0][x] = false;
							valid[1][y] = false;
							valid[2][z] = false;
							valid[3][v] = false;
						}

		// handle each dimension
		for (int d = 0; d < 4; d++) {

			// the forward case: one voxel before the first index that must be kept
			if (!valid[d][0])
				crop_inds[d][0] = 0;
			else
				for (int i = 1; i < orig_dims[d]; i++)
					if (!valid[d][i]) {
						crop_inds[d][0] = i-1;
						break;
					}

			// the reverse case: one voxel after the last index that must be kept.
			// The scan has to reach index 0, otherwise a region of interest that
			// only touches index 0 is missed and the upper bound stays at -1.
			if (!valid[d][orig_dims[d]-1])
				crop_inds[d][1] = orig_dims[d]-1;
			else
				for (int i = orig_dims[d]-2; i >= 0; i--)
					if (!valid[d][i]) {
						crop_inds[d][1] = i+1;
						break;
					}

			if (crop_inds[d][0] > crop_inds[d][1])
				crop_inds[d][0] = crop_inds[d][1];
		}

		// fix any problems introduced by the sv,pv
		for (int d = 0; d < 4; d++) {
			crop_inds[d][0] = Math.max(crop_inds[d][0] - sv[d] - pv[d], 0);
			crop_inds[d][1] = Math.min(crop_inds[d][1] + sv[d] + pv[d], orig_dims[d]-1);
		}

		JistLogger.logOutput(JistLogger.INFO, "-> Found Cropping Region:");
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 1: (0, %d) -> (%d, %d)", orig_dims[0]-1, crop_inds[0][0], crop_inds[0][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 2: (0, %d) -> (%d, %d)", orig_dims[1]-1, crop_inds[1][0], crop_inds[1][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 3: (0, %d) -> (%d, %d)", orig_dims[2]-1, crop_inds[2][0], crop_inds[2][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 4: (0, %d) -> (%d, %d)", orig_dims[3]-1, crop_inds[3][0], crop_inds[3][1]));

		for (int d = 0; d < 4; d++)
			dims[d] = crop_inds[d][1] - crop_inds[d][0]+1;

	}

	/**
	 * Checks that a rater volume has the dimensions and label count recorded from the
	 * first volume, taking the layout chosen by {@link #probability_index} into account.
	 *
	 * @param img the rater volume to check
	 * @throws RuntimeException if any dimension differs
	 */
	private void verify_rater_dimensions(ImageData img) {
		int rows = Math.max(img.getRows(), 1);
		int cols = Math.max(img.getCols(), 1);
		int slices = Math.max(img.getSlices(), 1);
		int components = Math.max(img.getComponents(), 1);

		boolean matches = (orig_dims[0] == rows) && (orig_dims[1] == cols);
		if (probability_index == 1)
			matches = matches && (num_labels == (short)slices) && (orig_dims[3] == components);
		else
			matches = matches && (orig_dims[2] == slices) && (num_labels == (short)components);

		if (!matches) {
			JistLogger.logOutput(JistLogger.SEVERE, "Error: Rater Dimensions do not match");
			throw new RuntimeException("Error: Rater Dimensions do not match");
		}
	}

	/**
	 * Disposes the i-th volume and restores the debug level that was active before the
	 * volume was loaded quietly. The level is restored even if disposing fails.
	 *
	 * @param obsvols    the rater probability volumes
	 * @param i          index of the volume to dispose
	 * @param orig_level debug level to restore
	 */
	private void release_probability_volume(ParamVolumeCollection obsvols, int i, int orig_level) {
		try {
			obsvols.getParamVolume(i).dispose();
		} finally {
			prefs.setDebugLevel(orig_level);
		}
	}

	/**
	 * Second pass over the input volumes: allocates {@link #obs} for the cropped region and
	 * fills it with the normalized label probabilities of every rater, then sets up identity
	 * {@code label_remap} / {@code label_unmap} tables.
	 *
	 * <p>If the volumes hold a single probability per voxel, {@code num_labels} is
	 * raised from 1 to 2 here.
	 *
	 * @param obsvols the rater probability volumes
	 * @throws RuntimeException if the probabilities of a voxel sum to more than 1.01
	 */
	private void load_probability_images(ParamVolumeCollection obsvols) {

		// handle the case where we only have a single image
		boolean singleim = false;
		if (num_labels == 1) {
			num_labels = 2;
			singleim = true;
		}

		// allocate space for the observations
		obs = new SparseMatrix6D(dims[0], dims[1], dims[2], dims[3], num_raters, num_labels);
		float [] lp = new float [num_labels];

		JistLogger.logOutput(JistLogger.INFO, "-> Loading all label information");

		for (int i = 0; i < num_raters; i++) {

			// Let the user know where we are
			print_status(i, num_raters);

			// load the image information (quietly)
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			try {
				ImageData img = obsvols.getParamVolume(i).getImageData(true);

				// Iterate over all potential voxels
				if (probability_index == 1) {
					// labels live on the slice axis: only slice 0 exists spatially
					for (int x = 0; x < dims[0]; x++)
						for (int y = 0; y < dims[1]; y++)
							load_probability_voxel(img, x, y, 0, i, singleim, lp);
				} else {
					for (int x = 0; x < dims[0]; x++)
						for (int y = 0; y < dims[1]; y++)
							for (int z = 0; z < dims[2]; z++)
								load_probability_voxel(img, x, y, z, i, singleim, lp);
				}
			} finally {
				// free the current volume (quietly), even when a voxel was rejected
				release_probability_volume(obsvols, i, orig_level);
			}
		}

		// unmap the label numbers on the observations
		label_remap = new short [num_labels];
		label_unmap = new short [num_labels];
		for (short i = 0; i < num_labels; i++) {
			label_remap[i] = i;
			label_unmap[i] = i;
		}
	}

	/**
	 * Reads the label probabilities of one rater at one cropped voxel, rejects sums above
	 * 1.01, and stores the normalized result in {@link #obs}.
	 *
	 * @param img      the rater volume (indexed in original, uncropped coordinates)
	 * @param x        cropped x index
	 * @param y        cropped y index
	 * @param z        cropped z index
	 * @param rater    index of the rater being loaded
	 * @param singleim true if the volume holds a single probability per voxel
	 * @param lp       scratch buffer of length {@code num_labels}; overwritten on every call
	 * @throws RuntimeException if the probabilities sum to more than 1.01
	 */
	private void load_probability_voxel(ImageData img, int x, int y, int z, int rater, boolean singleim, float [] lp) {
		final int v = 0;
		final float norm_val = 1;
		float sum;

		if (singleim) {
			// the stored value is the probability of label 1; label 0 gets the remainder
			lp[1] = img.getFloat(x + offx(), y + offy(), z + offz(), v + offv());
			sum = lp[1];
			lp[0] = 1 - lp[1];
		} else {
			sum = 0;
			for (int l = 0; l < num_labels; l++) {
				if (probability_index == 1)
					lp[l] = img.getFloat(x + offx(), y + offy(), l + offz(), v + offv());
				else
					lp[l] = img.getFloat(x + offx(), y + offy(), z + offz(), l + offv());
				sum += lp[l];
			}
		}

		// allow a small tolerance above 1 before rejecting the data
		if (sum > (norm_val + 0.01)) {
			JistLogger.logOutput(JistLogger.SEVERE, "Error: Value found greater than 1, invalid for probabilities");
			JistLogger.logOutput(JistLogger.SEVERE, String.format("Location (%d, %d %d, %d): Value %f", x, y, z, v, sum));
			throw new RuntimeException("Error: Value found greater than 1, invalid for probabilities");
		}

		// initialize and normalize this voxel
		obs.init_voxel(x, y, z, v, rater, lp, (short)0);
		obs.normalize(x, y, z, v, rater);
	}

	/** Returns the label that {@code obs.get_max_label} reports for rater {@code j} at the given cropped voxel. */
	public short get(int x, int y, int z, int v, int j) { return(obs.get_max_label(x, y, z, v, j)); }

	/** Returns the value stored for the label that {@code obs.get_max_label} reports for rater {@code j} at the given cropped voxel. */
	public float get_val(int x, int y, int z, int v, int j) {
		short label = obs.get_max_label(x, y, z, v, j);
		return(obs.get_val(x, y, z, v, j, label));
	}

	/** Returns all labels stored for rater {@code j} at the given cropped voxel. */
	public short [] get_all(int x, int y, int z, int v, int j) { return obs.get_all_labels(x, y, z, v, j); }

	/** Returns all values stored for rater {@code j} at the given cropped voxel. */
	public float [] get_all_vals(int x, int y, int z, int v, int j) { return obs.get_all_vals(x, y, z, v, j); }

	/** Releases the observations stored at the given cropped voxel. */
	public void free_obs(int x, int y, int z, int v) { obs.free(x, y, z, v); }

}
	