package edu.vanderbilt.masi.algorithms.regression;

import org.w3c.dom.*;

public class Tree {
	/**
	 * Node storage for a complete binary tree laid out breadth-first: the root
	 * is at index 0 and the children of node {@code k} live at {@code 2k+1}
	 * and {@code 2k+2}. Slots are {@code null} until populated via
	 * {@link #SetNode} or {@link #XmlParse}.
	 */
	Node[] nodes;

	/**
	 * Allocates storage for a complete tree of the given depth.
	 *
	 * @param decisionLevels number of decision levels below the root, in [0, 19];
	 *        the tree gets {@code 2^(decisionLevels+1) - 1} node slots
	 * @throws Exception if {@code decisionLevels} is outside [0, 19]
	 */
	public Tree(int decisionLevels) throws Exception {
		if (decisionLevels < 0)
			throw new Exception("Tree can't have less than 0 decision levels.");

		if (decisionLevels > 19)
			throw new Exception("Tree can't have more than 19 decision levels.");

		// This full allocation of node storage may be wasteful of memory
		// if trees are unbalanced but is efficient otherwise. Because child
		// node indices can be determined directly from the parent node's index,
		// it isn't necessary to store parent-child references within the
		// nodes.
		nodes = new Node[(1 << (decisionLevels + 1)) - 1];
	}

	/** @return the number of node slots (allocated capacity, not populated count) */
	public int GetNodeCount() {
		return nodes.length;
	}

	/** @return the node stored at {@code index}; may be {@code null} if never set */
	public Node GetNode(int index) {
		return nodes[index];
	}

	/** Stores {@code node} at slot {@code index}, replacing any previous node. */
	public void SetNode(int index, Node node) {
		nodes[index] = node;
	}

	/** Placeholder for structural validation; currently performs no checks. */
	public void CheckValid() {
	}

	/**
	 * Routes every data point from the root down to a leaf.
	 *
	 * @param data the data points to route
	 * @return array indexed by data point, holding the index of the leaf node
	 *         that point reached
	 */
	public int[] Apply(DataCollection data) {
		CheckValid();

		int dataCount = data.GetDataCount();
		int[] leafNodeIndices = new int[dataCount];

		// Working buffers: ApplyNode permutes dataIndices_ in place while it
		// partitions, and uses responses_ as scratch for feature responses.
		int[] dataIndices_ = new int[dataCount];
		for (int i = 0; i < dataCount; i++)
			dataIndices_[i] = i;

		float[] responses_ = new float[dataCount];

		ApplyNode(0, data, dataIndices_, 0, dataCount, leafNodeIndices, responses_);

		return leafNodeIndices;
	}

	/**
	 * Partitions {@code keys[i0, i1)} in place around {@code threshold}, applying
	 * the same swaps to {@code values} so each value stays paired with its key.
	 * The order within each side is not preserved.
	 *
	 * @return index {@code p} such that keys in {@code [i0, p)} are below
	 *         {@code threshold} and keys in {@code [p, i1)} are at or above it;
	 *         {@code i0} for an empty range
	 */
	public static int Partition(float[] keys, int[] values, int i0,
			int i1, float threshold) {
		// Without this guard j would start at i0 - 1 and the loop below would
		// touch elements outside the requested range.
		if (i0 >= i1)
			return i0;

		int i = i0;     // index of first element
		int j = i1 - 1; // index of last element

		// Invariant: keys[i0, i) < threshold and keys(j, i1) >= threshold.
		while (i != j) {
			if (keys[i] >= threshold) {
				// Move keys[i] to the upper side by swapping with keys[j].
				float key = keys[i];
				int value = values[i];

				keys[i] = keys[j];
				values[i] = values[j];

				keys[j] = key;
				values[j] = value;

				j--;
			} else {
				i++;
			}
		}

		// One unclassified element remains at i == j.
		return keys[i] >= threshold ? i : i + 1;
	}

