package edu.vanderbilt.masi.algorithms.adaboost;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * A depth-limited binary decision tree weak learner whose nodes are
 * {@link WeakLearnerDecisionStump} instances.
 *
 * <p>Nodes are stored level by level: {@code stumps[d]} holds the {@code 2^d}
 * node slots at depth {@code d}. Node {@code i} at depth {@code d} routes a
 * sample to node {@code 2*i} at depth {@code d+1} when it classifies the
 * sample as {@code <= 0}, and to node {@code 2*i+1} otherwise. Slots that were
 * not trained (or were discarded during training) stay in their reset state
 * and report {@code valid() == false}.
 */
public class WeakLearnerDecisionTree extends WeakLearner {
	
	/** Tree nodes, indexed as {@code stumps[depth][node]}. */
	private WeakLearnerDecisionStump [][] stumps;
	/** Number of tree levels (always at least 1). */
	private int max_depth;
	/** Minimum number of samples that must reach a non-root node for it to be trained. */
	private int min_samples;
	/** Number of node slots at each depth ({@code num_stumps[d] == 2^d}). */
	private int [] num_stumps;
	/** Split type forwarded to every stump's constructor. */
	private int split_type;
	/** Type identifier returned by {@link #get_type()}. */
	public final static String TYPE = "DecisionTree";
	
	/** Largest supported depth: the deepest level then holds {@code 1 << 30} slots, the last power of two that fits in an int. */
	private static final int MAX_SUPPORTED_DEPTH = 31;
	
	/**
	 * Creates an untrained tree with every node slot allocated and reset.
	 *
	 * @param in_max_depth number of tree levels; must be in [1, 31]
	 * @param in_min_samples minimum number of samples a non-root node must receive to be trained
	 * @param in_split_type split type passed to each {@link WeakLearnerDecisionStump}
	 * @throws IllegalArgumentException if {@code in_max_depth} is outside [1, 31]
	 */
	public WeakLearnerDecisionTree(int in_max_depth, int in_min_samples, int in_split_type) {
		
		// fail fast: a depth of 0 leaves no root, which find_best_learner() and
		// classify() both dereference; beyond 31 levels the int slot count overflows
		if (in_max_depth < 1 || in_max_depth > MAX_SUPPORTED_DEPTH)
			throw new IllegalArgumentException(String.format(
					"max_depth must be in [1, %d], got %d", MAX_SUPPORTED_DEPTH, in_max_depth));
		
		// set the parameters for this tree
		max_depth = in_max_depth;
		min_samples = in_min_samples;
		split_type = in_split_type;
		
		// allocate a complete binary tree: level d has 2^d node slots
		stumps = new WeakLearnerDecisionStump [max_depth][];
		num_stumps = new int [max_depth];
		for (int d = 0; d < max_depth; d++) {
			num_stumps[d] = 1 << d;
			stumps[d] = new WeakLearnerDecisionStump[num_stumps[d]];
			for (int i = 0; i < num_stumps[d]; i++)
				stumps[d][i] = new WeakLearnerDecisionStump(split_type);
		}
		
		// set all of the data to its default value
		reset();
	}
	
	/**
	 * Creates a tree from a parameter array in the layout returned by {@link #get_parms()}.
	 *
	 * @param parms {@code {max_depth, min_samples, split_type}}; each value is truncated to int
	 * @throws IllegalArgumentException if the depth is outside [1, 31]
	 */
	public WeakLearnerDecisionTree(double [] parms) {
		this((int)parms[0], (int)parms[1], (int)parms[2]);
	}
	
	/**
	 * Clears the training results: zeroes {@code rate} and {@code alpha} and
	 * resets every node slot.
	 */
	public void reset() {
		rate = 0;
		alpha = 0;
		for (int d = 0; d < max_depth; d++)
			for (int i = 0; i < num_stumps[d]; i++)
				stumps[d][i].reset();
	}
	
	/** @return the type identifier {@value #TYPE} */
	public String get_type() { return(TYPE); }
	
	/**
	 * Returns the construction parameters, suitable for
	 * {@link #WeakLearnerDecisionTree(double[])}.
	 *
	 * @return {@code {max_depth, min_samples, split_type}}
	 */
	public double [] get_parms() {
		double [] parms = new double [3];
		parms[0] = max_depth;
		parms[1] = min_samples;
		parms[2] = split_type;
		return(parms);
	}
	
