package edu.vanderbilt.masi.algorithms.clasisfication.decisiontree;

import java.util.*;

import org.json.*;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * One node of a binary classification decision tree.
 *
 * <p>A node stores the empirical label distribution of the training samples that reached it.
 * Internal nodes additionally hold a {@link Cutoff}: a sample is routed to {@link #left} when
 * {@code features[cutoff.num] < cutoff.val} and to {@link #right} otherwise. A tree is grown by
 * constructing a node from the training labels and calling {@link #run}; it can be persisted with
 * {@link #toJSON()} and restored with {@link #Node(JSONObject)}.
 */
public class Node {

	/** A node is only split while its majority-class probability is below this value. */
	private static final double SPLIT_PURITY_THRESHOLD = 0.9999;
	/** An unsplit node whose majority-class probability reaches this value is logged as pure. */
	private static final double PURE_LEAF_THRESHOLD = 0.99;

	/** Label -> fraction of this node's training samples that carry that label. */
	public HashMap<Integer,Float> classProbabilities;
	/** Split rule of an internal node; null when no split has been computed. */
	public Cutoff cutoff=null;
	/** Depth of this node; children are created at {@code depth + 1}. */
	public int depth;
	/** Nodes at or beyond this depth are not split further. */
	public int max_depth;
	/** Child receiving samples with {@code features[cutoff.num] < cutoff.val}; may be null. */
	public Node left;
	/** Child receiving samples with {@code features[cutoff.num] >= cutoff.val}; may be null. */
	public Node right;
	/** True once this node has been decided (or restored) to be a leaf. */
	private boolean leaf=false;
	/** Nodes holding this many samples or fewer are not split further. */
	private int min_leaf;
	/** Most probable label at this node, or -1 if not yet determined. */
	private int classification=-1;


	/**
	 * Creates a node for the given training labels and records their empirical distribution.
	 * The node is not split until {@link #run} is called.
	 *
	 * @param depth     depth of this node in the tree
	 * @param max_depth nodes at or beyond this depth are not split
	 * @param labels    labels of the samples that reach this node
	 * @param min_leaf  nodes with at most this many samples are not split
	 */
	public Node(int depth,int max_depth,int[] labels,int min_leaf){
		this.min_leaf = min_leaf;
		JistLogger.logOutput(JistLogger.FINE, "Starting on depth "+depth+" there are "+labels.length+" samples here");
		JistLogger.logFlush();
		this.max_depth = max_depth;
		this.depth=depth;
		// Count in integers: float accumulation stops being exact past 2^24 samples per label.
		HashMap<Integer,Integer> counts = new HashMap<Integer,Integer>();
		for (int label : labels) {
			Integer seen = counts.get(label);
			counts.put(label, (seen == null) ? 1 : seen + 1);
		}
		this.classProbabilities = new HashMap<Integer,Float>();
		for (Map.Entry<Integer,Integer> entry : counts.entrySet()) {
			this.classProbabilities.put(entry.getKey(), entry.getValue().floatValue() / labels.length);
		}
	}

	/**
	 * Restores a node, and recursively its subtree, from the object produced by {@link #toJSON()}.
	 *
	 * <p>JSON errors are printed and leave the node partially initialised (kept for
	 * compatibility with existing callers, which do not expect an exception here).
	 *
	 * @param obj serialized node
	 */
	public Node(JSONObject obj) {
		// Ensure score() never hands out a null map, even if parsing fails below.
		this.classProbabilities = new HashMap<Integer,Float>();
		try {
			this.min_leaf = obj.getInt("min_leaf");
			this.max_depth = obj.getInt("max_depth");
			this.depth = obj.getInt("depth");
			this.classification = obj.getInt("classification");
			this.leaf = obj.getBoolean("leaf");
			if(!this.leaf){
				this.left = obj.has("left") ? new Node(obj.getJSONObject("left")) : null;
				this.right = obj.has("right") ? new Node(obj.getJSONObject("right")) : null;
			}
			if(obj.has("cutoff")){
				JSONObject cut = obj.getJSONObject("cutoff");
				this.cutoff = new Cutoff(cut.getInt("num"),(float) cut.getDouble("val"));
			}
			JSONObject probabilities = obj.getJSONObject("probabilities");
			// names() returns null (not an empty array) for an empty object.
			JSONArray names = probabilities.names();
			if (names != null) {
				for(int i=0;i<names.length();i++){
					String key = names.getString(i);
					this.classProbabilities.put(Integer.parseInt(key), (float) probabilities.getDouble(key));
				}
			}
		} catch (JSONException e) {
			// Best-effort restore: report and keep whatever was read so far.
			e.printStackTrace();
		}
	}

	/**
	 * Grows the subtree rooted at this node.
	 *
	 * <p>Sets {@link #classification} to the most probable label. It then either splits the
	 * node using {@code IC} and recurses into the children, or marks the node as a leaf. A node
	 * becomes a leaf when it is pure enough, has reached {@link #max_depth}, or holds at most
	 * {@link #min_leaf} samples. If {@code IC} yields no cutoff, or one with a negative feature
	 * index, the node is left unsplit.
	 *
	 * @param features one feature row per sample
	 * @param target   label of each sample, parallel to {@code features}
	 * @param weight   per-sample weight, parallel to {@code target}; only forwarded to children
	 * @param IC       computes the split for this node's samples
	 */
	public void run(float[][] features, int[] target, float[] weight, ImpurityCalculator IC){
		float maxProb = 0f;
		for (Map.Entry<Integer,Float> entry : this.classProbabilities.entrySet()) {
			if (entry.getValue() > maxProb) {
				maxProb = entry.getValue();
				this.classification = entry.getKey();
			}
		}
		if(this.depth<this.max_depth && maxProb<SPLIT_PURITY_THRESHOLD && target.length>this.min_leaf){
			this.cutoff = IC.calculate(features, target);
			if(this.cutoff == null || this.cutoff.num<0){
				JistLogger.logOutput(JistLogger.FINE, "There is no impurity in this node.  Punting.");
				JistLogger.logFlush();
				return;
			}
			this.continueTree(features, target, weight, IC);
		}else if(maxProb>=PURE_LEAF_THRESHOLD){
			this.leaf=true;
			JistLogger.logOutput(JistLogger.FINE, "The Maximum Probability in this Node is "+maxProb+". Punting.");
		}else if(this.depth>=this.max_depth){
			this.leaf=true;
			JistLogger.logOutput(JistLogger.FINE, "Reached Max Depth. Punting.");
		}else if(target.length<=this.min_leaf){
			this.leaf=true;
			JistLogger.logOutput(JistLogger.FINE, "The number of samples in this node is less than the min leaf size. Punting.");
		}
		JistLogger.logFlush();
	}

	/**
	 * Partitions the samples with {@link #cutoff}, then builds and grows the child nodes.
	 * A side that receives no samples gets no child (it stays null), so lookups routed there
	 * fall back to this node's own prediction.
	 */
	private void continueTree(float[][] features, int[] target, float[] weight, ImpurityCalculator IC){
		boolean[] goesLeft = new boolean[target.length];
		int leftCount = 0;
		for (int i = 0; i < target.length; i++) {
			goesLeft[i] = features[i][this.cutoff.num] < this.cutoff.val;
			if (goesLeft[i]) {
				leftCount++;
			}
		}
		int rightCount = target.length - leftCount;

		// Sizes must be exact: a spare slot would reach the child as a phantom sample with
		// label 0 and weight 0. Feature rows are shared with the parent, so only the outer
		// arrays are allocated.
		float[][] leftFeatures = new float[leftCount][];
		float[][] rightFeatures = new float[rightCount][];
		int[] leftTarget = new int[leftCount];
		int[] rightTarget = new int[rightCount];
		float[] leftWeight = new float[leftCount];
		float[] rightWeight = new float[rightCount];
		int l = 0;
		int r = 0;
		for (int i = 0; i < target.length; i++) {
			if (goesLeft[i]) {
				leftFeatures[l] = features[i];
				leftTarget[l] = target[i];
				leftWeight[l] = weight[i];
				l++;
			} else {
				rightFeatures[r] = features[i];
				rightTarget[r] = target[i];
				rightWeight[r] = weight[i];
				r++;
			}
		}
		if (leftCount > 0) {
			this.left = new Node(this.depth+1,this.max_depth,leftTarget,this.min_leaf);
			this.left.run(leftFeatures, leftTarget, leftWeight, IC);
		}
		if (rightCount > 0) {
			this.right = new Node(this.depth+1,this.max_depth,rightTarget,this.min_leaf);
			this.right.run(rightFeatures, rightTarget, rightWeight, IC);
		}
	}

	/** True when this node has no usable split rule and must answer from its own statistics. */
	private boolean isTerminal(){
		return this.leaf || this.cutoff == null || this.cutoff.num < 0;
	}

	/**
	 * Predicts the label of one sample by walking down the tree.
	 *
	 * @param features feature vector of the sample
	 * @return the classification of the deepest reachable node on the sample's path
	 */
	public int classify(float[] features){
		JistLogger.logOutput(JistLogger.FINE, "In depth "+this.depth);
		if (this.isTerminal()) {
			return this.classification;
		}
		Node next = (features[this.cutoff.num] < this.cutoff.val) ? this.left : this.right;
		// A missing child (empty side of a split, or absent in JSON) falls back to this node.
		return (next == null) ? this.classification : next.classify(features);
	}

	/**
	 * Returns the class-probability map of the deepest reachable node on the sample's path.
	 * The returned map is the node's live map, not a copy.
	 *
	 * @param features feature vector of the sample
	 * @return label -> probability at the node that answers for this sample
	 */
	public HashMap<Integer,Float> score(float[] features){
		JistLogger.logOutput(JistLogger.FINER, "Scoring in depth "+this.depth);
		if (this.isTerminal()) {
			return this.classProbabilities;
		}
		Node next = (features[this.cutoff.num] < this.cutoff.val) ? this.left : this.right;
		return (next == null) ? this.classProbabilities : next.score(features);
	}

	/**
	 * Serializes this node and its subtree.
	 *
	 * <p>Keys written: {@code cutoff} (non-leaf nodes with a split only), {@code probabilities},
	 * {@code leaf}, {@code left}/{@code right} (when present), {@code depth}, {@code max_depth},
	 * {@code min_leaf} and {@code classification}.
	 *
	 * @return JSON representation readable by {@link #Node(JSONObject)}
	 */
	public JSONObject toJSON(){
		JSONObject obj = new JSONObject();
		try {
			if(!this.leaf && this.cutoff != null) obj.put("cutoff",this.cutoff.toJSON());
			obj.put("probabilities",this.classProbabilities);
			obj.put("leaf", this.leaf);
			if(this.left != null) obj.put("left",this.left.toJSON());
			if(this.right != null) obj.put("right", this.right.toJSON());
			obj.put("depth", this.depth);
			obj.put("max_depth", this.max_depth);
			obj.put("min_leaf", this.min_leaf);
			obj.put("classification", this.classification);
		} catch (JSONException e) {
			// Best-effort serialization: report and return what was written so far.
			e.printStackTrace();
		}

		return obj;
	}

}
