package edu.jhu.ece.iacl.algorithms.graphics.utilities.rbf;

import Jama.*;
import javax.vecmath.Point2f;
import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

/**
 * Uses radial basis functions to interpolate a scalar value w for a function
 * w = f(x, y, z), given the values of f(.) at several known positions.
 * <p>
 * The interpolant has the form
 * {@code s(p) = chi0 + chi1*x + chi2*y + chi3*z + sum_i lambda_i * phi(|p - p_i|)},
 * where {@code phi} is selected by the {@link RadialFunctionType} passed to the
 * constructor. The coefficients are solved with an SVD-based pseudo-inverse
 * that discards near-zero singular values. As a result, rank-deficient systems
 * (for example, duplicate centers) still produce a solution instead of failing.
 * <p>
 * Distances and radial basis values are evaluated in single precision
 * ({@code float}). The linear solve and the accumulation of interpolation sums
 * are done in double precision.
 * 
 * @author Blake Lucas
 * 
 */
public class RadialBasisFuncFloat4D extends RadialBasisFunctionFloat {
	/** Known function values; {@code vals[i]} is the value at {@code points[i]}. */
	protected double[] vals;

	/** Interpolation centers, i.e. the positions where the function value is known. */
	protected Point3f[] points;

	/**
	 * Builds the interpolant and immediately solves for its coefficients.
	 * 
	 * @param points
	 *            positions where the function value is known
	 * @param vals
	 *            function values; {@code vals[i]} belongs to {@code points[i]}.
	 *            Entries beyond {@code points.length} are ignored.
	 * @param func
	 *            radial basis function to use
	 * @throws IllegalArgumentException
	 *             if {@code vals} has fewer entries than {@code points}
	 */
	public RadialBasisFuncFloat4D(Point3f[] points, double[] vals,
			RadialFunctionType func) {
		// Fail fast with a clear message instead of an index error inside the solve.
		if (vals.length < points.length) {
			throw new IllegalArgumentException("Expected at least "
					+ points.length + " values but got " + vals.length);
		}
		this.points = points;
		this.vals = vals;
		radialFunc = func;
		computeCoefficients();
	}

	/**
	 * Evaluates the interpolant at {@code p}.
	 * 
	 * @param p
	 *            query position
	 * @return the interpolated value. The sum is accumulated in double
	 *         precision and narrowed to {@code float} once at the end.
	 */
	public float interpolate(Point3f p) {
		// Linear polynomial part.
		double sum = chi[0] + chi[1] * p.x + chi[2] * p.y + chi[3] * p.z;
		// Weighted radial basis contributions from every center.
		for (int i = 0; i < points.length; i++) {
			sum += lambda[i] * evalFunc(p, points[i]);
		}
		return (float) sum;
	}

	/**
	 * Evaluates the spatial gradient of the interpolant at {@code p}.
	 * 
	 * @param p
	 *            query position
	 * @return a new vector holding the gradient (d/dx, d/dy, d/dz). Each
	 *         component is accumulated in double precision and narrowed once.
	 */
	public Vector3f interpolateGradient(Point3f p) {
		// The gradient of the linear polynomial part is constant.
		double gx = chi[1];
		double gy = chi[2];
		double gz = chi[3];
		for (int i = 0; i < points.length; i++) {
			Vector3f g = evalFuncGradient(p, points[i]);
			gx += g.x * lambda[i];
			gy += g.y * lambda[i];
			gz += g.z * lambda[i];
		}
		return new Vector3f((float) gx, (float) gy, (float) gz);
	}

	/**
	 * Solves for the radial weights {@code lambda} and the linear polynomial
	 * coefficients {@code chi}. Both arrays are (re)allocated here.
	 * <p>
	 * The block system solved is
	 * 
	 * <pre>
	 * [ Phi  P ] [ lambda ]   [ vals ]
	 * [ P^T  0 ] [  chi   ] = [  0   ]
	 * </pre>
	 * 
	 * where {@code Phi[i][j] = evalFunc(points[i], points[j])} and row i of
	 * {@code P} is {@code [1, x_i, y_i, z_i]}.
	 */
	protected void computeCoefficients() {
		int n = points.length;
		// Jama matrices are constructed zero-filled. The lower-right 4x4 block
		// of A and the last four entries of b (the side conditions) therefore
		// need no explicit writes.
		Matrix A = new Matrix(n + 4, n + 4);
		Matrix b = new Matrix(n + 4, 1);
		for (int i = 0; i < n; i++) {
			Point3f p = points[i];
			// Polynomial block P (row i) and its transpose P^T (column i).
			A.set(i, n, 1);
			A.set(i, n + 1, p.x);
			A.set(i, n + 2, p.y);
			A.set(i, n + 3, p.z);
			A.set(n, i, 1);
			A.set(n + 1, i, p.x);
			A.set(n + 2, i, p.y);
			A.set(n + 3, i, p.z);

			b.set(i, 0, vals[i]);

			// Radial basis block Phi.
			for (int j = 0; j < n; j++) {
				A.set(i, j, evalFunc(p, points[j]));
			}
		}

		double[] x = solvePseudoInverse(A, b).getColumnPackedCopy();
		lambda = new double[n];
		System.arraycopy(x, 0, lambda, 0, n);
		chi = new double[4];
		System.arraycopy(x, n, chi, 0, 4);
	}

