package edu.vanderbilt.masi.algorithms.clasisfication.decisiontree;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;

import org.json.*;

import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.clasisfication.Classifier;

/**
 * Decision-tree classifier whose tree is a {@link Node} hierarchy.
 *
 * <p>Hyper-parameters (maximum depth, split criterion, minimum leaf size and
 * minimum split size) are set through the setters before calling
 * {@link #train(float[][], int[], float[])}. The configuration and, once
 * trained, the tree itself can be written as JSON with {@link #writeToFile(File)}
 * and restored with {@link #buildFromJSON(JSONObject)}.
 */
public class DecisionTreeClassifier extends Classifier implements Cloneable {

	/** Fraction of the samples given to {@code train} that are used to grow the tree; the rest are held out for validation. */
	private static final double TRAINING_FRACTION = 0.7;

	/** Maximum tree depth passed to the root {@link Node}. */
	private int max_depth=10;
	/** Impurity measure used to choose splits. */
	private SplitCriterion split_criterion=SplitCriterion.ENTROPY;
	/** Minimum leaf size passed to the root {@link Node}. */
	private int min_leaf=10;
	/** Minimum split size; stored, cloned and serialised, but not passed to {@link Node} by {@link #train}. */
	private int min_split=15;
	/** Root of the tree; set by {@link #train} or {@link #buildFromJSON}. */
	private Node head;

	/** Sets the maximum tree depth. */
	public void setMaxDepth(int d){
		this.max_depth = d;
	}

	/** Sets the split criterion used when training. */
	public void setSplitCriterion(SplitCriterion S){
		this.split_criterion = S;
	}

	/** Sets the minimum leaf size. */
	public void setMinLeaf(int m){
		this.min_leaf=m;
	}

	/** Sets the minimum split size. */
	public void setMinSplit(int m){
		this.min_split = m;
	}

