package edu.jhu.ece.iacl.algorithms.graphics.locator.kdtree;

import javax.vecmath.*;

/**
 * A mesh triangle object for fast-intersection tests with a cache to store the
 * last intersection point. The intersection procedures are a port of WildMagic.
 * 
 * @author Blake Lucas
 * 
 */
public abstract class KdTriangle extends BBox {
	/**
	 * Distance tolerance for the edge fallback of the intersection tests: a
	 * query segment/ray whose distance to a triangle edge is below this value
	 * is treated as touching that edge. This absorbs floating-point error for
	 * hits that graze an edge or a vertex.
	 */
	public static float EPS = 5e-4f;

	/** The three vertices of the triangle. */
	public KdPoint3[] pts;

	/**
	 * Point produced by the most recent intersection or distance query made
	 * through this class. This is a single unsynchronized static field shared
	 * by every triangle, so concurrent queries from several threads overwrite
	 * each other's result.
	 */
	protected static Point3f lastIntersectionPoint;

	/**
	 * Get the point cached by the most recent call to
	 * {@link #intersectionPoint(Point3f, Point3f)},
	 * {@link #intersectionPoint(Point3f, Vector3f)} or
	 * {@link #distance(Point3f)} on any triangle.
	 * 
	 * @return the cached point; {@code null} if the last intersection query
	 *         found no intersection
	 */
	public static Point3f getLastIntersectionPoint() {
		return lastIntersectionPoint;
	}

	/**
	 * Update the bounding box; implemented by subclasses.
	 */
	public abstract void update();

	/**
	 * Constructor
	 */
	public KdTriangle() {
		super();
	}

	/**
	 * Compute the cross product of the two edges leaving pts[0]. Its length is
	 * twice the triangle's area and its direction is the (unnormalized)
	 * right-handed normal for the vertex order pts[0], pts[1], pts[2].
	 * 
	 * @return (pts[1] - pts[0]) x (pts[2] - pts[0])
	 */
	private Vector3f edgeCross() {
		Vector3f edge1 = new Vector3f();
		Vector3f edge2 = new Vector3f();
		Vector3f cross = new Vector3f();
		edge1.sub(pts[1], pts[0]);
		edge2.sub(pts[2], pts[0]);
		cross.cross(edge1, edge2);
		return cross;
	}

	/**
	 * Get the unit normal of the triangle, oriented by the right-hand rule on
	 * the vertex order pts[0], pts[1], pts[2]. A degenerate (zero-area)
	 * triangle has no defined normal: normalizing the zero vector divides by
	 * zero.
	 * 
	 * @return unit normal
	 */
	public Vector3f getNormal() {
		Vector3f normal = edgeCross();
		normal.normalize();
		return normal;
	}

	/**
	 * Get triangle area (half the length of the edge cross product).
	 * 
	 * @return triangle area
	 */
	public float getArea() {
		return edgeCross().length() * 0.5f;
	}

	/**
	 * Get the triangle centroid (the average of the three vertices).
	 * 
	 * @return centroid
	 */
	public Point3f getCentroid() {
		Point3f p = new Point3f();
		p.add(pts[0]);
		p.add(pts[1]);
		p.add(pts[2]);
		// Use the float closest to 1/3; the former literal 0.333333f was
		// slightly too small and pulled every centroid towards the origin.
		p.scale(1.0f / 3.0f);
		return p;
	}

	/**
	 * Interpolate a location on the triangle from barycentric coordinates.
	 * 
	 * @param b
	 *            barycentric coordinates; b.x, b.y and b.z weight pts[0],
	 *            pts[1] and pts[2] respectively
	 * @return interpolated point
	 */
	public Point3f getPointCoords(Point3d b) {
		double x = pts[0].x * b.x + pts[1].x * b.y + pts[2].x * b.z;
		double y = pts[0].y * b.x + pts[1].y * b.y + pts[2].y * b.z;
		double z = pts[0].z * b.x + pts[1].z * b.y + pts[2].z * b.z;
		return new Point3f((float) x, (float) y, (float) z);
	}

