package edu.jhu.ece.iacl.algorithms.graphics.utilities.rbf;

import Jama.*;
import javax.vecmath.Point2f;
import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

/**
 * Use radial basis functions to interpolate w position for a function
 * w=f(x,y,z) where f(.) and its derivative are known for several positions. The
 * derivative vector is assumed to be <dw/dx,dw/dy,dw/dz>
 * 
 * @author Blake Lucas
 * 
 */
public class RadialBasisFunc4DWithDerivConst extends RadialBasisFunctionFloat {

	protected Vector3f[] norms;
	protected double[] vals;

	public RadialBasisFunc4DWithDerivConst(Point3f[] points, Vector3f[] norms,
			double[] vals, RadialFunctionType func) {
		this.points = points;
		this.norms = norms;
		this.vals = vals;
		radialFunc = func;
		computeCoefficients();
	}

	public float interpolate(Point3f p) {
		float ret = 0;
		ret += chi[0] + chi[1] * p.x + chi[2] * p.y + chi[3] * p.z;
		for (int i = 0; i < points.length; i++) {
			ret += (lambda[i] + lambda[i + points.length] * p.x
					+ lambda[i + 2 * points.length] * p.y + lambda[i + 3
					* points.length]
					* p.z)
					* evalFunc(p, points[i]);
		}
		return ret;
	}

	public Vector3f interpolateGradient(Point3f p) {
		Vector3f grad = new Vector3f();
		grad.x = (float) chi[1];
		grad.y = (float) chi[2];
		grad.z = (float) chi[3];
		Vector3f g;
		double val1;
		double val2;
		for (int i = 0; i < points.length; i++) {
			g = evalFuncGradient(p, points[i]);
			val1 = evalFunc(p, points[i]);
			val2 = (lambda[i] + lambda[i + points.length] * p.x
					+ lambda[i + 2 * points.length] * p.y + lambda[i + 3
					* points.length]
					* p.z);
			grad.x += lambda[i + points.length] * val1 + g.x * val2;
			grad.y += lambda[i + 2 * points.length] * val1 + g.y * val2;
			grad.z += lambda[i + 3 * points.length] * val1 + g.z * val2;
		}
		return grad;
	}

