package edu.jhu.ece.iacl.algorithms.graphics.locator.balltree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

import javax.vecmath.Point3d;
import javax.vecmath.Point3f;
import javax.vecmath.Vector3d;

import edu.jhu.ece.iacl.algorithms.graphics.utilities.PrincipalComponentAnalysisDouble;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.geom.EmbeddedPointSet;

/**
 * A bounding ball tree data structure for fast lookup into 3D scattered data.
 * This is useful for RBF interpolation of a point set. The optimal bounding
 * ball for scattered point data is an open problem. This algorithm addresses
 * the problem by computing a convex hull around data points and using the
 * center of mass to define the ball's center.
 * <p>
 * The tree is built top-down by {@link #buildTree()}: every point starts as a
 * direct child of the root, and each ball that is still too large is split in
 * two along the principal axis of its children (see
 * {@link #splitPosition(List, BBallEdge[])}). The original points remain
 * addressable by index through {@link #getLeafNode(int)}.
 * 
 * S. Omohundro, Five balltree construction algorithms: International Computer
 * Science Institute, 1989.
 * 
 * @author Blake Lucas
 * 
 */
public class BallTree extends AbstractCalculation {
	/** Top-most ball; created by {@code init}. */
	private BBall root;
	/** Balls deeper than this are never subdivided. */
	private int maxDepth;
	/** The input points in insertion order; entry {@code i} was created with index {@code i}. */
	private PointBall[] leafNodes;

	/**
	 * Get the root ball of the tree.
	 * 
	 * @return root ball, or null if the tree has not been initialized
	 */
	public BBall getRoot() {
		return root;
	}

	/**
	 * Split children along principal axis. The principal eigenvector of the
	 * children is obtained from a PCA, one {@link BBallEdge} is created per
	 * child from the child and that axis, and the first
	 * {@code children.size()} entries of {@code edges} are sorted by the
	 * natural ordering of {@link BBallEdge} (presumably the position along the
	 * axis; see that class).
	 * 
	 * @param children
	 *            children balls
	 * @param edges
	 *            scratch array that receives the sorted edges; must hold at
	 *            least {@code children.size()} entries
	 * @return index into {@code edges} of the first element that belongs to
	 *         the right half; this implementation always returns the median,
	 *         {@code children.size() / 2}
	 */
	protected int splitPosition(List<Ball> children, BBallEdge[] edges) {
		PrincipalComponentAnalysisDouble pca = new PrincipalComponentAnalysisDouble(
				children);
		Vector3d axis = pca.getPrinicpalEigenVector();
		int count = 0;
		for (Ball child : children) {
			edges[count++] = new BBallEdge(child, axis);
		}
		// Only the first 'count' slots are valid for this ball; the array is
		// shared across all splits and is as large as the whole point set.
		Arrays.sort(edges, 0, count);
		return count / 2;
	}

	/**
	 * Constructor. The tree is empty until one of the {@code init} methods is
	 * called.
	 * 
	 * @param maxDepth
	 *            max ball tree depth
	 */
	public BallTree(int maxDepth) {
		super();
		this.maxDepth = maxDepth;
		setLabel("Ball-Tree");
	}

	/**
	 * Constructor that immediately builds the tree for the given points.
	 * 
	 * @param parent
	 *            parent calculation
	 * @param pts
	 *            points
	 * @param maxDepth
	 *            maximum ball tree depth
	 */
	public BallTree(AbstractCalculation parent, Point3f[] pts, int maxDepth) {
		super(parent);
		this.maxDepth = maxDepth;
		setLabel("Ball-Tree");
		init(pts);
	}

	/**
	 * Initialize ball tree from single-precision points. Any previously built
	 * tree is discarded.
	 * 
	 * @param pts
	 *            points; the array position of each point becomes its leaf
	 *            index
	 */
	public void init(Point3f[] pts) {
		root = new BBall();
		List<Ball> rootChildren = root.getChildren();
		int index = 0;
		for (Point3f pt : pts) {
			rootChildren.add(new PointBall(pt, index++));
		}
		root.updateDescendants();
		buildTree();
	}

	/**
	 * Initialize ball tree from double-precision points. Any previously built
	 * tree is discarded.
	 * 
	 * @param pts
	 *            points; the array position of each point becomes its leaf
	 *            index
	 */
	public void init(Point3d[] pts) {
		root = new BBall();
		List<Ball> rootChildren = root.getChildren();
		int index = 0;
		for (Point3d pt : pts) {
			rootChildren.add(new PointBall(pt, index++));
		}
		root.updateDescendants();
		buildTree();
	}

	/**
	 * Get point set representing ball tree. One point is emitted per
	 * descendant of the root, located at the ball's (x, y, z) narrowed to
	 * float. The attached point data has two columns per point: the ball
	 * radius and the ball depth.
	 * 
	 * @return point set named "balltree"
	 */
	public EmbeddedPointSet getPointSet() {
		LinkedList<Ball> all = root.getAllDescendants();
		int count = all.size();
		Point3f[] centers = new Point3f[count];
		double[][] radiusAndDepth = new double[count][2];
		int row = 0;
		for (Ball b : all) {
			centers[row] = new Point3f((float) b.x, (float) b.y, (float) b.z);
			radiusAndDepth[row][0] = b.getRadius();
			radiusAndDepth[row][1] = b.getDepth();
			row++;
		}
		EmbeddedPointSet ps = new EmbeddedPointSet(centers);
		ps.setPointData(radiusAndDepth);
		ps.setName("balltree");
		return ps;
	}

