package edu.jhu.ece.iacl.algorithms.graphics.utilities.rbf;

import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

import Jama.*;

/**
 * Radial basis function interpolant of a scalar field w = f(x, y, z) that is
 * fitted to both function values and gradients at a set of sample points.
 * <p>
 * The interpolant has the form
 * 
 * <pre>
 * s(x) = chi0 + chi1*x + chi2*y + chi3*z
 *      + sum_j phi(|x - p_j|) * (l0_j + l1_j*x + l2_j*y + l3_j*z)
 * </pre>
 * 
 * where {@code lambda} stores the per-point weights as four consecutive
 * blocks (l0, l1, l2, l3), each of length {@code points.length}. The
 * coefficients are chosen so that {@code s(points[i]) == vals[i]} and
 * {@code grad s(points[i]) == norms[i]}, together with polynomial side
 * conditions on the weights. The resulting square system is solved with an
 * SVD pseudo-inverse.
 * <p>
 * The {@code tan} and {@code dot} constructor arguments are stored but are not
 * used by the current fitting code.
 * 
 * @author Blake Lucas
 * 
 */
public class RadialBasisFuncWithOrthoConst extends RadialBasisFunctionFloat {
	/** Singular values with magnitude at or below this are treated as zero. */
	private static final double SINGULAR_VALUE_TOLERANCE = 1E-6;
	/** Largest RMS value or gradient residual accepted after fitting. */
	private static final double MAX_FIT_ERROR = 1E-3;

	/** Stored tangent direction; not used by the current formulation. */
	protected Vector3f tan;
	/** Target function value at each sample point. */
	protected double[] vals;
	/** Target gradient at each sample point. */
	protected Vector3f[] norms;
	/** Stored target inner product; not used by the current formulation. */
	protected double dot;

	/** Sample positions the interpolant is fitted to. */
	protected Point3f[] points;

	/**
	 * Stores the constraints and immediately fits the interpolant.
	 * 
	 * @param points
	 *            sample positions
	 * @param norms
	 *            target gradient at each sample position (at least one entry
	 *            per point)
	 * @param vals
	 *            target function value at each sample position (at least one
	 *            entry per point)
	 * @param dot
	 *            stored in {@link #dot}; not used by the current fit
	 * @param tan
	 *            stored in {@link #tan}; not used by the current fit
	 * @param func
	 *            radial kernel used for interpolation
	 * @throws IllegalArgumentException
	 *             if {@code norms} or {@code vals} has fewer entries than
	 *             {@code points}
	 * @throws RuntimeException
	 *             if the fitted interpolant does not reproduce the values and
	 *             gradients to within {@link #MAX_FIT_ERROR} (RMS)
	 */
	public RadialBasisFuncWithOrthoConst(Point3f[] points, Vector3f[] norms,
			double[] vals, double dot, Vector3f tan, RadialFunctionType func)
			throws RuntimeException {
		if (norms.length < points.length || vals.length < points.length) {
			throw new IllegalArgumentException("Expected at least "
					+ points.length + " norms and vals, got " + norms.length
					+ " norms and " + vals.length + " vals");
		}
		this.points = points;
		this.norms = norms;
		radialFunc = func;
		this.tan = tan;
		this.vals = vals;
		this.dot = dot;
		computeCoefficients();
	}

	/**
	 * Evaluates the fitted interpolant at {@code p}.
	 * 
	 * @param p
	 *            evaluation position
	 * @return interpolated value s(p)
	 */
	public float interpolate(Point3f p) {
		float ret = 0;
		ret += chi[0] + chi[1] * p.x + chi[2] * p.y + chi[3] * p.z;
		for (int i = 0; i < points.length; i++) {
			ret += evalFunc(p, points[i]) * pointWeight(i, p);
		}
		return ret;
	}

	/**
	 * Evaluates the analytic gradient of the fitted interpolant at {@code p}.
	 * 
	 * @param p
	 *            evaluation position
	 * @return gradient of s at p
	 */
	public Vector3f interpolateGrad(Point3f p) {
		final int n = points.length;
		Vector3f grad = new Vector3f((float) chi[1], (float) chi[2],
				(float) chi[3]);
		for (int i = 0; i < n; i++) {
			Vector3f deriv = evalFuncGradient(p, points[i]);
			float d = evalFunc(p, points[i]);
			float l = (float) pointWeight(i, p);
			// Product rule: grad(phi * w) = w * grad(phi) + phi * grad(w),
			// where grad(w) = (lambda[i+n], lambda[i+2n], lambda[i+3n]).
			grad.x += l * deriv.x + lambda[i + n] * d;
			grad.y += l * deriv.y + lambda[i + 2 * n] * d;
			grad.z += l * deriv.z + lambda[i + 3 * n] * d;
		}
		return grad;
	}

