package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Set;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;


/**
 * Fuses the label volumes produced by several raters into one label volume per
 * rater class using an iterative EM scheme (logged as "Multi-Class STAPLE").
 *
 * <p>Every rater belongs to a class (read from the rater map file). A fused
 * "label" is a label set holding one observed value per class; the mapping
 * between label indices and label sets is owned by
 * {@link MultiClassPerformanceParameters}. All work, including the complete EM
 * run, happens inside the constructor; results are then read through
 * {@link #getProbabilityVolume()}, {@link #getLabelSet()} and
 * {@link #getThetaVolume()}.
 */
public class MultiClassSTAPLE {

	/** Class index of every rater, in the order the raters appear in the rater map file. */
	private ArrayList<Integer> classes;
	/** Class index to class name; replaced by theta's own map in {@link #loadInitialTheta(File)}. */
	private HashMap<Integer,String> intToClass;
	/** Class name to class index, in order of first appearance in the rater map file. */
	private HashMap<String,Integer> classToInt;
	/** Number of rater classes; overwritten with theta's value in {@link #loadInitialTheta(File)}. */
	private int numClasses;
	/** Number of raters belonging to each class of the rater map. */
	private int[] classCount;
	/** Performance parameters accumulated during the current iteration. */
	private MultiClassPerformanceParameters theta;
	/** Snapshot of the performance parameters from the previous iteration. */
	private MultiClassPerformanceParameters thetaPrev;
	/** Upper bound on the number of EM iterations. */
	private int maxIters;
	/** Number of raters (lines in the rater map file). */
	private int numRaters;
	/** EM stops once theta's convergence value is no longer above this threshold. */
	private float epsilon;
	/** A voxel is flagged as converged once its best label probability exceeds this value. */
	private float consensusThresh;
	/** Observed label volumes, one per rater. */
	private List<ImageData> images;
	/** Index of the most probable label at every voxel. */
	private ImageData estimate;
	/** Probability of the most probable label at every voxel. */
	private ImageData probability;
	/** 1 where a voxel has been flagged as converged, 0 elsewhere. */
	private ImageData converged;
	/** 1 where all raters of every class agree (only filled when requested), 0 elsewhere. */
	private ImageData consensus;
	/** One fused label volume per class. */
	private List<ImageData> resultLabels;
	/** Bounding box {xLo, xHi, yLo, yHi, zLo, zHi}; the upper bounds are exclusive. */
	private int[] croppingRegion;
	/** Value handed to {@code theta.nudge} after every iteration. */
	private float nudge;
	/** Number of voxels flagged as converged so far (used for progress logging only). */
	private int numConsensus;


	/**
	 * Loads the inputs, prepares the working volumes and immediately runs the
	 * EM algorithm to completion.
	 *
	 * @param obsVals      observed label volumes, one per rater
	 * @param raterMap     text file holding one class name per line, one line per rater
	 * @param initialTheta file from which the initial performance parameters are loaded
	 * @param maxIt        maximum number of EM iterations
	 * @param cons         probability threshold above which a voxel is flagged as converged
	 * @param eps          convergence threshold on the change of theta
	 * @param nud          nudge value applied to theta after every iteration
	 * @param dC           when {@code true}, voxels on which the raters of every class
	 *                     agree are labelled up front and excluded from the EM loop
	 * @throws IllegalStateException if the rater map file cannot be read
	 */
	public MultiClassSTAPLE(ParamVolumeCollection obsVals, ParamFile raterMap,
			ParamFile initialTheta, ParamInteger maxIt,
			ParamFloat cons, ParamFloat eps,ParamFloat nud,boolean dC){
		setRaterMap(raterMap.getValue());
		loadInitialTheta(initialTheta.getValue());
		maxIters = maxIt.getInt();
		numConsensus = 0;
		epsilon = eps.getFloat();
		consensusThresh = cons.getFloat();
		nudge = nud.getFloat();
		JistLogger.logOutput(JistLogger.INFO, "Initializing Multi-Class STAPLE");
		JistLogger.logOutput(JistLogger.INFO, "The maximum number of iterations is "+maxIters);
		JistLogger.logOutput(JistLogger.INFO, "The epsilon value for theta convergence is "+epsilon);
		JistLogger.logOutput(JistLogger.INFO, "The value to determine consensus is "+consensusThresh);
		JistLogger.logOutput(JistLogger.INFO,String.format("The nudge value is %f", nudge));
		JistLogger.logFlush();
		images = obsVals.getImageDataList();
		JistLogger.logOutput(JistLogger.INFO, "There are "+images.size() + " images");
		JistLogger.logFlush();
		initialize();
		if(dC){
			JistLogger.logOutput(JistLogger.INFO, "Determining Consensus Voxels");
			determineConsensusRegion();
		}
		runEM();
	}

