package edu.jhu.ece.iacl.algorithms.tgdm;

import javax.vecmath.Matrix3d;
import javax.vecmath.Matrix3f;
import javax.vecmath.Point3d;
import javax.vecmath.Point3f;
import javax.vecmath.Point3i;
import javax.vecmath.Vector3d;
import javax.vecmath.Vector3f;

import edu.jhu.ece.iacl.algorithms.graphics.GeometricUtilities;
import edu.jhu.ece.iacl.algorithms.graphics.locator.balltree.PointBall;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient.Normalization;
import edu.jhu.ece.iacl.algorithms.tgdm.GenericTGDM.NBPoint;
import edu.jhu.ece.iacl.algorithms.tgdm.ProfileTGDM.Tracking;
import edu.jhu.ece.iacl.algorithms.tgdm.TrackingTGDM.EmbeddedNBPoint;

/**
 * Tracking TGDM for an auxiliary surface whose evolution is driven by the forces of a
 * principal {@link CoupledForceTGDM}. For every narrow-band point of this level set, the
 * driving force and target location are interpolated from the nearest narrow-band points of
 * the principal surface, located through the correspondence points stored in
 * {@code Qx/Qy/Qz}.
 */
public class AuxiliaryForceTGDM extends TrackingTGDM {
	/** Reference to the principal surface whose forces drive this one. */
	protected CoupledForceTGDM principalTGDM;
	/**
	 * For each point of this narrow band (row = position in {@code narrowBand} when the cache
	 * was built), the indices into {@code principalTGDM.narrowBand} of its nearest principal
	 * points. Only populated when {@code principalTGDM.nearestNeighborCacheEnabled} is set.
	 */
	protected int[][] nearestNeighborCache;
	/** Not referenced within this class; presumably used by subclasses or callers - TODO confirm. */
	protected float[][][] initPhi;
	/** Decay rate of the distance weights {@code exp(-weightScaling * radius)} used for interpolation. */
	protected double weightScaling = 3;

	/**
	 * Creates an auxiliary TGDM coupled to the given principal surface.
	 *
	 * @param principalTGDM principal surface that supplies the driving forces
	 * @param rule          forwarded unchanged to the {@link TrackingTGDM} constructor
	 */
	public AuxiliaryForceTGDM(CoupledForceTGDM principalTGDM, int rule) {
		super(principalTGDM, rule);
		this.principalTGDM = principalTGDM;
	}

	/**
	 * Initializes the narrow band. If the principal surface has nearest-neighbor caching
	 * enabled, this also precomputes, for every narrow-band point, the indices of its
	 * {@code principalTGDM.NearestNeighbors} nearest principal narrow-band points.
	 *
	 * @param profile evolution profile forwarded to the superclass
	 */
	protected void initNarrowBand(ProfileTGDM profile) {
		super.initNarrowBand(profile);
		if (!principalTGDM.nearestNeighborCacheEnabled) {
			return;
		}
		int neighborCount = principalTGDM.NearestNeighbors;
		nearestNeighborCache = new int[narrowBand.size()][neighborCount];
		// Reused across iterations; each getNearestNeighbors call is expected to fill every slot.
		PointBall[] nbhrs = new PointBall[neighborCount];
		int row = 0;
		for (NBPoint pt : narrowBand) {
			// NOTE(review): rows are filled by iteration position, but updatePoint reads them back
			// via pt.index; this assumes pt.index equals the point's position in narrowBand.
			principalTGDM.balltree.getNearestNeighbors(lookupCorrespondencePoint(pt), nbhrs);
			for (int k = 0; k < neighborCount; k++) {
				nearestNeighborCache[row][k] = nbhrs[k].getIndex();
			}
			row++;
		}
	}

