package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * EM-based label fusion over multiple sets of rater labels.
 * <p>
 * Judging by the name, this appears to be a multi-set variant of Non-Local Spatial STAPLE
 * (TODO confirm against the accompanying method description). The constructor builds an
 * {@code ObservationVolumePartial2} from the target and registered atlases, reads one class
 * name per rater from a rater-map file, initializes the performance parameters and then runs
 * EM to completion, so {@link #getSegmentation()} can be called right after construction.
 */
public class MultiSetNonLocalSpatialSTAPLE extends AbstractCalculation {

	/** Observation dimensions (x, y, z, components) and number of raters, taken from {@code obs}. */
	private int r,c,s,v,n;
	/** Label-probability vectors have {@code maxLabel + 1} entries (indices 0..maxLabel). */
	private int maxLabel;
	/** Hard cap on the number of EM iterations. */
	private int maxIters;
	/** A voxel whose largest posterior exceeds this value is frozen to its label. */
	private float consThresh;
	/** EM stops once the convergence factor of {@code theta} drops to this value or below. */
	private float epsilon;
	/** Class name of the raters whose votes form the label prior. */
	private String targetClass;
	private ObservationBase2 obs;
	/** Performance parameters being accumulated in the current M-step. */
	private MultiSetPerformanceParameters theta;
	/** Performance parameters from the previous iteration, used by the E-step. */
	private MultiSetPerformanceParameters thetaPrev;
	/** One class name per rater, read line by line from the rater-map file. */
	private String[] classes;
	/** Current label estimate for every non-consensus voxel. */
	private short[][][][] segmentation;
	/** Voxels that have been frozen because their posterior exceeded {@code consThresh}. */
	private boolean[][][][] isConverged;
	/**
	 * Per-iteration counters: voxels that converged during this iteration, and voxels that were
	 * re-estimated (i.e. not yet converged when visited) during this iteration.
	 */
	private int numConvergedIter,numUnconvergedIter;

	/**
	 * Builds the observation model and performance parameters, then runs EM immediately.
	 *
	 * @param target          target volume, forwarded to {@code ObservationVolumePartial2}
	 * @param regIms          registered intensity volumes, forwarded to the observation model
	 * @param regLabs         registered label volumes, forwarded to the observation model
	 * @param weight_type     forwarded to the observation model
	 * @param sv_in           search-volume radius, replicated on the three spatial axes (0 on the fourth)
	 * @param pv_in           patch-volume radius, replicated on the three spatial axes (0 on the fourth)
	 * @param sp_stdev_in     spatial standard deviation, replicated on all four axes
	 * @param int_stdev_in    intensity standard deviation, forwarded to the observation model
	 * @param sel_type_in     forwarded to the observation model
	 * @param sel_thresh      forwarded to the observation model
	 * @param num_keep        forwarded to the observation model
	 * @param cons_thresh_in  posterior threshold above which a voxel is frozen
	 * @param probability_index forwarded to the observation model
	 * @param use_intensity_normalization_in forwarded to the observation model
	 * @param targetClass_in  class name whose raters form the label prior
	 * @param raterMap_in     text file with one class name per rater, one per line
	 * @param thetaInit       file used to initialize {@code MultiSetPerformanceParameters}
	 */
	public MultiSetNonLocalSpatialSTAPLE(ParamVolume target, ParamVolumeCollection regIms, ParamVolumeCollection regLabs,
			int weight_type, int sv_in, int pv_in, float sp_stdev_in, 
			float int_stdev_in, int sel_type_in, float sel_thresh,
			int num_keep, float cons_thresh_in, int probability_index, 
			boolean use_intensity_normalization_in,String targetClass_in,File raterMap_in,File thetaInit){
		super();
		int[] sv   = {sv_in,sv_in,sv_in,0};
		int[] pv   = {pv_in,pv_in,pv_in,0};
		float[] sp = {sp_stdev_in,sp_stdev_in,sp_stdev_in,sp_stdev_in};
		targetClass = targetClass_in;
		consThresh = cons_thresh_in;
		loadClasses(raterMap_in);
		obs = new ObservationVolumePartial2(target,regLabs,regIms, 
				weight_type, 
				sv, pv,sp,
				int_stdev_in,
				sel_type_in,
				sel_thresh,
				num_keep,1f,
				probability_index,
				use_intensity_normalization_in,
				classes,
				targetClass);

		r = obs.dimx();
		c = obs.dimy();
		s = obs.dimz();
		v = obs.dimv();
		n=obs.num_raters();
		theta = new MultiSetPerformanceParameters(targetClass,classes,thetaInit);
		thetaPrev = new MultiSetPerformanceParameters(theta);
		maxLabel = theta.getNumTarget();
		segmentation = new short[r][c][s][v];
		maxIters = 100;   // Need to do something about this
		epsilon  = 1e-4f; // This too
		setUpConverged();
		run();
	}

	/** Allocates the convergence mask; Java zero-initializes boolean arrays, so every voxel starts unconverged. */
	private void setUpConverged(){
		isConverged = new boolean[r][c][s][v];
	}

	/**
	 * Runs EM until the convergence factor drops to {@code epsilon} or {@code maxIters} is reached.
	 * Each iteration visits every non-consensus voxel once (E-step + accumulation of the M-step),
	 * then normalizes the accumulated performance parameters.
	 */
	private void run(){
		JistLogger.logOutput(JistLogger.WARNING, "++ Running EM ++");
		int numIter = 0;
		// Scratch buffer for per-voxel label probabilities, reused for every voxel.
		float[] lp = new float[maxLabel + 1];

		double convergence = Double.MAX_VALUE;
		while(convergence > epsilon && numIter < maxIters){

			long time_start = System.nanoTime();

			numIter++;

			// iterate over all non-consensus voxels
			runPreEM();
			for (int x = 0; x < obs.dimx(); x++) {
				print_status(x, obs.dimx());
				for (int y = 0; y < obs.dimy(); y++)
					for (int z = 0; z < obs.dimz(); z++)
						for (int w = 0; w < obs.dimv(); w++)
							if (!obs.is_consensus(x, y, z, w))
								runEMVoxel(x, y, z, w, lp);
			}
			runPostEM();

			// calculate the convergence factor
			convergence = theta.get_convergence_factor(thetaPrev);

			// calculate the time that has elapsed for this iteration
			double elapsed_time = ((double)(System.nanoTime() - time_start)) / 1e9;

			JistLogger.logOutput(JistLogger.WARNING, String.format("Convergence Factor (%d, %.3fs): %f. %d/%d converged this iteration", numIter, elapsed_time, convergence,numConvergedIter,numUnconvergedIter));
		}

	}

	/**
	 * Processes one voxel: for an unconverged voxel runs the E-step, updates its label estimate
	 * and possibly freezes it; a frozen voxel contributes a one-hot vector on its stored label.
	 * In both cases the resulting probabilities are accumulated into {@code theta}.
	 *
	 * @param lp scratch buffer of length {@code maxLabel + 1}; overwritten
	 */
	private void runEMVoxel(int x,int y,int z,int v,float[] lp){
		// Reuse the caller's buffer instead of allocating a new array per voxel.
		if(lp == null || lp.length != maxLabel + 1)
			lp = new float[maxLabel + 1];
		Arrays.fill(lp, 0f);

		if(!isConverged[x][y][z][v]){
			runEStepVoxel(x,y,z,v,lp,thetaPrev);

			normalize(lp);

			// NOTE(review): getEstimate returns -1 when no label has positive probability
			// (degenerate E-step); that sentinel is stored as-is in the segmentation.
			int l = getEstimate(lp);

			segmentation[x][y][z][v] = (short) l;

			if(getMaxProb(lp) > consThresh){
				isConverged[x][y][z][v] = true;
				numConvergedIter++;
			}
			numUnconvergedIter++;
		} else{
			lp[segmentation[x][y][z][v]] = 1f;
		}

		runMStepVoxel(x,y,z,v,lp);
	}

	/** Scales {@code lp} in place to sum to one; left untouched when the sum is not positive (avoids 0/0 = NaN). */
	private void normalize(float[] lp){
		float sum = 0f;
		for(float f:lp)
			sum += f;
		if(!(sum > 0f))
			return;
		for(int i=0;i<lp.length;i++)
			lp[i] = lp[i] / sum;
	}

	/**
	 * Adds this voxel's soft label counts to {@code theta}: for each label with positive
	 * probability and each rater, every observed (label, weight) pair is added with weight
	 * multiplied by that label's probability.
	 */
	private void runMStepVoxel(int x,int y,int z,int v,float[] lp){

		for(int lab=0;lab<maxLabel+1;lab++){
			if(lp[lab] > 0){
				for(int j=0;j<n;j++){
					short[] obsLabs = obs.get_all(x, y, z, v, j);
					float[] obsVals = obs.get_all_vals(x, y, z, v, j);

					for(int k=0;k<obsLabs.length;k++)
						theta.add(j, obsLabs[k], lab, obsVals[k]*lp[lab]);
				}
			}
		}
	}

	/** Returns the largest entry of {@code lp}, or 0 if no entry is positive. */
	private float getMaxProb(float[] lp){
		float mp = 0f;
		for(float p : lp)
			if(p > mp)
				mp = p;
		return mp;
	}

	/** Returns the index of the first largest positive entry of {@code lp}, or -1 if none is positive. */
	private int getEstimate(float[] lp){
		int l = -1;
		float mp = 0f;
		for(int i=0;i<lp.length;i++)
			if(lp[i] > mp){
				mp = lp[i];
				l = i;
			}
		return l;
	}

	/**
	 * Computes the label posterior for one voxel and writes it, normalized, into {@code lp}.
	 * <p>
	 * The prior comes from {@link #initialize_label_probabilities}; each rater's log-likelihood
	 * from {@code t.getLog} is added in log space, and the result is exponentiated with the
	 * log-sum-exp shift (subtracting the maximum final log value) to avoid underflow.
	 * If no label is admissible, a SEVERE message is logged and {@code lp} is zero-filled.
	 *
	 * @param lp must be zero-filled on entry; receives the posterior
	 * @param t  performance parameters to evaluate (the previous iteration's)
	 */
	private void runEStepVoxel(int x,int y,int z,int v,float[] lp, MultiSetPerformanceParameters t){
		initialize_label_probabilities(x,y,z,v,lp);

		int numLabels = maxLabel + 1;
		double[] lpd = new double[numLabels];

		// set initial (log) probabilities; labels with zero prior become -inf and stay impossible
		for(int i=0;i<numLabels;i++)
			lpd[i] = Math.log(lp[i]);

		// Run E-Step: accumulate every rater's log-likelihood
		for(int j=0;j<n;j++){
			short[] obsLabs = obs.get_all(x, y, z, v, j);
			float[] obsVals = obs.get_all_vals(x, y, z, v, j);
			for(int lab=0;lab<numLabels;lab++){
				if(lpd[lab] == Double.NEGATIVE_INFINITY)
					continue;
				lpd[lab] += t.getLog(j,lab,obsLabs,obsVals);
			}
		}

		// Shift by the max of the FINAL log values. Double.MIN_VALUE is the smallest positive
		// double, not the most negative one, so the search must start from -infinity.
		double maxfact = Double.NEGATIVE_INFINITY;
		for(double d : lpd)
			if(d > maxfact)
				maxfact = d;

		double normfact = 0;
		for(int lab=0;lab<numLabels;lab++){
			lpd[lab] = Math.exp(lpd[lab] - maxfact);
			normfact += lpd[lab];
		}

		// With a finite maxfact the maximal label contributes exp(0) = 1, so normfact >= 1;
		// it is 0 or NaN only when every label is impossible (or a NaN crept in).
		if(!(normfact > 0)){
			JistLogger.logOutput(JistLogger.SEVERE, "XXXXX - Problem Found - XXXXX");
			Arrays.fill(lp, 0f);
			return;
		}

		for (int lab = 0; lab < numLabels; lab++)
			lp[lab] = (float) (lpd[lab] / normfact);
	}

	/** Starts an iteration: snapshots {@code theta} into {@code thetaPrev}, resets it and clears the counters. */
	private void runPreEM(){
		thetaPrev.copy(theta);
		theta.reset();
		numConvergedIter = 0;
		numUnconvergedIter = 0;
	}

	/** Finishes an iteration by normalizing the accumulated performance parameters. */
	private void runPostEM(){
		theta.normalize();
	}

	/**
	 * Builds the final segmentation at the original (uncropped) dimensions: non-consensus voxels
	 * take the EM estimate, consensus voxels take the observation model's consensus estimate,
	 * both placed at the observation offsets.
	 */
	public ImageData getSegmentation(){
		ImageData im = new ImageDataInt("segmentation",obs.orig_dimx(),obs.orig_dimy(),obs.orig_dimz(),obs.orig_dimv());
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					for(int l=0;l<v;l++)
						if(!obs.is_consensus(i, j, k, l))
							im.set(i+obs.offx(), j+obs.offy(),k+obs.offz(),l+obs.offv(),segmentation[i][j][k][l]);
						else
							im.set(i+obs.offx(), j+obs.offy(),k+obs.offz(),l+obs.offv(),obs.get_consensus_estimate(i, j, k, l));
		return im;
	}

	/**
	 * Reads the rater-map file, one class name per line, into {@link #classes}.
	 *
	 * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read;
	 *         continuing without classes would only fail later with a NullPointerException
	 */
	private void loadClasses(File f) {
		List<String> lines = new ArrayList<String>();
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(f));
			String line;
			while((line=br.readLine())!=null)
				lines.add(line);
		} catch (IOException e) {
			throw new RuntimeException("Unable to read rater map file: " + f, e);
		} finally {
			if(br != null){
				try {
					br.close();
				} catch (IOException ignored) {
					// nothing useful to do: all content has been read or an error is already propagating
				}
			}
		}
		classes = lines.toArray(new String[lines.size()]);
	}

	/**
	 * Logs a 10-step text progress bar, only when the bar grows between {@code ind - 1} and {@code ind}.
	 *
	 * @param ind current index (0-based)
	 * @param num total number of indices
	 */
	protected void print_status(int ind,int num) {

		final int total = 10;
		int currval = (int)((total * (float)ind) / ((float)(num-1)));
		int prevval = (int)((total * ((float)ind-1)) / ((float)(num-1)));

		if (currval <= prevval)
			return;

		StringBuilder msg = new StringBuilder(total + 2);
		msg.append('[');
		for (int i = 0; i < currval; i++)
			msg.append('=');
		for (int i = currval; i < total; i++)
			msg.append('+');
		msg.append(']');

		JistLogger.logOutput(JistLogger.WARNING, msg.toString());
		JistLogger.logFlush();
	}

	/**
	 * Adds the votes of every rater whose class equals {@code targetClass} into {@code lp}
	 * (observed label index += observation weight).
	 * NOTE(review): assumes {@code classes} has at least {@code n} entries and every observed
	 * label is below {@code maxLabel + 1}; otherwise an ArrayIndexOutOfBoundsException results.
	 */
	private void initialize_label_probabilities(int x,int y,int z,int v,float[] lp){
		for(int i=0;i<n;i++){
			if(!targetClass.equals(classes[i]))
				continue;
			short[] obsLabs = obs.get_all(x, y, z, v, i);
			float[] obsVals = obs.get_all_vals(x, y, z, v, i);
			for(int j=0;j<obsLabs.length;j++)
				lp[obsLabs[j]] += obsVals[j];
		}
	}
}