	/**
	 * Linear weight attached to sample point {@code i}, evaluated at
	 * {@code p}: lambda[i] + lambda[i+n]*x + lambda[i+2n]*y + lambda[i+3n]*z.
	 */
	private double pointWeight(int i, Point3f p) {
		final int n = points.length;
		return lambda[i] + lambda[i + n] * p.x + lambda[i + 2 * n] * p.y
				+ lambda[i + 3 * n] * p.z;
	}

	/**
	 * Evaluates the radial kernel phi(r) for r = |p1 - p2|.
	 * 
	 * @return kernel value, or 0 for an unhandled kernel type
	 */
	public float evalFunc(Point3f p1, Point3f p2) {
		float r = p1.distance(p2);
		switch (radialFunc) {
		case LINEAR:
			return r;
		case THIN_PLATE:
			return (float) ((r > 0) ? r * r * Math.log(r) : 0);
		case GAUSSIAN:
			return (float) Math.exp(-alpha * r * r);
		case MULTIQUADRIC:
			return (float) Math.sqrt((r * r + C * C));
		default:
			return 0;
		}
	}

	/**
	 * Evaluates the gradient of phi(|p1 - p2|) with respect to {@code p1}.
	 * 
	 * @return gradient vector; zero when the points coincide or the kernel
	 *         type is unhandled
	 */
	public Vector3f evalFuncGradient(Point3f p1, Point3f p2) {
		float r = p1.distance(p2);
		// Unit direction from p2 towards p1; zero when the points coincide.
		float ux = (r != 0) ? (p1.x - p2.x) / r : 0;
		float uy = (r != 0) ? (p1.y - p2.y) / r : 0;
		float uz = (r != 0) ? (p1.z - p2.z) / r : 0;
		float dPhi; // d(phi)/dr
		switch (radialFunc) {
		case LINEAR:
			dPhi = 1;
			break;
		case THIN_PLATE:
			dPhi = (float) ((r > 0) ? 2 * r * Math.log(r) + r : 0);
			break;
		case GAUSSIAN:
			dPhi = (float) (-2 * alpha * r * Math.exp(-alpha * r * r));
			break;
		case MULTIQUADRIC:
			dPhi = (float) (r / Math.sqrt((r * r + C * C)));
			break;
		default:
			dPhi = 0;
		}
		return new Vector3f(dPhi * ux, dPhi * uy, dPhi * uz);
	}

	/**
	 * Assembles and solves the (4n+4) x (4n+4) system for {@code chi} and
	 * {@code lambda}, then checks that the fit reproduces the constraints.
	 * 
	 * @throws RuntimeException
	 *             if the RMS value or gradient residual exceeds
	 *             {@link #MAX_FIT_ERROR}
	 */
	protected void computeCoefficients() {
		final int n = points.length;
		final int size = 4 * n + 4;
		// new Matrix(m, n) is zero-filled, so the 4x4 polynomial block and the
		// side-condition entries of b are already zero.
		Matrix A = new Matrix(size, size);
		Matrix b = new Matrix(size, 1);
		fillValueConstraints(A, b);
		fillGradientConstraints(A, b);

		// Pseudo-inverse solve: x = V * S^+ * U^T * b.
		SingularValueDecomposition svd = new SingularValueDecomposition(A);
		Matrix sInv = svd.getS();
		int zeros = 0;
		for (int i = 0; i < sInv.getColumnDimension(); i++) {
			double s = sInv.get(i, i);
			if (Math.abs(s) > SINGULAR_VALUE_TOLERANCE) {
				sInv.set(i, i, 1 / s);
			} else {
				zeros++;
				sInv.set(i, i, 0);
			}
		}
		// Apply right-to-left so the full pseudo-inverse matrix is never formed.
		Matrix coeff = svd.getV().times(
				sInv.times(svd.getU().transpose().times(b)));

		chi = new double[4];
		for (int k = 0; k < 4; k++) {
			chi[k] = coeff.get(4 * n + k, 0);
		}
		lambda = new double[4 * n];
		for (int i = 0; i < 4 * n; i++) {
			lambda[i] = coeff.get(i, 0);
		}

		verifyFit(A, b, coeff, zeros);
	}

