package edu.jhu.ece.iacl.algorithms.graphics.utilities;

import java.util.List;

import Jama.*;

import javax.vecmath.Matrix3f;
import javax.vecmath.Point3d;
import javax.vecmath.Point3f;
import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

import edu.jhu.ece.iacl.algorithms.VersionUtil;
import edu.jhu.ece.iacl.algorithms.icp.IterativeClosestPointRegistration;

/**
 * Principal component analysis (PCA) of a set of 3D points.
 * <p>
 * On construction this class computes the centroid of the points, the 3x3
 * scatter matrix of the centred points, and the rotation that aligns the
 * eigenvector with the smallest eigenvalue with the Z axis. The centroid and
 * the rotations are stored in single precision; the eigen-decomposition
 * itself is carried out by JAMA in double precision.
 * <p>
 * The <code>get...TransformationMatrix</code> methods build 4x4 homogeneous
 * matrices whose upper-left 3x3 block holds the (optionally scaled)
 * eigenvectors as columns and whose last column holds the centroid.
 * 
 * @author Blake Lucas
 * 
 */
public class PrincipalComponentAnalysisFloat {
	/** Quarter turn used to swing the X or Y axis onto the Z axis. */
	private static final float HALF_PI = (float) Math.PI * 0.5f;

	/**
	 * @return the revision number parsed from the embedded revision keyword
	 */
	public static String getVersion() {
		return VersionUtil.parseRevisionNumber("$Revision: 1.1 $");
	}


	/** Scatter matrix (centred points, transposed, times centred points); 3x3. */
	protected Matrix convariance;
	/** Rotation applied after subtracting the centroid (see {@link #translateAndRotatePoint}). */
	protected Matrix3f correctRot;
	/** Inverse of {@link #correctRot}. */
	protected Matrix3f returnRot;
	/** Mean of the input points; set by the <code>computeCovariance</code> methods. */
	protected Point3f centroid;
	// Not read or written anywhere in this class.
	private String consistancymethod;
	/** Eigenvector matrix from the most recent get...TransformationMatrix call. */
	private Matrix lastV;

	/**
	 * Runs the analysis on the given points.
	 * 
	 * @param points
	 *            points to analyse; must not be null. For an empty array the
	 *            centroid components are NaN (0/0 in float arithmetic).
	 */
	public PrincipalComponentAnalysisFloat(Point3f[] points) {
		correctRot = computeRotation(convariance = computeCovariance(points));
		returnRot = new Matrix3f();
		returnRot.invert(correctRot);
	}

	/**
	 * Returns the rotation that {@link #translateAndRotatePoint(Point3f)}
	 * applies after subtracting the centroid. Note that, despite the method
	 * name, this is the rotation into the PCA frame, not its inverse.
	 * 
	 * @return the internal rotation matrix (not a copy)
	 */
	public Matrix3f getBackwardRotation() {
		return correctRot;
	}

	/**
	 * Returns the inverse of the rotation returned by
	 * {@link #getBackwardRotation()}, i.e. the rotation that
	 * {@link #translateAndRotateBackPoint(Point3f)} applies.
	 * 
	 * @return the internal inverse rotation matrix (not a copy)
	 */
	public Matrix3f getBackgroundRotation() {
		return returnRot;
	}

	/**
	 * @return the centroid of the input points (internal instance, not a copy)
	 */
	public Point3f getCentroid() {
		return centroid;
	}

	/**
	 * @return the 3x3 scatter matrix computed from the input points
	 */
	public Matrix getCovarianceMatrix() {
		return convariance;
	}

	/**
	 * @return the eigenvector matrix produced by the most recent call to one
	 *         of the rigid / global-scale / affine matrix getters, before any
	 *         sign correction against a target; null if none was called yet
	 */
	public Matrix getLastV(){
		return lastV;
	}
	
	/**
	 * @return a 4x4 homogeneous matrix with an identity rotation block and the
	 *         centroid as translation
	 */
	public Matrix getTranslationTransformationMatrix() {
		Matrix transform = new Matrix(4, 4);
		for (int k = 0; k < 4; k++) {
			transform.set(k, k, 1);
		}
		writeTranslation(transform);
		return transform;
	}

