package edu.jhu.ece.iacl.algorithms.tgdm;

import java.util.Hashtable;
import java.util.LinkedList;
import java.util.List;

import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

import edu.jhu.ece.iacl.algorithms.graphics.GeometricUtilities;
import edu.jhu.ece.iacl.algorithms.graphics.isosurf.IsoSurfaceOnGrid;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient.Normalization;
import edu.jhu.ece.iacl.algorithms.topology.ConnectivityRule;
import edu.jhu.ece.iacl.algorithms.volume.DistanceField;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.data.BinaryMinHeap;
import edu.jhu.ece.iacl.jist.structures.geom.EmbeddedSurface;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.MaskVolume6;
import edu.jhu.ece.iacl.jist.structures.image.VoxelFloat;
import edu.jhu.ece.iacl.jist.structures.image.VoxelIndexed;

import Jama.*;

/**
 * Fast-marching extension of a normalized vector field {@code Q} away from the
 * zero level set of a signed level-set volume.
 * <p>
 * Voxels leave the heap in increasing order of {@code |levelSet|}. Each newly
 * reached voxel gets a vector interpolated from its upwind neighbors and scaled
 * to unit length. There are two seeding modes:
 * <ul>
 * <li>{@link #solve(float[][][], EmbeddedSurface, int, double)} seeds the
 * voxels around each vertex of an embedded surface. It uses
 * {@code getPointAtOffset(vertex, offset)} as that vertex's vector.</li>
 * <li>{@link #solve(float[][][], float[][][], float[][][], float[][][], float[][][], int, double)}
 * seeds every voxel that has a neighbor (from {@code MaskVolume6}) of
 * opposite level-set sign. The seed vectors come from the supplied
 * {@code Qx/Qy/Qz} volumes.</li>
 * </ul>
 * Voxels that are never reached are reported as {@code NaN} by
 * {@link #getQx()}, {@link #getQy()} and {@link #getQz()}.
 */
public class FastMarchingSphericalExtension extends AbstractCalculation {
	/** Seeding surface (surface mode only); cleared when marching finishes. */
	EmbeddedSurface surf;
	/**
	 * Maps {@code hashKey(d, i, j, k)} of a grid edge to the id of the surface
	 * vertex lying on it. The edge starts at voxel (i, j, k) and runs along
	 * axis d (1 = x, 2 = y, 3 = z). The map is only built in surface mode.
	 */
	Hashtable<Long, Integer> hash = null;
	int rows, cols, slices;
	/** Data offset passed to {@code EmbeddedSurface.getPointAtOffset}. */
	int offset = 0;
	/** Voxel not reached yet. */
	public static final byte UNVISITED = 0;
	/** Voxel in the heap with a tentative Q. */
	public static final byte VISITED = 1;
	/** Voxel removed from the heap, or fixed/excluded during seeding. */
	public static final byte SOLVED = 2;

	float[][][] levelSetM;
	/** Level-set gradient, indexed [row][col][slice][x|y|z]. */
	float[][][][] grad;
	byte[][][] labelM;
	/** Output components of Q, filled at the end of {@code solve}. */
	float[][][] Qx, Qy, Qz;
	/** Extended vector per voxel; {@code null} where none was computed. */
	Point3f[][][] Q;
	BinaryMinHeap heap;
	byte[] neighborsX;
	byte[] neighborsY;
	byte[] neighborsZ;

	/** Creates a stand-alone calculation. */
	public FastMarchingSphericalExtension() {
		super();
		setLabel("Fast-Marching Spherical Extension");
		initNeighbors();
	}

	/**
	 * Creates a calculation attached to {@code parent}.
	 *
	 * @param parent
	 *            passed to the {@code AbstractCalculation} constructor
	 */
	public FastMarchingSphericalExtension(AbstractCalculation parent) {
		super(parent);
		setLabel("Fast-Marching Spherical Extension");
		initNeighbors();
	}

