package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.JistPreferences;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

public abstract class ObservationBase extends AbstractCalculation {
	
	protected short num_labels;
	protected int num_raters;
	protected int [] dims;
	protected int [] orig_dims;
	protected int [][] crop_inds;
	protected ImageHeader header;
	protected short [][][][] cons_estimate;
	protected boolean [][][][] consensus;
	protected float cons_thresh;
	protected short [] label_remap;
	protected short [] label_unmap;
	protected int [] sv;
	protected int [] pv;
	protected float [] dimres;
	protected boolean use_weighted_prior;
	JistPreferences prefs = JistPreferences.getPreferences();
	
	public ObservationBase (float cons_thresh_in,
							int [] sv_in,
							int [] pv_in) {
		super();
		setLabel("ObservationBase");
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Generic Observation Base +++");
		cons_thresh = Math.max(0.7f, cons_thresh_in);
		sv = sv_in;
		pv = pv_in;
		dimres = new float [4];
		use_weighted_prior = false;
	}
	
	public ObservationBase (float cons_thresh_in) {
		this(cons_thresh_in, new int [4], new int [4]);
	}
	
	// abstract functions
	public abstract short get(int x, int y, int z, int v, int j);
	public abstract float get_val(int x, int y, int z, int v, int j);
	public abstract short [] get_all(int x, int y, int z, int v, int j);
	public abstract float [] get_all_vals(int x, int y, int z, int v, int j);
	public abstract void free_obs(int x, int y, int z, int v);
	
	public float [][] get_all_vals_full(int x, int y, int z, int v) {
		float [][] full_vals = new float [num_raters][num_labels];
		
		for (int j = 0; j < num_raters; j++) {
			short [] obslabs = get_all(x, y, z, v, j);
			float [] obsvals = get_all_vals(x, y, z, v, j);
			for (int i = 0; i < obslabs.length; i++)
				full_vals[j][obslabs[i]] = obsvals[i];
		}
		
		return(full_vals);
	}
	
	public void initialize_adaptive_probabilities(int x, int y, int z, int v, int dilation, float [] lp) {

		// set the current region of interest
		int xl = Math.max(x - dilation, 0);
		int xh = Math.min(x + dilation, dimx()-1);
		int yl = Math.max(y - dilation, 0);
		int yh = Math.min(y + dilation, dimy()-1);
		int zl = Math.max(z - dilation, 0);
		int zh = Math.min(z + dilation, dimz()-1);
		int vl = Math.max(v - dilation, 0);
		int vh = Math.min(v + dilation, dimv()-1);
		
		Arrays.fill(lp, 0f);
		
		for (int xi = xl; xi <= xh; xi++)
			for (int yi = yl; yi <= yh; yi++)
				for (int zi = zl; zi <= zh; zi++)
					for (int vi = vl; vi <= vh; vi++)
						if (!this.is_consensus(xi, yi, zi, vi))
							for (int j = 0; j < num_raters; j++) {
								short [] obslabels = this.get_all(xi, yi, zi, vi, j);
								for (int l = 0; l < obslabels.length; l++)
									lp[obslabels[l]] = 1;
							}
	}
		
