package edu.vanderbilt.masi.algorithms.clasisfication.decisiontree;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.TreeSet;
import java.util.HashSet;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Searches every feature column for the threshold that best separates the
 * class labels according to a Gini-style purity score.
 *
 * <p>For a candidate split the score is
 * {@code pL * sum_k (nLk/nL)^2 + pR * sum_k (nRk/nR)^2}, where {@code nLk} and
 * {@code nRk} are the per-class counts on each side. Larger is purer; this is
 * one minus the size-weighted Gini impurity of the two sides, so maximising it
 * minimises the impurity.
 *
 * <p>All working state is local to {@link #calculate(float[][], int[])}, so the
 * calculator keeps nothing between calls.
 */
public class GiniCalculator extends ImpurityCalculator {

	public GiniCalculator(){super();}

	/**
	 * Finds the feature index and cutoff value with the highest purity score.
	 *
	 * <p>Samples are sorted per feature and each sample's value is considered as
	 * a cutoff with that sample (and everything sorted before it) on the left.
	 * Two kinds of candidates are skipped because they cannot describe a useful
	 * split: the position that leaves the right side empty, and positions that
	 * fall between two samples with the same feature value.
	 *
	 * @param features sample-major matrix, {@code features[sample][feature]}
	 * @param target   class label per sample; expected to have one entry per row
	 *                 of {@code features}
	 * @return the best cutoff found, or {@code new Cutoff(-1,-1)} when there is
	 *         no data or no usable split
	 */
	@Override
	public Cutoff calculate(float[][] features,int[] target) {
		// Nothing to split: reuse the sentinel already returned when no cutoff is found.
		if(features.length==0 || target.length==0){
			return new Cutoff(-1,-1);
		}

		// Count occurrences of each label with exact integer arithmetic.
		HashMap<Integer,Integer> counts = new HashMap<Integer,Integer>();
		for(int i=0;i<target.length;i++){
			Integer seen = counts.get(target[i]);
			counts.put(target[i], seen==null ? 1 : seen+1);
		}
		for(int label:counts.keySet()){
			float probability = ((float) counts.get(label))/target.length;
			JistLogger.logOutput(JistLogger.FINE, "Label "+label+" has probability "+probability);
		}

		Cutoff C = null;
		float best = -10000000;
		int numFeatures = features[0].length;
		for(int i=0;i<numFeatures;i++){
			JistLogger.logOutput(JistLogger.FINE, "Calculating GINI for feature number "+i);

			// Pair each sample's value for this feature with its label and order by value.
			ArrayList<FeatureCalculationComparable> F = new ArrayList<FeatureCalculationComparable>(features.length);
			for(int j=0;j<features.length;j++) F.add( new FeatureCalculationComparable(features[j][i],target[j]));
			Collections.sort(F);

			// Start with every sample on the right; the sweep moves them left one at a time.
			HashMap<Integer,Integer> numRight = new HashMap<Integer,Integer>();
			HashMap<Integer,Integer> numLeft = new HashMap<Integer,Integer>();
			int right = 0;
			int left = 0;
			for(int label:counts.keySet()){
				numRight.put(label,counts.get(label));
				right += counts.get(label);
				numLeft.put(label,0);
			}

			int n = F.size();
			for(int k=0;k<n;k++){
				FeatureCalculationComparable f = F.get(k);
				left++;
				right--;
				int c = f.getC();
				numRight.put(c, numRight.get(c)-1);
				numLeft.put(c, numLeft.get(c)+1);

				// An empty right side is not a split (and would divide by zero below).
				if(right<=0) continue;
				// A cutoff between two equal values cannot separate them, so the
				// counts at this position do not describe a realisable partition.
				if(k+1<n && sameValue(f,F.get(k+1))) continue;

				float pL = ((float) left)/(left+right);
				float pR = ((float) right)/(left+right);
				float score = 0;
				for(int j: numLeft.values()) score += pL*(((float) j)/left)*(((float) j)/left);
				for(int j: numRight.values()) score += pR*(((float) j)/right)*(((float) j)/right);
				if(score>best){
					best = score;
					C = new Cutoff(i,f.getVal());
				}
			}
			// Note: this is the best score seen so far across features, not only this feature's.
			JistLogger.logOutput(JistLogger.FINE, "The score was "+best);
		}
		if(C !=null){
			JistLogger.logOutput(JistLogger.FINE, "The best features is number "+C.num+" with a value of "+C.val+" with an impurity of "+best);
			JistLogger.logFlush();
		}else{
			C = new Cutoff(-1,-1);
		}
		return C;
	}

	/**
	 * Reports whether two samples carry the same feature value, i.e. neither is
	 * strictly smaller than the other. Values that cannot be ordered (NaN) are
	 * treated as the same so no cutoff is placed between them.
	 */
	private static boolean sameValue(FeatureCalculationComparable a, FeatureCalculationComparable b){
		return !(a.getVal()<b.getVal() || a.getVal()>b.getVal());
	}

}