	/** Caches the neighbor offsets of {@code MaskVolume6}. */
	private void initNeighbors() {
		MaskVolume6 mask = new MaskVolume6();
		neighborsX = mask.getNeighborsX();
		neighborsY = mask.getNeighborsY();
		neighborsZ = mask.getNeighborsZ();
	}

	/**
	 * Packs an edge (start voxel plus axis) into a hash key.
	 * <p>
	 * The key is computed in {@code long} arithmetic throughout, so large
	 * volumes cannot overflow {@code int} before the value is widened.
	 */
	private long hashKey(int d, int i, int j, int k) {
		return d + 4L * (k + (long) slices * (j + (long) cols * i));
	}

	/** Returns true when the voxel index lies inside the volume. */
	private boolean inBounds(int i, int j, int k) {
		return i >= 0 && i < rows && j >= 0 && j < cols && k >= 0
				&& k < slices;
	}

	/** Returns true when any component of {@code p} is NaN. */
	private static boolean hasNaN(Point3f p) {
		return Float.isNaN(p.x) || Float.isNaN(p.y) || Float.isNaN(p.z);
	}

	/**
	 * Indexes every surface vertex by the grid edge it lies on.
	 * <p>
	 * The edge starts at the floor of the vertex position. Its axis is the
	 * first coordinate with a fractional part. A vertex exactly on a grid
	 * point gets axis 0, and a diagnostic line is printed.
	 */
	protected void buildHash() {
		int vertCount = surf.getVertexCount();
		hash = new Hashtable<Long, Integer>();
		for (int id = 0; id < vertCount; id++) {
			Point3f p = surf.getVertex(id);
			int i = (int) Math.floor(p.x);
			int j = (int) Math.floor(p.y);
			int k = (int) Math.floor(p.z);
			int d;
			if (p.x > i) {
				d = 1;
			} else if (p.y > j) {
				d = 2;
			} else if (p.z > k) {
				d = 3;
			} else {
				d = 0;
				System.out.println("D " + p);
			}
			hash.put(hashKey(d, i, j, k), id);
		}
	}

	/**
	 * Returns the sample {position, Q} that voxel (i2, j2, k2) sees in the
	 * direction of neighbor (i1, j1, k1).
	 * <p>
	 * When the two voxels have opposite level-set signs, the sample is the
	 * surface vertex on the edge between them, together with its offset
	 * data. Otherwise the sample is the neighbor voxel itself and its current
	 * Q, which may be {@code null}.
	 *
	 * @throws IllegalStateException
	 *             if a sign change is found but no surface vertex is indexed
	 *             for that edge
	 */
	protected Point3f[] lookupPoint(int i1, int j1, int k1, int i2, int j2,
			int k2) {
		boolean crossesSurface = levelSetM[i1][j1][k1]
				* levelSetM[i2][j2][k2] < 0;
		if (!crossesSurface) {
			return new Point3f[] { new Point3f(i1, j1, k1), Q[i1][j1][k1] };
		}
		// Order the endpoints so that (i1, j1, k1) is the edge's lower voxel,
		// which is the one buildHash keyed on.
		if (i1 > i2 || j1 > j2 || k1 > k2) {
			int tmp = i1;
			i1 = i2;
			i2 = tmp;
			tmp = j1;
			j1 = j2;
			j2 = tmp;
			tmp = k1;
			k1 = k2;
			k2 = tmp;
		}
		int d;
		if (i2 > i1) {
			d = 1;
		} else if (j2 > j1) {
			d = 2;
		} else if (k2 > k1) {
			d = 3;
		} else {
			System.out.println("NOT FOUND " + i1 + " " + j1 + " " + k1);
			d = 0;
		}
		if (hash == null) {
			throw new IllegalStateException(
					"No surface vertex index available for sign change between ("
							+ i1 + "," + j1 + "," + k1 + ") and (" + i2 + ","
							+ j2 + "," + k2 + ")");
		}
		long hashVal = hashKey(d, i1, j1, k1);
		Integer vid = hash.get(hashVal);
		if (vid == null) {
			long hashVal2 = hashKey(d, i2, j2, k2);
			throw new IllegalStateException("COULD NOT FIND POINT! (" + i1
					+ "," + j1 + "," + k1 + ") (" + i2 + "," + j2 + "," + k2
					+ ") hash " + hashVal + " " + hashVal2 + " hash "
					+ hash.get(hashVal2) + " " + hash.size());
		}
		int v = vid.intValue();
		return new Point3f[] { surf.getVertex(v),
				surf.getPointAtOffset(v, offset) };
	}