	/**
	 * Builds a 4x4 matrix whose rotation block holds the eigenvectors, each
	 * multiplied by the mean absolute eigenvalue, and whose translation is the
	 * centroid.
	 * 
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getRigidGlobalScaleTransformationMatrix() {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = lastV;
		return assembleTransform(axes, signedFactors(eigenValues, meanAbsEigenvalue(eigenValues)));
	}

	/**
	 * Builds a 4x4 matrix whose rotation block holds the unscaled eigenvectors
	 * and whose translation is the centroid. The eigenvector and eigenvalue
	 * matrices are printed to standard output as a diagnostic.
	 * 
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getRigidTransformationMatrix() {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = lastV;
		printDecomposition(axes, eigenValues);
		return assembleTransform(axes, signedFactors(eigenValues, 1.0));
	}

	/**
	 * Builds a 4x4 matrix whose rotation block holds the eigenvectors, each
	 * multiplied by the square root of its own eigenvalue, and whose
	 * translation is the centroid.
	 * 
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getAffineTransformationMatrix() {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = lastV;
		return assembleTransform(axes, sqrtFactors(eigenValues));
	}

	/**
	 * Same as {@link #getRigidGlobalScaleTransformationMatrix()}, but each
	 * eigenvector is first sign-flipped, if necessary, so that it lies within
	 * 90 degrees of the corresponding column of <code>Vt</code>.
	 * 
	 * @param Vt
	 *            target eigenvector matrix (eigenvectors as columns)
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getRigidGlobalScaleTransformationMatrix(Matrix Vt) {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = preserveOrder(lastV, Vt);
		return assembleTransform(axes, signedFactors(eigenValues, meanAbsEigenvalue(eigenValues)));
	}

	/**
	 * Same as {@link #getRigidTransformationMatrix()}, but each eigenvector is
	 * first sign-flipped, if necessary, so that it lies within 90 degrees of
	 * the corresponding column of <code>Vt</code>. The sign-corrected
	 * eigenvector matrix and the eigenvalue matrix are printed to standard
	 * output as a diagnostic.
	 * 
	 * @param Vt
	 *            target eigenvector matrix (eigenvectors as columns)
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getRigidTransformationMatrix(Matrix Vt) {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = preserveOrder(lastV, Vt);
		printDecomposition(axes, eigenValues);
		return assembleTransform(axes, signedFactors(eigenValues, 1.0));
	}

	/**
	 * Same as {@link #getAffineTransformationMatrix()}, but each eigenvector
	 * is first sign-flipped, if necessary, so that it lies within 90 degrees
	 * of the corresponding column of <code>Vt</code>.
	 * 
	 * @param Vt
	 *            target eigenvector matrix (eigenvectors as columns)
	 * @return the homogeneous transformation matrix
	 */
	public Matrix getAffineTransformationMatrix(Matrix Vt) {
		Matrix eigenValues = decomposeCovariance();
		Matrix axes = preserveOrder(lastV, Vt);
		return assembleTransform(axes, sqrtFactors(eigenValues));
	}

	/**
	 * Subtracts the centroid from the point and rotates it, in place.
	 * 
	 * @param p
	 *            point to transform (modified)
	 */
	public void translateAndRotatePoint(Point3f p) {
		p.sub(centroid);
		correctRot.transform(p);
	}

	/**
	 * Prints the matrix to standard output, one row per line, each entry
	 * formatted as <code>%8.4f</code>.
	 * 
	 * @param M
	 *            matrix to print
	 */
	protected void printMatrix(Matrix M) {
		int rowCount = M.getRowDimension();
		int colCount = M.getColumnDimension();
		for (int r = 0; r < rowCount; r++) {
			for (int c = 0; c < colCount; c++) {
				System.out.printf("%8.4f ", M.get(r, c));
			}
			System.out.println("");
		}
	}

	/**
	 * Applies {@link #translateAndRotatePoint(Point3f)} to every point, in
	 * place.
	 * 
	 * @param points
	 *            points to transform (modified)
	 */
	public void translateAndRotatePoints(Point3f[] points) {
		for (Point3f p : points) {
			p.sub(centroid);
			correctRot.transform(p);
		}
	}

