package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;

/**
 * Common scaffolding for label fusion algorithms.
 *
 * <p>Holds the rater observations ({@link #obs}) and an integer label estimate
 * ({@link #estimate}) allocated at the original (uncropped) image dimensions.
 * Per-voxel work is performed on the observation grid ({@code obs.dimx()} etc.)
 * and written back into the original grid shifted by {@code obs.offx()} ..
 * {@code obs.offv()}. Voxels outside that region are treated as label 0.
 */
public abstract class LabelFusionBase extends AbstractCalculation {

	/** Current label estimate, sized to the original (uncropped) dimensions. */
	protected ImageData estimate;

	/** Observations (rater labels) that the fusion operates on. */
	protected ObservationBase obs;

	/**
	 * Stores the observations, allocates the output estimate and seeds it with
	 * the consensus labels.
	 *
	 * <p>NOTE(review): this invokes the overridable
	 * {@link #set_consensus_estimate()} from the constructor, so any subclass
	 * override runs before the subclass's own fields are initialized.
	 *
	 * @param obs_in  the observations to fuse
	 * @param outname name assigned to the output estimate volume
	 */
	public LabelFusionBase (ObservationBase obs_in,
							String outname) {
		super();

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Generic Label Fusion Algorithm +++");

		// set the input observations
		obs = obs_in;

		// set the output estimate (integer labels at the original dimensions)
		estimate = new ImageDataInt(obs.orig_dimx(), obs.orig_dimy(), obs.orig_dimz(), obs.orig_dimv());
		estimate.setHeader(obs.get_header());
		estimate.setName(outname);

		// set the initial estimate
		set_consensus_estimate();
	}

	/**
	 * Runs the fusion algorithm.
	 *
	 * @return the fused label volume
	 */
	public abstract ImageData run ();

	/**
	 * Builds a probability volume for a single label at the original dimensions.
	 *
	 * <p>Voxels outside the observation region get probability 1 for label 0 and
	 * 0 otherwise. Consensus voxels get 1 if their consensus label equals
	 * {@code l}, else 0. All other voxels take the value produced by
	 * {@link #set_label_probabilities}.
	 *
	 * @param l    label whose probability is requested (used as an index into
	 *             the per-voxel probability vector)
	 * @param name name of the returned volume
	 * @return a float volume of per-voxel probabilities for label {@code l}
	 */
	public ImageData get_label_prob(short l, String name) {

		float [] lp = new float [obs.num_labels()];

		// allocate space for the current label probability
		ImageData labelprob = new ImageDataFloat(name,
												 obs.orig_dimx(),
												 obs.orig_dimy(),
												 obs.orig_dimz(),
												 obs.orig_dimv());

		// voxels outside the observation region are background (label 0)
		float background = (l == 0) ? 1 : 0;
		for (int x = 0; x < obs.orig_dimx(); x++)
			for (int y = 0; y < obs.orig_dimy(); y++)
				for (int z = 0; z < obs.orig_dimz(); z++)
					for (int v = 0; v < obs.orig_dimv(); v++)
						labelprob.set(x, y, z, v, background);

		// iterate over each voxel of the observation region
		for (int x = 0; x < obs.dimx(); x++)
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++)
					for (int v = 0; v < obs.dimv(); v++) {
						float val;
						if (obs.is_consensus(x, y, z, v))
							val = (obs.get_consensus_estimate(x, y, z, v) == l) ? 1 : 0;
						else {
							set_label_probabilities(x, y, z, v, lp);
							val = lp[l];
						}
						labelprob.set(x+obs.offx(),
									  y+obs.offy(),
									  z+obs.offz(),
									  v+obs.offv(),
									  val);
					}