	/**
	 * Get the barycentric coordinates of a point with respect to this triangle.
	 * <p>
	 * The coordinates are clamped into the triangle: each of the first two is
	 * limited to [0, 1], and if their sum exceeds 1 both are reduced equally so
	 * that the third coordinate becomes 0. A degenerate (zero-area) triangle
	 * yields (0, 0, 1).
	 * 
	 * @param p
	 *            point
	 * @return barycentric coordinates (weights of pts[0], pts[1], pts[2]),
	 *         which always sum to 1
	 */
	public Point3d getBaryCoords(Point3f p) {
		// Write p - C = l1*(A - C) + l2*(B - C), where A, B, C = pts[0], pts[1],
		// pts[2], and solve the 2x2 normal equations of that system. This works
		// for every non-degenerate triangle orientation, including triangles
		// lying in an axis-aligned plane (the previous x/(y+z) row formulation
		// was singular for e.g. every triangle in an x = const plane). For a
		// point off the triangle's plane it yields the coordinates of the
		// point's orthogonal projection onto that plane.
		double ux = (double) pts[0].x - pts[2].x;
		double uy = (double) pts[0].y - pts[2].y;
		double uz = (double) pts[0].z - pts[2].z;

		double vx = (double) pts[1].x - pts[2].x;
		double vy = (double) pts[1].y - pts[2].y;
		double vz = (double) pts[1].z - pts[2].z;

		double wx = (double) p.x - pts[2].x;
		double wy = (double) p.y - pts[2].y;
		double wz = (double) p.z - pts[2].z;

		double uu = ux * ux + uy * uy + uz * uz;
		double uv = ux * vx + uy * vy + uz * vz;
		double vv = vx * vx + vy * vy + vz * vz;
		double wu = wx * ux + wy * uy + wz * uz;
		double wv = wx * vx + wy * vy + wz * vz;
		double det = uu * vv - uv * uv;

		double l1 = (vv * wu - uv * wv) / det;
		double l2 = (uu * wv - uv * wu) / det;
		// A degenerate triangle gives det == 0, i.e. NaN or +/-Infinity here.
		if (Double.isNaN(l1) || Double.isInfinite(l1)) {
			l1 = 0;
		}
		if (Double.isNaN(l2) || Double.isInfinite(l2)) {
			l2 = 0;
		}
		if (l1 > 1 || l2 > 1 || l1 + l2 > 1 || l1 < 0 || l2 < 0) {
			// Outside the triangle: clamp back onto it.
			l1 = Math.max(Math.min(l1, 1), 0);
			l2 = Math.max(Math.min(l2, 1), 0);
			if (l1 + l2 > 1) {
				// Split the excess evenly so that l1 + l2 == 1 (third weight 0).
				double diff = 0.5 * (1 - l1 - l2);
				l1 += diff;
				l2 += diff;
			}
		}
		return new Point3d(l1, l2, 1 - l1 - l2);
	}

