package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Per-rater, per-class confusion ("theta") parameters used by multi-class label fusion.
 * <p>
 * {@code thetas[r][c]} holds, for rater {@code r}, the parameters relating the labels the rater
 * observes to the labels of class {@code c}. {@code ind} enumerates the admissible joint label sets
 * (one label per class) and {@code alpha} holds a per-label-set, per-rater exponent applied to the
 * product of thetas in {@link #getLogProbability(int, int, int)}.
 */
public class MultiClassPerformanceParameters {

	/** Weight passed to {@code MultiClassPerformanceParameterBase.giveLittleWeight} in {@link #initialize()}. */
	private static final float LITTLE_WEIGHT = 0.000005f;

	/** Tolerance on |sum(vec^alpha) - 1| at which the alpha bisection accepts the exponent. */
	private static final double ALPHA_TOLERANCE = 0.0001;

	/** Hard cap on bisection steps; the search also stops as soon as the float interval stops shrinking. */
	private static final int MAX_BISECTION_STEPS = 1000;

	private int numClasses;
	private int numRaters;
	/** thetas[rater][class]. */
	private MultiClassPerformanceParameterBase[][] thetas;
	/** initialProbabilities[rater class][class], as read by {@link #loadFromFile(File, HashMap)}. */
	private MultiClassPerformanceParameterBase[][] initialProbabilities;
	/** Admissible label sets: num label sets x num classes. */
	private int[][] ind;
	/** Per-label-set exponents: num label sets x num raters. */
	private float[][] alpha;
	/** classes[rater]: the class index (first index into initialProbabilities) of each rater. */
	private int[] classes;
	private HashMap<Integer,String> intToClass;
	/** Total number of theta entries; normaliser used by {@link #calculateConvergence}. */
	private int numElements;


	public MultiClassPerformanceParameters(){}

	/**
	 * Exports the thetas as one float volume per rater.
	 * <p>
	 * Volume {@code i} is indexed {@code (observed label, target label, class)}. The second axis
	 * is sized by the largest target-label count over all classes, so every class fits.
	 *
	 * @return one volume per rater, in rater order
	 */
	public List<ImageData> getAsVolume(){
		// The second axis holds target labels of every class, so it must fit the largest of them.
		int maxTargets = 0;
		for(MultiClassPerformanceParameterBase[] row: thetas){
			for(MultiClassPerformanceParameterBase t: row){
				maxTargets = Math.max(maxTargets, t.getNumTargetLabels());
			}
		}
		List<ImageData> ims = new ArrayList<ImageData>(numRaters);
		for(int i=0;i<numRaters;i++){
			ImageData im = new ImageDataFloat(thetas[i][0].getNumObsLabels(),maxTargets,numClasses);
			im.setName("Rater "+i+" theta values");
			for(int j=0;j<numClasses;j++){
				MultiClassPerformanceParameterBase t = thetas[i][j];
				for(int k=0;k<t.getNumObsLabels();k++){
					for(int l=0;l<t.getNumTargetLabels();l++){
						im.set(k,l,j,t.getValue(k, l));
					}
				}
			}
			ims.add(im);
		}
		return ims;
	}

	/**
	 * Makes this object a copy of {@code c}.
	 * <p>
	 * The thetas, initial probabilities and alphas are deep-copied. {@code classes},
	 * {@code intToClass} and {@code ind} are shared with {@code c}, not copied. Arrays that are
	 * still {@code null} in {@code c} stay {@code null} here.
	 *
	 * @param c the parameters to copy from
	 */
	public void copy(MultiClassPerformanceParameters c){
		numClasses = c.getNumClasses();
		numRaters  = c.getNumRaters();
		// Needed by calculateConvergence; without it a copy divides by zero.
		numElements = c.numElements;
		classes = c.classes;
		intToClass = c.intToClass;
		initialProbabilities = deepCopy(c.initialProbabilities);
		thetas = deepCopy(c.thetas);
		ind = c.ind;
		alpha = copyRows(c.alpha);
	}

	/** Deep-copies a 2-D array of parameter blocks via {@code copy()}; null-safe. */
	private static MultiClassPerformanceParameterBase[][] deepCopy(MultiClassPerformanceParameterBase[][] src){
		if(src == null) return null;
		MultiClassPerformanceParameterBase[][] dst = new MultiClassPerformanceParameterBase[src.length][];
		for(int i=0;i<src.length;i++){
			dst[i] = new MultiClassPerformanceParameterBase[src[i].length];
			for(int j=0;j<src[i].length;j++){
				dst[i][j] = src[i][j] == null ? null : src[i][j].copy();
			}
		}
		return dst;
	}

	/** Copies each row of a 2-D float array; null-safe. */
	private static float[][] copyRows(float[][] src){
		if(src == null) return null;
		float[][] dst = new float[src.length][];
		for(int i=0;i<src.length;i++){
			dst[i] = src[i].clone();
		}
		return dst;
	}

	public int getNumRaters(){ return numRaters;}

	/**
	 * Loads the initial class-vs-class probabilities from a theta file.
	 * <p>
	 * Expected layout: the first line is the number of classes. It is followed by repeated
	 * sections, each a line {@code rowClass,colClass} and then a block parsed by
	 * {@code new MultiClassPerformanceParameterBase(BufferedReader)}. Blank lines between
	 * sections are ignored.
	 * <p>
	 * Class names not yet present in {@code classMap} are added to it (and to the internal
	 * reverse map) with the next free index, so <b>the caller's map is modified</b>. I/O failures
	 * are logged, not thrown.
	 *
	 * @param f        the theta file to read
	 * @param classMap class name to class index; updated with any new class names
	 */
	public void loadFromFile(File f,HashMap<String,Integer> classMap){
		intToClass = new HashMap<Integer,String>(classMap.size());
		for(String s: classMap.keySet()) intToClass.put(classMap.get(s), s);
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(f));
			String header = br.readLine();
			if(header == null){
				throw new IOException("theta file " + f + " is empty");
			}
			numClasses = Integer.parseInt(header.trim());
			JistLogger.logOutput(JistLogger.INFO, numClasses + " classes found in the theta file");
			initialProbabilities = new MultiClassPerformanceParameterBase[numClasses][numClasses];
			String line;
			while((line=br.readLine())!=null){
				if(line.trim().isEmpty()) continue;
				String[] lineBits = line.split(",");
				int rowClass = registerClass(lineBits[0], classMap);
				int colClass = registerClass(lineBits[1], classMap);
				initialProbabilities[rowClass][colClass] = new MultiClassPerformanceParameterBase(br);
			}
		} catch (IOException e) {
			JistLogger.logError(JistLogger.WARNING, "Failed to read theta file " + f + ": " + e);
		} finally {
			if(br != null){
				try {
					br.close();
				} catch (IOException ignored) {
					// Nothing useful can be done if closing a read-only stream fails.
				}
			}
		}
	}

	/** Returns the index of {@code name}, assigning it the next free index (kept in both maps) if unseen. */
	private int registerClass(String name, HashMap<String,Integer> classMap){
		if(!classMap.containsKey(name)){
			int idx = classMap.size();
			classMap.put(name, idx);
			// Keep the reverse map in sync so toString can print newly-seen classes.
			intToClass.put(idx, name);
		}
		return classMap.get(name);
	}

	public int getNumClasses(){ return numClasses; }

	/**
	 * Prepares the parameters for estimation: builds the admissible label sets, adds a small
	 * weight to and normalises every theta, then computes the alphas and the element count.
	 */
	public void initialize(){
		generateIndMatrix();
		giveLittleWeight();
		normalize();
		calculateAlpha();
		calculateNumElements();
	}

	private void calculateNumElements(){
		numElements = countElements();
	}

	/** Sums obs x target label counts over every theta block. */
	private int countElements(){
		int count = 0;
		for(MultiClassPerformanceParameterBase[] row: thetas){
			for(MultiClassPerformanceParameterBase t: row){
				count += t.getNumObsLabels()*t.getNumTargetLabels();
			}
		}
		return count;
	}

	private void giveLittleWeight(){
		for(MultiClassPerformanceParameterBase[] row: thetas){
			for(MultiClassPerformanceParameterBase t: row){
				t.giveLittleWeight(LITTLE_WEIGHT);
			}
		}
	}

	/**
	 * Creates one row of thetas per rater by copying the initial probabilities of that rater's class.
	 *
	 * @param classes1 class index of each rater, in rater order
	 */
	public void setInitialThetas(ArrayList<Integer> classes1){
		numRaters = classes1.size();
		classes = new int[numRaters];
		thetas = new MultiClassPerformanceParameterBase[numRaters][numClasses];
		for(int i = 0; i<numRaters; i++){
			int c = classes1.get(i);
			classes[i] = c;
			for(int j=0;j<numClasses;j++){
				thetas[i][j] = initialProbabilities[c][j].copy();
			}
		}
	}

	/**
	 * Builds {@code ind}: every combination of one label per class that passes
	 * {@link #testIfPossible(int[])}.
	 * <p>
	 * Classes are appended one column at a time. Each existing row is repeated once per label of the
	 * new class, with earlier columns varying fastest, and infeasible rows are pruned after every
	 * extension. This keeps the intermediate table small.
	 */
	private void generateIndMatrix(){
		JistLogger.logOutput(JistLogger.INFO, "Starting Ind Matrix Generation");
		JistLogger.logFlush();

		// Seed: one single-column row per label of class 0.
		int[][] current = new int[initialProbabilities[0][0].getNumTargetLabels()][1];
		for(int k=0;k<current.length;k++) current[k][0] = k;

		for(int c=1;c<numClasses;c++){
			int numLabels = initialProbabilities[0][c].getNumTargetLabels();
			int[][] extended = new int[numLabels*current.length][c+1];
			for(int j=0;j<numLabels;j++){
				for(int k=0;k<current.length;k++){
					int[] row = extended[j*current.length + k];
					// current[k] holds exactly c columns (classes 0..c-1).
					System.arraycopy(current[k], 0, row, 0, c);
					row[c] = j;
				}
			}
			ind = extended;
			filterInds();
			current = ind;
		}

		ind = current;
	}

	/**
	 * Checks that every pair of labels in {@code labelSet} has a non-zero initial probability.
	 *
	 * @param labelSet one label per class, for classes {@code 0..labelSet.length-1}
	 * @return true if no pairwise initial probability is zero
	 */
	private boolean testIfPossible(int[] labelSet){
		if(labelSet.length==1) return true;
		for(int i=0;i<labelSet.length;i++){
			for(int j=0;j<labelSet.length;j++){
				if(initialProbabilities[i][j].getValue(labelSet[i],labelSet[j])==0) return false;
			}
		}
		return true;
	}

	/** Drops rows of {@code ind} that fail {@link #testIfPossible(int[])}, preserving row order. */
	private void filterInds(){
		JistLogger.logOutput(JistLogger.INFO, "There are "+ind.length+" possible label sets before filtering");
		JistLogger.logFlush();
		List<int[]> kept = new ArrayList<int[]>(ind.length);
		for(int[] labelSet: ind){
			if(testIfPossible(labelSet)){
				kept.add(labelSet);
			}
		}
		ind = kept.toArray(new int[kept.size()][]);
		JistLogger.logOutput(JistLogger.INFO,"There are "+ind.length+" indices remaining after filtering");
		JistLogger.logFlush();
	}

	/** Recomputes {@code alpha} for every (label set, rater) pair. */
	public void calculateAlpha(){
		alpha = new float[ind.length][numRaters];
		for(int j=0;j<numRaters;j++){
			for(int i=0;i<alpha.length;i++){
				alpha[i][j] = calculateAlpha(j,i);
			}
		}
	}

	/**
	 * Computes the exponent {@code a} such that {@code sum_obs vec[obs]^a ~= 1}.
	 * <p>
	 * Here {@code vec[obs]} is the product over classes of {@code thetas[r][c](obs, l_c)} for label
	 * set {@code idx}.
	 *
	 * @param r   rater index
	 * @param idx label-set index into {@code ind}
	 * @return the exponent; 0 when every product is below 1e-5
	 */
	private float calculateAlpha(int r,int idx){
		float[] vec = productsOverClasses(r, ind[idx]);

		// Range of the products. Seeding with the infinities (not Float.MIN_VALUE, which is the
		// smallest *positive* float) makes the scan correct for any value.
		float ma = Float.NEGATIVE_INFINITY;
		float mi = Float.POSITIVE_INFINITY;
		for(float el:vec){
			if(ma < el){
				ma = el;
			}
			if(mi > el){
				mi = el;
			}
		}
		if(ma<0.00001)
			return 0;

		boolean allOne = true;
		for(float el:vec){
			if(el != 1){
				allOne = false;
				break;
			}
		}
		if(allOne){
			return toLogSpace(0.9999f,vec.length);
		}

		// Clip so the log-space bound stays finite.
		if(ma > 0.9999){
			ma = 0.9999f;
		}
		float hi = toLogSpace(ma,vec.length);
		float lo = toLogSpace(mi,vec.length);
		float mid = ((hi - lo)/2) + lo;
		return runBinarySearch(lo,mid,hi,vec);
	}

	/** For each observed label of rater {@code r}, the product over classes of thetas at {@code labelSet}. */
	private float[] productsOverClasses(int r, int[] labelSet){
		float[] vec = new float[thetas[r][0].getNumObsLabels()];
		Arrays.fill(vec, 1f);
		for(int obs=0;obs<vec.length;obs++){
			for(int c=0;c<numClasses;c++){
				vec[obs] *= thetas[r][c].getValue(obs, labelSet[c]);
			}
		}
		return vec;
	}

	/** Returns the exponent {@code a} with {@code val^a == 1/n}, i.e. {@code log(1/n) / log(val)}. */
	private float toLogSpace(float val,int n){
		float top = (float) Math.log(1/(float)n);
		float bottom = (float) Math.log(val);
		return top/bottom;
	}

	/**
	 * Bisects for an exponent {@code me} in {@code [mi, ma]} with {@code |sum(vec^me) - 1| < ALPHA_TOLERANCE}.
	 * <p>
	 * The loop is iterative. It stops when the float interval can no longer shrink, or after
	 * {@link #MAX_BISECTION_STEPS} steps, and then returns the best estimate. The previous recursive
	 * form could recurse forever (StackOverflowError) when the tolerance was unreachable.
	 *
	 * @return the exponent, or 0 (with a warning) if the sum is NaN
	 */
	private float runBinarySearch(float mi,float me, float ma,float[] vec){
		for(int step=0;step<MAX_BISECTION_STEPS;step++){
			float val = 0;
			for(float el:vec){
				val += Math.pow(el, me);
			}
			if(Math.abs(val - 1) < ALPHA_TOLERANCE){
				return me;
			}
			if(val > 1){
				// For entries in [0,1] the sum shrinks as the exponent grows: move the lower bound up.
				mi = me;
			}else if(val < 1){
				ma = me;
			}else{
				// Only reachable when val is NaN.
				JistLogger.logError(JistLogger.WARNING, "Something went wrong in alpha calculation");
				return 0;
			}
			float next = (ma - mi)/2 + mi;
			if(next == me){
				// Interval collapsed at float resolution without meeting the tolerance.
				return me;
			}
			me = next;
		}
		return me;
	}

	public int getNumLabels(){ return ind.length; }

	/** Calls {@code reset()} on every theta block. */
	public void reset(){
		for(MultiClassPerformanceParameterBase[] row: thetas){
			for(MultiClassPerformanceParameterBase t: row){
				t.reset();
			}
		}
	}

	/**
	 * Allocates empty n x n thetas, one rater per class, for accumulating matrices via
	 * {@link #addValue}.
	 *
	 * @param numLabels number of labels of each class; its length n is both the class and rater count
	 * @param classMap  class index to class name
	 */
	public void setUpToBuildMatrices(int[] numLabels,HashMap<Integer,String> classMap){
		intToClass = classMap;
		int n = numLabels.length;
		// Keep the counts consistent with the n x n thetas allocated below.
		numClasses = n;
		numRaters = n;
		thetas = new MultiClassPerformanceParameterBase[n][n];
		classes = new int[n];
		for(int i=0;i<n;i++){
			classes[i] = i;
			for(int j=0;j<n;j++){
				thetas[i][j] = new MultiClassPerformanceParameterBase(numLabels[i],numLabels[j]);
			}
		}
	}

	/**
	 * Calculates the log of the probability of the label given the rater's observation and confusion matrices
	 * 
	 * @param rater The index of the rater
	 * @param obs The observed label of the rater
	 * @param label the index of the label we are working with
	 * @return The log of the probability that the rater observed obs given that the true label is indexed by label
	 */
	public double getLogProbability(int rater,int obs, int label){
		float a = alpha[label][rater];
		if(a == 0){
			// p^0 == 1 for every p, so the log is 0. Returning early avoids 0 * log(0) = 0 * -Infinity = NaN.
			return 0;
		}
		int[] l = ind[label];
		double prob = 0;
		//loop over all elements in the label
		for(int i=0;i<l.length;i++){
			float val = thetas[rater][i].getValue(obs, l[i]);
			prob += Math.log(val);
		}
		return prob * a;
	}

	/** Calls {@code normalize()} on every theta block. */
	public void normalize(){
		for(int i = 0;i < thetas.length;i++){
			for(int j=0;j<thetas[i].length;j++){
				thetas[i][j].normalize();
			}
		}
	}

	/** Returns label set {@code i} (shared array; do not modify). */
	public int[] getLabelSet(int i){
		return ind[i];
	}

	/**
	 * Accumulates the weighted posterior of each label set into rater {@code rater}'s thetas.
	 *
	 * @param rater rater index
	 * @param obs   label observed by the rater
	 * @param prob  per-label-set weights, indexed like {@code ind}; non-positive entries are skipped
	 */
	public void runMStep(int rater,int obs,float[] prob){
		for(int i=0;i<prob.length;i++){
			if(prob[i] > 0){
				int[] l = ind[i];
				float w = prob[i]*alpha[i][rater];
				for(int j=0;j<l.length;j++){
					thetas[rater][j].addValue(obs,l[j],w);
				}
			}
		}
	}

	/**
	 * Serialises the thetas.
	 * <p>
	 * For each (rater, class) block the output is a {@code raterClass,class} line, then an
	 * {@code obs,target} line, then the block's own {@code toString()}. The final character of the
	 * whole output is dropped.
	 */
	@Override
	public String toString(){
		StringBuilder sb = new StringBuilder();
		for(int i =0;i<thetas.length;i++){
			for(int j=0;j<thetas[i].length;j++){
				MultiClassPerformanceParameterBase t = thetas[i][j];
				sb.append(intToClass.get(classes[i])).append(',').append(intToClass.get(j)).append('\n');
				sb.append(t.getNumObsLabels()).append(',').append(t.getNumTargetLabels()).append('\n');
				sb.append(t.toString());
			}
		}
		// Drop the trailing character; guard against an empty buffer.
		if(sb.length() > 0){
			sb.setLength(sb.length()-1);
		}
		return sb.toString();
	}

	/**
	 * Mean per-element difference between these thetas and {@code theta2}'s.
	 *
	 * @return the summed {@code calculateDiff} divided by the number of theta entries; 0 if there are none
	 */
	public float calculateConvergence(MultiClassPerformanceParameters theta2){
		float diff = 0;
		for(int i=0;i<thetas.length;i++){
			for(int j=0;j<thetas[i].length;j++){
				diff += thetas[i][j].calculateDiff(theta2.thetas[i][j]);
			}
		}
		// numElements is only set by initialize(); fall back to counting if it was never computed.
		int n = numElements > 0 ? numElements : countElements();
		return n > 0 ? diff/n : 0f;
	}

	/** Nudges every theta block towards the initial probabilities of its rater's class. */
	public void nudge(float v){
		for(int i=0;i<thetas.length;i++){
			for(int j=0;j<thetas[i].length;j++){
				thetas[i][j].nudge(v,initialProbabilities[classes[i]][j]);
			}
		}
	}

	public void addValue(int c1,int c2, int obs, int target, float val){
		thetas[c1][c2].addValue(obs, target, val);
	}

	public HashMap<Integer,String> getClassMap(){ return intToClass; }

	/**
	 * Finds the first label set whose leading entries equal {@code labels}.
	 *
	 * @param labels labels to match, one per class starting at class 0
	 * @return the label-set index, or -1 if none matches
	 */
	public int findLabel(int[] labels){
		for(int i=0;i<ind.length;i++){
			int[] candidate = ind[i];
			// A label set with fewer entries than labels cannot match (and would overflow below).
			if(candidate.length < labels.length) continue;
			boolean correct = true;
			for(int j=0;j<labels.length;j++){
				if(labels[j] != candidate[j]){
					correct = false;
					break;
				}
			}
			if(correct)
				return i;
		}
		return -1;
	}
}