	/**
	 * Interpolates neighbor forces using the weights {@code exp(-weightScaling * radius)}.
	 *
	 * @param radii  distance from the query point to each neighbor
	 * @param forces force vector of each neighbor, parallel to {@code radii}
	 * @return the normalized weighted average of {@code forces}, or a copy of
	 *         {@code forces[0]} when the total weight is negligible (<= 1E-10)
	 */
	protected Vector3d interpolateForce(double[] radii, Vector3d[] forces) {
		Vector3d v = new Vector3d();
		double wsum = 0;
		for (int i = 0; i < radii.length; i++) {
			double w = neighborWeight(radii[i]);
			v.x += w * forces[i].x;
			v.y += w * forces[i].y;
			v.z += w * forces[i].z;
			wsum += w;
		}
		if (wsum > 1E-10) {
			v.scale(1 / wsum);
			return v;
		}
		// Every neighbor is so far away that the weights vanish: fall back to the first one.
		return new Vector3d(forces[0]);
	}

	/**
	 * Interpolates neighbor voxel locations using the weights
	 * {@code exp(-weightScaling * radius)}.
	 *
	 * @param radii  distance from the query point to each neighbor
	 * @param points voxel location of each neighbor, parallel to {@code radii}
	 * @return the normalized weighted average of {@code points}, or {@code points[0]} when
	 *         the total weight is negligible (<= 1E-10)
	 */
	protected Point3d interpolatePoint(double[] radii, Point3i[] points) {
		Point3d v = new Point3d();
		double wsum = 0;
		for (int i = 0; i < radii.length; i++) {
			double w = neighborWeight(radii[i]);
			v.x += w * points[i].x;
			v.y += w * points[i].y;
			v.z += w * points[i].z;
			wsum += w;
		}
		if (wsum > 1E-10) {
			v.scale(1 / wsum);
			return v;
		}
		// Every neighbor is so far away that the weights vanish: fall back to the first one.
		Point3i p = points[0];
		return new Point3d(p.x, p.y, p.z);
	}

	/**
	 * Computes the interpolation weight {@code exp(-weightScaling * radius)}. The result is
	 * rounded to float precision, as this class has always done.
	 */
	private double neighborWeight(double radius) {
		return (float) Math.exp(-weightScaling * radius);
	}

	/** Returns this level set's correspondence point (Qx, Qy, Qz) stored at the voxel of {@code pt}. */
	private Point3d lookupCorrespondencePoint(NBPoint pt) {
		return new Point3d(Qx[pt.x][pt.y][pt.z], Qy[pt.x][pt.y][pt.z], Qz[pt.x][pt.y][pt.z]);
	}