	/**
	 * Compute the intersection of the segment from {@code p1} to {@code p2}
	 * with this triangle.
	 * <p>
	 * First the line through the segment is intersected with the triangle (a
	 * port of WildMagic's segment/triangle test). The hit is accepted when its
	 * parameter t along {@code seg.direction}, measured from {@code p1}, lies
	 * in [0, seg.extent]. If the line misses the triangle's interior, the three
	 * edges are tested: the first edge closer than {@link #EPS} to the segment
	 * supplies the hit. That hit is accepted if its projection onto the segment
	 * direction also lies in [0, seg.extent]. A segment exactly parallel to the
	 * triangle's plane is always reported as a miss.
	 * <p>
	 * The result (hit or {@code null}) is also stored for
	 * {@link #getLastIntersectionPoint()}.
	 * <p>
	 * NOTE(review): the range checks assume {@code seg.extent} is measured
	 * from {@code p1}. If KdSegment uses WildMagic's centre/half-extent
	 * convention, only part of the segment is tested. Confirm against
	 * KdSegment.
	 * 
	 * @param p1
	 *            point p1
	 * @param p2
	 *            point p2
	 * @return intersection point, or {@code null} if there is none
	 */
	public Point3f intersectionPoint(Point3f p1, Point3f p2) {
		// Names follow the WildMagic source this is ported from.
		KdSegment seg = new KdSegment(p1, p2);
		Vector3f dr = seg.direction;
		float len = seg.extent;
		Vector3f kDiff = new Vector3f();
		Vector3f kEdge1 = new Vector3f();
		Vector3f kEdge2 = new Vector3f();
		Vector3f kNormal = new Vector3f();
		kDiff.sub(p1, pts[0]);
		kEdge1.sub(pts[1], pts[0]);
		kEdge2.sub(pts[2], pts[0]);
		kNormal.cross(kEdge1, kEdge2);

		// Solve Q + t*D = b1*E1 + b2*E2 (Q = kDiff, D = segment direction,
		// E1 = kEdge1, E2 = kEdge2, N = Cross(E1,E2)) by
		// |Dot(D,N)|*b1 = sign(Dot(D,N))*Dot(D,Cross(Q,E2))
		// |Dot(D,N)|*b2 = sign(Dot(D,N))*Dot(D,Cross(E1,Q))
		// |Dot(D,N)|*t = -sign(Dot(D,N))*Dot(Q,N)
		double fDdN = dr.dot(kNormal);
		double fSign;
		if (fDdN > 0) {
			fSign = 1.0;
		} else if (fDdN < 0) {
			fSign = -1.0;
			fDdN = -fDdN;
		} else {
			// Segment and triangle are parallel: deliberately reported as "no
			// intersection" even if the segment does touch the triangle.
			lastIntersectionPoint = null;
			return null;
		}
		Vector3f crs = new Vector3f();
		crs.cross(kDiff, kEdge2);
		double fDdQxE2 = fSign * dr.dot(crs);
		if (fDdQxE2 >= 0.0) {
			crs.cross(kEdge1, kDiff);
			double fDdE1xQ = fSign * dr.dot(crs);
			if (fDdE1xQ >= 0.0 && fDdQxE2 + fDdE1xQ <= fDdN) {
				// The line intersects the triangle; check the segment does.
				double fQdN = -fSign * kDiff.dot(kNormal);
				double t = fQdN * (1.0 / fDdN);
				if (t >= 0 && t <= len) {
					// scaleAdd builds p1 + t*dr without mutating seg.direction.
					Point3f hit = new Point3f();
					hit.scaleAdd((float) t, dr, p1);
					lastIntersectionPoint = hit;
					return hit;
				}
			}
		}

		// No interior hit on the segment: accept a grazing hit on an edge if
		// it lies within the segment's range.
		Point3f intersect = edgeIntersection(seg);
		if (intersect != null) {
			kDiff.sub(intersect, p1);
			double l = seg.direction.dot(kDiff);
			if (l >= 0 && l <= len) {
				lastIntersectionPoint = intersect;
				return intersect;
			}
		}
		lastIntersectionPoint = null;
		return null;
	}