	/**
	 * Allocates the working and result volumes with the dimensions of the first
	 * observation, computes the cropping region and resets every voxel.
	 * The probability volume starts at 0 inside the cropping region and at 1
	 * outside of it.
	 */
	private void initialize(){
		ImageData first = images.get(0);
		estimate = new ImageDataInt(first.getRows(),first.getCols(),first.getSlices(),first.getComponents());
		probability = new ImageDataFloat(first.getRows(),first.getCols(),first.getSlices(),first.getComponents());
		probability.setName("ProbabilityVolume");
		converged = new ImageDataInt(first.getRows(),first.getCols(),first.getSlices(),first.getComponents());
		consensus = new ImageDataInt(first.getRows(),first.getCols(),first.getSlices(),first.getComponents());
		determineCroppingRegion();

		int rows = estimate.getRows();
		int cols = estimate.getCols();
		int slices = estimate.getSlices();
		int comps = estimate.getComponents();

		resultLabels = new ArrayList<ImageData>();
		for(int i=0;i<numClasses;i++){
			ImageData im = new ImageDataInt(rows,cols,slices,comps);
			im.setName("Segmentation_"+intToClass.get(i));
			resultLabels.add(im);
		}

		// Explicit reset of every volume; kept from the original implementation,
		// which was itself unsure whether freshly allocated volumes need it.
		for(int i=0;i<rows;i++){
			for(int j=0;j<cols;j++){
				for(int k=0;k<slices;k++){
					for(int l=0;l<comps;l++){
						estimate.set(i,j,k,l,0);
						probability.set(i,j,k,l,inCroppingRegion(i,j,k) ? 0 :1 );
						converged.set(i,j,k,l,0);
						consensus.set(i,j,k,l,0);
						for(int m=0;m<resultLabels.size();m++){
							resultLabels.get(m).set(i,j,k,l,0);
						}
					}
				}
			}
		}
	}