	/**
	 * Fills rows 0..n-1 with the value constraints s(p_i) = vals[i], and fills
	 * columns 0..n-1 of the four side-condition rows.
	 */
	private void fillValueConstraints(Matrix A, Matrix b) {
		final int n = points.length;
		for (int i = 0; i < n; i++) {
			Point3f p = points[i];
			b.set(i, 0, vals[i]);
			// Linear polynomial terms and the matching side-condition column.
			A.set(i, 4 * n, 1);
			A.set(i, 4 * n + 1, p.x);
			A.set(i, 4 * n + 2, p.y);
			A.set(i, 4 * n + 3, p.z);
			A.set(4 * n, i, 1);
			A.set(4 * n + 1, i, p.x);
			A.set(4 * n + 2, i, p.y);
			A.set(4 * n + 3, i, p.z);
			for (int j = 0; j < n; j++) {
				double d = evalFunc(p, points[j]);
				A.set(i, j, d);
				A.set(i, j + n, p.x * d);
				A.set(i, j + 2 * n, p.y * d);
				A.set(i, j + 3 * n, p.z * d);
			}
		}
	}

	/**
	 * Fills rows n..4n-1 with the gradient constraints
	 * grad s(p_i) = norms[i] (x, y and z components in consecutive blocks),
	 * and fills the matching side-condition columns.
	 */
	private void fillGradientConstraints(Matrix A, Matrix b) {
		final int n = points.length;
		for (int i = 0; i < n; i++) {
			Point3f p = points[i];
			Vector3f target = norms[i];
			final int rowX = i + n;
			final int rowY = i + 2 * n;
			final int rowZ = i + 3 * n;

			b.set(rowX, 0, target.x);
			b.set(rowY, 0, target.y);
			b.set(rowZ, 0, target.z);

			// The derivative of the linear polynomial picks out chi1/chi2/chi3.
			A.set(rowX, 4 * n + 1, 1);
			A.set(rowY, 4 * n + 2, 1);
			A.set(rowZ, 4 * n + 3, 1);
			A.set(4 * n + 1, rowX, 1);
			A.set(4 * n + 2, rowY, 1);
			A.set(4 * n + 3, rowZ, 1);

			for (int j = 0; j < n; j++) {
				Vector3f deriv = evalFuncGradient(p, points[j]);
				double d = evalFunc(p, points[j]);

				A.set(rowX, j, deriv.x);
				A.set(rowX, j + n, p.x * deriv.x + d);
				A.set(rowX, j + 2 * n, p.y * deriv.x);
				A.set(rowX, j + 3 * n, p.z * deriv.x);

				A.set(rowY, j, deriv.y);
				A.set(rowY, j + n, p.x * deriv.y);
				A.set(rowY, j + 2 * n, p.y * deriv.y + d);
				A.set(rowY, j + 3 * n, p.z * deriv.y);

				A.set(rowZ, j, deriv.z);
				A.set(rowZ, j + n, p.x * deriv.z);
				A.set(rowZ, j + 2 * n, p.y * deriv.z);
				A.set(rowZ, j + 3 * n, p.z * deriv.z + d);
			}
		}
	}

	/**
	 * Computes the RMS value and gradient residuals of the fit. If either
	 * exceeds {@link #MAX_FIT_ERROR}, dumps the system to stderr and throws.
	 * Previously this path called System.exit(0), which terminated the whole
	 * JVM and made the exception unreachable.
	 */
	private void verifyFit(Matrix A, Matrix b, Matrix coeff, int zeros) {
		final int n = points.length;
		double err = 0;
		double errD = 0;
		for (int j = 0; j < n; j++) {
			err += Math.pow(vals[j] - interpolate(points[j]), 2);
			Vector3f v = interpolateGrad(points[j]);
			v.sub(norms[j]);
			errD += v.lengthSquared();
		}
		err = Math.sqrt(err / n);
		errD = Math.sqrt(errD / n);

		if (err > MAX_FIT_ERROR || errD > MAX_FIT_ERROR) {
			dumpSystem(A, b, coeff);
			throw new RuntimeException("Error Too High " + zeros + " " + err
					+ " " + errD);
		}
	}

	/** Prints the augmented system [A | b | coeff] to stderr for diagnosis. */
	private static void dumpSystem(Matrix A, Matrix b, Matrix coeff) {
		for (int row = 0; row < A.getRowDimension(); row++) {
			for (int col = 0; col < A.getColumnDimension(); col++) {
				System.err.printf("%7.2f ", A.get(row, col));
			}
			System.err.printf("| %7.2f", b.get(row, 0));
			System.err.printf(" | %7.2f", coeff.get(row, 0));
			System.err.println();
		}
	}
}