	protected void computeCoefficients() {
		Matrix A = new Matrix(4 * points.length + 4, 4 * points.length + 4);
		Matrix b = new Matrix(4 * points.length + 4, 1);
		double d;
		Vector3f deriv;
		// Populate s(x) entries
		for (int i = 0; i < points.length; i++) {
			Point3f p = points[i];
			A.set(i, 4 * points.length, 1);
			A.set(i, 4 * points.length + 1, p.x);
			A.set(i, 4 * points.length + 2, p.y);
			A.set(i, 4 * points.length + 3, p.z);

			b.set(i, 0, vals[i]);
			A.set(4 * points.length, i, 1);
			A.set(4 * points.length + 1, i, p.x);
			A.set(4 * points.length + 2, i, p.y);
			A.set(4 * points.length + 3, i, p.z);

			for (int j = 0; j < points.length; j++) {
				d = evalFunc(p, points[j]);
				A.set(i, j, d);
				A.set(i, j + points.length, points[i].x * d);
				A.set(i, j + 2 * points.length, points[i].y * d);
				A.set(i, j + 3 * points.length, points[i].z * d);

			}
		}

		for (int i = 0; i < points.length; i++) {
			Point3f p = points[i];
			Vector3f n = norms[i];
			// Populate ds(x)/dx entries
			b.set(i + points.length, 0, n.x);

			A.set(i + points.length, 4 * points.length, 0);
			A.set(i + points.length, 4 * points.length + 1, 1);
			A.set(i + points.length, 4 * points.length + 2, 0);
			A.set(i + points.length, 4 * points.length + 3, 0);

			A.set(4 * points.length, i + points.length, 0);
			A.set(4 * points.length + 1, i + points.length, 1);
			A.set(4 * points.length + 2, i + points.length, 0);
			A.set(4 * points.length + 3, i + points.length, 0);

			// Populate ds(x)/dy entries
			b.set(i + 2 * points.length, 0, n.y);

			A.set(i + 2 * points.length, 4 * points.length, 0);
			A.set(i + 2 * points.length, 4 * points.length + 1, 0);
			A.set(i + 2 * points.length, 4 * points.length + 2, 1);
			A.set(i + 2 * points.length, 4 * points.length + 3, 0);

			A.set(4 * points.length, i + 2 * points.length, 0);
			A.set(4 * points.length + 1, i + 2 * points.length, 0);
			A.set(4 * points.length + 2, i + 2 * points.length, 1);
			A.set(4 * points.length + 3, i + 2 * points.length, 0);

			// Populate ds(x)/dz entries
			b.set(i + 3 * points.length, 0, n.z);

			A.set(i + 3 * points.length, 4 * points.length, 0);
			A.set(i + 3 * points.length, 4 * points.length + 1, 0);
			A.set(i + 3 * points.length, 4 * points.length + 2, 0);
			A.set(i + 3 * points.length, 4 * points.length + 3, 1);

			A.set(4 * points.length, i + 3 * points.length, 0);
			A.set(4 * points.length + 1, i + 3 * points.length, 0);
			A.set(4 * points.length + 2, i + 3 * points.length, 0);
			A.set(4 * points.length + 3, i + 3 * points.length, 1);

			for (int j = 0; j < points.length; j++) {
				deriv = evalFuncGradient(p, points[j]);
				d = evalFunc(p, points[j]);
				A.set(i + points.length, j, deriv.x);
				A.set(i + points.length, j + points.length, points[i].x
						* deriv.x + d);
				A.set(i + points.length, j + 2 * points.length, points[i].y
						* deriv.x);
				A.set(i + points.length, j + 3 * points.length, points[i].z
						* deriv.x);

				A.set(i + 2 * points.length, j, deriv.y);
				A.set(i + 2 * points.length, j + points.length, points[i].x
						* deriv.y);
				A.set(i + 2 * points.length, j + 2 * points.length, points[i].y
						* deriv.y + d);
				A.set(i + 2 * points.length, j + 3 * points.length, points[i].z
						* deriv.y);

				A.set(i + 3 * points.length, j, deriv.z);
				A.set(i + 3 * points.length, j + points.length, points[i].x
						* deriv.z);
				A.set(i + 3 * points.length, j + 2 * points.length, points[i].y
						* deriv.z);
				A.set(i + 3 * points.length, j + 3 * points.length, points[i].z
						* deriv.z + d);

			}
		}
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				A.set(4 * points.length + i, 4 * points.length + j, 0);
			}
			b.set(4 * points.length + i, 0, 0);
		}
		SingularValueDecomposition svd = new SingularValueDecomposition(A);
		Matrix S = svd.getS();
		Matrix V = svd.getV();
		Matrix U = svd.getU();
		int zeros = 0;
		for (int i = 0; i < S.getColumnDimension(); i++) {
			if (Math.abs(S.get(i, i)) > 1E-12) {
				S.set(i, i, 1 / S.get(i, i));
			} else {
				zeros++;
				S.set(i, i, 0);
			}
		}
		Matrix coeff = V.times(S.times(U.transpose())).times(b);
		chi = new double[4];
		lambda = new double[4 * points.length];
		chi[0] = coeff.get(4 * points.length, 0);
		chi[1] = coeff.get(4 * points.length + 1, 0);
		chi[2] = coeff.get(4 * points.length + 2, 0);
		chi[3] = coeff.get(4 * points.length + 3, 0);

		for (int i = 0; i < 4 * points.length; i++) {
			lambda[i] = coeff.get(i, 0);
		}
	}

	public float evalFunc(Point3f p1, Point3f p2) {
		float r = (float) Math
				.sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y)
						* (p1.y - p2.y) + (p1.z - p2.z) * (p1.z - p2.z));
		switch (radialFunc) {
		case LINEAR:
			return r;
		case THIN_PLATE:
			return (float) ((r > 0) ? r * r * Math.log(r) : 0);
		case GAUSSIAN:
			return (float) Math.exp(-alpha * r * r);
		case MULTIQUADRIC:
			return (float) Math.sqrt((r * r + C * C));
		default:
			return 0;
		}
	}

	public Vector3f evalFuncGradient(Point3f p1, Point3f p2) {
		float r = (float) Math
				.sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y)
						* (p1.y - p2.y) + (p1.z - p2.z) * (p1.z - p2.z));
		float rx = (r != 0) ? (p1.x - p2.x) / r : 0;
		float ry = (r != 0) ? (p1.y - p2.y) / r : 0;
		float rz = (r != 0) ? (p1.z - p2.z) / r : 0;

		float d = 0;
		switch (radialFunc) {
		case LINEAR:
			d = 1;
			break;
		case THIN_PLATE:
			d = (float) ((r > 0) ? 2 * r * Math.log(Math.abs(r)) + r : 0);
			break;
		case GAUSSIAN:
			d = (float) (-2 * alpha * r * Math.exp(-alpha * r * r));
			break;
		case MULTIQUADRIC:
			d = (float) (r / Math.sqrt((r * r + C * C)));
			break;
		default:
		}
		return new Vector3f(d * rx, d * ry, d * rz);
	}
}
