package edu.vanderbilt.masi.algorithms.adaboost;

import java.util.Arrays;
import org.json.*;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * AdaBoost training and ensemble classification.
 *
 * <p>{@link #train} runs a fixed number of boosting rounds with a caller-supplied
 * {@link WeakLearner} and returns the JSON form of the learner chosen in each round.
 * {@link #classify} combines one or more trained ensembles using one of the
 * {@code FUSION_TYPE_*} strategies.
 */
public class AdaBoost extends AbstractCalculation {

	/** Fusion: average over ensembles of the raw weighted vote {@code sum_j alpha_j * h_j(x)}. */
	public static final int FUSION_TYPE_MEAN_CLASSIFIER = 0;
	/** Fusion: average over ensembles of {@code 1 / (1 + exp(-vote))}. */
	public static final int FUSION_TYPE_MEAN_PROBABILITY = 1;
	/** Fusion: average over ensembles of {@code log(1 / (1 + exp(-vote)))}. */
	public static final int FUSION_TYPE_MAXIMUM_LIKELIHOOD = 2;

	/**
	 * Trains an AdaBoost ensemble and returns the weak learner selected in every round.
	 *
	 * <p>Each round asks {@code bwl} to find its best learner under the current sample weights.
	 * The round then does four things:
	 * <ul>
	 *   <li>adds {@code alpha * h(x)} to the running ensemble response of every sample;</li>
	 *   <li>multiplies each sample weight by {@code exp(-alpha * h(x) * y)};</li>
	 *   <li>renormalizes the weights to sum to one;</li>
	 *   <li>appends {@code bwl.toJSON()} to the result.</li>
	 * </ul>
	 * The same {@code bwl} instance is reused every round, so the returned JSON entries are
	 * the only per-round snapshots.
	 *
	 * @param X feature matrix indexed as {@code X[feature][sample]}; {@code X.length} is the
	 *          number of features
	 * @param Y sample labels; {@code true} is treated as +1 and {@code false} as -1
	 * @param numiters number of boosting rounds (one JSON entry is produced per round)
	 * @param bwl weak-learner search object, reused every round; {@code find_best_learner}
	 *          presumably updates its state in place (TODO confirm in WeakLearner)
	 * @param num_samples_total number of samples to train on (presumably no larger than
	 *          {@code Y.length} and every {@code X[f].length} — TODO confirm with callers)
	 * @return a JSON array holding {@code bwl.toJSON()} for each round, in round order
	 */
	public static JSONArray train(float [][] X,
								  boolean [] Y,
								  int numiters,
								  WeakLearner bwl,
								  int num_samples_total) {

		JistLogger.logOutput(JistLogger.INFO, "-> Training AdaBoost Classifier");
		int num_features = X.length;

		// Pre-sort sample indices once per feature; the sorted indices are handed to the weak
		// learner every round (the log message suggests this enables fast threshold search).
		JistLogger.logOutput(JistLogger.INFO, "-> Sorting features for fast thresholding");
		int [][] sort_inds = sortFeatureIndices(X, num_samples_total);
		JistLogger.logOutput(JistLogger.INFO, "");

		JistLogger.logOutput(JistLogger.INFO, "Allocating space for the results");
		JSONArray bwlList = new JSONArray();

		// Uniform initial sample weights; H (the ensemble response) starts at zero because
		// Java zero-initializes new arrays.
		JistLogger.logOutput(JistLogger.INFO, "Allocating space for the weights");
		double [] W = new double [num_samples_total];
		double [] H = new double [num_samples_total];
		Arrays.fill(W, 1.0 / num_samples_total);

		for (int iter = 0; iter < numiters; iter++) {

			// Let the weak learner pick its best hypothesis under the current weights.
			bwl.find_best_learner(Y, X, sort_inds, W, num_samples_total, num_features);

			double W_norm = 0;
			double true_error = 0;
			for (int i = 0; i < num_samples_total; i++) {
				int yval = Y[i] ? 1 : -1;
				int chval = bwl.classify(X, i);

				// Accumulate this round's weighted vote into the ensemble response.
				H[i] += chval * bwl.get_alpha();

				// Count samples the ensemble-so-far gets wrong (a zero response counts as correct).
				if (yval * H[i] < 0) {
					true_error++;
				}

				// Re-weight: with alpha > 0, samples where chval * yval < 0 gain weight.
				W[i] *= Math.exp(-bwl.get_alpha() * (chval * yval));
				W_norm += W[i];
			}

			// NOTE(review): if W_norm is 0 or non-finite (e.g. an infinite alpha from a
			// zero-error weak learner), every weight becomes NaN here — confirm that
			// WeakLearner bounds alpha.
			for (int i = 0; i < num_samples_total; i++) {
				W[i] /= W_norm;
			}
			true_error /= num_samples_total;

			bwl.print_info(iter, true_error);

			// Snapshot this round's learner, since bwl is reused in the next round.
			bwlList.put(bwl.toJSON());
		}

		return bwlList;
	}

	/**
	 * Returns, for every feature, the index array from
	 * {@code ArrayIndexComparator.createIndexArray()} sorted by that feature's values.
	 * The index array is presumably {@code 0..num_samples_total-1}.
	 */
	private static int [][] sortFeatureIndices(float [][] X, int num_samples_total) {
		int num_features = X.length;
		int [][] sort_inds = new int [num_features][num_samples_total];
		for (int f = 0; f < num_features; f++) {
			JistLogger.printStatusBar(f, num_features);

			// Sort boxed indices with the comparator, then unbox into the primitive row.
			ArrayIndexComparator comparator = new ArrayIndexComparator(X[f], num_samples_total);
			Integer [] tmp = comparator.createIndexArray();
			Arrays.sort(tmp, comparator);
			for (int i = 0; i < num_samples_total; i++) {
				sort_inds[f][i] = tmp[i];
			}
		}
		return sort_inds;
	}

	/**
	 * Classifies one sample by fusing the outputs of several trained AdaBoost ensembles.
	 *
	 * <p>For each ensemble {@code learners[i]}, the weighted vote
	 * {@code v_i = sum_j alpha_ij * h_ij(x)} is computed and mapped according to
	 * {@code fusion_type}. The mapped values are then averaged over all ensembles:
	 * <ul>
	 *   <li>{@link #FUSION_TYPE_MEAN_CLASSIFIER}: {@code v_i}</li>
	 *   <li>{@link #FUSION_TYPE_MEAN_PROBABILITY}: {@code 1 / (1 + exp(-v_i))}</li>
	 *   <li>{@link #FUSION_TYPE_MAXIMUM_LIKELIHOOD}: {@code log(1 / (1 + exp(-v_i)))}, computed
	 *       in a numerically stable way</li>
	 * </ul>
	 *
	 * @param X feature matrix passed unchanged to each weak learner's {@code classify}
	 * @param sample_num index of the sample to classify
	 * @param learners {@code learners[i][j]} is the j-th weak learner of the i-th ensemble
	 * @param fusion_type one of the {@code FUSION_TYPE_*} constants
	 * @return the fused score averaged over ensembles; {@code NaN} when {@code learners} is empty
	 * @throws IllegalArgumentException if {@code fusion_type} is not a known fusion type
	 */
	public static float classify(float [][] X,
								 int sample_num,
								 WeakLearner [][] learners,
								 int fusion_type) {

		int num_classifiers = learners.length;
		double pp = 0;

		switch (fusion_type) {
		case FUSION_TYPE_MEAN_CLASSIFIER:
			for (WeakLearner [] ensemble : learners) {
				pp += weightedVote(ensemble, X, sample_num);
			}
			break;
		case FUSION_TYPE_MEAN_PROBABILITY:
			for (WeakLearner [] ensemble : learners) {
				pp += 1 / (1 + Math.exp(-weightedVote(ensemble, X, sample_num)));
			}
			break;
		case FUSION_TYPE_MAXIMUM_LIKELIHOOD:
			for (WeakLearner [] ensemble : learners) {
				pp += logSigmoid(weightedVote(ensemble, X, sample_num));
			}
			break;
		default:
			throw new IllegalArgumentException(
					String.format("Invalid Classifier Fusion Type: %d", fusion_type));
		}

		// Average over ensembles (0/0 yields NaN when there are no ensembles).
		pp /= (double) num_classifiers;
		return (float) pp;
	}

	/** Returns {@code sum_j alpha_j * h_j(x)} over the weak learners of one ensemble. */
	private static double weightedVote(WeakLearner [] ensemble, float [][] X, int sample_num) {
		double vote = 0;
		for (WeakLearner learner : ensemble) {
			vote += learner.get_alpha() * learner.classify(X, sample_num);
		}
		return vote;
	}

	/**
	 * Numerically stable {@code log(1 / (1 + exp(-v)))}.
	 *
	 * <p>The naive {@code -log(1 + exp(-v))} overflows {@code exp(-v)} to infinity for
	 * {@code v} below about -709.8 and returns {@code -Infinity} instead of roughly {@code v}.
	 * Splitting on the sign of {@code v} keeps the argument of {@code exp} non-positive, and
	 * {@code log1p} keeps precision when {@code exp(...)} is tiny.
	 */
	private static double logSigmoid(double v) {
		if (v >= 0) {
			return -Math.log1p(Math.exp(-v));
		}
		// -log(1 + e^{-v}) = v - log(1 + e^{v}), which is safe for negative v.
		return v - Math.log1p(Math.exp(v));
	}

}
