package edu.vanderbilt.masi.algorithms.clasisfication.decisiontree;

import java.util.HashMap;
import java.util.TreeSet;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Selects a decision-tree split by maximizing information gain, i.e. the
 * reduction in Shannon entropy (in bits) of the target labels.
 *
 * <p>A candidate split is a (feature index, threshold value) pair: samples whose
 * feature value is strictly less than the threshold go left, all others go right.
 * Every distinct observed value of every feature is tried as a threshold.
 */
public class InformationGainCalculator extends ImpurityCalculator {

	/** Natural log of 2, used to convert natural logarithms to base 2. */
	private static final double LN2 = Math.log(2);

	/**
	 * Finds the (feature, value) cutoff with the highest information gain.
	 *
	 * @param features sample matrix indexed as features[sample][feature]
	 * @param target   label of each sample; target[i] belongs to features[i]
	 * @return the best cutoff, or {@code Cutoff(-1,-1)} when no split separates
	 *         the samples (e.g. no samples, or every feature is constant)
	 */
	@Override
	public Cutoff calculate(float[][] features, int[] target) {
		// Count occurrences of each label across all samples.
		HashMap<Integer,Integer> counts = new HashMap<Integer,Integer>();
		for(int l : target){
			Integer c = counts.get(l);
			counts.put(l, c == null ? 1 : c + 1);
		}
		for(int l : counts.keySet()){
			float prob = ((float) counts.get(l)) / target.length;
			JistLogger.logOutput(JistLogger.FINE, "Label "+l+" has probability "+prob);
		}
		// The parent entropy is identical for every candidate split, so compute it once.
		double parentEntropy = entropy(counts, target.length);

		Cutoff C = null;
		// Information gain is non-negative, so any real candidate beats this.
		float best = Float.NEGATIVE_INFINITY;
		// Guard: features[0] would throw on an empty sample set.
		int numFeatures = features.length == 0 ? 0 : features[0].length;
		for(int i=0;i<numFeatures;i++){
			JistLogger.logOutput(JistLogger.FINE, "Calculating information gain for feature number "+i);
			TreeSet<Float> vals = new TreeSet<Float>();
			for(int j=0;j<features.length;j++) vals.add(features[j][i]);
			JistLogger.logOutput(JistLogger.FINER, "There are "+vals.size()+" values to calculate for.");
			for(float f:vals){
				float impurity = calculateImpurity(features,target,i,f,parentEntropy);
				// Strict '>' keeps the first (lowest feature, lowest value) cutoff on ties.
				if(impurity>best){
					best = impurity;
					C = new Cutoff(i,f);
				}
				JistLogger.logOutput(JistLogger.FINEST, "The impurity for the value "+f+" was "+impurity);
			}
		}
		if(C !=null){
			JistLogger.logOutput(JistLogger.FINE, "The best features is number "+C.num+" with a value of "+C.val+" with an impurity of "+best);
			JistLogger.logFlush();
		}else{
			C = new Cutoff(-1,-1);
		}
		return C;
	}

	/**
	 * Computes the information gain of splitting on {@code features[*][feature] < value}.
	 *
	 * @param parentEntropy entropy (bits) of the unsplit label distribution
	 * @return H(parent) - pL*H(left) - pR*H(right); {@link Float#NEGATIVE_INFINITY}
	 *         when the split leaves one side empty, so it is never selected
	 */
	private float calculateImpurity(float[][] features,int[] target,int feature,float value,double parentEntropy){
		HashMap<Integer,Integer> numLeft = new HashMap<Integer,Integer>();
		HashMap<Integer,Integer> numRight = new HashMap<Integer,Integer>();
		int left=0;
		int right=0;
		for(int i=0;i<target.length;i++){
			HashMap<Integer,Integer> side;
			if(features[i][feature]<value){
				left++;
				side = numLeft;
			}else{
				right++;
				side = numRight;
			}
			Integer c = side.get(target[i]);
			side.put(target[i], c == null ? 1 : c + 1);
		}
		// A split with an empty side separates nothing (and would divide by zero).
		if(left==0 || right==0) return Float.NEGATIVE_INFINITY;
		double pL = ((double) left)/(left+right);
		double pR = ((double) right)/(left+right);
		return (float) (parentEntropy - pL*entropy(numLeft,left) - pR*entropy(numRight,right));
	}

	/**
	 * Shannon entropy in bits of a label histogram.
	 *
	 * @param counts label -> number of samples with that label
	 * @param total  total number of samples represented by {@code counts}
	 * @return -sum(p*log2(p)), using the convention 0*log(0) = 0
	 */
	private static double entropy(HashMap<Integer,Integer> counts, int total){
		double h = 0;
		for(int c : counts.values()){
			// Skip empty bins: 0 * Math.log(0) would be 0 * -Infinity = NaN.
			if(c == 0) continue;
			double p = ((double) c)/total;
			h -= p*Math.log(p)/LN2;
		}
		return h;
	}

}