	/**
	 * Get all leaf nodes.
	 * 
	 * @return the internal array of point balls in insertion order (not a
	 *         copy), or null if the tree has not been initialized
	 */
	public PointBall[] getLeafNodes() {
		return leafNodes;
	}

	/**
	 * Get leaf node at specified index
	 * 
	 * @param i
	 *            index
	 * @return ball
	 */
	public PointBall getLeafNode(int i) {
		return leafNodes[i];
	}

	/**
	 * Comparator that orders balls by increasing
	 * {@code distanceToSphere(pivot)}.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	protected class DistanceCompare implements Comparator<Ball> {
		protected Point3d pivot;

		public DistanceCompare(Point3d pivot) {
			this.pivot = pivot;
		}

		public int compare(Ball b1, Ball b2) {
			// Double.compare yields a consistent total order even when a
			// distance is NaN; casting signum of a difference collapses every
			// NaN case to "equal", which violates the comparator contract.
			return Double.compare(b1.distanceToSphere(pivot),
					b2.distanceToSphere(pivot));
		}

	}

	/**
	 * Get nearest neighbors from point.
	 * 
	 * @param pt
	 *            point
	 * @param K
	 *            number of nearest neighbors
	 * @return array of length K holding the nearest neighbor balls, closest
	 *         first; trailing entries are null when the tree holds fewer than
	 *         K points
	 */
	public PointBall[] getNearestNeighbors(Point3d pt, int K) {
		PointBall[] nbhrs = new PointBall[K];
		getNearestNeighbors(pt, nbhrs);
		return nbhrs;
	}

	/**
	 * Get nearest neighbors for point using brute force search. This should
	 * agree with neighbors obtained from using the ball tree, and the ball tree
	 * has been tested to insure that it does.
	 * 
	 * @param pt
	 *            point
	 * @param nbhrs
	 *            array to fill with the nearest neighbors, closest first; its
	 *            length is the number of neighbors requested. Entries beyond
	 *            the number of available points are left untouched.
	 */
	public void getNearestNeighborsBrute(Point3d pt, Point3d[] nbhrs) {
		// PriorityQueue rejects an initial capacity below 1, so an empty tree
		// must not be used as the capacity directly.
		PriorityQueue<PointBall> minHeap = new PriorityQueue<PointBall>(
				Math.max(1, leafNodes.length), new DistanceCompare(pt));
		for (PointBall leaf : leafNodes) {
			minHeap.add(leaf);
		}
		drainClosest(minHeap, nbhrs);
	}

	/**
	 * Get nearest neighbors for point using a best-first traversal of the
	 * tree: the ball closest to {@code pt} is expanded first, and the search
	 * stops as soon as enough points have been reached.
	 * 
	 * @param pt
	 *            point
	 * @param nbhrs
	 *            array to store nearest neighbors, closest first; its length
	 *            is the number of neighbors requested. Entries beyond the
	 *            number of points found are left untouched.
	 */
	public void getNearestNeighbors(Point3d pt, Point3d[] nbhrs) {
		int K = nbhrs.length;
		DistanceCompare byDistance = new DistanceCompare(pt);
		// Frontier of balls still to be examined, closest first.
		PriorityQueue<Ball> queue = new PriorityQueue<Ball>(1000, byDistance);
		queue.add(root);
		// Points reached so far. Capacity is clamped because PriorityQueue
		// rejects an initial capacity below 1 (K == 0).
		PriorityQueue<PointBall> minHeap = new PriorityQueue<PointBall>(
				Math.max(1, 2 * K), byDistance);
		while (!queue.isEmpty() && minHeap.size() < K) {
			Ball closest = queue.remove();
			if (closest instanceof PointBall) {
				minHeap.add((PointBall) closest);
			} else {
				queue.addAll(closest.getChildren());
			}
		}
		drainClosest(minHeap, nbhrs);
	}

	/**
	 * Move the closest points from a heap into the front of an output array.
	 * At most {@code out.length} points are taken; if the heap holds fewer,
	 * the remaining slots of {@code out} are not modified.
	 * 
	 * @param heap
	 *            points ordered by distance, closest first
	 * @param out
	 *            destination array
	 */
	private void drainClosest(PriorityQueue<PointBall> heap, Point3d[] out) {
		int count = Math.min(out.length, heap.size());
		for (int i = 0; i < count; i++) {
			out[i] = heap.remove();
		}
	}

