package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.HashMap;
import java.util.Set;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Voxel-wise statistical label fusion over a set of rater images.
 *
 * <p>Each rater image is assigned a class name through a rater-map file (one
 * class name per line, in rater order). Raters whose class equals the target
 * class vote to form the per-voxel prior and the initial consensus region; all
 * raters contribute to the E-step and M-step through the performance
 * parameters held in {@code theta}.
 *
 * <p>The whole algorithm runs inside the constructor; afterwards the results
 * are available from {@link #getSegmentation()} and {@link #getProbabilities()}.
 * Note that {@link #runMStepVoxel} is protected and is invoked (indirectly)
 * from the constructor, so an overriding subclass will see it called before
 * its own constructor body has executed.
 */
public class SimplifiedMultiSetSTAPLE {

	/** Rater observations, indexed [row][col][slice][component][rater]. */
	private short[][][][][] raters;
	/** True where a voxel is treated as decided and is skipped by the EM loop. */
	private boolean[][][][] consensus;
	/** Per-voxel flag; initialised to false and never set to true by any code in this class. */
	private boolean[][][][] converged;
	/** Current label estimate per voxel (-1 when no label has been chosen). */
	private int[][][][] segmentation;
	/** Probability associated with the current label estimate per voxel. */
	private float[][][][] probabilities;
	/** Class name of each rater, in rater order, as read from the rater-map file. */
	private String[] raterMap;
	/** Integer id of each rater's class (written here, not read within this class). */
	private int[] classes;
	/** Image dimensions: rows, columns, slices, components (each at least 1). */
	private int r,c,s,t;
	private int numLabels;
	private int maxIter;
	private float epsilon;
	private float consensusVal;
	private float regularization;
	private int numRaters;
	private String targetClass;
	/** Class name to integer id, in order of first appearance in the rater map. */
	private HashMap<String,Integer> stringToInt;
	/** Inverse of {@link #stringToInt} (written here, not read within this class). */
	private HashMap<Integer,String> intToString;
	/** Performance parameters being accumulated in the current iteration. */
	private SimplifiedMultiClassPerformanceParameters theta;
	/** Snapshot of the performance parameters from the previous iteration. */
	private SimplifiedMultiClassPerformanceParameters thetaOld;

	/**
	 * Loads the inputs and runs the fusion to completion.
	 *
	 * @param ims                rater images; must be non-empty. Every image is
	 *                           disposed after its voxels have been copied.
	 * @param ratersMap          text file with one class name per rater, one per line
	 * @param initTheta          file handed to the performance-parameter constructor
	 * @param iter               maximum number of EM iterations (at least one always runs)
	 * @param consThresh         probability above which a voxel becomes consensus
	 * @param eps                convergence threshold on the change in theta
	 * @param targetSegmentation class name of the segmentation being estimated
	 * @param reg                value passed to {@code theta.regularize} each iteration
	 * @throws IllegalArgumentException if {@code ims} is null/empty or the rater map
	 *                                  does not contain exactly one entry per rater
	 * @throws IllegalStateException    if the rater map file cannot be read
	 */
	public SimplifiedMultiSetSTAPLE(List<ImageData> ims, File ratersMap, File initTheta,
			Number iter, Number consThresh, Number eps, String targetSegmentation, Number reg) {

		maxIter = iter.intValue();
		epsilon = eps.floatValue();
		consensusVal = consThresh.floatValue();
		regularization = reg.floatValue();
		loadRaters(ims);

		raterMap = new String[numRaters];
		classes = new int[numRaters];
		loadRaterMap(ratersMap);
		targetClass = targetSegmentation;
		theta = new SimplifiedMultiClassPerformanceParameters(initTheta,raterMap,targetSegmentation);
		numLabels = theta.getNumLabels();
		determineConsensus();
		System.out.flush();
		run();
	}

	/**
	 * Copies the current label estimate into a new integer image.
	 *
	 * @return a new image named after the target class holding one label per voxel
	 */
	public ImageData getSegmentation(){
		ImageData im = new ImageDataInt(r,c,s,t);
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					for(int l=0;l<t;l++)
						im.set(i,j,k,l,segmentation[i][j][k][l]);
		im.setName(targetClass);
		return im;
	}

	/**
	 * Iterates the EM procedure over every non-consensus voxel until
	 * {@link #testConvergence(int)} reports completion. The convergence test is
	 * evaluated after each pass, so at least one iteration is always performed.
	 */
	private void run(){

		JistLogger.logOutput(JistLogger.WARNING, "++ Starting statistical fusion ++");
		JistLogger.logFlush();
		boolean convergenceStatus = false;
		int numIter = 0;
		while(!convergenceStatus){
			JistLogger.logOutput(JistLogger.WARNING,"Starting iteration "+(++numIter));
			JistLogger.logFlush();
			runPreEM();
			for(int i=0;i<r;i++)
				for(int j=0;j<c;j++)
					for(int k=0;k<s;k++)
						for(int l=0;l<t;l++)
							if(!consensus[i][j][k][l])
								runEMVoxel(i,j,k,l);
			runPostEM();
			convergenceStatus = testConvergence(numIter);
		}
	}

	/** Snapshots theta into thetaOld and clears theta ready for re-accumulation. */
	private void runPreEM(){
		thetaOld = theta.copy();
		theta.clear();
	}

	/** Regularizes and then normalizes the freshly accumulated theta. */
	private void runPostEM(){
		theta.regularize(regularization);
		theta.normalize();
	}

	/**
	 * Runs one EM update for a single voxel: builds the prior, refines it with
	 * the E-step (using the previous iteration's theta), accumulates the result
	 * into theta, and records the winning label and its probability.
	 */
	private void runEMVoxel(int i,int j,int k,int l){
		float[] lp = calculatePrior(i,j,k,l);
		if(!converged[i][j][k][l]){
			lp = runEStepVoxel(i,j,k,l,lp,thetaOld);
		}
		runMStepVoxel(i,j,k,l,lp);
		setSegAndProb(i,j,k,l,lp);
	}

	/**
	 * Stores the most probable label and its probability for a voxel. When the
	 * probability exceeds the consensus threshold the voxel is flagged as
	 * consensus (so later iterations skip it) and its probability is set to 1.
	 *
	 * @param lp label probabilities for the voxel
	 */
	private void setSegAndProb(int i,int j,int k,int l,float[] lp){
		int maxLabel = -1;
		float maxProb = 0;
		for(int x=0;x<lp.length;x++){
			// strict comparison: ties keep the lowest label index
			if(lp[x] > maxProb){
				maxProb = lp[x];
				maxLabel = x;
			}
		}
		if(maxLabel<0){
			JistLogger.logOutput(JistLogger.WARNING, "Something went wrong and the maximum probability was 0");
		}
		probabilities[i][j][k][l] = maxProb;
		segmentation[i][j][k][l]  = maxLabel;
		if(maxProb > consensusVal){
			consensus[i][j][k][l] = true;
			probabilities[i][j][k][l] = 1;
		}
	}

	/**
	 * Builds the label prior for a voxel from the votes of the raters whose
	 * class equals the target class. If no rater belongs to the target class
	 * the prior is uniform over all labels.
	 *
	 * @return a new array of length {@code numLabels} summing to 1
	 */
	private float[] calculatePrior(int i,int j,int k, int l){
		float[] lp = new float[numLabels];
		Arrays.fill(lp, 0f);
		if(converged[i][j][k][l]){
			lp[segmentation[i][j][k][l]] = 1f;
		}else{
			float sum = 0;
			boolean noneClass = true;
			for(int m=0;m<numRaters;m++){
				if(raterMap[m].equals(targetClass)){
					noneClass = false;
					// NOTE(review): assumes observed labels lie in [0, numLabels) - confirm upstream
					int obs = raters[i][j][k][l][m];
					lp[obs] += 1f;
					sum += 1f;
				}
			}
			if(noneClass){
				Arrays.fill(lp, 1f/numLabels);
				sum = 1;
			}
			if(sum==0)
				JistLogger.logOutput(JistLogger.WARNING, "Something went wrong in the prior calculation and the total value is 0");
			for(int m=0;m<lp.length;m++)
				lp[m] = lp[m]/sum;
		}
		return lp;
	}

	/**
	 * E-step for one voxel: combines the prior with each rater's log
	 * contribution from {@code currTheta} and normalizes the result.
	 *
	 * <p>The sum is carried out in log space and shifted by the largest log
	 * value before exponentiating, so that a product over many raters does not
	 * underflow to zero.
	 *
	 * @param lp        prior label probabilities; overwritten with the result
	 * @param currTheta performance parameters to evaluate the raters with
	 * @return {@code lp}, now holding the normalized label probabilities; if no
	 *         label has a usable score the prior is returned unchanged
	 */
	private float[] runEStepVoxel(int i,int j,int k,int l,float[] lp,
			SimplifiedMultiClassPerformanceParameters currTheta){

		double[] logPost = new double[numLabels];
		// Must start at -infinity: log-probabilities are <= 0, so seeding with
		// Double.MIN_VALUE (a tiny POSITIVE number) would never be exceeded and
		// the underflow-protecting shift below would be a no-op.
		double maxLog = Double.NEGATIVE_INFINITY;

		for (int lab = 0; lab < numLabels; lab++) {

			// start from the log prior (-infinity when the prior is zero)
			logPost[lab] = Math.log(lp[lab]);

			// labels with zero prior stay at zero; otherwise add every rater
			if (lp[lab] > 0) {
				for (int m = 0; m < numRaters; m++) {

					// once the score hits -infinity no rater can bring it back
					if (logPost[lab] == Double.NEGATIVE_INFINITY)
						break;

					short obsLabel = raters[i][j][k][l][m];
					logPost[lab] += currTheta.getLog(m, obsLabel, lab);
				}
			}

			if (logPost[lab] > maxLog)
				maxLog = logPost[lab];
		}

		// every label is impossible: nothing sensible to normalize
		if (maxLog == Double.NEGATIVE_INFINITY) {
			JistLogger.logOutput(JistLogger.SEVERE, "XXXXX - Problem Found - XXXXX");
			return lp;
		}

		// shift by the maximum, return to linear space and accumulate the norm
		double normfact = 0;
		for (int lab = 0; lab < numLabels; lab++) {
			logPost[lab] = Math.exp(logPost[lab] - maxLog);
			normfact += logPost[lab];
		}

		// also catches NaN, for which (normfact > 0) is false
		if (!(normfact > 0)) {
			JistLogger.logOutput(JistLogger.SEVERE, "XXXXX - Problem Found - XXXXX");
			return lp;
		}

		// normalize across the label probabilities
		for (int lab = 0; lab < numLabels; lab++)
			lp[lab] = (float) (logPost[lab] / normfact);

		return lp;
	}

	/**
	 * M-step contribution of one voxel: for every label with non-zero
	 * probability, adds that probability to theta for each rater's observed
	 * label at this voxel.
	 *
	 * @param x  row index
	 * @param y  column index
	 * @param z  slice index
	 * @param v  component index
	 * @param lp label probabilities for the voxel
	 */
	protected void runMStepVoxel(int x, int y, int z, int v, float [] lp) {

		for (int lab = 0; lab < numLabels; lab++) {
			if (lp[lab] > 0) {
				for (int j = 0; j < numRaters; j++) {
					short obsLabel = raters[x][y][z][v][j];
					theta.add(j, obsLabel, lab, lp[lab]);
				}
			}
		}
	}

	/**
	 * Reads one class name per rater from the given file into {@code raterMap}
	 * and derives the integer class id of each rater.
	 *
	 * <p>Blank lines after the last expected entry are ignored.
	 *
	 * @param f rater-map file, one class name per line in rater order
	 * @throws IllegalArgumentException if the file does not hold exactly one
	 *                                  entry per rater
	 * @throws IllegalStateException    if the file cannot be read
	 */
	private void loadRaterMap(File f){
		int count = 0;
		BufferedReader br = null;
		try{
			br = new BufferedReader(new FileReader(f));
			String line;
			while((line=br.readLine())!=null){
				if(count >= raterMap.length){
					// tolerate trailing blank lines, reject real surplus entries
					if(line.trim().length() == 0)
						continue;
					throw new IllegalArgumentException("Rater map file "+f
							+" has more entries than the "+raterMap.length+" rater images provided");
				}
				raterMap[count++] = line;
			}
		}catch(IOException e){
			// continuing would leave null entries and fail later with an opaque NullPointerException
			throw new IllegalStateException("Unable to read rater map file "+f, e);
		}finally{
			if(br != null){
				try{
					br.close();
				}catch(IOException ignored){
					// nothing useful can be done if closing fails
				}
			}
		}
		if(count < raterMap.length)
			throw new IllegalArgumentException("Rater map file "+f+" has "+count
					+" entries but "+raterMap.length+" rater images were provided");

		determineRaterClassRelationship();
		for(int m=0;m<raterMap.length;m++){
			classes[m] = stringToInt.get(raterMap[m]);
		}
	}

	/**
	 * Assigns consecutive integer ids to the distinct class names in
	 * {@code raterMap}, in order of first appearance, filling both lookup maps.
	 */
	private void determineRaterClassRelationship(){
		stringToInt = new HashMap<String,Integer>();
		intToString = new HashMap<Integer,String>();
		int n = 0;
		for(String name: raterMap){
			if(!stringToInt.containsKey(name)){
				stringToInt.put(name, n);
				intToString.put(n, name);
				n++;
			}
		}
	}

	/**
	 * Copies every rater image into the {@code raters} array and disposes the
	 * image. Dimensions are taken from the first image; any reported dimension
	 * that is not positive is treated as 1.
	 *
	 * @param ims rater images; must be non-null and non-empty
	 * @throws IllegalArgumentException if {@code ims} is null or empty
	 */
	private void loadRaters(List<ImageData> ims){
		if(ims == null || ims.isEmpty())
			throw new IllegalArgumentException("At least one rater image is required");

		ImageData im = ims.get(0);
		numRaters = ims.size();
		r = Math.max(im.getRows(), 1);
		c = Math.max(im.getCols(), 1);
		s = Math.max(im.getSlices(), 1);
		t = Math.max(im.getComponents(), 1);

		raters = new short[r][c][s][t][numRaters];
		for(int i=0;i<numRaters;i++){
			im = ims.get(i);
			for(int j=0;j<r;j++)
				for(int k=0;k<c;k++)
					for(int l=0;l<s;l++)
						for(int m=0;m<t;m++)
							raters[j][k][l][m][i] = im.getShort(j,k,l,m);
			im.dispose();
		}
	}

	/**
	 * Allocates the per-voxel state and marks as consensus every voxel on which
	 * all target-class raters agree. Voxels where they disagree, or where no
	 * rater belongs to the target class, are left for the EM loop.
	 */
	private void determineConsensus(){
		JistLogger.logOutput(JistLogger.WARNING, "++ Starting Consensus Region Determination ++");
		JistLogger.logFlush();
		consensus = new boolean[r][c][s][t];
		segmentation = new int[r][c][s][t];
		converged = new boolean[r][c][s][t];
		probabilities = new float[r][c][s][t];
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					for(int l=0;l<t;l++){
						// assume agreement until a dissenting rater is found
						probabilities[i][j][k][l] = 1;
						consensus[i][j][k][l] = true;
						converged[i][j][k][l] = false;
						segmentation[i][j][k][l] = -1;
						int lab = -1;
						boolean noneClass = true;
						for(int m=0;m<numRaters;m++){
							if(!raterMap[m].equals(targetClass))
								continue;
							noneClass = false;
							int obs = raters[i][j][k][l][m];
							if(lab < 0){
								// first target-class rater fixes the candidate label
								segmentation[i][j][k][l] = obs;
								lab = obs;
							}
							if(lab != obs){
								consensus[i][j][k][l] = false;
								segmentation[i][j][k][l] = -1;
								probabilities[i][j][k][l] = 1f/numLabels;
								break;
							}
						}
						// with no target-class rater there is nothing to agree on
						if(noneClass)
							consensus[i][j][k][l] = false;
					}
				}
			}
		}

		JistLogger.logOutput(JistLogger.WARNING, "++ Finished Consensus Region Determination ++");
		JistLogger.logFlush();
	}

	/**
	 * Copies the per-voxel probability of the chosen label into a new float image.
	 *
	 * @return a new image named "Probabilities"
	 */
	public ImageData getProbabilities(){
		ImageData im = new ImageDataFloat(r,c,s,t);
		im.setName("Probabilities");
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					for(int l=0;l<t;l++)
						im.set(i,j,k,l,probabilities[i][j][k][l]);
		return im;
	}

	/**
	 * Decides whether the EM loop should stop.
	 *
	 * @param numIter number of iterations completed so far
	 * @return true once the iteration limit is reached or the change in theta
	 *         since the previous iteration falls below epsilon
	 */
	private boolean testConvergence(int numIter){
		if(numIter >= maxIter)
			return true;
		float diff = theta.calculateDifference(thetaOld);
		JistLogger.logOutput(JistLogger.WARNING, "Epsilon value was "+diff);
		return diff < epsilon;
	}
}
