package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

public class ApplyMarkovRandomField extends AbstractCalculation {
	
	// the main parameters
	private float [][][][] target;
	private float epsilon;
	private int maxiter;
	private ObservationBase obs;
	
	// the derived parameters
	private float [][] H;
	private float [] betas;
	private float ffact;
	private int [] sv;
	private int [] pv;
	private double [][][][] dist_factor;
		
	public ApplyMarkovRandomField(ObservationBase obs_in,
								  float [][][][] target_in,
								  float epsilon_in,
								  int maxiter_in,
								  float val_ondiag,
								  float val_offdiag,
								  float int_stdev,
								  float beta0,
								  float beta1,
								  float beta2,
								  float beta3,
								  int sv0,
								  int sv1,
								  int sv2,
								  int sv3,
								  int pv0,
								  int pv1,
								  int pv2,
								  int pv3,
								  float [] dimres) {
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing a Markov Random Field +++");
		
		// set the over-arching parameters
		epsilon = epsilon_in;
		maxiter = maxiter_in;
		target = target_in;
		obs = obs_in;
		
		// set the front factor for the gaussian intensity difference
		ffact = (float) (-1 / (2 * int_stdev * int_stdev));
		
		// set the strength parameters
		betas = new float [4];
		betas[0] = beta0;
		betas[1] = beta1;
		betas[2] = beta2;
		betas[3] = beta3;
		
		// set the search volume
		sv = new int [4];
		sv[0] = Math.min(sv0, obs.dimx()-1);
		sv[1] = Math.min(sv1, obs.dimy()-1);
		sv[2] = Math.min(sv2, obs.dimz()-1);
		sv[3] = Math.min(sv3, obs.dimv()-1);
		
		// set the patch volume
		pv = new int [4];
		pv[0] = Math.min(Math.round((float)pv0 / dimres[0]), (obs.dimx()-1)/2);
		pv[1] = Math.min(Math.round((float)pv1 / dimres[1]), (obs.dimy()-1)/2);
		pv[2] = Math.min(Math.round((float)pv2 / dimres[2]), (obs.dimz()-1)/2);
		pv[3] = Math.min(Math.round((float)pv3 / dimres[3]), (obs.dimv()-1)/2);
		
		dist_factor = new double [2*sv[0]+1][2*sv[1]+1][2*sv[2]+1][2*sv[3]+1];
		set_dist_factor(dimres);
		
		// set the H matrix
		H = new float [obs.num_labels()][obs.num_labels()];
		for (int i = 0; i < obs.num_labels(); i++)
			for (int j = 0; j < obs.num_labels(); j++)
				H[i][j] = (i == j) ? val_ondiag : val_offdiag;
		
		// print out some information
		JistLogger.logOutput(JistLogger.INFO, "-> Determined the following information");
		JistLogger.logOutput(JistLogger.INFO, "Convergence Threshold: " + epsilon);
		JistLogger.logOutput(JistLogger.INFO, "Maximum Number of Iterations: " + maxiter);
		JistLogger.logOutput(JistLogger.INFO, "On-Diagonal Label Compatibility: " + val_ondiag);
		JistLogger.logOutput(JistLogger.INFO, "Off-Diagonal Label Compatibility: " + val_offdiag);
		JistLogger.logOutput(JistLogger.INFO, "Intensity Standard Deviation: " + int_stdev);
		JistLogger.logOutput(JistLogger.INFO, String.format("Search Volume Dimensions (voxels): [%d %d %d %d]", sv[0], sv[1], sv[2], sv[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Patch Volume Dimensions (voxels): [%d %d %d %d]", pv[0], pv[1], pv[2], pv[3]));
		JistLogger.logOutput(JistLogger.INFO, "MRF Strength (Mean Component): " + betas[0]);
		JistLogger.logOutput(JistLogger.INFO, "MRF Strength (Intensity Component): " + betas[1]);
		JistLogger.logOutput(JistLogger.INFO, "MRF Strength (Label Component): " + betas[2]);
		JistLogger.logOutput(JistLogger.INFO, "MRF Strength (Joint Component): " + betas[3]);
	}
	
	public void run (SparseMatrix5D sparseW) {
		
		// set the previous value for the Sparse Matrix
		SparseMatrix5D sparseW_prev = new SparseMatrix5D(obs.dimx(),
														 obs.dimy(),
														 obs.dimz(),
														 obs.dimv(),
														 obs.num_labels(),
														 sparseW);
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Running Markov Random Field +++");
		
		// initialize some variables
		float convergence_factor = 100000;
		int numiter = 0;
		
		while (convergence_factor > epsilon && numiter < maxiter) {
			
			// copy the previous estimate
			if (numiter > 0)
				sparseW_prev.copy(sparseW);
			
			// increment the number of iterations
			numiter++;
			
			// iterate over every voxel
			for (int x = 0; x < obs.dimx(); x++)
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++) 
						for (int v = 0; v < obs.dimv(); v++)
							if (!obs.is_consensus(x, y, z, v) && 
								sparseW_prev.get_max_val(x, y, z, v) < 1)
								apply_MRF_voxel(sparseW, sparseW_prev, x, y, z, v);
			
			// calculate the convergence factor and print the result
			convergence_factor = sparseW.get_convergence_factor(sparseW_prev);
			JistLogger.logOutput(JistLogger.INFO, String.format("Convergence Factor (%d): %f", numiter, convergence_factor));				
		}
	}
	
	private void apply_MRF_voxel(SparseMatrix5D sparseW,
								 SparseMatrix5D sparseW_prev,
								 int x,
								 int y,
								 int z,
								 int v) {
		
		// initialize some arrays
		short [] Wlabels;
		short [] W2labels;
		float[] W2probs;
				
		// set the labels that we are interested in for the current location
		Wlabels = sparseW_prev.get_all_labels(x, y, z, v);
		int num_keep = Wlabels.length;
		
		double [] finalvals = new double [num_keep];
		double normfact = 0;
		
		// set the search neighborhood
		int xm = Math.min(x, obs.dimx() - x - 1);
		int ym = Math.min(y, obs.dimy() - y - 1);
		int zm = Math.min(z, obs.dimz() - z - 1);
		int vm = Math.min(v, obs.dimv() - v - 1);
		int xs = (xm < sv[0]) ? xm : sv[0];
		int ys = (ym < sv[1]) ? ym : sv[1];
		int zs = (zm < sv[2]) ? zm : sv[2];
		int vs = (vm < sv[3]) ? vm : sv[3];
		
		// iterate over all of the considered labels
		for (int k = 0; k < num_keep; k++) {
			
			// initialize some variables
			short currlabel = Wlabels[k];
			short label;
			double diff, ival, hval, val;
			double priorval = sparseW_prev.get_val_ind(x, y, z, v, k);
			double expval1 = 0;
			double expval2 = 0;
			double expval3 = 0;
			double expval4 = 0;
			
			// iterate over the search neighborhood
			for (int xp = x-xs; xp <= x+xs; xp++)
				for (int yp = y-ys; yp <= y+ys; yp++)
					for (int zp = z-zs; zp <= z+zs; zp++)
						for (int vp = v-vs; vp <= v+vs; vp++) {
							
							int xd = (xp - x) + sv[0];
							int yd = (yp - y) + sv[1];
							int zd = (zp - z) + sv[2];
							int vd = (vp - v) + sv[3];
							double df = dist_factor[xd][yd][zd][vd];
							
							// set the intensity difference value
							diff = get_norm_diff(x, y, z, v, xp, yp, zp, vp);
							ival = Math.exp(ffact * diff * diff);
							
							// things are handled differently if the current voxel is consensus or not
							if (!obs.is_consensus(xp, yp, zp, vp)) {
								
								// get the label/probability information for this voxel
								W2labels = sparseW_prev.get_all_labels(xp, yp, zp, vp);
								W2probs = sparseW_prev.get_all_vals(xp, yp, zp, vp);
								int num_keep2 = W2labels.length;
																
								for (int k2 = 0; k2 < num_keep2; k2++) {
									label = W2labels[k2];
									val = W2probs[k2];
									hval = H[currlabel][label];
									if (label == currlabel) {
										expval1 += df * val;
										expval2 += df * ival * val;
									}
									expval3 += df * H[currlabel][label] * val;
									expval4 += df * hval * ival * val;
								}
							} else {
								label = obs.get_consensus_estimate(xp, yp, zp, vp);
								hval = H[currlabel][label];
								if (label == currlabel) {
									expval1 += df * 1;
									expval2 += df * ival;
								}
								expval3 += df * H[currlabel][label];
								expval4 += df * hval * ival;
							}
						}
			
			// set the final value based upon the impact of the energy function
			finalvals[k] = (priorval * Math.exp(betas[0] * expval1 +
											   	betas[1] * expval2 +
											   	betas[2] * expval3 +
											   	betas[3] * expval4));
			normfact += finalvals[k];
		}
		
		// set the final values
		for (int k = 0; k < num_keep; k++)
			sparseW.set_val_ind(x, y, z, v, k, (float) (finalvals[k] / normfact)); 
		
	}
	
	private void set_dist_factor(float [] dimres) {
			
		// scale the resolutions for proper weighting
		float [] dr = new float [4];
		float minval = 10000f;
		for (int i = 0; i < 4; i++) {
			if (sv[i] == 0)
				dr[i] = 0;
			else
				dr[i] = dimres[i];
			if (dr[i] < minval && dr[i] > 0)
				minval = dr[i];
		}
		for (int i = 0; i < 4; i++) {
			dr[i] /= minval;
			//JistLogger.logOutput(JistLogger.INFO, String.format("%f", dr[i]));
		}
		
		int svxf = 2*sv[0]+1;
		int svyf = 2*sv[1]+1;
		int svzf = 2*sv[2]+1;
		int svvf = 2*sv[3]+1;
		
		for (int x = 0; x < svxf; x++)
			for (int y = 0; y < svyf; y++)
				for (int z = 0; z < svzf; z++)
					for (int v = 0; v < svvf; v++) {
						double dist = Math.sqrt((dr[0] * (sv[0]-x))*(dr[0] * (sv[0]-x)) + 
											  	(dr[1] * (sv[1]-y))*(dr[1] * (sv[1]-y)) + 
											  	(dr[2] * (sv[2]-z))*(dr[2] * (sv[2]-z)) + 
											  	(dr[3] * (sv[3]-v))*(dr[3] * (sv[3]-v)));
						dist_factor[x][y][z][v] = (dist == 0) ? 1 : 1 / dist;
						JistLogger.logOutput(JistLogger.INFO, String.format("%d %d %d %d %f", x, y, z, v, dist_factor[x][y][z][v]));
					}
	}
	
	private float get_norm_diff(int x,
								int y,
								int z,
								int v,
								int xi,
								int yi,
								int zi,
								int vi) {
		float diff = 0;
		float tval, ival;
		
		int xm = Math.min(Math.min(x, xi), obs.dimx() - Math.max(x, xi) - 1);
		int ym = Math.min(Math.min(y, yi), obs.dimy() - Math.max(y, yi) - 1);
		int zm = Math.min(Math.min(z, zi), obs.dimz() - Math.max(z, zi) - 1);
		int vm = Math.min(Math.min(v, vi), obs.dimv() - Math.max(v, vi) - 1);
		
		int xs = (xm < pv[0]) ? xm : pv[0];
		int ys = (ym < pv[1]) ? ym : pv[1];
		int zs = (zm < pv[2]) ? zm : pv[2];
		int vs = (vm < pv[3]) ? vm : pv[3];
		
		for (int xp = -xs; xp <= xs; xp++)
			for (int yp = -ys; yp <= ys; yp++)
				for (int zp = -zs; zp <= zs; zp++)
					for (int vp = -vs; vp <= vs; vp++) {
						tval = target[x+xp][y+yp][z+zp][v+vp];
						ival = target[xi+xp][yi+yp][zi+zp][vi+vp];
						diff += (tval - ival) * (tval - ival);
					}
		diff /= (2*xs+1) * (2*ys+1) * (2*zs+1) * (2*vs+1); 
			
		return(diff);
		
	}
	
	
}