	/**
	 * Computes the updated level set value at {@code pt} from the forces of the principal
	 * surface.
	 * <p>
	 * The principal force {@code F} and location {@code p2} are interpolated from the
	 * principal narrow-band points nearest to this voxel's correspondence point. Let
	 * {@code d = p2 - pt} and let {@code n} be the normalized central-difference gradient of
	 * the level set. The speed is {@code s = (F . d) / (n . d)}, or 0 when
	 * {@code |n . d| <= 1E-10}. Then {@code s * n} is the normal velocity whose component
	 * along {@code d} matches that of {@code F}. That velocity is stored in
	 * {@code ((EmbeddedNBPoint) pt).vec}.
	 * <p>
	 * NaN correspondence points and NaN velocities are reported on {@code System.err} but do
	 * not stop the update.
	 *
	 * @param pt      narrow-band point to update; must be an {@link EmbeddedNBPoint}
	 * @param profile evolution profile; not used by this implementation
	 * @return {@code phi(pt) - s * |grad phi(pt)|}
	 * @throws NullPointerException if the ball tree returns fewer than
	 *                              {@code principalTGDM.NearestNeighbors} neighbors
	 */
	protected double updatePoint(NBPoint pt, ProfileTGDM profile) {
		int x = pt.x;
		int y = pt.y;
		int z = pt.z;

		// Correspondence point of this voxel, used to locate neighbors on the principal surface.
		Point3d q = lookupCorrespondencePoint(pt);
		if (Double.isNaN(q.x) || Double.isNaN(q.y) || Double.isNaN(q.z)) {
			System.err.println("Q is NaN " + q + " " + pt);
		}

		int neighborCount = principalTGDM.NearestNeighbors;
		Vector3d[] forces = new Vector3d[neighborCount];
		Point3i[] pts = new Point3i[neighborCount];
		double[] radii = new double[neighborCount];
		if (principalTGDM.nearestNeighborCacheEnabled) {
			// Neighbor indices were precomputed in initNarrowBand; distances are measured to
			// the neighbors' current correspondence points.
			for (int k = 0; k < neighborCount; k++) {
				NBPoint pd = principalTGDM.narrowBand.get(nearestNeighborCache[pt.index][k]);
				Point3d nq = new Point3d(principalTGDM.Qx[pd.x][pd.y][pd.z],
						principalTGDM.Qy[pd.x][pd.y][pd.z], principalTGDM.Qz[pd.x][pd.y][pd.z]);
				if (Double.isNaN(nq.x) || Double.isNaN(nq.y) || Double.isNaN(nq.z)) {
					System.err.println("NEIGHBOR " + pd + " is NaN");
				}
				radii[k] = nq.distance(q);
				forces[k] = ((EmbeddedNBPoint) pd).vec;
				pts[k] = pd;
			}
		} else {
			// Query the principal ball tree directly for this correspondence point.
			PointBall[] nbhrs = new PointBall[neighborCount];
			principalTGDM.balltree.getNearestNeighbors(q, nbhrs);
			for (int k = 0; k < neighborCount; k++) {
				PointBall ball = nbhrs[k];
				// Check before dereferencing so a missing neighbor yields a descriptive error.
				if (ball == null) {
					throw new NullPointerException("Nearest neighbor " + k + " of " + q + " is null");
				}
				NBPoint pd = principalTGDM.narrowBand.get(ball.getIndex());
				if (Double.isNaN(ball.x) || Double.isNaN(ball.y) || Double.isNaN(ball.z)) {
					System.err.println("NEIGHBOR " + pd + " is NaN");
				}
				radii[k] = ball.distance(q);
				forces[k] = ((EmbeddedNBPoint) pd).vec;
				pts[k] = pd;
			}
		}
		Vector3d commuteForce = interpolateForce(radii, forces);
		Point3d pt2 = interpolatePoint(radii, pts);

		// Direction from this voxel to the interpolated principal location.
		Vector3d dir = new Vector3d();
		dir.sub(pt2, new Point3d(x, y, z));

		// Neighbor indices, clamped at the volume border: outside the volume the voxel itself
		// is used. x spans rows, y spans cols, z spans slices.
		int xLo = Math.max(x - 1, 0);
		int xHi = Math.min(x + 1, rows - 1);
		int yLo = Math.max(y - 1, 0);
		int yHi = Math.min(y + 1, cols - 1);
		int zLo = Math.max(z - 1, 0);
		int zHi = Math.min(z + 1, slices - 1);
		// Read into doubles so the differences below are evaluated in double precision.
		double west = implicitLevelSet[xLo][y][z];
		double east = implicitLevelSet[xHi][y][z];
		double north = implicitLevelSet[x][yLo][z];
		double south = implicitLevelSet[x][yHi][z];
		double back = implicitLevelSet[x][y][zLo];
		double front = implicitLevelSet[x][y][zHi];

		// Central-difference gradient of phi; its length is |grad phi| and its direction the normal.
		Vector3d gradient = new Vector3d((east - west) / 2, (south - north) / 2, (front - back) / 2);
		double gphi = gradient.length();
		GeometricUtilities.normalize(gradient);

		double numer = commuteForce.dot(dir);
		double denom = gradient.dot(dir);
		// The normal is (nearly) perpendicular to dir, so no normal speed can be derived.
		double force = (Math.abs(denom) > 1E-10) ? numer / denom : 0;
		gradient.scale(force);
		if (Double.isNaN(gradient.x) || Double.isNaN(gradient.y) || Double.isNaN(gradient.z)) {
			System.err.println("VELOCITY IS NaN " + numer + " " + denom + " " + dir + " " + pt);
		}
		((EmbeddedNBPoint) pt).vec = gradient;
		// Compute the new level set value from the normal speed.
		return implicitLevelSet[x][y][z] - force * gphi;
	}
}