		// return the estimated label probabilities
		return(labelprob);
	}

	/**
	 * Builds a volume holding the probabilities of every label, with the label
	 * index stored along the fourth dimension.
	 *
	 * <p>Voxels outside the observation region get probability 1 for label 0.
	 * Only the first component ({@code v = 0}) of the observations is used.
	 *
	 * @param name name of the returned volume
	 * @return a float volume of size orig_dimx x orig_dimy x orig_dimz x num_labels
	 * @throws RuntimeException if the original observations are 4-D
	 *         ({@code orig_dimv() > 1})
	 */
	public ImageData get_all_label_probs(String name) {

		// we can't do this unless the original dimensions are less than 4-D.
		if (obs.orig_dimv() > 1) {
			String errstr = "Error: Cannot save label probabilities as full matrix with 4-D input";
			JistLogger.logOutput(JistLogger.SEVERE, errstr);
			throw new RuntimeException(errstr);
		}

		// allocate space for all label probabilities
		ImageData labelprobs = new ImageDataFloat(name,
												 obs.orig_dimx(),
												 obs.orig_dimy(),
												 obs.orig_dimz(),
												 obs.num_labels());

		// account for the original cropping: outside voxels are background
		for (int x = 0; x < obs.orig_dimx(); x++)
			for (int y = 0; y < obs.orig_dimy(); y++)
				for (int z = 0; z < obs.orig_dimz(); z++)
					for (int l = 0; l < obs.num_labels(); l++)
						labelprobs.set(x, y, z, l, (l==0) ? 1 : 0);

		// iterate over each voxel of the observation region
		float [] lp = new float [obs.num_labels()];
		for (int x = 0; x < obs.dimx(); x++)
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++)
					if (obs.is_consensus(x, y, z, 0)) {
						for (short l = 0; l < obs.num_labels(); l++) {
							float val = (obs.get_consensus_estimate(x, y, z, 0) == l) ? 1 : 0;
							labelprobs.set(x+obs.offx(),
										   y+obs.offy(),
										   z+obs.offz(),
										   l,
										   val);
						}
					} else {
						set_label_probabilities(x, y, z, 0, lp);
						for (short l = 0; l < obs.num_labels(); l++) {
							labelprobs.set(x+obs.offx(),
										   y+obs.offy(),
										   z+obs.offz(),
										   l,
										   lp[l]);
						}
					}

		// return the estimated label probabilities
		return(labelprobs);
	}

	/**
	 * Fills {@code lp} with the label probabilities of the given
	 * (non-consensus) voxel of the observation grid.
	 *
	 * @param x  voxel x index (observation grid)
	 * @param y  voxel y index (observation grid)
	 * @param z  voxel z index (observation grid)
	 * @param v  voxel component index (observation grid)
	 * @param lp output array of length {@code obs.num_labels()}
	 */
	protected abstract void set_label_probabilities(int x, int y, int z, int v, float [] lp);

	/**
	 * Resets {@link #estimate} to label 0 everywhere, then writes the consensus
	 * label into every consensus voxel of the observation region (shifted by the
	 * crop offsets).
	 */
	protected void set_consensus_estimate() {

		for (int x = 0; x < obs.orig_dimx(); x++)
			for (int y = 0; y < obs.orig_dimy(); y++)
				for (int z = 0; z < obs.orig_dimz(); z++)
					for (int v = 0; v < obs.orig_dimv(); v++)
						estimate.set(x, y, z, v, 0);

		for (int x = 0; x < obs.dimx(); x++)
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++)
					for (int v = 0; v < obs.dimv(); v++)
						if (obs.is_consensus(x, y, z, v))
							estimate.set(x+obs.offx(),
										 y+obs.offy(),
										 z+obs.offz(),
										 v+obs.offv(),
										 obs.get_consensus_estimate(x, y, z, v));
	}

	/**
	 * Returns the label with the highest probability (argmax). It does NOT
	 * modify or normalize {@code lp}.
	 *
	 * <p>On ties the lowest label wins (strict {@code >} comparison). If no
	 * entry is greater than 0, label 0 is returned.
	 *
	 * @param lp per-label probabilities, length at least {@code obs.num_labels()}
	 * @return the index of the largest entry
	 */
	protected short get_estimate_voxel(float [] lp) {

		short estlabel = 0;
		float maxval = 0;
		for (short s = 0; s < obs.num_labels(); s++) {
			if (lp[s] > maxval) {
				maxval = lp[s];
				estlabel = s;
			}
		}

		return(estlabel);
	}

	/**
	 * Returns the largest entry of {@code lp}, or 0 if no entry is positive.
	 *
	 * @param lp per-label probabilities, length at least {@code obs.num_labels()}
	 * @return the maximum probability (never less than 0)
	 */
	public float get_max_label_probability(float [] lp) {
		float maxval = 0;

		for (int s = 0; s < obs.num_labels(); s++)
			if (lp[s] > maxval)
				maxval = lp[s];

		return(maxval);
	}

	/**
	 * Normalizes the first {@code obs.num_labels()} entries of {@code lp} in
	 * place so that they sum to 1. An all-zero vector is left unchanged
	 * instead of being turned into NaN.
	 *
	 * @param lp per-label values to normalize
	 */
	protected void normalize_label_probabilities(float [] lp) {
		normalize_in_place(lp, obs.num_labels());
	}

	/**
	 * Divides the first {@code num_labels} entries of {@code lp} by their sum.
	 * Does nothing when the sum is exactly 0, because 0/0 would yield NaN.
	 */
	private static void normalize_in_place(float [] lp, int num_labels) {
		float lpsum = 0;
		for (int s = 0; s < num_labels; s++)
			lpsum += lp[s];

		// no votes/weight at all: there is no meaningful distribution to produce
		if (lpsum == 0)
			return;

		for (int s = 0; s < num_labels; s++)
			lp[s] /= lpsum;
	}

	/**
	 * Logs a 10-step text progress bar, but only when step {@code ind} crosses
	 * into a new tenth of the {@code num} total steps.
	 *
	 * <p>NOTE(review): assumes {@code num > 1}; with {@code num == 1} the
	 * divisor is 0.
	 *
	 * @param ind current step index (0-based)
	 * @param num total number of steps
	 */
	protected void print_status(int ind,
								int num) {

		final int total = 10;
		int currval = (int)((total * (float)ind) / ((float)(num-1)));
		int prevval = (int)((total * ((float)ind-1)) / ((float)(num-1)));

		if (currval <= prevval)
			return;

		StringBuilder msg = new StringBuilder("[");
		for (int i = 0; i < currval; i++)
			msg.append('=');
		for (int i = currval; i < total; i++)
			msg.append('+');
		msg.append(']');

		JistLogger.logOutput(JistLogger.INFO, msg.toString());
	}

	/**
	 * Computes weighted majority-vote label probabilities for every
	 * non-consensus voxel of the observation grid.
	 *
	 * <p>Each rater contributes its observation values to the corresponding
	 * labels. The totals are then normalized to sum to 1. A voxel whose votes
	 * sum to 0 keeps an all-zero vector rather than NaN.
	 *
	 * @param obs the observations
	 * @return a sparse matrix indexed by (x, y, z, v, label) holding the
	 *         probabilities of the non-consensus voxels
	 */
	public static SparseMatrix5D get_majority_vote_probabilities(ObservationBase obs) {

		SparseMatrix5D sparseW = new SparseMatrix5D(obs.dimx(), obs.dimy(), obs.dimz(), obs.dimv(), obs.num_labels());

		float [] lp = new float [obs.num_labels()];

		for (int x = 0; x < obs.dimx(); x++)
			for (int y = 0; y < obs.dimy(); y++)
				for (int z = 0; z < obs.dimz(); z++)
					for (int v = 0; v < obs.dimv(); v++)
						if (!obs.is_consensus(x, y, z, v)) {

							// initialize the label probabilities
							Arrays.fill(lp, 0f);

							// add up the (weighted) votes from each rater
							for (int j = 0; j < obs.num_raters(); j++) {
								short [] obslabels = obs.get_all(x, y, z, v, j);
								float [] obsvals = obs.get_all_vals(x, y, z, v, j);
								for (int i = 0; i < obslabels.length; i++)
									lp[obslabels[i]] += obsvals[i];
							}

							// normalize (guarding against an all-zero vote vector)
							normalize_in_place(lp, obs.num_labels());

							// add to the sparse matrix
							sparseW.init_voxel(x, y, z, v, lp);
						}

		return(sparseW);
	}

	/**
	 * Reads the truth label volume into {@code truth} over the observation
	 * region, shifted by the crop offsets.
	 *
	 * <p>The loaded image is always disposed, including when the dimension
	 * check fails.
	 *
	 * @param truthvol parameter holding the truth label volume
	 * @param truth    destination array indexed [x][y][z][v] over the
	 *                 observation grid
	 * @throws RuntimeException if the volume's dimensions differ from the
	 *         original observation dimensions
	 */
	protected void load_truth_labels(ParamVolume truthvol,
								   short [][][][] truth) {

		ImageData img = truthvol.getImageData(true);
		try {
			// make sure that the dimensions match (missing dims count as 1)
			if (obs.orig_dims[0] != Math.max(img.getRows(), 1) ||
					obs.orig_dims[1] != Math.max(img.getCols(), 1) ||
					obs.orig_dims[2] != Math.max(img.getSlices(), 1) ||
					obs.orig_dims[3] != Math.max(img.getComponents(), 1)) {
				String errstr = "Error: Target Image Dimensions do not match";
				JistLogger.logOutput(JistLogger.SEVERE, errstr);
				throw new RuntimeException(errstr);
			}

			for (int x = 0; x < obs.dimx(); x++)
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int v = 0; v < obs.dimv(); v++)
							truth[x][y][z][v] = img.getShort(x + obs.offx(),
									y + obs.offy(),
									z + obs.offz(),
									v + obs.offv());
		} finally {
			// release the image even when the dimension check throws
			img.dispose();
		}
	}

}