	/**
	 * Reads the rater map: one class name per line, one line per rater. Class
	 * indices are assigned in order of first appearance.
	 *
	 * @param f rater map file
	 * @throws IllegalStateException if the file cannot be opened or read
	 */
	private void setRaterMap(File f){
		classes = new ArrayList<Integer>();
		intToClass = new HashMap<Integer,String>();
		classToInt = new HashMap<String,Integer>();

		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(f));
			String line;
			int n = 0;
			while((line=br.readLine())!=null){
				if(!classToInt.containsKey(line)){
					classToInt.put(line, n);
					intToClass.put(n, line);
					n++;
				}
				classes.add(classToInt.get(line));
			}
		} catch (IOException e) {
			// A missing or partially read rater map makes every later step
			// meaningless, so fail here with the real cause attached.
			throw new IllegalStateException("Unable to read rater map file "+f, e);
		} finally {
			if(br != null){
				try {
					br.close();
				} catch (IOException ignored) {
					// The content has already been consumed; a failing close
					// cannot affect the parsed rater map.
				}
			}
		}

		numClasses = classToInt.size();
		numRaters = classes.size();
		classCount = new int[numClasses];
		for(int i:classes)
			classCount[i]++;
	}

	/**
	 * Loads and initializes the performance parameters, then adopts theta's
	 * class map and class count as the ones used by this object.
	 *
	 * @param f file holding the initial theta values
	 */
	private void loadInitialTheta(File f){
		theta = new MultiClassPerformanceParameters();
		theta.loadFromFile(f,classToInt);
		theta.setInitialThetas(classes);
		theta.initialize();
		thetaPrev = new MultiClassPerformanceParameters();
		intToClass = theta.getClassMap();
		numClasses = theta.getNumClasses();
	}

	/**
	 * Main EM loop. Each iteration visits every non-consensus voxel inside the
	 * cropping region, then measures how much theta changed. The loop ends when
	 * that change is no longer above {@code epsilon} or after {@code maxIters}
	 * iterations.
	 */
	private void runEM(){
		int numIters = 0;
		float convergenceValue = Float.MAX_VALUE;
		JistLogger.logOutput(JistLogger.INFO, "Starting EM Algorithm");
		while(convergenceValue > epsilon && numIters < maxIters){
			numIters++;
			long timeStart = System.nanoTime();
			runPreEM();
			for(int i=croppingRegion[0];i<croppingRegion[1];i++){
				double elapsedTime = ((double)(System.nanoTime() - timeStart)) / 1e9;
				JistLogger.logOutput(JistLogger.FINE, String.format("Starting row %d after %.3fs. Thus far %d voxels have converged.", i,elapsedTime,numConsensus));
				JistLogger.logFlush();
				for(int j=croppingRegion[2];j<croppingRegion[3];j++){
					for(int k=croppingRegion[4];k<croppingRegion[5];k++){
						for(int l=0;l<estimate.getComponents();l++){
							if(!consensus.getBoolean(i, j,k,l)){
								float[] lp = determineLabelPrior(i,j,k,l);
								float[] prob = runEMVoxel(i,j,k,l,lp);
								setLabels(prob,i,j,k,l);
							}
						}
					}
				}
			}
			runPostEM();
			convergenceValue = theta.calculateConvergence(thetaPrev);
			double elapsedTime = ((double)(System.nanoTime() - timeStart)) / 1e9;
			JistLogger.logOutput(JistLogger.INFO, String.format("Convergence Factor (%d, %.3fs): %f", numIters, elapsedTime, convergenceValue));
		}
	}

	/** @return {@code true} if the voxel has been flagged as converged */
	private boolean isConverged(int i, int j, int k, int l){
		return (converged.getInt(i, j,k,l) == 1);
	}

	/** Finishes an iteration: nudges theta and normalizes it. */
	private void runPostEM(){
		theta.nudge(nudge);
		theta.normalize();
	}

	/**
	 * Processes one voxel: computes the normalized label probabilities and
	 * feeds them into theta's M-step accumulation.
	 *
	 * <p>A converged voxel skips the E-step and uses a one-hot distribution on
	 * its current estimate instead.
	 *
	 * @param lp label prior of the voxel; may be modified in place
	 * @return the normalized label probabilities of the voxel
	 */
	private float[] runEMVoxel(int i,int j, int k, int l,float[] lp){
		float[] prob;
		if(!isConverged(i,j,k,l)){
			prob = runEStepVoxel(i,j,k,l,lp);
		}
		else{
			prob = new float[lp.length];
			prob[estimate.getInt(i, j,k,l)]=1;
		}
		prob = normalizeProb(prob);
		runMStepVoxel(i,j,k,l,prob);
		return prob;
	}

	/**
	 * Builds the label prior of a voxel from the raters' votes: per class, the
	 * fraction of that class's raters voting for each observed value; the prior
	 * of a label set is the product of these fractions over the classes.
	 *
	 * @return prior indexed by label, 0 for labels not supported by the votes
	 */
	private float[] determineLabelPrior(int x, int y,int z,int c){
		List<HashMap<Integer,Float>> labelCounts = new ArrayList<HashMap<Integer,Float>>();
		for(int i=0;i<numClasses;i++){
			labelCounts.add(new HashMap<Integer,Float>());
		}
		// Count the votes per class and observed value.
		for(int i=0;i<numRaters;i++){
			int obs = images.get(i).getInt(x, y,z,c);
			HashMap<Integer,Float> hm  = labelCounts.get(classes.get(i));
			float n   = hm.containsKey(obs) ? hm.get(obs) : 0f;
			hm.put(obs, n + 1);
		}
		// Turn the counts into per-class vote fractions. Only values of existing
		// keys are replaced, so iterating over the key set stays valid.
		for(int i=0;i<numClasses;i++){
			int cl  = classCount[i];
			HashMap<Integer,Float> hm  = labelCounts.get(i);
			for(int k:hm.keySet()){
				hm.put(k, hm.get(k)/cl);
			}
		}
		return determineAllProbs(labelCounts);
	}

	/**
	 * Converts per-class vote fractions into a prior over theta's labels.
	 *
	 * @param labelCounts per class, the vote fraction of every observed value
	 * @return array of length {@code theta.getNumLabels()}; combinations that
	 *         theta does not know ({@code findLabel} returns -1) are dropped
	 */
	private float[] determineAllProbs(List<HashMap<Integer,Float>> labelCounts){
		int[][] ind = determineAllLabelSets(labelCounts);
		float[] lp = new float[theta.getNumLabels()];
		for(int[] row:ind){
			int l = theta.findLabel(row);
			float p = 1f;
			for(int i=0;i<row.length;i++){
				p *= labelCounts.get(i).get(row[i]);
			}
			if(l!=-1)
				lp[l] = p;
		}
		return lp;
	}

	/**
	 * Enumerates the Cartesian product of the values observed in each class.
	 *
	 * @param labelCounts per class, a map whose keys are the observed values
	 * @return one row per combination; column {@code i} holds the value of class {@code i}
	 */
	private int[][] determineAllLabelSets(List<HashMap<Integer,Float>> labelCounts){
		// Seed with the values of class 0, one single-column row each.
		Set<Integer> keys = labelCounts.get(0).keySet();
		int[][] current = new int[keys.size()][1];
		int n=0;
		for(int k:keys){
			current[n][0]=k;
			n++;
		}
		// Extend every existing row with each value of the next class.
		for(int i=1;i<numClasses;i++){
			keys = labelCounts.get(i).keySet();
			int[][] extended = new int[current.length*keys.size()][i+1];
			int j=0;
			for(int l:keys){
				for(int k=0;k<current.length;k++){
					int idx = j*current.length + k;
					System.arraycopy(current[k], 0, extended[idx], 0, current[k].length);
					extended[idx][i] = l;
				}
				j++;
			}
			current = extended;
		}
		return current;
	}

	/**
	 * Prepares an iteration: finalizes theta, stores it in {@code thetaPrev}
	 * and resets theta so it can accumulate the new M-step.
	 */
	private void runPreEM(){
		theta.calculateAlpha();
		theta.normalize();
		thetaPrev.copy(theta);
		theta.reset();
	}

	/**
	 * E-step of one voxel. For every label with a positive prior, sums the log
	 * prior and the log-probabilities that {@code thetaPrev} reports for each
	 * rater's observation, then exponentiates relative to the largest sum.
	 *
	 * @param lp label prior; overwritten with the unnormalized weights
	 * @return {@code lp}, holding the unnormalized label weights
	 */
	private float[] runEStepVoxel(int x,int y,int z,int c,float[] lp){
		double[] lbp = new double[lp.length];
		// Must start at negative infinity: Double.MIN_VALUE is the smallest
		// POSITIVE double, so negative log-weights would never replace it and
		// the stabilizing shift below would have no effect.
		double maxFact = Double.NEGATIVE_INFINITY;
		//for each label
		for(int i=0;i<lbp.length;i++){
			if(lp[i]>0){
				lbp[i] = Math.log(lp[i]);
				//for each rater
				for(int j=0;j<numRaters;j++){
					int obs = images.get(j).getInt(x,y,z,c);
					double prob = thetaPrev.getLogProbability(j,obs,i);
					lbp[i] += prob;
				}
			}else{
				lbp[i] = Double.NEGATIVE_INFINITY;
			}
			if(lbp[i] > maxFact)
				maxFact = lbp[i];
		}
		if(maxFact == Double.NEGATIVE_INFINITY){
			// No label carries any weight; (-inf) - (-inf) would produce NaN,
			// so report plain zeros and let normalizeProb flag the voxel.
			Arrays.fill(lp, 0f);
			return lp;
		}
		// Shift by the maximum so the largest weight is exp(0) = 1 and the
		// others cannot all underflow to zero.
		for(int i=0;i<lbp.length;i++){
			lp[i] = (float) Math.exp(lbp[i] - maxFact);
		}
		return lp;
	}

	/**
	 * Normalizes the weights in place so they sum to one; NaN entries are
	 * replaced by 0 first. A zero sum is logged and, as the log message
	 * announces, results in NaN entries.
	 *
	 * @param prob weights to normalize; modified in place
	 * @return {@code prob}
	 */
	private float[] normalizeProb(float[] prob){
		float sum = 0;
		for(int i = 0;i<prob.length;i++){
			// NaN never compares equal to anything (itself included), so an
			// == test cannot detect it; Float.isNaN is required.
			if(Float.isNaN(prob[i])) prob[i] = 0;
			sum += prob[i];
		}
		if(sum==0){
			JistLogger.logOutput(JistLogger.SEVERE,"Sum of 0 found for weight probability.\nExpect NaNs in the next error message and more verbosity.");
		}
		for(int i = 0;i<prob.length;i++){
			prob[i] = prob[i]/sum;
		}
		return prob;
	}

	/**
	 * Stores the most probable label of a voxel in the estimate, probability
	 * and per-class result volumes, and flags the voxel as converged once its
	 * probability exceeds {@code consensusThresh}.
	 *
	 * @param prob normalized label probabilities of the voxel
	 */
	private void setLabels(float[] prob,int r,int c,int s,int l){
		// Float.MIN_VALUE is the smallest positive float, so only labels with a
		// strictly larger weight can win; zeros and NaNs leave maxInd at -1.
		float maxProb = Float.MIN_VALUE;
		int maxInd = -1;
		for(int i = 0;i < prob.length;i++){
			if(prob[i] > maxProb){
				maxProb = prob[i];
				maxInd = i;
			}
		}
		if(maxProb>consensusThresh && converged.getInt(r, c,s,l)!=1){
			converged.set(r,c,s,l,1);
			numConsensus++;
		}
		if(maxInd==-1){
			JistLogger.logOutput(JistLogger.SEVERE, "Something went wrong in label calculation");
			JistLogger.logOutput(JistLogger.SEVERE, String.format("No Label had any weight at voxel %d %d %d %d",r,c,s,l));
			StringBuilder observed = new StringBuilder("Here are the observed labels: ");
			for(ImageData im: images)
				observed.append(im.getInt(r, c, s, l)).append(", ");
			JistLogger.logOutput(JistLogger.SEVERE, observed.toString());
			StringBuilder weights = new StringBuilder("Here are the label probabilities: ");
			for(float p:prob)
				weights.append(p).append(", ");
			JistLogger.logOutput(JistLogger.SEVERE, weights.toString());
			// NOTE(review): execution continues with maxInd == -1, so
			// theta.getLabelSet(-1) is called below; confirm how it reacts.
		}
		int[] labelSet = theta.getLabelSet(maxInd);
		for(int i = 0;i<labelSet.length;i++){
			resultLabels.get(i).set(r,c,s,l,labelSet[i]);
		}
		probability.set(r,c,s,l,maxProb);
		estimate.set(r,c,s,l,maxInd);
	}

	/** @return volume holding, per voxel, the probability of the selected label */
	public ImageData getProbabilityVolume(){ return probability; }

	/** @return the fused label volumes, one per class */
	public List<ImageData> getLabelSet(){ return resultLabels; }

	/**
	 * Passes every rater's observation at the voxel, together with the voxel's
	 * label probabilities, to theta's M-step accumulation.
	 */
	private void runMStepVoxel(int r,int c,int s,int k,float[] prob){
		for(int i = 0;i<numRaters;i++){
			int obs = images.get(i).getInt(r, c, s, k);
			theta.runMStep(i,obs,prob);
		}
	}

	/**
	 * Computes the bounding box of all voxels whose observations, summed over
	 * the raters, are positive. Lower bounds are inclusive, upper bounds
	 * exclusive. When no voxel qualifies the box spans the whole volume.
	 */
	private void determineCroppingRegion(){
		JistLogger.logOutput(JistLogger.INFO, "Determining cropping region");
		int r = images.get(0).getRows();
		int c = images.get(0).getCols();
		int s = images.get(0).getSlices();
		int k = images.get(0).getComponents();
		croppingRegion = new int[]{0, r, 0, c, 0, s};

		int minX = r, maxX = -1;
		int minY = c, maxY = -1;
		int minZ = s, maxZ = -1;
		for(int x = 0;x < r; x++){
			for(int y=0;y<c;y++){
				for(int z=0;z<s;z++){
					for(int a=0;a<k;a++){
						int sum = 0;
						for(ImageData im:images){
							sum += im.getInt(x,y,z,a);
						}
						if(sum > 0){
							if(x < minX) minX = x;
							if(x > maxX) maxX = x;
							if(y < minY) minY = y;
							if(y > maxY) maxY = y;
							if(z < minZ) minZ = z;
							if(z > maxZ) maxZ = z;
						}
					}
				}
			}
		}
		// maxX stays at -1 only when no voxel was positive; keep the
		// full-volume defaults in that case.
		if(maxX >= 0){
			croppingRegion[0] = minX;
			croppingRegion[1] = maxX+1;
			croppingRegion[2] = minY;
			croppingRegion[3] = maxY+1;
			croppingRegion[4] = minZ;
			croppingRegion[5] = maxZ+1;
		}

		JistLogger.logOutput(JistLogger.INFO, String.format("Cropping region is x: [%d %d] y: [%d %d] z: [%d %d]",croppingRegion[0],croppingRegion[1],croppingRegion[2],croppingRegion[3],croppingRegion[4],croppingRegion[5]));

	}

	/**
	 * @return {@code true} if the voxel lies inside the cropping region; the
	 *         upper bounds are exclusive, matching the loops in {@link #runEM()}
	 */
	private boolean inCroppingRegion(int i,int j, int k){
		return i>=croppingRegion[0] && i<croppingRegion[1]
				&& j>=croppingRegion[2] && j<croppingRegion[3]
				&& k>=croppingRegion[4] && k<croppingRegion[5];
	}

	/**
	 * Labels every voxel on which the raters agree (see
	 * {@link #determineIfConsensus(int, int, int, int)}) with probability 1 and
	 * marks it in the consensus volume so the EM loop skips it.
	 */
	private void determineConsensusRegion(){
		for(int i=0;i<estimate.getRows();i++){
			for(int j=0;j<estimate.getCols();j++){
				for(int k=0;k<estimate.getSlices();k++){
					for(int l=0;l<estimate.getComponents();l++){
						int n = determineIfConsensus(i,j,k,l);
						if(n == -1){
							continue;
						}
						estimate.set(i,j,k,l,n);
						probability.set(i,j,k,l,1);
						consensus.set(i,j,k,l,1);
						int[] labelSet = theta.getLabelSet(n);
						for(int m=0;m<labelSet.length;m++){
							resultLabels.get(m).set(i,j,k,l,labelSet[m]);
						}
					}
				}
			}
		}
	}

	/**
	 * Checks whether, within every class, all raters observed the same value at
	 * the voxel.
	 *
	 * @return the index theta assigns to the agreed label set, or -1 when some
	 *         class disagrees internally or theta does not know the label set
	 */
	private int determineIfConsensus(int i,int j, int k, int l){
		int[] labels = new int[numClasses];
		Arrays.fill(labels,-1); // -1 marks a class without an observation so far
		for(int im = 0;im < images.size();im++){
			int label = images.get(im).getInt(i, j,k,l);
			int c = classes.get(im);
			if(labels[c]==-1)
				labels[c] = label;
			else if(labels[c] != label)
				return -1;
		}
		return theta.findLabel(labels);
	}

	/**
	 * Normalizes theta and returns it as a list of volumes.
	 *
	 * @return theta as produced by {@code theta.getAsVolume()}
	 */
	public List<ImageData> getThetaVolume(){
		theta.normalize();
		return theta.getAsVolume();
	}
}