	/**
	 * Solves {@code A x = b} with the SVD pseudo-inverse. Singular values whose
	 * magnitude is at most 1E-12 (an absolute threshold) are treated as zero.
	 * 
	 * @param A
	 *            square system matrix
	 * @param b
	 *            right-hand side column vector
	 * @return the solution column vector
	 */
	private static Matrix solvePseudoInverse(Matrix A, Matrix b) {
		SingularValueDecomposition svd = new SingularValueDecomposition(A);
		double[] sigma = svd.getSingularValues();
		// Apply U^T to b first, so the full pseudo-inverse matrix is never formed.
		Matrix y = svd.getU().transpose().times(b);
		for (int i = 0; i < sigma.length; i++) {
			double s = sigma[i];
			y.set(i, 0, (Math.abs(s) > 1E-12) ? y.get(i, 0) / s : 0);
		}
		return svd.getV().times(y);
	}

	/**
	 * Euclidean distance between two points, computed in float arithmetic.
	 */
	private static float distance(Point3f p1, Point3f p2) {
		float dx = p1.x - p2.x;
		float dy = p1.y - p2.y;
		float dz = p1.z - p2.z;
		return (float) Math.sqrt(dx * dx + dy * dy + dz * dz);
	}

	/**
	 * Evaluates the radial basis function {@code phi(r)} for
	 * {@code r = |p1 - p2|}.
	 * <p>
	 * Uses {@code radialFunc}, and {@code alpha} (Gaussian) or {@code C}
	 * (multiquadric). These are presumably fields of the parent class; verify
	 * there.
	 * 
	 * @param p1
	 *            first point
	 * @param p2
	 *            second point
	 * @return {@code phi(r)}: {@code r} (linear), {@code r^2 ln r} (thin plate,
	 *         0 at r = 0), {@code exp(-alpha r^2)} (Gaussian),
	 *         {@code sqrt(r^2 + C^2)} (multiquadric), or 0 for any other type
	 */
	public float evalFunc(Point3f p1, Point3f p2) {
		float r = distance(p1, p2);
		switch (radialFunc) {
		case LINEAR:
			return r;
		case THIN_PLATE:
			// r^2 ln r -> 0 as r -> 0; avoid evaluating log(0).
			return (float) ((r != 0) ? r * r * Math.log(r) : 0);
		case GAUSSIAN:
			return (float) Math.exp(-alpha * r * r);
		case MULTIQUADRIC:
			return (float) Math.sqrt((r * r + C * C));
		default:
			return 0;
		}
	}

	/**
	 * Evaluates the gradient of {@code phi(|p1 - p2|)} with respect to
	 * {@code p1}, via the chain rule {@code phi'(r) * (p1 - p2) / r}.
	 * 
	 * @param p1
	 *            point at which the gradient is taken
	 * @param p2
	 *            radial basis center
	 * @return a new gradient vector. It is the zero vector when
	 *         {@code p1 == p2}, because the direction (p1 - p2) / r is
	 *         undefined there.
	 */
	public Vector3f evalFuncGradient(Point3f p1, Point3f p2) {
		float r = distance(p1, p2);
		// Unit direction (p1 - p2) / r, i.e. dr/dp1; zero at the center.
		float rx = (r != 0) ? (p1.x - p2.x) / r : 0;
		float ry = (r != 0) ? (p1.y - p2.y) / r : 0;
		float rz = (r != 0) ? (p1.z - p2.z) / r : 0;

		// d = phi'(r)
		float d = 0;
		switch (radialFunc) {
		case LINEAR:
			d = 1;
			break;
		case THIN_PLATE:
			// d/dr (r^2 ln r) = 2 r ln r + r
			d = (float) ((r > 0) ? 2 * r * Math.log(r) + r : 0);
			break;
		case GAUSSIAN:
			d = (float) (-2 * alpha * r * Math.exp(-alpha * r * r));
			break;
		case MULTIQUADRIC:
			d = (float) (r / Math.sqrt((r * r + C * C)));
			break;
		default:
		}
		return new Vector3f(d * rx, d * ry, d * rz);
	}
}