	/**
	 * Returns the length of the array written by {@link #get_vals(double[])}:
	 * two header values (rate, alpha) plus, for each node slot, either the
	 * stump's elements (valid node) or a single placeholder (invalid node).
	 *
	 * <p>NOTE(review): this counts valid stumps with {@code get_num_elements()},
	 * while get_vals/set_vals advance by {@code WeakLearnerDecisionStump.NUM_ELEMENTS};
	 * confirm the two agree for a valid stump.
	 *
	 * @return the number of doubles needed to serialize this tree
	 */
	public int get_num_elements() {
		int num_stump_elements = 0;
		for (int d = 0; d < max_depth; d++)
			for (int i = 0; i < num_stumps[d]; i++)
				if (stumps[d][i].valid())
					num_stump_elements += stumps[d][i].get_num_elements();
				else
					num_stump_elements++;
		return(2 + num_stump_elements);
	}
	
	/**
	 * Restores the tree from the layout written by {@link #get_vals(double[])}.
	 * A slot whose first value is {@code -1} is treated as an empty (reset) node.
	 *
	 * <p>NOTE(review): assumes a valid stump's first serialized value can never
	 * be exactly {@code -1}; confirm against WeakLearnerDecisionStump.get_vals.
	 *
	 * @param vals serialized values, starting with rate and alpha
	 */
	public void set_vals(double [] vals) {
		rate = vals[0];
		alpha = vals[1];
		
		int c = 2;
		for (int d = 0; d < max_depth; d++)
			for (int i = 0; i < num_stumps[d]; i++) {
				if (vals[c] == -1) {
					stumps[d][i].reset();
					c++;
				} else {
					stumps[d][i].set_vals(vals, c);
					c += WeakLearnerDecisionStump.NUM_ELEMENTS;
				}
			}
	}
	
	/**
	 * Serializes the tree into {@code vals}.
	 *
	 * <p>Layout: rate, alpha, then every node slot in level order. A valid
	 * stump writes its own values; an invalid slot writes a single {@code -1}.
	 *
	 * @param vals destination array; presumably sized by {@link #get_num_elements()}
	 */
	public void get_vals(double [] vals) {
		vals[0] = rate;
		vals[1] = alpha;
		
		int c = 2;
		for (int d = 0; d < max_depth; d++)
			for (int i = 0; i < num_stumps[d]; i++) {
				if (stumps[d][i].valid()) {
					stumps[d][i].get_vals(vals, c);
					c += WeakLearnerDecisionStump.NUM_ELEMENTS;
				} else {
					vals[c] = -1;
					c++;
				}
			}
	}
	
	/**
	 * Trains the tree level by level, then computes {@code rate} and {@code alpha}.
	 *
	 * <p>The root is trained on all samples. Each deeper node is trained only
	 * on the samples routed to it, and only when:
	 * <ul>
	 * <li>all of its ancestors are valid, and</li>
	 * <li>at least {@code min_samples} samples reach it.</li>
	 * </ul>
	 * A node whose split equals its parent's is discarded. Training stops early
	 * once an entire level has no valid node.
	 *
	 * @param Y sample labels
	 * @param X feature data, passed through to the stumps (layout defined by
	 *          WeakLearnerDecisionStump.classify; TODO confirm)
	 * @param sort_inds passed through to the stumps; presumably per-feature sort order
	 * @param W sample weights
	 * @param num_samples_total number of samples
	 * @param num_features number of features
	 */
	public void find_best_learner(boolean [] Y,
								  float [][] X,
								  int [][] sort_inds,
								  double [] W,
								  int num_samples_total,
								  int num_features) {
		
		reset();
		
		// the root is always trained on every sample
		stumps[0][0].find_best_learner(Y, X, sort_inds, W, num_samples_total, num_features);
		
		boolean [] kk = new boolean [num_samples_total];
		// ancestor indices / branch directions of the node being trained; reused
		// across nodes because only entries [0, d) are written and read at depth d
		int [] path_nodes = new int [max_depth];
		int [] path_dirs = new int [max_depth];

		for (int d = 1; d < max_depth; d++) {
			for (int n = 0; n < num_stumps[d]; n++) {
				
				// skip nodes that sit below an untrained or discarded ancestor
				if (!trace_path(d, n, path_nodes, path_dirs))
					continue;
				
				// make sure we have enough samples
				if (build_keep_mask(X, num_samples_total, d, path_nodes, path_dirs, kk) < min_samples)
					continue;
				
				// find the best learner on the samples routed to this node
				stumps[d][n].find_best_learner(Y, X, sort_inds, W, num_samples_total, num_features, kk);
				
				// a split identical to the parent's adds nothing, so discard it
				if (stumps[d][n].equals(stumps[d-1][path_nodes[d-1]]))
					stumps[d][n].reset();
			}
			
			// an empty level means nothing deeper can be trained
			if (!any_valid(d))
				break;
		}
		
		// sum of the weights of correctly classified samples
		rate = 0;
		for (int i = 0; i < num_samples_total; i++)
			rate += (classify(X, i) > 0 == Y[i]) ? W[i] : 0;
		
		// set the alpha value
		this.set_alpha();
	}
	