	// public functions
	public int[] dims() { return(dims); }
	public int dimx() { return(dims[0]); }
	public int dimy() { return(dims[1]); }
	public int dimz() { return(dims[2]); }
	public int dimv() { return(dims[3]); }
	public int num_vox() { return(dims[0]*dims[1]*dims[2]*dims[3]); }
	public int[] orig_dims() { return(orig_dims); }
	public int orig_dimx() { return(orig_dims[0]); }
	public int orig_dimy() { return(orig_dims[1]); }
	public int orig_dimz() { return(orig_dims[2]); }
	public int orig_dimv() { return(orig_dims[3]); }
	public int offx() { return(crop_inds[0][0]); }
	public int offy() { return(crop_inds[1][0]); }
	public int offz() { return(crop_inds[2][0]); }
	public int offv() { return(crop_inds[3][0]); }
	public short num_labels() { return(num_labels); }
	public int num_raters() { return(num_raters); }
	public ImageHeader get_header() { return(header); }
	public short get_consensus_estimate(int x, int y, int z, int v) { return(cons_estimate[x][y][z][v]); }
	public boolean is_consensus(int x, int y, int z, int v) { return(consensus[x][y][z][v]); }
	public short get_label_remap(short l) { return(label_remap[l]); }
	public void create_atlas_selection_matrix(float global_thresh,
			  							      float local_thresh) {
		throw new RuntimeException("[ObservationBase] create_atlas_selection_matrix not implemented.");
	}
	public void create_weighted_prior() {
		throw new RuntimeException("[ObservationBase] create_weighted_prior not implemented.");
	}
	public void get_weighted_prior(int x, int y, int z, int v, float [] lp) {
		throw new RuntimeException("[ObservationBase] get_weighted_prior not implemented.");
	}
	public boolean get_local_selection(int x, int y, int z, int v, int j) { return(true); }
	public boolean ignore_consensus() { return(cons_thresh <= 1); }
	public float consensus_threshold() { return(cons_thresh); }
	public void normalize_all() {}
	public void update_consensus() {
		
		// we can skip this if we don't want to ignore consensus voxels
		if (!ignore_consensus())
			return;
		
		float [] lp = new float [num_labels];
		float maxval;
		short labelest = 0;
		
		// first set the initial consensus
		JistLogger.logOutput(JistLogger.INFO, "-> Setting final consensus voxels/estimate");
		float numcon = 0;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++) 
					for (int v = 0; v < dims[3]; v++) 
						if (!is_consensus(x, y, z, v)) {
							
							maxval = 0;
							labelest = 0;
							
							if (use_weighted_prior) {
								
								/**
								 * Use the weighted prior
								 */
								get_weighted_prior(x, y, z, v, lp);
							} else {
							
								/**
								 * Use a standard voting technique
								 */
								
								// initialize the label probability
								Arrays.fill(lp, 0f);
								
								// set the label probability for this voxel
								for (int j = 0; j < num_raters; j++)
									if (get_local_selection(x, y, z, v, j)) {
										short [] obslabels = get_all(x, y, z, v, j);
										float [] obsvals = get_all_vals(x, y, z, v, j);
										for (int l = 0; l < obslabels.length; l++)
											lp[obslabels[l]] += obsvals[l];
									}
							}
							
							// normalize everything
							float sum = 0;
							for (short s = 0; s < num_labels; s++) {
								sum += lp[s];
								if (lp[s] > maxval) {
									maxval = lp[s];
									labelest = s;
								}
							}
							maxval /= sum;
							
							// set the consensus estimate
							cons_estimate[x][y][z][v] = labelest;
							consensus[x][y][z][v] = (ignore_consensus()) ? maxval >= (cons_thresh - 0.00001f) : false;
							
							if (consensus[x][y][z][v]) {
								numcon++;
								free_obs(x, y, z, v);
							}
							
						} else {
							numcon++; 
						}
		JistLogger.logOutput(JistLogger.INFO, "Final Fraction Consensus: " + numcon / (dims[0]*dims[1]*dims[2]*dims[3]));
	}
	
	// protected functions
	protected void copy_data(ObservationBase obs2) {
		this.num_raters = obs2.num_raters;
		this.num_labels = obs2.num_labels;
		this.orig_dims = obs2.orig_dims;
		this.dims = obs2.dims;
		this.dimres = obs2.dimres;
		this.label_remap = obs2.label_remap;
		this.label_unmap = obs2.label_unmap;
		this.sv = obs2.sv;
		this.pv = obs2.pv;
		this.crop_inds = obs2.crop_inds;
		this.header = obs2.header;
	}
	protected void determine_cropping_information(ParamVolumeCollection obsvols) {
				
		// set the number of raters
		num_raters = obsvols.getParamVolumeList().size();

		JistLogger.logOutput(JistLogger.INFO, String.format("-> Attempting to determine valid cropping region (found %d label images)", num_raters));
		
		// Allocate some space
		boolean [][][][] crop = null;
		dims = new int [4];
		orig_dims = new int [4];

		for (int i = 0; i < num_raters; i++) {
			
			// Let the user know where we are
			print_status(i, num_raters);
			
			// load the image information (quietly)
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			ImageData img = obsvols.getParamVolume(i).getImageData(true);
			
			// if this is the first volume, then get the pertinent information
			if (i == 0) {
				
				// set the dimensions
				orig_dims[0] = Math.max(img.getRows(), 1);
				orig_dims[1] = Math.max(img.getCols(), 1);
				orig_dims[2] = Math.max(img.getSlices(), 1);
				orig_dims[3] = Math.max(img.getComponents(), 1);
				header = img.getHeader().clone();
				
				float [] tempdimres = header.getDimResolutions();
				int lengthdimres = tempdimres.length;
				for (int d = 0; d < 4; d++) {
					dimres[d] = (d < lengthdimres) ? tempdimres[d] : 1; 
					if (dimres[d] == 0)
						dimres[d] = 1;
					
					// set the patch / search half-window size in mm / voxels
					sv[d] = (sv[d] >= 0) ? Math.min(Math.round((float)sv[d] / dimres[d]), (orig_dims[d]-1)/2) : -sv[d];
					pv[d] = (pv[d] >= 0) ? Math.min(Math.round((float)pv[d] / dimres[d]), (orig_dims[d]-1)/2) : -pv[d];
					
					// take care of any issues regarding image size
					sv[d] = Math.min(sv[d], (orig_dims[d]-1)/2);
					pv[d] = Math.min(pv[d], (orig_dims[d]-1)/2);
				}
				
				// allocate the observation matrix
				crop = new boolean [orig_dims[0]][orig_dims[1]][orig_dims[2]][orig_dims[3]];
				for (int x = 0; x < orig_dims[0]; x++)
					for (int y = 0; y < orig_dims[1]; y++)
						for (int z = 0; z < orig_dims[2]; z++)
							for (int v = 0; v < orig_dims[3]; v++)
								crop[x][y][z][v] = true;
				
			}
			
			// make sure that the dimensions match
			if (orig_dims[0] != Math.max(img.getRows(), 1) ||
				orig_dims[1] != Math.max(img.getCols(), 1) ||
				orig_dims[2] != Math.max(img.getSlices(), 1) ||
				orig_dims[3] != Math.max(img.getComponents(), 1)) {
				JistLogger.logOutput(JistLogger.SEVERE, "Error: Rater Dimensions do not match");
				throw new RuntimeException("Error: Rater Dimensions do not match");
			}
			
			// Iterate over all potential voxels
			for (int x = 0; x < orig_dims[0]; x++)
				for (int y = 0; y < orig_dims[1]; y++)
					for (int z = 0; z < orig_dims[2]; z++)
						for (int v = 0; v < orig_dims[3]; v++)
							if (crop[x][y][z][v] == true)
								if (img.getFloat(x, y, z, v) != 0)
									crop[x][y][z][v] = false;
			
			// free the current volume (quietly)
			obsvols.getParamVolume(i).dispose();
			prefs.setDebugLevel(orig_level);
		}
		
		// initialize the cropping indices
		crop_inds = new int [4][2];
		for (int i = 0; i < 4; i++)
				for (int j = 0; j < 2; j++)
					crop_inds[i][j] = -1;
			
		// determine which locations along each dimension are valid for cropping
		boolean [][] valid = new boolean [4][];
		valid[0] = new boolean[orig_dims[0]];
		valid[1] = new boolean[orig_dims[1]];
		valid[2] = new boolean[orig_dims[2]];
		valid[3] = new boolean[orig_dims[3]];
		for (int d = 0; d < 4; d++)
			for (int i = 0; i < orig_dims[d]; i++)
				valid[d][i] = true;
		for (int x = 0; x < orig_dims[0]; x++)
			for (int y = 0; y < orig_dims[1]; y++)
				for (int z = 0; z < orig_dims[2]; z++)
					for (int v = 0; v < orig_dims[3]; v++)
						if (crop[x][y][z][v] == false) {
							valid[0][x] = false;
							valid[1][y] = false;
							valid[2][z] = false;
							valid[3][v] = false;
						}
		
		// handle each dimension
		for (int d = 0; d < 4; d++) {
			
			// the forward case
			if (!valid[d][0])
				crop_inds[d][0] = 0;
			else
				for (int i = 1; i < orig_dims[d]; i++)
					if (!valid[d][i] && crop_inds[d][0] == -1)
						crop_inds[d][0] = i-1;
			
			// the reverse case
			if (!valid[d][orig_dims[d]-1])
				crop_inds[d][1] = orig_dims[d]-1;
			else
				for (int i = orig_dims[d]-2; i > 0; i--)
					if (!valid[d][i] && crop_inds[d][1] == -1)
						crop_inds[d][1] = i+1;
			
			if (crop_inds[d][0] > crop_inds[d][1])
				crop_inds[d][0] = crop_inds[d][1];
		}
		
		// fix any problems introduced by the sv,pv
		for (int d = 0; d < 4; d++) {
			crop_inds[d][0] = Math.max(crop_inds[d][0] - sv[d] - pv[d], 0);
			crop_inds[d][1] = Math.min(crop_inds[d][1] + sv[d] + pv[d], orig_dims[d]-1);
		}
		
		JistLogger.logOutput(JistLogger.INFO, "-> Found Cropping Region:");
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 1: (0, %d) -> (%d, %d)", orig_dims[0]-1, crop_inds[0][0], crop_inds[0][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 2: (0, %d) -> (%d, %d)", orig_dims[1]-1, crop_inds[1][0], crop_inds[1][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 3: (0, %d) -> (%d, %d)", orig_dims[2]-1, crop_inds[2][0], crop_inds[2][1]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimension 4: (0, %d) -> (%d, %d)", orig_dims[3]-1, crop_inds[3][0], crop_inds[3][1]));
		
		for (int d = 0; d < 4; d++)
			dims[d] = crop_inds[d][1] - crop_inds[d][0]+1;
				
	}
	
	protected short [][][][][] load_label_images(ParamVolumeCollection obsvols) {
		
		// set the number of raters
		short [][][][][] obs = new short [dims[0]][dims[1]][dims[2]][dims[3]][num_raters];
		num_labels = 0;
		short val;
		
		JistLogger.logOutput(JistLogger.INFO, "-> Loading all label information");

		for (int i = 0; i < num_raters; i++) {
			
			// Let the user know where we are
			print_status(i, num_raters);
			
			// load the image information (quietly)
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			ImageData img = obsvols.getParamVolume(i).getImageData(true);
			
			// make sure that the dimensions match
			if (orig_dims[0] != Math.max(img.getRows(), 1) ||
				orig_dims[1] != Math.max(img.getCols(), 1) ||
				orig_dims[2] != Math.max(img.getSlices(), 1) ||
				orig_dims[3] != Math.max(img.getComponents(), 1)) {
				JistLogger.logOutput(JistLogger.SEVERE, "Error: Rater Dimensions do not match");
				throw new RuntimeException("Error: Rater Dimensions do not match");
			}
			
			// Iterate over all potential voxels
			for (int x = 0; x < dims[0]; x++)
				for (int y = 0; y < dims[1]; y++)
					for (int z = 0; z < dims[2]; z++)
						for (int v = 0; v < dims[3]; v++) {
							val = img.getShort(x + offx(),
											   y + offy(),
											   z + offz(),
											   v + offv());
							
							// check to see if it increases the number of labels
							if (val > num_labels)
								num_labels = val;
							
							// add to the observation matrix
							obs[x][y][z][v][i] = val;
						}
			
			// free the current volume (quietly)
			obsvols.getParamVolume(i).dispose();
			prefs.setDebugLevel(orig_level);
		}
		
		// set the final number of labels
		num_labels++;
		
		// unmap the label numbers on the observations
		set_label_remap(obs);
		
		return(obs);
	}
	
	protected void print_status(int ind,
								int num) {
					
		int total = (num > 10) ? 10 : num;
		
		int currval = (int)(((float)total * (((float)(ind+1)) / ((float)(num)))));
		int prevval = (int)(((float)total * (((float)(ind)) / ((float)(num)))));
		
		if (currval > prevval) {
			String msg = "[";
			for (int i = 0; i < currval; i++)
				msg += "=";
			for (int i = currval; i < total; i++)
				msg += "+";
			msg += "]";
			JistLogger.logOutput(JistLogger.INFO, msg);
		}
	}

	protected void set_label_remap(short [][][][][] obs) {
		
		JistLogger.logOutput(JistLogger.INFO, "-> Unmapping labels to simple space\n");
		
		// construct the temporary histogram
		short [] tmp_hist = new short [num_labels];
		for (short l = 0; l < num_labels; l++)
			tmp_hist[l] = 0;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++)
						for (int j = 0; j < num_raters; j++)
							tmp_hist[obs[x][y][z][v][j]] = 1;
		
		// set the number of unique labels
		short num_unique_labels = 0;
		for (short l = 0; l < num_labels; l++)
			if (tmp_hist[l] > 0)
				num_unique_labels++;
		
		// construct the mappings
		label_remap = new short [num_unique_labels];
		label_unmap = new short [num_labels];
		short count = 0;
		for (short l = 0; l < num_labels; l++)
			if (tmp_hist[l] > 0) {
				label_remap[count] = l;
				label_unmap[l] = count;
				count++;
			} else
				label_unmap[l] = -1;
		
		// unmap the obs
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++)
						for (int j = 0; j < num_raters; j++)
							obs[x][y][z][v][j] = label_unmap[obs[x][y][z][v][j]];
		
		num_labels = num_unique_labels;
		
	}

	protected void remap_estimate(ImageData estimate) {
		JistLogger.logOutput(JistLogger.INFO, "In ObservationBase - Remapping labels to original space\n");
		for (int x = 0; x < orig_dims[0]; x++)
			for (int y = 0; y < orig_dims[1]; y++)
				for (int z = 0; z < orig_dims[2]; z++)
					for (int v = 0; v < orig_dims[3]; v++)
						estimate.set(x, y, z, v, label_remap[estimate.getInt(x,  y, z, v)]);
	}
	
	public short unmap_label(short l) { return(label_unmap[l]); }
	
	public void add_to_probabilities(int x, int y, int z, int v, int j, float [] lp, float fact) {
		short [] labs = this.get_all(x, y, z, v, j);
		float [] vals = this.get_all_vals(x, y, z, v, j);
		for (int i = 0; i < labs.length; i++)
			lp[labs[i]] += fact*vals[i];
	}
}