	/**
	 * Undoes {@link #translateAndRotatePoint(Point3f)}: applies the inverse
	 * rotation and then adds the centroid, in place.
	 * 
	 * @param p
	 *            point to transform (modified)
	 */
	public void translateAndRotateBackPoint(Point3f p) {
		returnRot.transform(p);
		p.add(centroid);
	}

	/**
	 * Applies {@link #translateAndRotateBackPoint(Point3f)} to every point, in
	 * place.
	 * 
	 * @param points
	 *            points to transform (modified)
	 */
	public void translateAndRotateBackPoints(Point3f[] points) {
		for (Point3f p : points) {
			returnRot.transform(p);
			p.add(centroid);
		}
	}

	/**
	 * Rotates the vector in place; no translation is applied.
	 * 
	 * @param p
	 *            vector to rotate (modified)
	 */
	public void rotateVector(Vector3f p) {
		correctRot.transform(p);
	}

	/**
	 * Rotates every vector in place; no translation is applied.
	 * 
	 * @param points
	 *            vectors to rotate (modified)
	 */
	public void rotateVectors(Vector3f[] points) {
		for (Vector3f p : points) {
			correctRot.transform(p);
		}
	}

	/**
	 * Applies the inverse rotation to the vector in place.
	 * 
	 * @param p
	 *            vector to rotate back (modified)
	 */
	public void rotateBackVector(Vector3f p) {
		returnRot.transform(p);
	}

	/**
	 * Applies the inverse rotation to every vector in place.
	 * 
	 * @param points
	 *            vectors to rotate back (modified)
	 */
	public void rotateBackVectors(Vector3f[] points) {
		for (Vector3f p : points) {
			returnRot.transform(p);
		}
	}

	/**
	 * Computes the centroid of the points (stored in {@link #centroid}) and
	 * the scatter matrix of the centred points. The result is not divided by
	 * the number of points.
	 * 
	 * @param points
	 *            input points
	 * @return the 3x3 matrix B<sup>T</sup>B, where row i of B is point i minus
	 *         the centroid
	 */
	protected Matrix computeCovariance(Point3f[] points) {
		centroid = new Point3f();
		int count = points.length;
		for (Point3f p : points) {
			centroid.x += p.x;
			centroid.y += p.y;
			centroid.z += p.z;
		}
		centroid.x /= count;
		centroid.y /= count;
		centroid.z /= count;

		Matrix centred = new Matrix(count, 3);
		for (int row = 0; row < count; row++) {
			Point3f p = points[row];
			centred.set(row, 0, p.x - centroid.x);
			centred.set(row, 1, p.y - centroid.y);
			centred.set(row, 2, p.z - centroid.z);
		}
		return centred.transpose().times(centred);
	}

	/**
	 * List variant of {@link #computeCovariance(Point3f[])}: computes the
	 * centroid (stored in {@link #centroid}) and the unnormalised scatter
	 * matrix of the centred points.
	 * 
	 * @param points
	 *            input points
	 * @return the 3x3 matrix B<sup>T</sup>B, where row i of B is point i minus
	 *         the centroid
	 */
	protected Matrix computeCovariance(List<Point3f> points) {
		centroid = new Point3f();
		int count = points.size();
		for (Point3f p : points) {
			centroid.x += p.x;
			centroid.y += p.y;
			centroid.z += p.z;
		}
		centroid.x /= count;
		centroid.y /= count;
		centroid.z /= count;

		Matrix centred = new Matrix(count, 3);
		int row = 0;
		for (Point3f p : points) {
			centred.set(row, 0, p.x - centroid.x);
			centred.set(row, 1, p.y - centroid.y);
			centred.set(row, 2, p.z - centroid.z);
			row++;
		}
		return centred.transpose().times(centred);
	}