	/**
	 * Recursively routes the data points {@code dataIndices[i0, i1)} from node
	 * {@code nodeIndex} to the leaves, recording the reached leaf index per point.
	 * Points whose response is below the node's threshold go to child
	 * {@code 2*nodeIndex+1}; the rest go to {@code 2*nodeIndex+2}.
	 */
	private void ApplyNode(
			int nodeIndex,
			DataCollection data,
			int[] dataIndices,
			int i0,
			int i1,
			int[] leafNodeIndices,
			float[] responses_) {
		Node node = nodes[nodeIndex];

		if (node.IsLeaf()) {
			for (int i = i0; i < i1; i++)
				leafNodeIndices[dataIndices[i]] = nodeIndex;
			return;
		}

		if (i0 == i1)   // No samples left
			return;

		for (int i = i0; i < i1; i++)
			responses_[i] = node.Feature.GetResponse(data, dataIndices[i]);

		int ii = Partition(responses_, dataIndices, i0, i1, node.Threshold);

		// Recurse for child nodes.
		ApplyNode(nodeIndex * 2 + 1, data, dataIndices, i0, ii, leafNodeIndices, responses_);
		ApplyNode(nodeIndex * 2 + 2, data, dataIndices, ii, i1, leafNodeIndices, responses_);
	}

	/**
	 * Appends one {@code <node>} child to {@code treeElement} per slot, in slot
	 * order, so that {@link #XmlParse} can restore them by position. Every slot
	 * must be populated; a {@code null} slot throws NullPointerException.
	 */
	public void CreateTreeElement(Document doc, Element treeElement) {
		for (Node item : nodes) {
			Element nodeElement = doc.createElement("node");
			item.CreateNodeElement(doc, nodeElement);
			treeElement.appendChild(nodeElement);
		}
	}

	/**
	 * Restores nodes from the {@code <node>} elements under {@code treeElement},
	 * assigning the k-th element (in document order) to slot k. Note that
	 * {@code getElementsByTagName} matches all descendants, not only direct
	 * children.
	 *
	 * @throws IllegalArgumentException if there are more node elements than slots
	 */
	public void XmlParse(Element treeElement) {
		NodeList children = treeElement.getElementsByTagName("node");
		int count = children.getLength();
		// Reject up front instead of partially overwriting nodes and then
		// failing with an ArrayIndexOutOfBoundsException.
		if (count > nodes.length)
			throw new IllegalArgumentException("Tree XML contains " + count
					+ " node elements but the tree has only " + nodes.length + " slots.");

		for (int i = 0; i < count; ++i) {
			Element nodeElement = (Element) children.item(i);
			Node node = new Node();
			node.XmlParse(nodeElement);
			nodes[i] = node;
		}
	}

	////////////////////////////////////

	/**
	 * @param boxCount number of boxes
	 * @return for each box index, the result of {@link #GetNodeIdForBox}
	 */
	public int[][] GetNodeIdForBoxes(int boxCount) {
		int[][] idArray = new int[boxCount][];
		for (int i = 0; i < boxCount; ++i) {
			idArray[i] = GetNodeIdForBox(i);
		}
		return idArray;
	}

	/**
	 * Finds the two leaf nodes on the deepest tree level with the smallest
	 * variance for box {@code index}.
	 *
	 * @return {@code {best, secondBest}} node indices; an entry with no
	 *         qualifying leaf keeps its default value 0
	 */
	private int[] GetNodeIdForBox(int index) {
		int[] id = new int[2];
		double min1 = Double.MAX_VALUE, min2 = Double.MAX_VALUE;
		// The deepest level of a complete tree with nodes.length slots occupies
		// indices (nodes.length - 1) / 2 .. nodes.length - 1.
		for (int i = (nodes.length - 1) / 2; i < nodes.length; ++i) {
			Node node = nodes[i];
			if (node == null || !node.IsLeaf()) continue;
			double variance = node.GetBoxesVariance()[index];
			if (variance < min1) {
				// New best: previous best becomes the runner-up.
				min2 = min1;
				id[1] = id[0];
				min1 = variance;
				id[0] = i;
			} else if (variance < min2) {
				min2 = variance;
				id[1] = i;
			}
		}

		return id;
	}

}