	/**
	 * Allocates {@link #labelM} with every voxel UNVISITED.
	 *
	 * @return number of voxels with {@code |levelSet| <= maxDist}
	 */
	private int resetLabels(double maxDist) {
		labelM = new byte[rows][cols][slices];
		int count = 0;
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					labelM[i][j][k] = UNVISITED;
					if (Math.abs(levelSetM[i][j][k]) <= maxDist) {
						count++;
					}
				}
			}
		}
		return count;
	}

	/**
	 * Pushes voxel (i, j, k) onto the heap and assigns it {@code label}.
	 * The heap priority is the voxel's {@code |levelSet|}.
	 */
	private void enqueue(int i, int j, int k, byte label) {
		VoxelIndexed<VoxelFloat> vox = new VoxelIndexed<VoxelFloat>(
				new VoxelFloat((float) Math.abs(levelSetM[i][j][k])));
		vox.setRefPosition(i, j, k);
		labelM[i][j][k] = label;
		heap.add(vox);
	}

	/**
	 * Seeds the heap in surface mode.
	 * <p>
	 * For each surface vertex, the voxels at the floor and at the ceiling of
	 * its position get an initial Q from {@link #updateNode} and are marked
	 * VISITED.
	 */
	protected void initializeBoundary(double maxDist) {
		buildHash();
		int count = resetLabels(maxDist);
		setTotalUnits(count);
		heap = new BinaryMinHeap(count, rows, cols, slices);
		int vertCount = surf.getVertexCount();
		for (int id = 0; id < vertCount; id++) {
			Point3f p = surf.getVertex(id);
			seedFromVertex(id, (int) Math.floor(p.x), (int) Math.floor(p.y),
					(int) Math.floor(p.z));
			seedFromVertex(id, (int) Math.ceil(p.x), (int) Math.ceil(p.y),
					(int) Math.ceil(p.z));
		}
	}

	/**
	 * Computes an initial Q for an unvisited voxel next to surface vertex
	 * {@code id} and enqueues it. Nothing happens if no Q can be computed.
	 */
	private void seedFromVertex(int id, int i, int j, int k) {
		if (labelM[i][j][k] != UNVISITED) {
			return;
		}
		Point3f q = updateNode(i, j, k, surf.getPointAtOffset(id, offset));
		if (q == null) {
			return;
		}
		if (hasNaN(q)) {
			System.err.println("Initial Q2 IS NaN" + q + " " + i + "," + j
					+ "," + k);
		}
		Q[i][j][k] = q;
		enqueue(i, j, k, VISITED);
	}

	/**
	 * Seeds the heap in field mode.
	 * <p>
	 * Skeleton voxels ({@code skelvol > 0}) are marked SOLVED and never
	 * marched. Every other voxel within {@code maxDist} that has an
	 * opposite-sign neighbor takes its normalized Q from {@code Qx/Qy/Qz}.
	 */
	protected void initializeBoundary(float[][][] Qx, float[][][] Qy,
			float[][][] Qz, float[][][] skelvol, double maxDist) {
		int count = resetLabels(maxDist);
		heap = new BinaryMinHeap(count, rows, cols, slices);
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					if (skelvol != null && skelvol[i][j][k] > 0) {
						// Do not compute skeleton points.
						labelM[i][j][k] = SOLVED;
					} else if (Math.abs(levelSetM[i][j][k]) <= maxDist) {
						int koff = findSignChange(i, j, k);
						if (koff >= 0) {
							seedFromField(i, j, k, koff, Qx, Qy, Qz);
						}
					}
				}
			}
		}
		setTotalUnits(count);
	}

	/**
	 * Returns the index of the first in-bounds neighbor whose level-set value
	 * has the opposite sign to voxel (i, j, k), or -1 if there is none.
	 */
	private int findSignChange(int i, int j, int k) {
		float l1 = levelSetM[i][j][k];
		for (int koff = 0; koff < MaskVolume6.length; koff++) {
			int ni = i + neighborsX[koff];
			int nj = j + neighborsY[koff];
			int nk = k + neighborsZ[koff];
			if (inBounds(ni, nj, nk) && l1 * levelSetM[ni][nj][nk] < 0) {
				return koff;
			}
		}
		return -1;
	}

	/**
	 * Seeds boundary voxel (i, j, k) from the supplied field.
	 * <p>
	 * A seed that is NaN after normalization is reported and excluded (Q left
	 * {@code null}, label SOLVED). Otherwise it would be used as an upwind
	 * value by neighbors. Marching could also reach the voxel, and that would
	 * need a surface vertex index, which does not exist in this mode.
	 */
	private void seedFromField(int i, int j, int k, int koff,
			float[][][] Qx, float[][][] Qy, float[][][] Qz) {
		Point3f seed = new Point3f(Qx[i][j][k], Qy[i][j][k], Qz[i][j][k]);
		Q[i][j][k] = seed;
		// In case Q is not already normalized.
		GeometricUtilities.normalize(seed);
		if (hasNaN(seed)) {
			int ni = i + neighborsX[koff];
			int nj = j + neighborsY[koff];
			int nk = k + neighborsZ[koff];
			System.err.println("Initial Q IS NaN" + seed + " "
					+ levelSetM[i][j][k] + " " + levelSetM[ni][nj][nk] + " "
					+ Qx[i][j][k] + " " + Qy[i][j][k] + " " + Qz[i][j][k]
					+ " " + Qx[ni][nj][nk] + " " + Qy[ni][nj][nk] + " "
					+ Qz[ni][nj][nk]);
			Q[i][j][k] = null;
			labelM[i][j][k] = SOLVED;
		} else {
			enqueue(i, j, k, SOLVED);
		}
	}

	/** Stores the level set and its dimensions. */
	private void setVolume(float[][][] levelSet) {
		levelSetM = levelSet;
		rows = levelSet.length;
		cols = levelSet[0].length;
		slices = levelSet[0][0].length;
	}

	/**
	 * Prints the level set's maximum signed distance next to {@code maxDist},
	 * and warns if the former is larger.
	 */
	private void reportMaxDistance(float[][][] levelSet, double maxDist) {
		double maxSignedDist = DistanceField.maxSignedDistance(levelSet);
		System.out.println(maxSignedDist + " " + maxDist);
		if (maxSignedDist > maxDist) {
			System.err.println("Maximum signed distance cannot be larger!");
		}
	}

	/**
	 * Runs the extension seeded from a vector field (field mode).
	 *
	 * @param levelSet
	 *            signed level-set volume
	 * @param Qx
	 *            input x components: seeds at boundary voxels, and copied
	 *            through at skeleton voxels
	 * @param Qy
	 *            input y components, used like {@code Qx}
	 * @param Qz
	 *            input z components, used like {@code Qx}
	 * @param skelvol
	 *            optional mask ({@code null} allowed). Voxels with a value
	 *            {@code > 0} are not marched, and their input Q is copied to
	 *            the output.
	 * @param offset
	 *            stored in {@link #offset}; only consulted for surface lookups
	 * @param maxDist
	 *            voxels with {@code |levelSet| >= maxDist} are not marched
	 */
	public void solve(float[][][] levelSet, float[][][] Qx, float[][][] Qy,
			float[][][] Qz, float[][][] skelvol, int offset, double maxDist) {
		this.offset = offset;
		setVolume(levelSet);
		reportMaxDistance(levelSet, maxDist);
		FastMarchingGradient fmg = new FastMarchingGradient(this);
		this.grad = fmg.gradient(levelSetM, Normalization.MAGNITUDE);
		Q = new Point3f[rows][cols][slices];
		initializeBoundary(Qx, Qy, Qz, skelvol, maxDist);
		computeMap(maxDist);
		exportQ(skelvol, Qx, Qy, Qz);
	}

	/**
	 * Runs the extension seeded from an embedded surface (surface mode).
	 *
	 * @param levelSet
	 *            signed level-set volume
	 * @param embeddedSurf
	 *            surface whose vertices lie on grid edges of the level set
	 * @param offset
	 *            data offset passed to {@code getPointAtOffset} for each
	 *            vertex
	 * @param maxDist
	 *            voxels with {@code |levelSet| >= maxDist} are not marched
	 */
	public void solve(float[][][] levelSet, EmbeddedSurface embeddedSurf,
			int offset, double maxDist) {
		this.offset = offset;
		setVolume(levelSet);
		FastMarchingGradient fmg = new FastMarchingGradient(this);
		this.grad = fmg.gradient(levelSetM, Normalization.MAGNITUDE);
		Q = new Point3f[rows][cols][slices];
		this.surf = embeddedSurf;
		reportMaxDistance(levelSet, maxDist);
		initializeBoundary(maxDist);
		computeMap(maxDist);
		exportQ(null, null, null, null);
	}

	/**
	 * Fills the output component arrays from {@link #Q}.
	 * <p>
	 * A skeleton voxel (when {@code skelvol} is non-null) takes the skeleton
	 * arrays' value. Any other voxel without a Q becomes NaN.
	 */
	private void exportQ(float[][][] skelvol, float[][][] skelQx,
			float[][][] skelQy, float[][][] skelQz) {
		this.Qx = new float[rows][cols][slices];
		this.Qy = new float[rows][cols][slices];
		this.Qz = new float[rows][cols][slices];
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					Point3f p = Q[i][j][k];
					if (p != null) {
						this.Qx[i][j][k] = p.x;
						this.Qy[i][j][k] = p.y;
						this.Qz[i][j][k] = p.z;
						if (hasNaN(p)) {
							System.out.println("Final Q IS NaN " + p + " "
									+ i + " " + j + " " + k);
						}
					} else if (skelvol != null && skelvol[i][j][k] > 0) {
						this.Qx[i][j][k] = skelQx[i][j][k];
						this.Qy[i][j][k] = skelQy[i][j][k];
						this.Qz[i][j][k] = skelQz[i][j][k];
					} else {
						this.Qx[i][j][k] = Float.NaN;
						this.Qy[i][j][k] = Float.NaN;
						this.Qz[i][j][k] = Float.NaN;
					}
				}
			}
		}
	}

	/**
	 * Drains the heap in {@code |levelSet|} order, computing Q for every
	 * reachable voxel with {@code |levelSet| < maxDist}.
	 * <p>
	 * Afterwards the heap, labels, vertex index and surface are released, and
	 * the calculation is marked completed.
	 */
	public void computeMap(double maxDist) {
		while (heap.size() > 0) {
			@SuppressWarnings("unchecked")
			VoxelIndexed<VoxelFloat> he = (VoxelIndexed<VoxelFloat>) heap
					.remove();
			int i = he.getRow();
			int j = he.getColumn();
			int k = he.getSlice();
			labelM[i][j][k] = SOLVED;
			incrementCompletedUnits();
			for (int koff = 0; koff < MaskVolume6.length; koff++) {
				int ni = i + neighborsX[koff];
				int nj = j + neighborsY[koff];
				int nk = k + neighborsZ[koff];
				if (!inBounds(ni, nj, nk)) {
					continue;
				}
				if (Math.abs(levelSetM[ni][nj][nk]) >= maxDist) {
					continue;
				}
				byte label = labelM[ni][nj][nk];
				if (label != VISITED && label != UNVISITED) {
					continue;
				}
				Point3f q = updateNode(ni, nj, nk, Q[i][j][k]);
				if (q == null) {
					continue;
				}
				Q[ni][nj][nk] = q;
				// The priority (|levelSet|) never changes, so a VISITED voxel
				// only needs its Q refreshed, not a heap reorder.
				if (label == UNVISITED) {
					enqueue(ni, nj, nk, VISITED);
				}
			}
		}
		heap.makeEmpty();
		heap = null;
		labelM = null;
		hash = null;
		surf = null;
		System.gc();
		markCompleted();
	}

	/**
	 * Estimates a unit Q at voxel (i, j, k) from its upwind neighbors.
	 * <p>
	 * A neighbor is upwind when its sample point lies on the zero-level-set
	 * side of the pivot along {@code grad}. This presumes that {@code grad}
	 * is the level-set gradient; confirm against FastMarchingGradient.
	 * <p>
	 * With fewer than 3 upwind samples, the result is an inverse-distance
	 * weighted average. Otherwise it is a weighted least-squares estimate,
	 * falling back to the weighted average when that estimate is near zero.
	 *
	 * @param qparent
	 *            not used by this implementation
	 * @return the normalized vector, or {@code null} if none could be formed
	 */
	protected Point3f updateNode(int i, int j, int k, Point3f qparent) {
		Vector3f tan = new Vector3f(grad[i][j][k][0], grad[i][j][k][1],
				grad[i][j][k][2]);
		Point3f pivot = new Point3f(i, j, k);
		Vector3f diff = new Vector3f();
		boolean sgn = (levelSetM[i][j][k] > 0);
		List<Point3f> upwindPoints = new LinkedList<Point3f>();
		List<Point3f> upwindQs = new LinkedList<Point3f>();
		for (int koff = 0; koff < MaskVolume6.length; koff++) {
			int ni = i + neighborsX[koff];
			int nj = j + neighborsY[koff];
			int nk = k + neighborsZ[koff];
			if (!inBounds(ni, nj, nk)) {
				continue;
			}
			Point3f[] ret = lookupPoint(ni, nj, nk, i, j, k);
			if (ret[1] == null) {
				continue;
			}
			Point3f p = ret[0];
			if (sgn) {
				diff.sub(pivot, p);
			} else {
				diff.sub(p, pivot);
			}
			if (tan.dot(diff) >= 0) {
				upwindPoints.add(p);
				upwindQs.add(ret[1]);
			}
		}
		if (upwindPoints.size() < 3) {
			return inverseDistanceAverage(pivot, upwindPoints, upwindQs);
		}
		Point3f q = leastSquaresEstimate(pivot, tan, sgn, upwindPoints,
				upwindQs);
		double l = Math.sqrt(q.x * q.x + q.y * q.y + q.z * q.z);
		if (l > 1E-5) {
			q.scale(1.0f / (float) l);
			return q;
		}
		return inverseDistanceAverage(pivot, upwindPoints, upwindQs);
	}

	/**
	 * Averages {@code qs}, weighting each by the inverse distance from
	 * {@code pivot} to its sample point (weights capped at 1E6), and
	 * normalizes the result.
	 *
	 * @return the unit average, or {@code null} if its length is at most 1E-6
	 *         (including when the lists are empty)
	 */
	private static Point3f inverseDistanceAverage(Point3f pivot,
			List<Point3f> points, List<Point3f> qs) {
		Point3f q = new Point3f();
		for (int n = 0; n < points.size(); n++) {
			double w = pivot.distance(points.get(n));
			w = (w > 1E-6) ? 1 / w : 1E6;
			Point3f p = new Point3f(qs.get(n));
			p.scale((float) w);
			q.add(p);
		}
		double l = Math.sqrt(q.x * q.x + q.y * q.y + q.z * q.z);
		if (l > 1E-6) {
			q.scale(1.0f / (float) l);
			return q;
		}
		return null;
	}

	/**
	 * Computes the weighted least-squares Q estimate along the signed
	 * gradient direction.
	 * <p>
	 * The estimate uses the pseudo-inverse (via SVD) of
	 * {@code D^T W D}. Here D holds the offsets from each sample point to the
	 * pivot, and W holds the inverse-distance weights. The result is not
	 * normalized.
	 */
	private static Point3f leastSquaresEstimate(Point3f pivot, Vector3f tan,
			boolean sgn, List<Point3f> upwindPoints, List<Point3f> upwindQs) {
		int m = upwindPoints.size();
		Matrix D = new Matrix(m, 3);
		Matrix W = new Matrix(m, m);
		Matrix one = new Matrix(m, 1);
		Matrix qxCol = new Matrix(m, 1);
		Matrix qyCol = new Matrix(m, 1);
		Matrix qzCol = new Matrix(m, 1);
		Vector3f diff = new Vector3f();
		for (int n = 0; n < m; n++) {
			Point3f p = upwindPoints.get(n);
			diff.sub(pivot, p);
			D.set(n, 0, diff.x);
			D.set(n, 1, diff.y);
			D.set(n, 2, diff.z);
			double w = diff.length();
			w = (w > 1E-6) ? 1 / w : 1E6;
			W.set(n, n, w);
			Point3f q = upwindQs.get(n);
			qxCol.set(n, 0, q.x);
			qyCol.set(n, 0, q.y);
			qzCol.set(n, 0, q.z);
			one.set(n, 0, 1);
		}
		Matrix Dsqr = D.transpose().times(W).times(D);
		SingularValueDecomposition svd = new SingularValueDecomposition(Dsqr);
		Matrix S = svd.getS();
		Matrix V = svd.getV();
		Matrix U = svd.getU();
		// Invert the singular values, zeroing near-singular ones.
		for (int n = 0; n < S.getColumnDimension(); n++) {
			if (Math.abs(S.get(n, n)) > 1E-12) {
				S.set(n, n, 1 / S.get(n, n));
			} else {
				S.set(n, n, 0);
			}
		}
		Matrix Dinv = V.times(S.times(U.transpose())).times(D.transpose())
				.times(W);
		Matrix T = new Matrix(1, 3);
		if (sgn) {
			T.set(0, 0, tan.x);
			T.set(0, 1, tan.y);
			T.set(0, 2, tan.z);
		} else {
			T.set(0, 0, -tan.x);
			T.set(0, 1, -tan.y);
			T.set(0, 2, -tan.z);
		}
		Matrix ot = T.times(Dinv);
		double denom = ot.times(one).get(0, 0);
		if (Math.abs(denom) < 1E-10) {
			denom = 1;
		}
		Point3f q = new Point3f();
		q.x = (float) ((ot.times(qxCol)).get(0, 0) / denom);
		q.y = (float) ((ot.times(qyCol)).get(0, 0) / denom);
		q.z = (float) ((ot.times(qzCol)).get(0, 0) / denom);
		return q;
	}

	/** Returns the per-voxel Q; entries are {@code null} where unset. */
	public Point3f[][][] getQ() {
		return Q;
	}

	/** Returns the exported x components (NaN where unset). */
	public float[][][] getQx() {
		return Qx;
	}

	/** Returns the exported y components (NaN where unset). */
	public float[][][] getQy() {
		return Qy;
	}

	/** Returns the exported z components (NaN where unset). */
	public float[][][] getQz() {
		return Qz;
	}

	/** Returns the level-set gradient used for the upwind tests. */
	public float[][][][] getGradient() {
		return grad;
	}
}