	/**
	 * Compute the intersection of a ray and this triangle.
	 * <p>
	 * The line {@code org + t*seg.direction} is intersected with the triangle
	 * (a port of WildMagic's segment/triangle test). The hit is accepted when
	 * |t| &lt;= seg.extent, i.e. on either side of {@code org}. Otherwise the
	 * three edges are tested, and the hit supplied by the first edge closer
	 * than {@link #EPS} is returned without any further range check. A ray
	 * exactly parallel to the triangle's plane is always reported as a miss.
	 * <p>
	 * The result (hit or {@code null}) is also stored for
	 * {@link #getLastIntersectionPoint()}.
	 * <p>
	 * NOTE(review): how KdSegment turns ({@code org}, {@code v}) into
	 * direction/extent (e.g. whether |v| sets the extent) is defined
	 * elsewhere. Confirm against KdSegment.
	 * 
	 * @param org
	 *            ray origin
	 * @param v
	 *            ray direction
	 * @return intersection point, or {@code null} if there is none
	 */
	public Point3f intersectionPoint(Point3f org, Vector3f v) {
		// Names follow the WildMagic source this is ported from.
		KdSegment seg = new KdSegment(org, v);
		Vector3f dr = seg.direction;
		float len = seg.extent;
		Vector3f kDiff = new Vector3f();
		Vector3f kEdge1 = new Vector3f();
		Vector3f kEdge2 = new Vector3f();
		Vector3f kNormal = new Vector3f();
		kDiff.sub(org, pts[0]);
		kEdge1.sub(pts[1], pts[0]);
		kEdge2.sub(pts[2], pts[0]);
		kNormal.cross(kEdge1, kEdge2);

		// Solve Q + t*D = b1*E1 + b2*E2 (Q = kDiff, D = segment direction,
		// E1 = kEdge1, E2 = kEdge2, N = Cross(E1,E2)) by
		// |Dot(D,N)|*b1 = sign(Dot(D,N))*Dot(D,Cross(Q,E2))
		// |Dot(D,N)|*b2 = sign(Dot(D,N))*Dot(D,Cross(E1,Q))
		// |Dot(D,N)|*t = -sign(Dot(D,N))*Dot(Q,N)
		double fDdN = dr.dot(kNormal);
		double fSign;
		if (fDdN > 0) {
			fSign = 1.0;
		} else if (fDdN < 0) {
			fSign = -1.0;
			fDdN = -fDdN;
		} else {
			// Ray and triangle are parallel: deliberately reported as "no
			// intersection" even if the ray does touch the triangle.
			lastIntersectionPoint = null;
			return null;
		}
		Vector3f crs = new Vector3f();
		crs.cross(kDiff, kEdge2);
		double fDdQxE2 = fSign * dr.dot(crs);
		if (fDdQxE2 >= 0.0) {
			crs.cross(kEdge1, kDiff);
			double fDdE1xQ = fSign * dr.dot(crs);
			if (fDdE1xQ >= 0.0 && fDdQxE2 + fDdE1xQ <= fDdN) {
				// The line intersects the triangle; check |t| <= extent
				// without dividing (both sides scaled by |Dot(D,N)|).
				double fQdN = -fSign * kDiff.dot(kNormal);
				double fExtDdN = len * fDdN;
				if (-fExtDdN <= fQdN && fQdN <= fExtDdN) {
					double t = fQdN * (1.0 / fDdN);
					// scaleAdd builds org + t*dr without mutating seg.direction
					// (which might alias the caller's v).
					Point3f hit = new Point3f();
					hit.scaleAdd((float) t, dr, org);
					lastIntersectionPoint = hit;
					return hit;
				}
			}
		}

		// No interior hit: fall back to a grazing hit on an edge.
		Point3f intersect = edgeIntersection(seg);
		lastIntersectionPoint = intersect;
		return intersect;
	}

	/**
	 * Edge fallback shared by both intersection tests. The edges (pts[0],
	 * pts[1]), (pts[1], pts[2]) and (pts[2], pts[0]) are tested in that order.
	 * The first one whose distance to {@code query} is below {@link #EPS}
	 * supplies the hit through its {@code getLastIntersect()}.
	 * 
	 * @param query
	 *            segment being intersected with the triangle
	 * @return the edge hit, or {@code null} if no edge is within EPS
	 */
	private Point3f edgeIntersection(KdSegment query) {
		for (int k = 0; k < 3; k++) {
			KdSegment edge = new KdSegment(pts[k], pts[(k + 1) % 3]);
			if (query.distance(edge) < EPS) {
				return edge.getLastIntersect();
			}
		}
		return null;
	}

