package edu.vanderbilt.masi.algorithms.adaboost;

/**
 * Split criterion that scores a decision stump by the weighted information gain
 * (reduction in base-2 binary label entropy) of the split it induces.
 */
public class SplitCriterionInformationGain extends SplitCriterionBase {

	/** Natural logarithm of 2; dividing a natural log by this yields a base-2 log (bits). */
	private static final double LOG_2 = Math.log(2);

	/**
	 * Creates the criterion and sets its display name.
	 */
	public SplitCriterionInformationGain() {
		str = "Information Gain";
	}

	/**
	 * Computes the weighted information gain of the split produced by {@code stump}.
	 *
	 * <p>Samples whose weight is not strictly positive are ignored. Each remaining sample is
	 * routed to the "left" branch when {@code stump.classify(X, i) == -1} and to the "right"
	 * branch otherwise. The gain is
	 * {@code entropy(stump.get_rate()) - sum_b (w_b / w_total) * entropy(pos_b / w_b)},
	 * where {@code w_b} is a branch's total weight and {@code pos_b} the weight of its
	 * {@code true}-labelled samples. Branch weights are normalized by the total positive
	 * weight, so {@code W} does not have to sum to 1.
	 *
	 * <p>NOTE(review): {@code stump.get_rate()} is presumably the overall positive-label rate
	 * of the parent node; confirm against {@code WeakLearnerDecisionStump}.
	 *
	 * @param stump the stump whose split is evaluated
	 * @param X feature matrix, passed through to {@code stump.classify(X, i)}
	 * @param Y binary labels, indexed by sample
	 * @param W per-sample weights, indexed by sample
	 * @param num_samples_total number of leading samples of {@code Y}/{@code W} to consider
	 * @return the information gain in bits; equals {@code entropy(stump.get_rate())} when no
	 *         sample has positive weight
	 */
	public double get(WeakLearnerDecisionStump stump, float [][] X, boolean [] Y, double [] W, int num_samples_total) {

		double leftNeg = 0;
		double leftPos = 0;
		double rightNeg = 0;
		double rightPos = 0;

		for (int i = 0; i < num_samples_total; i++) {
			double w = W[i];
			if (!(w > 0)) {
				continue; // zero/negative (and NaN) weights do not participate in the split
			}
			boolean left = stump.classify(X, i) == -1;
			if (left) {
				if (Y[i]) {
					leftPos += w;
				} else {
					leftNeg += w;
				}
			} else {
				if (Y[i]) {
					rightPos += w;
				} else {
					rightNeg += w;
				}
			}
		}

		double leftWeight = leftNeg + leftPos;
		double rightWeight = rightNeg + rightPos;
		double totalWeight = leftWeight + rightWeight;
		double parentEntropy = entropy(stump.get_rate());

		// No usable samples: there is no conditional entropy to subtract.
		if (!(totalWeight > 0)) {
			return parentEntropy;
		}

		return parentEntropy
				- weightedBranchEntropy(leftPos, leftWeight, totalWeight)
				- weightedBranchEntropy(rightPos, rightWeight, totalWeight);
	}

	/**
	 * Returns a branch's contribution to the conditional entropy: its share of the total
	 * weight times the binary entropy of its positive-label fraction. Empty branches
	 * contribute 0.
	 */
	private double weightedBranchEntropy(double posWeight, double branchWeight, double totalWeight) {
		if (!(branchWeight > 0)) {
			return 0;
		}
		return (branchWeight / totalWeight) * entropy(posWeight / branchWeight);
	}

	/**
	 * Binary entropy, in bits, of a Bernoulli variable with success probability {@code rate}.
	 *
	 * @param rate probability of one of the two outcomes
	 * @return 0 when {@code rate <= 0} or {@code rate >= 1}; 1 when {@code rate} is NaN
	 *         (treated as maximal uncertainty); otherwise
	 *         {@code -(rate*log2(rate) + (1-rate)*log2(1-rate))}
	 */
	protected double entropy(double rate) {

		// handle the edge cases
		if (Double.isNaN(rate)) {
			return 1;
		}
		if (rate <= 0 || rate >= 1) {
			return 0;
		}

		return -1 * ((rate * Math.log(rate) / LOG_2) + ((1 - rate) * Math.log(1 - rate) / LOG_2));
	}

}