	/**
	 * Computes the rotation that first expresses a vector in the eigenvector
	 * basis of <code>A</code> and then turns the axis with the smallest
	 * eigenvalue onto the Z axis.
	 * 
	 * @param A
	 *            3x3 scatter matrix
	 * @return the combined rotation
	 */
	protected Matrix3f computeRotation(Matrix A) {
		EigenvalueDecomposition ed = new EigenvalueDecomposition(A);
		Matrix eigenVectors = ed.getV();
		Matrix eigenValues = ed.getD();
		double ex = eigenValues.get(0, 0);
		double ey = eigenValues.get(1, 1);
		double ez = eigenValues.get(2, 2);

		// Pick the axis swap that brings the minimum principal axis onto Z.
		Matrix3f rotation = new Matrix3f();
		if (ex < ey) {
			if (ex < ez) {
				// X is the minimum principal axis
				rotation.rotY(HALF_PI);
			} else {
				// Z is the minimum principal axis: identity
				rotation.rotX(0);
			}
		} else {
			if (ey < ez) {
				// Y is the minimum principal axis
				rotation.rotX(HALF_PI);
			} else {
				// Z is the minimum principal axis: identity
				rotation.rotX(0);
			}
		}

		// Inverse of the eigenvector matrix maps world coordinates into the
		// eigenvector basis; narrow it to float for vecmath.
		Matrix toEigenBasis = eigenVectors.inverse();
		Matrix3f basisChange = new Matrix3f();
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				basisChange.setElement(i, j, (float) toEigenBasis.get(i, j));
			}
		}
		rotation.mul(basisChange);
		return rotation;
	}
	
	/**
	 * Ensures that the three columns of <code>A</code> form a right-handed
	 * triple. If (col0 x col1) . col2 is negative, a copy of <code>A</code> is
	 * returned whose third column is replaced by col0 x col1; otherwise
	 * <code>A</code> itself is returned.
	 * 
	 * @param A
	 *            matrix whose first three rows/columns hold the axes
	 * @return a matrix with a non-negative triple product
	 */
	protected Matrix toRotationMatrix(Matrix A){
		Matrix z = cross(A, 0, A, 1);
		if (dot(z, 0, A, 2) >= 0) {
			return A;
		}
		Matrix fixed = A.copy();
		for (int row = 0; row < 3; row++) {
			fixed.set(row, 2, z.get(row, 0));
		}
		return fixed;
	}

	/**
	 * Eigen-decomposes the scatter matrix, remembers the eigenvector matrix in
	 * {@link #lastV} and returns the diagonal eigenvalue matrix.
	 */
	private Matrix decomposeCovariance() {
		EigenvalueDecomposition ed = new EigenvalueDecomposition(convariance);
		lastV = ed.getV();
		return ed.getD();
	}

	/**
	 * Builds the 4x4 homogeneous matrix whose rotation block is
	 * <code>axes</code> with column n multiplied by
	 * <code>columnFactor[n]</code> and whose translation is the centroid.
	 */
	private Matrix assembleTransform(Matrix axes, double[] columnFactor) {
		// new Matrix(4, 4) is zero-filled, so the bottom row only needs its 1.
		Matrix transform = new Matrix(4, 4);
		for (int m = 0; m < 3; m++) {
			for (int n = 0; n < 3; n++) {
				transform.set(m, n, columnFactor[n] * axes.get(m, n));
			}
		}
		writeTranslation(transform);
		transform.set(3, 3, 1);
		return transform;
	}

	/** Writes the centroid into the translation column of a 4x4 matrix. */
	private void writeTranslation(Matrix transform) {
		transform.set(0, 3, centroid.x);
		transform.set(1, 3, centroid.y);
		transform.set(2, 3, centroid.z);
	}

	/** Prints the eigenvector and eigenvalue matrices to standard output. */
	private void printDecomposition(Matrix axes, Matrix eigenValues) {
		System.out.println("R matrix");
		printMatrix(axes);
		System.out.println("D matrix");
		printMatrix(eigenValues);
	}

	/** Mean of the absolute values of the three diagonal entries. */
	private static double meanAbsEigenvalue(Matrix eigenValues) {
		double sum = Math.abs(eigenValues.get(0, 0)) + Math.abs(eigenValues.get(1, 1))
				+ Math.abs(eigenValues.get(2, 2));
		return sum / 3.0;
	}

	/**
	 * Per-column factors <code>magnitude * sign(eigenvalue)</code>. A zero
	 * eigenvalue (planar or collinear point set) counts as positive so that
	 * the corresponding axis is kept instead of being multiplied by zero.
	 */
	private static double[] signedFactors(Matrix eigenValues, double magnitude) {
		double[] factors = new double[3];
		for (int n = 0; n < 3; n++) {
			double sign = eigenValues.get(n, n) < 0 ? -1.0 : 1.0;
			factors[n] = magnitude * sign;
		}
		return factors;
	}

	/**
	 * Per-column factors <code>sqrt(eigenvalue)</code>. Eigenvalues of the
	 * scatter matrix can only be negative through round-off, so they are
	 * clamped at zero to keep the square root from returning NaN.
	 */
	private static double[] sqrtFactors(Matrix eigenValues) {
		double[] factors = new double[3];
		for (int n = 0; n < 3; n++) {
			factors[n] = Math.sqrt(Math.max(0.0, eigenValues.get(n, n)));
		}
		return factors;
	}
	
	/**
	 * Returns a copy of <code>Vs</code> in which column j is negated whenever
	 * its angle to column j of <code>Vt</code> exceeds 90 degrees (or the
	 * angle is undefined because one of the two columns has zero length).
	 * This assumes that the j-th source and target eigenvectors correspond to
	 * each other and that the two data sets differ by a rotation of less than
	 * pi radians.
	 * 
	 * @param Vs
	 *            source eigenvectors (columns); not modified
	 * @param Vt
	 *            target eigenvectors (columns)
	 * @return sign-corrected copy of <code>Vs</code>
	 */
	private Matrix preserveOrder(Matrix Vs, Matrix Vt){
		Matrix aligned = Vs.copy();
		for (int col = 0; col < 3; col++) {
			double angle = vectAng(Vs.get(0, col), Vs.get(1, col), Vs.get(2, col),
					Vt.get(0, col), Vt.get(1, col), Vt.get(2, col));
			boolean flip = angle > Math.PI / 2 || angle < 0;
			if (!flip) {
				continue;
			}
			for (int row = 0; row < 3; row++) {
				aligned.set(row, col, -Vs.get(row, col));
			}
		}
		return aligned;
	}
	
	/**
	 * Cross product of column <code>Acol</code> of <code>A</code> with column
	 * <code>Bcol</code> of <code>B</code>, returned as a 3x1 matrix.
	 */
	private static Matrix cross(Matrix A, int Acol, Matrix B, int Bcol){
		double ax = A.get(0, Acol), ay = A.get(1, Acol), az = A.get(2, Acol);
		double bx = B.get(0, Bcol), by = B.get(1, Bcol), bz = B.get(2, Bcol);
		Matrix product = new Matrix(3, 1);
		product.set(0, 0, ay * bz - az * by);
		product.set(1, 0, az * bx - ax * bz);
		product.set(2, 0, ax * by - ay * bx);
		return product;
	}
	
	/**
	 * Dot product of column <code>Acol</code> of <code>A</code> with column
	 * <code>Bcol</code> of <code>B</code> (first three rows only).
	 */
	private static double dot(Matrix A, int Acol, Matrix B, int Bcol){
		return A.get(0, Acol)*B.get(0, Bcol) + A.get(1, Acol)*B.get(1, Bcol)+A.get(2, Acol)*B.get(2, Bcol);
	}
	
	/**
	 * Angle in radians, in [0, pi], between vectors (vx1, vy1, vz1) and
	 * (vx2, vy2, vz2).
	 * 
	 * @return the angle, or -1 if either vector has zero length
	 */
	private static double vectAng(double vx1, double vy1, double vz1, double vx2, double vy2, double vz2) {
		double len1 = Math.sqrt(vx1*vx1+vy1*vy1+vz1*vz1);
		double len2 = Math.sqrt(vx2*vx2+vy2*vy2+vz2*vz2);
		if (len1 == 0 || len2 == 0) {
			return -1;
		}
		double cosine = (vx1/len1)*(vx2/len2) + (vy1/len1)*(vy2/len2) + (vz1/len1)*(vz2/len2);
		// Round-off can push the cosine slightly outside [-1, 1]; clamp it so
		// that acos does not return NaN.
		return Math.acos(Math.max(-1.0, Math.min(cosine, 1.0)));
	}
	
}