	/**
	 * Get the distance between a point and this triangle (a port of
	 * WildMagic's point/triangle distance).
	 * <p>
	 * The closest point is written as pts[0] + s*(pts[1]-pts[0]) +
	 * t*(pts[2]-pts[0]). The squared distance is minimized over s &gt;= 0,
	 * t &gt;= 0, s + t &lt;= 1. The numbered "regions" below are the parts of
	 * the (s, t) plane in which the unconstrained minimum can fall. The closest
	 * point found is stored for {@link #getLastIntersectionPoint()}.
	 * <p>
	 * NOTE(review): for a degenerate (zero-area) triangle fDet is 0, and the
	 * region 0 branch can then divide by zero.
	 * 
	 * @param p
	 *            point
	 * @return distance from {@code p} to the closest point of the triangle
	 */
	public double distance(Point3f p) {
		Point3f p1 = pts[0];
		Point3f p2 = pts[1];
		Point3f p3 = pts[2];
		Vector3f kDiff = new Vector3f();
		kDiff.sub(p1, p);
		Vector3f kEdge0 = new Vector3f();
		kEdge0.sub(p2, p1);
		Vector3f kEdge1 = new Vector3f();
		kEdge1.sub(p3, p1);
		float fA00 = kEdge0.lengthSquared();
		float fA01 = kEdge0.dot(kEdge1);
		float fA11 = kEdge1.lengthSquared();
		float fB0 = kDiff.dot(kEdge0);
		float fB1 = kDiff.dot(kEdge1);
		float fC = kDiff.lengthSquared();
		float fDet = Math.abs(fA00 * fA11 - fA01 * fA01);
		float fS = fA01 * fB1 - fA11 * fB0;
		float fT = fA01 * fB0 - fA00 * fB1;
		float fSqrDistance;
		if (fS + fT <= fDet) {
			if (fS < (float) 0.0) {
				if (fT < (float) 0.0) // region 4
				{
					if (fB0 < (float) 0.0) {
						fT = (float) 0.0;
						if (-fB0 >= fA00) {
							fS = (float) 1.0;
							fSqrDistance = fA00 + ((float) 2.0) * fB0 + fC;
						} else {
							fS = -fB0 / fA00;
							fSqrDistance = fB0 * fS + fC;
						}

					} else {
						fS = (float) 0.0;
						if (fB1 >= (float) 0.0) {
							fT = (float) 0.0;
							fSqrDistance = fC;
						} else if (-fB1 >= fA11) {
							fT = (float) 1.0;
							fSqrDistance = fA11 + ((float) 2.0) * fB1 + fC;
						} else {
							fT = -fB1 / fA11;
							fSqrDistance = fB1 * fT + fC;
						}
					}

				} else // region 3
				{
					fS = (float) 0.0;
					if (fB1 >= (float) 0.0) {
						fT = (float) 0.0;
						fSqrDistance = fC;
					} else if (-fB1 >= fA11) {
						fT = (float) 1.0;
						fSqrDistance = fA11 + ((float) 2.0) * fB1 + fC;
					} else {
						fT = -fB1 / fA11;
						fSqrDistance = fB1 * fT + fC;
					}
				}
			} else if (fT < (float) 0.0) // region 5
			{
				fT = (float) 0.0;
				if (fB0 >= (float) 0.0) {
					fS = (float) 0.0;
					fSqrDistance = fC;
				} else if (-fB0 >= fA00) {
					fS = (float) 1.0;
					fSqrDistance = fA00 + ((float) 2.0) * fB0 + fC;
				} else {
					fS = -fB0 / fA00;
					fSqrDistance = fB0 * fS + fC;
				}
			} else // region 0
			{
				// minimum at interior point
				float fInvDet = ((float) 1.0) / fDet;
				fS *= fInvDet;
				fT *= fInvDet;
				fSqrDistance = fS
						* (fA00 * fS + fA01 * fT + ((float) 2.0) * fB0) + fT
						* (fA01 * fS + fA11 * fT + ((float) 2.0) * fB1) + fC;
			}
		} else {
			float fTmp0, fTmp1, fNumer, fDenom;

			if (fS < (float) 0.0) // region 2
			{
				fTmp0 = fA01 + fB0;
				fTmp1 = fA11 + fB1;
				if (fTmp1 > fTmp0) {
					fNumer = fTmp1 - fTmp0;
					fDenom = fA00 - 2.0f * fA01 + fA11;
					if (fNumer >= fDenom) {
						fS = (float) 1.0;
						fT = (float) 0.0;
						fSqrDistance = fA00 + ((float) 2.0) * fB0 + fC;
					} else {
						fS = fNumer / fDenom;
						fT = (float) 1.0 - fS;
						fSqrDistance = fS
								* (fA00 * fS + fA01 * fT + 2.0f * fB0) + fT
								* (fA01 * fS + fA11 * fT + ((float) 2.0) * fB1)
								+ fC;
					}
				} else {
					fS = (float) 0.0;
					if (fTmp1 <= (float) 0.0) {
						fT = (float) 1.0;
						fSqrDistance = fA11 + ((float) 2.0) * fB1 + fC;
					} else if (fB1 >= (float) 0.0) {
						fT = (float) 0.0;
						fSqrDistance = fC;
					} else {
						fT = -fB1 / fA11;
						fSqrDistance = fB1 * fT + fC;
					}
				}
			} else if (fT < (float) 0.0) // region 6
			{
				fTmp0 = fA01 + fB1;
				fTmp1 = fA00 + fB0;
				if (fTmp1 > fTmp0) {
					fNumer = fTmp1 - fTmp0;
					fDenom = fA00 - ((float) 2.0) * fA01 + fA11;
					if (fNumer >= fDenom) {
						fT = (float) 1.0;
						fS = (float) 0.0;
						fSqrDistance = fA11 + ((float) 2.0) * fB1 + fC;
					} else {
						fT = fNumer / fDenom;
						fS = (float) 1.0 - fT;
						fSqrDistance = fS
								* (fA00 * fS + fA01 * fT + ((float) 2.0) * fB0)
								+ fT
								* (fA01 * fS + fA11 * fT + ((float) 2.0) * fB1)
								+ fC;
					}
				} else {
					fT = (float) 0.0;
					if (fTmp1 <= (float) 0.0) {
						fS = (float) 1.0;
						fSqrDistance = fA00 + ((float) 2.0) * fB0 + fC;
					} else if (fB0 >= (float) 0.0) {
						fS = (float) 0.0;
						fSqrDistance = fC;
					} else {
						fS = -fB0 / fA00;
						fSqrDistance = fB0 * fS + fC;
					}
				}
			} else // region 1
			{
				fNumer = fA11 + fB1 - fA01 - fB0;
				if (fNumer <= (float) 0.0) {
					fS = (float) 0.0;
					fT = (float) 1.0;
					fSqrDistance = fA11 + ((float) 2.0) * fB1 + fC;
				} else {
					fDenom = fA00 - 2.0f * fA01 + fA11;
					if (fNumer >= fDenom) {
						fS = (float) 1.0;
						fT = (float) 0.0;
						fSqrDistance = fA00 + ((float) 2.0) * fB0 + fC;
					} else {
						fS = fNumer / fDenom;
						fT = (float) 1.0 - fS;
						fSqrDistance = fS
								* (fA00 * fS + fA01 * fT + ((float) 2.0) * fB0)
								+ fT
								* (fA01 * fS + fA11 * fT + ((float) 2.0) * fB1)
								+ fC;
					}
				}
			}
		}

		// account for numerical round-off error
		if (fSqrDistance < (float) 0.0) {
			fSqrDistance = (float) 0.0;
		}

		// closest point = pts[0] + s*E0 + t*E1
		kEdge0.scale(fS);
		kEdge1.scale(fT);
		lastIntersectionPoint = (Point3f) pts[0].clone();
		lastIntersectionPoint.add(kEdge0);
		lastIntersectionPoint.add(kEdge1);
		return Math.sqrt(fSqrDistance);
	}

}
