package edu.vanderbilt.masi.algorithms.adaboost;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Decision-stump weak learner: classifies a sample by comparing a single
 * feature value against a threshold.
 * <p>
 * The learned state can be serialized to, and restored from, a flat
 * {@code double[]} of length {@link #NUM_ELEMENTS} with
 * {@link #get_vals(double[], int)} / {@link #set_vals(double[], int)}.
 * The order is: rate, alpha, sign, thresh, feature_ind, split_val.
 * <p>
 * Data layout used by this class: {@code X[feature][sample]}; labels
 * {@code Y[sample]} are {@code true} for the +1 class; {@code W[sample]}
 * are per-sample weights.
 */
public class WeakLearnerDecisionStump extends WeakLearner {
	
	/** +1: predict +1 when x > thresh; -1: predict +1 when x <= thresh. */
	private int sign;
	/** Threshold applied to feature {@link #feature_ind}. */
	private double thresh;
	/** Index of the tested feature; -1 means no feature chosen (stump is invalid). */
	private int feature_ind;
	/** Split-criterion score of this stump; larger wins when choosing among features. */
	private double split_val;
	/** One of the SPLIT_TYPE_* constants; selects {@link #split_criterion}. */
	private int split_type;
	/** Criterion used to score the optimized stump for a feature. */
	private SplitCriterionBase split_criterion;
	
	public final static String TYPE = "DecisionStump"; 
	
	/** Number of values written by get_vals and read by set_vals. */
	public final static int NUM_ELEMENTS = 6;
	
	public final static int SPLIT_TYPE_CLASSIFICATION_RATE = 0;
	public final static int SPLIT_TYPE_GINI = 1;
	public final static int SPLIT_TYPE_INFORMATION_GAIN = 2;
	public final static int SPLIT_TYPE_GAIN_RATIO = 3;

	/**
	 * Creates an untrained stump that scores splits with the given criterion.
	 *
	 * @param in_split_type one of the SPLIT_TYPE_* constants
	 * @throws IllegalArgumentException (a RuntimeException) if the split type is unknown
	 */
	public WeakLearnerDecisionStump(int in_split_type) {
		split_type = in_split_type;
		
		// initialize the split criterion object
		switch (split_type) {
			case SPLIT_TYPE_CLASSIFICATION_RATE:
				split_criterion = new SplitCriterionClassificationRate();
				break;
			case SPLIT_TYPE_GINI:
				split_criterion = new SplitCriterionGini();
				break;
			case SPLIT_TYPE_INFORMATION_GAIN:
				split_criterion = new SplitCriterionInformationGain();
				break;
			case SPLIT_TYPE_GAIN_RATIO:
				split_criterion = new SplitCriterionGainRatio();
				break;
			default:
				throw new IllegalArgumentException("Invalid Split Type in WeakLearnerDecisionStump");
		}
		
		reset();
	}
	
	/**
	 * Creates a stump from a parameter vector as produced by {@link #get_parms()}.
	 *
	 * @param parms {@code parms[0]} is the split type
	 */
	public WeakLearnerDecisionStump(double [] parms) {
		this((int)parms[0]);
	}
	
	/** Clears the learned state; afterwards {@link #valid()} returns false. */
	public void reset() {
		sign = -1;
		rate = 0;
		thresh = 0;
		feature_ind = -1;
		alpha = 0;
		split_val = Double.NEGATIVE_INFINITY;
	}
	
	/**
	 * Compares the serialized state of this stump with another one, element
	 * by element. Note: this overloads (does not override) Object.equals.
	 *
	 * @param stump stump to compare with; {@code null} is never equal
	 * @return true if all NUM_ELEMENTS stored values are equal
	 */
	public boolean equals(WeakLearnerDecisionStump stump) {
		if (stump == null)
			return(false);
		double [] vals1 = get_vals();
		double [] vals2 = stump.get_vals();
		for (int i = 0; i < NUM_ELEMENTS; i++)
			if (vals1[i] != vals2[i])
				return(false);
		return(true);
	}
	
	/** @return the type tag {@link #TYPE} */
	public String get_type() { return(TYPE); }
	
	/** @return a one-element parameter vector holding the split type */
	public double [] get_parms() {
		double [] parms = new double [1];
		parms[0] = split_type;
		return(parms);
	}
	
	/** @return {@link #NUM_ELEMENTS} */
	public int get_num_elements() { return(NUM_ELEMENTS); } 
	
	/** Restores the state from {@code vals} starting at index 0. */
	public void set_vals(double [] vals) {
		set_vals(vals, 0);
	}

	/**
	 * Restores the state from {@code vals[st_ind .. st_ind+NUM_ELEMENTS-1]}
	 * in the order rate, alpha, sign, thresh, feature_ind, split_val.
	 */
	public void set_vals(double [] vals, int st_ind) {
		rate = vals[st_ind++];
		alpha = vals[st_ind++];
		sign = (int)vals[st_ind++];
		thresh = vals[st_ind++];
		feature_ind = (int)vals[st_ind++];
		split_val = vals[st_ind++];
	}	
	
	/** Writes the state into {@code vals} starting at index 0. */
	public void get_vals(double [] vals) {
		get_vals(vals, 0);
	}
	
	/**
	 * Writes the state into {@code vals[st_ind .. st_ind+NUM_ELEMENTS-1]}
	 * in the order rate, alpha, sign, thresh, feature_ind, split_val.
	 */
	public void get_vals(double [] vals, int st_ind) {
		vals[st_ind + 0] = rate;
		vals[st_ind + 1] = alpha;
		vals[st_ind + 2] = sign;
		vals[st_ind + 3] = thresh;
		vals[st_ind + 4] = feature_ind;
		vals[st_ind + 5] = split_val;
	}
	
	/** @return true once a feature has been selected */
	public boolean valid() { return(feature_ind >= 0); }
	
	/**
	 * Classifies sample {@code ind}.
	 *
	 * @param X   feature matrix, indexed {@code X[feature][sample]}
	 * @param ind sample index
	 * @return +1 or -1
	 * @throws RuntimeException if the stump is not valid
	 */
	public int classify(float [][] X, int ind) {
		
		if (!valid())
			throw new RuntimeException("Trying to use invalid decision stump");
		
		if (sign == 1)
			return((X[feature_ind][ind] > thresh) ? 1 : -1);
		else
			return((X[feature_ind][ind] <= thresh) ? 1 : -1);
	}
	
	/** Logs a one-line summary of this stump for the given boosting iteration. */
	public void print_info(int iter, double error) {
		JistLogger.logOutput(JistLogger.INFO, String.format("Iteration %03d (%.5f): %s", 
															iter+1, error, get_summary_string()));
	}
	
	/** @return the string form of the split criterion */
	public String get_criterion_string() {
		return(split_criterion.toString());
	}
	
	/** @return "1-rate, split_val, alpha, feature_ind, sign, thresh" formatted as text */
	public String get_summary_string() {
		String str = String.format("%.5f, %.5f, %.5f, %04d, %s, %.5f", 
				                   1-rate, split_val, alpha, feature_ind, 
								   (sign == 1) ? "+1" : "-1", thresh);
		return(str);
	}

	/**
	 * Optimizes a stump for every feature, keeps the one with the largest
	 * split-criterion value, stores it in this object and calls set_alpha().
	 * If no feature scores above -infinity, this stump is left invalid.
	 *
	 * @param Y                 labels, true = +1 class
	 * @param X                 features, {@code X[feature][sample]}
	 * @param sort_inds         per-feature sample indices; the threshold sweep
	 *                          assumes ascending order of {@code X[f]}
	 * @param W                 sample weights; the rate bookkeeping assumes they sum to 1
	 * @param num_samples_total number of samples
	 * @param num_features      number of features to search
	 */
	public void find_best_learner(boolean [] Y,
								  float [][] X,
								  int [][] sort_inds,
								  double [] W,
								  int num_samples_total,
								  int num_features) {
		
		// Seed the "best" record with the reset (invalid, feature_ind = -1)
		// state. A zero-filled array would otherwise yield feature_ind = 0 and
		// sign = 0 when no feature beats -infinity.
		reset();
		double [] best_vals = new double [NUM_ELEMENTS];
		get_vals(best_vals);
		
		// set the fraction positive given the weights
		double frac_pos = 0;
		for (int j = 0; j < num_samples_total; j++)
			if (Y[j])
				frac_pos += W[j];
		
		// find the best feature (weak learner)
		for (int f = 0; f < num_features; f++) {
			
			// reset the values
			reset();
			feature_ind = f;
			
			// optimize the decision stump for this feature
			optimize_decision_stump(Y, X, sort_inds, W, num_samples_total, f, frac_pos);
			
			// keep only the best learner
			if (split_val > best_vals[NUM_ELEMENTS-1])
				get_vals(best_vals);
		}
		
		// set the values to the best decision stump
		set_vals(best_vals);
		set_alpha();
	}

	/**
	 * Same as the six-argument overload, but only samples with {@code kk[j]}
	 * true contribute. Excluded samples get weight 0 and the kept weights are
	 * rescaled to sum to 1.
	 */
	public void find_best_learner(boolean [] Y,
								  float [][] X,
								  int [][] sort_inds,
								  double [] W,
								  int num_samples_total,
								  int num_features,
								  boolean [] kk) {
		
		// get the re-normalized weights
		double [] new_W = get_renormalized_weights(W, kk, num_samples_total);
		
		// find the best learner using the updated weights
		find_best_learner(Y, X, sort_inds, new_W, num_samples_total, num_features);
	}
	
	/**
	 * Sweeps a threshold over feature {@code f} in sorted order. It stores the
	 * sign/threshold with the highest weighted rate, then scores the result
	 * with the split criterion.
	 */
	private void optimize_decision_stump (boolean [] Y,
										  float [][] X,
										  int [][] sort_inds,
										  double [] W,
										  int num_samples_total,
										  int f,
										  double frac_pos) {
		
		float [] tX = X[f];
		int [] si = sort_inds[f];
		
		// p1 is the weighted rate of the sign=+1 stump at the current threshold.
		// The sign=-1 stump has rate 1-p1, which relies on the weights summing to 1.
		double p1 = frac_pos;
		boolean first = true;
		
		for (int i = 0; i < num_samples_total; i++) {
			int ii = si[i];
			
			if (W[ii] > 0) {
				
				// At the first weighted sample, start just below it so every
				// sample is on the "above threshold" side.
				if (first) {
					if (p1 >= 1-p1)
						set(1, p1, tX[ii]-0.001f);
					else
						set(-1, 1-p1, tX[ii]-0.001f);
					first = false;
				}
				
				// Moving the threshold past this sample flips its sign=+1
				// prediction to -1: positives become wrong, negatives right.
				p1 += (Y[ii]) ? -W[ii] : W[ii];
			}
			
			// nothing to evaluate until a weighted sample has been seen
			if (first)
				continue;
			
			// Only place thresholds at the end of a run of tied values. This must
			// also run for zero-weight samples; otherwise a tie run ending in a
			// zero-weight sample would never be evaluated.
			if (i < num_samples_total-1 && tX[ii] == tX[si[i+1]])
				continue;
			
			if (p1 > rate)
				set(1, p1, tX[ii]);
			else if (1-p1 > rate)
				set(-1, 1-p1, tX[ii]);
	    }
		
		// calculate the split criterion value
		split_val = split_criterion.get(this, X, Y, W, num_samples_total);
	}
	
	/** Stores a candidate sign, rate and threshold. */
	private void set(int sign, double rate, double thresh) {
		this.sign = sign;
		this.rate = rate;
		this.thresh = thresh;
	}

	/**
	 * Returns a copy of {@code W} in which samples with {@code kk[j]} false
	 * have weight 0 and the kept weights are divided by their sum. If that
	 * sum is 0, the masked weights are returned unnormalized rather than
	 * dividing by zero (which would produce NaN weights).
	 */
	private double [] get_renormalized_weights(double [] W,
											   boolean [] kk,
											   int num_samples_total) {
		
		// mask out excluded samples and accumulate the kept total
		double [] Wt = new double [num_samples_total];
		double norm_fact = 0;
		for (int j = 0; j < num_samples_total; j++) {
			Wt[j] = (kk[j]) ? W[j] : 0;
			norm_fact += Wt[j];
		}
		
		// nothing to normalize by (e.g. no sample kept)
		if (norm_fact == 0)
			return(Wt);
		
		// normalize the weights
		for (int j = 0; j < num_samples_total; j++)
			if (kk[j])
				Wt[j] /= norm_fact;
		
		return(Wt);
		
	}
}