	/**
	 * Update the position of a point in the ball tree. The leaf is always
	 * moved to the new position; the tree structure only changes when the
	 * leaf's current parent no longer contains it.
	 * 
	 * @param index
	 *            index of point to update
	 * @param pt
	 *            new position
	 * @throws IllegalStateException
	 *             if the leaf is not registered as a child of its own parent,
	 *             i.e. the tree is internally inconsistent
	 */
	public void update(int index, Point3d pt) {
		PointBall leaf = leafNodes[index];
		Ball ball = leaf.parent;
		leaf.set(pt);
		if (ball.contains(pt)) {
			// Parent contains new location for the point, do not change tree
			return;
		}
		// Parent does not contain new point location, remove point from parent
		if (!ball.getChildren().remove(leaf)) {
			// Report the corrupted tree to the caller rather than terminating
			// the whole JVM from inside a library method.
			throw new IllegalStateException("COULD NOT REMOVE CHILD! " + index
					+ " " + pt);
		}
		// Find parent that contains the current point
		while (ball != null && !ball.contains(pt)) {
			ball = ball.parent;
		}
		// If no parent contains this point, insert point at root
		if (ball == null) {
			ball = root;
		}
		insert(leaf, ball);
	}

	/**
	 * Insert point into ball. Starting at {@code ball}, the tree is descended
	 * by repeatedly moving to the child whose sphere is closest to the point;
	 * as soon as a point ball is seen among the children the descent stops and
	 * the new point becomes its sibling. The ancestors of the receiving ball
	 * are then asked to update themselves.
	 * 
	 * @param pt
	 *            point
	 * @param ball
	 *            ball at which to start the descent
	 */
	public void insert(PointBall pt, Ball ball) {
		while (!ball.isLeaf()) {
			double closestDist = Double.MAX_VALUE;
			Ball closest = null;
			for (Ball child : ball.getChildren()) {
				if (child instanceof PointBall) {
					closest = child;
					break;
				}
				double d = child.distanceToSphere(pt);
				if (d < closestDist) {
					closestDist = d;
					closest = child;
				}
			}
			if (closest == null) {
				// No child to descend into (e.g. update() just removed this
				// ball's only child): insert here instead of dereferencing
				// null on the next iteration.
				break;
			}
			ball = closest;
		}
		// If this is a point ball, insert point into parent
		if (ball instanceof PointBall) {
			ball = ball.parent;
		}
		ball.getChildren().add(pt);
		pt.parent = ball;
		ball.updateAncestors();

	}

	/**
	 * Build ball tree for all points currently attached to the root. Balls are
	 * processed breadth-first; each processed ball with at least two children
	 * has them replaced by a left and a right ball, and a new ball is queued
	 * for further splitting only if it received more than four children.
	 * Progress is reported once per tree level.
	 */
	protected void buildTree() {
		// Snapshot the points before the root's child list is rewritten.
		leafNodes = new PointBall[root.getChildren().size()];
		root.getChildren().toArray(leafNodes);
		int initSz = leafNodes.length;

		// NOTE(review): for fewer than two points Math.log(0) is -Infinity and
		// the cast produces a hugely negative unit count; kept as-is because
		// the contract of setTotalUnits is not visible here.
		setTotalUnits(Math.min(maxDepth, (int) Math.floor(Math.log(initSz / 2)
				/ Math.log(2))));

		// Scratch array shared by every split; large enough for the root.
		BBallEdge[] edges = new BBallEdge[initSz];
		for (int i = 0; i < edges.length; i++) {
			edges[i] = new BBallEdge();
		}

		LinkedList<Ball> balls = new LinkedList<Ball>();
		balls.add(root);
		int depthCount = 0;
		while (!balls.isEmpty()) {
			Ball ball = balls.remove();
			if (ball.getDepth() > maxDepth) {
				continue;
			}
			if (ball.getDepth() > depthCount) {
				// Depth increased, increment completed units
				depthCount = ball.getDepth();
				incrementCompletedUnits();
			}
			List<Ball> children = ball.getChildren();
			// Less than two children, no need to subdivide
			if (children.size() < 2) {
				continue;
			}
			// Number of children, i.e. number of valid entries in edges
			int sz = children.size();
			// Find best split position
			int splitPos = splitPosition(children, edges);
			// Check if split position is outside range of children
			if (splitPos <= 0 || splitPos == sz - 1) {
				continue;
			}
			// Split child into left and right
			BBall leftChild = new BBall(ball.getDepth() + 1);
			BBall rightChild = new BBall(ball.getDepth() + 1);
			List<Ball> lchildren = leftChild.getChildren();
			List<Ball> rchildren = rightChild.getChildren();
			// Add balls to left child
			for (int i = 0; i < splitPos; i++) {
				lchildren.add(edges[i].getBall());
			}
			// Add balls to right child
			for (int i = splitPos; i < sz; i++) {
				rchildren.add(edges[i].getBall());
			}
			// Erase children
			children.clear();
			// Replace children with only two children
			children.add(leftChild);
			children.add(rightChild);
			// Add balls to queue
			if (lchildren.size() > 4) {
				balls.add(leftChild);
			}
			if (rchildren.size() > 4) {
				balls.add(rightChild);
			}
		}
		root.updateDescendants();
		markCompleted();
	}

}