	/**
	 * Writes this classifier's configuration, and the tree if trained, to
	 * {@code f} as a single JSON object.
	 *
	 * <p>The JSON is built completely before the file is opened, so a JSON
	 * error does not truncate an existing file. I/O and JSON errors are
	 * printed to standard error and not rethrown.
	 *
	 * @param f destination file (overwritten)
	 */
	@Override
	public void writeToFile(File f) {
		try {
			JSONObject obj = new JSONObject();
			obj.put("classifier_type", "Decision Tree");
			obj.put("max_depth", this.max_depth);
			obj.put("min_leaf",this.min_leaf);
			obj.put("min_split",this.min_split);
			obj.put("trained", this.trained);
			switch(this.split_criterion){
			case ENTROPY:
				obj.put("split_criterion", "entropy");
				break;
			case GINI:
				obj.put("split_criterion", "gini");
				break;
			case RANDOM:
				obj.put("split_criterion", "random");
				break;
			default:
				break;
			}
			if(this.trained){
				obj.put("head", this.head.toJSON());
			}
			// try-with-resources closes both writers even if write() throws.
			try (FileWriter fstream = new FileWriter(f);
					BufferedWriter bw = new BufferedWriter(fstream)) {
				bw.write(obj.toString());
			}
		} catch (IOException e) {
			e.printStackTrace();
		} catch (JSONException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Trains the tree on a random 70% of the supplied samples. The remaining
	 * 30% are used as a validation set, and the resulting classification rate
	 * is logged.
	 *
	 * @param features per-sample feature rows; row {@code i} belongs to {@code target[i]}
	 * @param target   class label of each sample
	 * @param weight   passed unchanged to {@link Node#run}
	 * @throws IllegalArgumentException if {@code target} is null or empty, or
	 *         {@code features} is null or has fewer rows than {@code target}
	 */
	public void train(float[][] features,int[] target,float[] weight) {
		if (target == null || target.length == 0) {
			throw new IllegalArgumentException("train requires at least one target label");
		}
		if (features == null || features.length < target.length) {
			throw new IllegalArgumentException("features must have one row per target label (got "
					+ (features == null ? "null" : features.length + " rows") + " for "
					+ target.length + " labels)");
		}
		ImpurityCalculator C = null;
		if(this.split_criterion==SplitCriterion.GINI) C = new GiniCalculator();
		else if(this.split_criterion==SplitCriterion.ENTROPY) C = new InformationGainCalculator();
		// NOTE(review): no calculator is created for SplitCriterion.RANDOM, so Node.run
		// receives null in that case — confirm Node handles it.

		int total = target.length;
		int numTraining = (int) (total * TRAINING_FRACTION);
		int numTesting = total - numTraining;

		JistLogger.logOutput(JistLogger.FINE, "Building training and validation test");
		boolean[] inTraining = chooseTrainingMask(total, numTraining, new Random());
		// Rows are filled with references to the caller's arrays, so only the outer arrays are allocated.
		float[][] training_features = new float[numTraining][];
		int[] training_target = new int[numTraining];
		float[][] test_features = new float[numTesting][];
		int[] test_target = new int[numTesting];
		int num_training = 0;
		int num_testing = 0;
		for (int i = 0; i < total; i++) {
			if (inTraining[i]) {
				training_features[num_training] = features[i];
				training_target[num_training] = target[i];
				num_training++;
			} else {
				test_features[num_testing] = features[i];
				test_target[num_testing] = target[i];
				num_testing++;
			}
		}

		JistLogger.logOutput(JistLogger.FINE, "Starting Building of Tree");
		// NOTE(review): weight is not subset like features/target; if it holds one entry per
		// sample it no longer lines up with training_target — confirm what Node.run expects.
		this.head = new Node(1,this.max_depth,training_target,this.min_leaf);
		this.head.run(training_features, training_target, weight, C);
		this.trained = true;

		int correct = 0;
		for (int i = 0; i < test_target.length; i++) {
			if (this.head.classify(test_features[i]) == test_target[i]) correct++;
		}
		JistLogger.logOutput(JistLogger.INFO, "The classification rate was "+correct+"/"+test_target.length);
		JistLogger.logFlush();
	}

	/**
	 * Picks a uniformly random subset of {@code numTraining} indices out of
	 * {@code [0, total)} using a partial Fisher–Yates shuffle.
	 *
	 * @return mask whose entry {@code i} is true when sample {@code i} is used for training
	 */
	private static boolean[] chooseTrainingMask(int total, int numTraining, Random rand) {
		int[] order = new int[total];
		for (int i = 0; i < total; i++) order[i] = i;
		boolean[] inTraining = new boolean[total];
		for (int k = 0; k < numTraining; k++) {
			// Swap a random not-yet-chosen index into slot k.
			int j = k + rand.nextInt(total - k);
			int tmp = order[k];
			order[k] = order[j];
			order[j] = tmp;
			inTraining[order[k]] = true;
		}
		return inTraining;
	}

	/**
	 * Restores the configuration and, if {@code "trained"} is true, the tree
	 * from a JSON object of the form written by {@link #writeToFile(File)}.
	 *
	 * <p>An unrecognised {@code "split_criterion"} leaves the current criterion
	 * unchanged. {@code trained} is updated only after the tree has been
	 * rebuilt, so a malformed {@code "head"} does not leave the classifier
	 * marked as trained with no tree. JSON errors are printed to standard
	 * error and not rethrown.
	 *
	 * @param obj serialised classifier
	 */
	@Override
	public void buildFromJSON(JSONObject obj) {
		try {
			this.max_depth = obj.getInt("max_depth");
			this.min_leaf  = obj.getInt("min_leaf");
			this.min_split = obj.getInt("min_split");
			String split = obj.getString("split_criterion");
			if(split.equals("entropy")) this.split_criterion = SplitCriterion.ENTROPY;
			else if(split.equals("gini")) this.split_criterion = SplitCriterion.GINI;
			else if(split.equals("random")) this.split_criterion = SplitCriterion.RANDOM;
			boolean isTrained = obj.getBoolean("trained");
			if(isTrained) this.head = new Node(obj.getJSONObject("head"));
			this.trained = isTrained;
		} catch (JSONException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Creates a new, untrained classifier with the same hyper-parameters as
	 * this one. The tree is not copied.
	 *
	 * @return fresh untrained classifier
	 */
	public DecisionTreeClassifier basicClone(){
		DecisionTreeClassifier DT = new DecisionTreeClassifier();
		DT.setMaxDepth(this.max_depth);
		DT.setMinLeaf(this.min_leaf);
		DT.setMinSplit(this.min_split);
		DT.setSplitCriterion(this.split_criterion);
		return DT;
	}

	/**
	 * Scores one feature vector with the tree.
	 *
	 * @param features feature vector to score
	 * @return the map produced by the root node's {@code score}
	 * @throws IllegalStateException if no tree has been trained or loaded
	 */
	@Override
	public HashMap<Integer, Float> score(float[] features) {
		if (this.head == null) {
			throw new IllegalStateException("DecisionTreeClassifier has not been trained or loaded");
		}
		return this.head.score(features);
	}

}