	/**
	 * Fills in the ancestor path of node {@code n} at depth {@code d}, walking
	 * from its parent up to the root.
	 *
	 * @param d depth of the node (at least 1)
	 * @param n index of the node within its level
	 * @param path_nodes receives, at index {@code dp}, the ancestor's index at depth {@code dp}
	 * @param path_dirs receives, at index {@code dp}, the stump output ({@code -1} or {@code 1})
	 *                  that routes from that ancestor toward node {@code n}
	 * @return {@code true} if every ancestor is valid
	 */
	private boolean trace_path(int d, int n, int [] path_nodes, int [] path_dirs) {
		boolean valid = true;
		int cn = n;
		for (int dp = d-1; dp >= 0; dp--) {
			path_nodes[dp] = cn / 2;
			// odd children are reached when the parent outputs > 0
			path_dirs[dp] = (cn % 2 == 0) ? -1 : 1;
			cn /= 2;
			valid = valid && stumps[dp][path_nodes[dp]].valid();
		}
		return(valid);
	}
	
	/**
	 * Marks the samples that the ancestors on the given path route to the node.
	 *
	 * @param X feature data passed to the ancestors' classify
	 * @param num_samples_total number of samples to test
	 * @param d depth of the node (at least 1)
	 * @param path_nodes ancestor indices, as filled by trace_path
	 * @param path_dirs required ancestor outputs, as filled by trace_path
	 * @param kk output mask; {@code kk[j]} is set to whether sample {@code j} reaches the node
	 * @return the number of samples that reach the node
	 */
	private int build_keep_mask(float [][] X, int num_samples_total, int d,
								int [] path_nodes, int [] path_dirs, boolean [] kk) {
		Arrays.fill(kk, true);
		for (int dp = 0; dp < d; dp++) {
			WeakLearnerDecisionStump ancestor = stumps[dp][path_nodes[dp]];
			for (int j = 0; j < num_samples_total; j++)
				if (kk[j])
					kk[j] = ancestor.classify(X, j) == path_dirs[dp];
		}
		
		int num_keep = 0;
		for (int j = 0; j < num_samples_total; j++)
			if (kk[j])
				num_keep++;
		return(num_keep);
	}
	
	/**
	 * @param d a depth in [0, max_depth)
	 * @return {@code true} if at least one node at depth {@code d} is valid
	 */
	private boolean any_valid(int d) {
		for (int n = 0; n < num_stumps[d]; n++)
			if (stumps[d][n].valid())
				return(true);
		return(false);
	}
	
	/**
	 * Classifies one sample by descending from the root. The walk stops at the
	 * first level whose target node is invalid.
	 *
	 * @param X feature data
	 * @param ind index of the sample to classify
	 * @return the output of the last stump evaluated
	 */
	public int classify(float [][] X, int ind) {
		
		// set the initial value
		int val = stumps[0][0].classify(X, ind);
		
		// set the next stump index we want to use
		int ii = val > 0 ? 1 : 0;
		
		// traverse the tree
		for (int d = 1; d < max_depth; d++) {
			
			// if the desired stump is not valid, return the current value
			if (!stumps[d][ii].valid())
				return(val);
			
			// update the value using the current stump
			val = stumps[d][ii].classify(X, ind);
			
			// set the next stump index we want to use
			ii = 2*ii + (val > 0 ? 1 : 0);
		}
		
		// return the final value
		return(val);
	}
	
	/**
	 * Logs a one-line summary of this iteration, then each valid node's summary string.
	 *
	 * @param iter zero-based iteration number (logged as {@code iter+1})
	 * @param error error value supplied by the caller
	 */
	public void print_info(int iter, double error) {
		JistLogger.logOutput(JistLogger.INFO, String.format("Iteration %03d (%.5f): %.5f, %.5f", 
															 iter+1, error, alpha, 1-rate));
		for (int d = 0; d < max_depth; d++)
			for (int i = 0; i < num_stumps[d]; i++)
				if (stumps[d][i].valid()) {
					String str = stumps[d][i].get_summary_string();
					JistLogger.logOutput(JistLogger.INFO, String.format("=== %d,%d: %s", d, i, str));
				}
	}
	
	/** @return the root stump's criterion string */
	public String get_criterion_string() {
		return(stumps[0][0].get_criterion_string());
	}
}
