package edu.jhu.ece.iacl.algorithms.graphics.map;

import java.util.Stack;
import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.sparse.BiCG;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import no.uib.cipr.matrix.sparse.IterationMonitor;
import no.uib.cipr.matrix.sparse.IterativeSolverNotConvergedException;

import edu.jhu.ece.iacl.algorithms.VersionUtil;
import edu.jhu.ece.iacl.algorithms.graphics.GeometricUtilities;
import edu.jhu.ece.iacl.algorithms.graphics.locator.sphere.SphericalMapLocator;
import edu.jhu.ece.iacl.algorithms.graphics.surf.ProgressiveSurface;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.data.BinaryMinHeap;
import edu.jhu.ece.iacl.jist.structures.geom.EmbeddedSurface;
import edu.jhu.ece.iacl.jist.structures.geom.VertexIndexed;

/**
 * Method for correcting a spherical map so that it is a bijective embedding.
 * <p>
 * The correction alternates between decimating the mesh (highest spherical
 * curvature first, with degenerate and off-sphere vertices promoted to the top
 * of the queue) and implicitly smoothing the remaining vertices until the map
 * reports itself harmonic. The removed vertices are then re-inserted in the
 * reverse order of their removal.
 * 
 * @author Blake Lucas
 * 
 */
public class SphericalMapCorrection extends ProgressiveSurface {
	/**
	 * @return revision number of this class.
	 */
	public static String getVersion() {
		return VersionUtil.parseRevisionNumber("$Revision: 1.1 $");
	}

	/** Implicit fairing step size used when assembling (I - lambda * L). */
	protected double lambda;
	/**
	 * Fraction of {@code maxNonHarmonic} vertices to remove per decimation
	 * pass.
	 */
	protected double decimationStep;
	/**
	 * Maximum fraction of vertices that may be decimated before the correction
	 * gives up and {@link #solve(EmbeddedSurface)} returns null.
	 */
	protected double maxDecimation;
	/** Number of smoothing iterations run after each decimation pass. */
	protected int smoothStepIters;

	/**
	 * Set the maximum fraction of vertices that may be decimated before the
	 * correction gives up.
	 * 
	 * @param maxDecimation
	 *            fraction of the vertex count (1.0 allows full decimation)
	 */
	public void setMaxDecimation(double maxDecimation) {
		this.maxDecimation = maxDecimation;
	}

	/** Number of decimated vertices (size of {@link #vertStack}). */
	protected int decimationAmount;
	/** Decimation operations; undone in LIFO order by {@link #tessellate()}. */
	protected Stack<MeshTopologyOperation> vertStack = new Stack<MeshTopologyOperation>();

	/**
	 * @param lambda
	 *            implicit fairing step size
	 * @param decimationStep
	 *            fraction of {@code maxNonHarmonic} vertices removed per pass
	 * @param smoothStepIters
	 *            smoothing iterations after each decimation pass
	 */
	public SphericalMapCorrection(double lambda, double decimationStep,
			int smoothStepIters) {
		super();
		configure(lambda, decimationStep, smoothStepIters);
	}

	/**
	 * @param parent
	 *            parent calculation used for progress reporting
	 * @param lambda
	 *            implicit fairing step size
	 * @param decimationStep
	 *            fraction of {@code maxNonHarmonic} vertices removed per pass
	 * @param smoothStepIters
	 *            smoothing iterations after each decimation pass
	 */
	public SphericalMapCorrection(AbstractCalculation parent, double lambda,
			double decimationStep, int smoothStepIters) {
		super(parent);
		configure(lambda, decimationStep, smoothStepIters);
	}

	/**
	 * Construct with default parameters (lambda 0.8, step 0.02, 3 iterations).
	 * 
	 * @param parent
	 *            parent calculation used for progress reporting
	 */
	public SphericalMapCorrection(AbstractCalculation parent) {
		this(parent, 0.8f, 0.02f, 3);
	}

	/**
	 * Construct with default parameters (lambda 0.8, step 0.02, 3 iterations).
	 * Chaining also initializes maxDecimation to 1.0 like the other
	 * constructors; leaving it at 0 made every correction fail immediately.
	 */
	public SphericalMapCorrection() {
		this(0.8f, 0.02f, 3);
	}

	/**
	 * Shared field initialization for all constructors.
	 */
	private void configure(double lambda, double decimationStep,
			int smoothStepIters) {
		this.lambda = lambda;
		this.decimationStep = decimationStep;
		this.smoothStepIters = smoothStepIters;
		this.maxDecimation = 1.0f;
		this.edgeMetric = new ArcLengthEdgeMetric();
		setLabel("Harmonic Map Correction");
	}

	/**
	 * Metric to decimate spherical map derived from the discrete curvature of
	 * the spherical map.
	 * <p>
	 * Returns {@code 2 * MIN_HEAP_VAL} for vertices whose length is outside
	 * [0.99, 1.01] and {@code MIN_HEAP_VAL} for vertices with an incident edge
	 * no longer than {@code MIN_EDGE_LENGTH}, so that degenerate points are
	 * removed first. Otherwise it returns the negated magnitude of the mean
	 * curvature vector, so high-curvature vertices leave the min-heap first.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	protected class SphericalCurvatureVertexMetric implements HeapVertexMetric {
		public double evaluate(int id) {
			int[] nbrs = neighborVertexVertexTable[id];
			Point3f pivot = surf.getVertex(id);
			// If this a degenerate point, promote to top of queue
			// Degenerate points should be removed first
			double l = GeometricUtilities.length(pivot);
			if (l > 1.01 || l < 0.99) {
				// Point is not on sphere!
				System.err.println("POINT NOT ON SPHERE " + id + " " + pivot
						+ " " + l);
				return 2 * MIN_HEAP_VAL;
			}
			for (int nbr : nbrs) {
				if (pivot.distance(surf.getVertex(nbr)) <= MIN_EDGE_LENGTH) {
					return MIN_HEAP_VAL;
				}
			}
			Vector3f c = getMeanCurvature(id);
			return -c.length();
		}

	}

	/**
	 * Decimate surface to remove fold-overs.
	 * <p>
	 * Vertices are removed in heap order by spherical edge collapses until the
	 * quota is reached. Removal continues past the quota while the popped value
	 * is at or below {@code MIN_HEAP_VAL}, which marks degenerate vertices.
	 * At least 4 vertices are always kept. Every successful collapse is pushed
	 * onto {@link #vertStack} so that it can be undone later.
	 * 
	 * @param decimateAmount
	 *            Fraction of {@code maxNonHarmonic} vertices to remove.
	 * @return Number of vertices removed by this call.
	 */
	protected int decimate(double decimateAmount) {
		int vertCount = surf.getVertexCount();
		int removePtsCount = (int) Math.ceil(decimateAmount * maxNonHarmonic);
		BinaryMinHeap heap = new BinaryMinHeap(vertCount, vertCount, 1, 1);
		// Vertices removed by earlier passes have no neighbors left
		int alreadyRemoved = 0;
		for (int i = 0; i < vertCount; i++) {
			if (neighborVertexVertexTable[i].length > 0) {
				// Insert un-decimated points into heap
				VertexIndexed vox = new VertexIndexed(heapMetric(i));
				vox.setPosition(i);
				heap.add(vox);
			} else {
				alreadyRemoved++;
			}
		}
		int stepCount = 0;
		double lastVal = 0;
		// Count removals made in this call (stepCount) so that the
		// "more than 4 vertices left" guard stays accurate while we remove.
		while (!heap.isEmpty()
				&& (vertCount - alreadyRemoved - stepCount) > 4
				&& (stepCount < removePtsCount || lastVal <= MIN_HEAP_VAL)) {
			VertexIndexed vox = (VertexIndexed) heap.remove();
			lastVal = vox.getValue();
			int v2 = vox.getRow();
			if (neighborVertexVertexTable[v2].length == 0) {
				// V2 has already been removed
				continue;
			}
			SphericalEdgeCollapse collapse = new SphericalEdgeCollapse();
			if (collapse.apply(v2)) {
				// Edge successfully removed; re-rank the affected neighbors
				vertStack.push(collapse);
				for (int nbr : collapse.getChangedVerts()) {
					VertexIndexed updated = new VertexIndexed(heapMetric(nbr));
					updated.setPosition(nbr);
					heap.change(nbr, 0, 0, updated);
				}
				stepCount++;
			} else {
				// Haven't hit this statement yet, but it maybe possible
				System.err
						.println("COULD NOT REMOVE VERTEX BECAUSE OF TOPOLOGY "
								+ v2 + " " + vox.getValue());
			}
		}
		System.out.println("DECIMATION " + 100 * vertStack.size()
				/ (float) vertCount + " %");

		heap.makeEmpty();
		return stepCount;
	}

	/**
	 * Solve for spherical mapping using two surfaces that have spherical
	 * coordinates associated with each vertex. Offset refers to the offset into
	 * the vertex data to find and store spherical coordinates.
	 * <p>
	 * On success, the vertex data of {@code targetSurf} is replaced with a new
	 * array of width {@code offset + 3}:
	 * <ul>
	 * <li>column 0 holds the first vertex-data value of the corrected sphere,
	 * when present;</li>
	 * <li>columns {@code offset..offset+2} hold the corrected spherical
	 * coordinates.</li>
	 * </ul>
	 * Any other columns are zero.
	 * 
	 * @param sourceSurf
	 *            surface whose spherical map drives the link-length metric
	 * @param targetSurf
	 *            surface whose embedded spherical map is corrected
	 * @param offset
	 *            column offset of the spherical coordinates in the vertex data
	 * @return corrected sphere, or null if the correction failed (in which
	 *         case {@code targetSurf} is left unmodified)
	 */
	public EmbeddedSurface solve(EmbeddedSurface sourceSurf,
			EmbeddedSurface targetSurf, int offset) {
		VertexRadialSearchMethod searchMethod = new VertexRadialSearchMethod();
		searchMethod.setAngleSteps(30);
		vertInsertOptMethod = searchMethod;
		vertInsertMetric = new LinkLengthMetric(this, sourceSurf, targetSurf);
		// Extract sphere from embedded surface and solve harmonic mapping on
		// sphere
		EmbeddedSurface sphere = targetSurf.getEmbeddedSphere(offset, false);
		int vertCount = targetSurf.getVertexCount();

		sphere = solve(sphere);
		if (sphere == null) {
			// Correction gave up; do not touch the target surface
			return null;
		}
		// Modify embedded surface
		double[][] errData = sphere.getVertexData();
		double[][] vertData = new double[vertCount][offset + 3];
		for (int id = 0; id < vertCount; id++) {
			Point3f pt = sphere.getVertex(id);
			if (errData != null && id < errData.length && errData[id] != null
					&& errData[id].length > 0) {
				vertData[id][0] = errData[id][0];
			}
			vertData[id][offset] = pt.x;
			vertData[id][offset + 1] = pt.y;
			vertData[id][offset + 2] = pt.z;
		}
		targetSurf.setVertexData(vertData);
		// Return corrected sphere
		return sphere;
	}

	/**
	 * Re-insert deleted vertices to restore original parameterization. Pops
	 * and restores every operation on {@link #vertStack} (most recent first),
	 * reporting one progress unit per operation.
	 */
	public void tessellate() {
		setTotalUnits(vertStack.size());
		while (!vertStack.isEmpty()) {
			MeshTopologyOperation vs = vertStack.pop();
			vs.restore();
			incrementCompletedUnits();
		}
		markCompleted();
	}

	/**
	 * @return number of decimated vertices recorded by the last decimation
	 *         pass of {@link #solve(EmbeddedSurface)}.
	 */
	public int getDecimationAmount() {
		return decimationAmount;
	}

	/**
	 * Link length metric is used to minimize distance between surfaces that are
	 * implicitly mapped via spherical maps of both surfaces.
	 * <p>
	 * For a vertex id, it locates the working sphere's position of that vertex
	 * on {@code source} (via a {@link SphericalMapLocator}). It then returns
	 * the distance from that point to vertex id of {@code target}.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	protected class LinkLengthMetric implements VertexInsertionMetric {
		EmbeddedSurface source, target;
		SphericalMapLocator locator;

		public LinkLengthMetric(AbstractCalculation parent,
				EmbeddedSurface source, EmbeddedSurface target) {
			this.source = source;
			this.target = target;
			locator = new SphericalMapLocator(parent, source, 0);
		}

		public double evaluate(int id) {
			Point3f sp = locator.locatePoint(surf.getVertex(id));
			Point3f tp = target.getVertex(id);
			return sp.distance(tp);
		}
	}

	/**
	 * Count non-harmonic points on a spherical map.
	 * 
	 * @param origSurf
	 *            spherical map
	 * @return number of non-harmonic points
	 */
	public int nonHarmonicPoints(EmbeddedSurface origSurf) {
		init(origSurf, false);
		return nonHarmonicPoints();
	}

	/**
	 * Label non-harmonic points on spherical parameterization.
	 * 
	 * @param sphere
	 *            spherical map; its vertex data is replaced by the labels
	 * @return number of non-harmonic points
	 */
	public int labelNonHarmonicPoints(EmbeddedSurface sphere) {
		init(sphere, false);
		return labelCurrentSurface(sphere);
	}

	/**
	 * Label non-harmonic points on surface with spherical parameterization.
	 * 
	 * @param origSurf
	 *            surface with embedded spherical map; its vertex data is
	 *            replaced by the labels
	 * @param offset
	 *            offset of spherical coordinates into vertex data
	 * @return number of non-harmonic points
	 */
	public int labelNonHarmonicPoints(EmbeddedSurface origSurf, int offset) {
		EmbeddedSurface sphere = origSurf.getEmbeddedSphere(offset, false);
		init(sphere, false);
		return labelCurrentSurface(origSurf);
	}

	/**
	 * Label every vertex of the currently initialized surface with 1 if it is
	 * not wound correctly and 0 otherwise. The labels are stored as the vertex
	 * data of {@code dest}. The method also updates {@code maxNonHarmonic} and
	 * the progress fraction.
	 * 
	 * @param dest
	 *            surface that receives the single-column label data
	 * @return number of non-harmonic points
	 */
	private int labelCurrentSurface(EmbeddedSurface dest) {
		int vertCount = surf.getVertexCount();
		double[][] labels = new double[vertCount][1];
		int nonHarmonicCount = 0;
		for (int id = 0; id < vertCount; id++) {
			if (!isWoundCorrectly(id)) {
				labels[id][0] = 1;
				nonHarmonicCount++;
			}
		}
		dest.setVertexData(labels);
		maxNonHarmonic = Math.max(maxNonHarmonic, nonHarmonicCount);
		System.out.println("Non-Harmonic Points " + nonHarmonicCount);
		setCompletedUnits(1 - ((maxNonHarmonic == 0) ? 0 : nonHarmonicCount
				/ (float) maxNonHarmonic));
		return nonHarmonicCount;
	}

	/**
	 * Count non-harmonic points of the spherical map embedded in a surface.
	 * 
	 * @param origSurf
	 *            surface with embedded spherical map
	 * @param offset
	 *            offset of spherical coordinates into vertex data
	 * @return number of non-harmonic points
	 */
	public int nonHarmonicPoints(EmbeddedSurface origSurf, int offset) {
		return nonHarmonicPoints(origSurf.getEmbeddedSphere(offset, false));
	}

	/**
	 * Use the {@link HormannVertexMetric} for vertex re-insertion.
	 * 
	 * @param refSurface
	 *            reference surface
	 * @param rho
	 *            exponent of the area term
	 */
	public void setUseHormannMetric(EmbeddedSurface refSurface, double rho) {
		this.vertInsertMetric = new HormannVertexMetric(refSurface, rho);
	}

	/**
	 * Use the {@link LinkLengthMetric} for vertex re-insertion.
	 * 
	 * @param source
	 *            surface located through its spherical map
	 * @param target
	 *            surface whose vertices are compared against
	 */
	public void setUseLinkLengthMetric(EmbeddedSurface source,
			EmbeddedSurface target) {
		this.vertInsertMetric = new LinkLengthMetric(this, source, target);
	}

	/**
	 * Solve for spherical mapping of surface.
	 * <p>
	 * Repeatedly decimates and smooths until the map is harmonic, until at most
	 * 4 vertices remain, or until a decimation pass removes nothing. The
	 * removed vertices are then re-inserted.
	 * 
	 * @param origSphere
	 *            Uncorrected spherical map.
	 * @return Corrected spherical map, or null if more than
	 *         {@code maxDecimation} of the vertices had to be removed.
	 */
	public EmbeddedSurface solve(EmbeddedSurface origSphere) {
		init(origSphere, false);
		int vertCount = surf.getVertexCount();
		this.vertexMetric = new SphericalCurvatureVertexMetric();
		if (!isHarmonic()) {
			int remaining = vertCount;
			while (remaining > 4) {
				int removed = decimate(decimationStep);
				decimationAmount = vertStack.size();
				remaining = vertCount - decimationAmount;
				if (isHarmonic()) {
					break;
				}
				smooth(smoothStepIters);
				if (isHarmonic()) {
					break;
				}
				if (decimationAmount >= vertCount * maxDecimation) {
					// Give up: discard the recorded operations (Stack.empty()
					// only tests for emptiness, it does not clear)
					vertStack.clear();
					this.surf = null;
					return null;
				}
				if (removed == 0) {
					// No further progress is possible; avoid looping forever
					System.err.println("DECIMATION STALLED WITH " + remaining
							+ " POINTS");
					break;
				}
			}
			tessellate();
			isHarmonic();
		}
		EmbeddedSurface correctedSphere = surf;
		this.surf = null;
		return correctedSphere;
	}

	/** Not used within this class. */
	double maxMeanCurvature = 0;
	/** Not used within this class. */
	protected static float DELTA = 0.25f;
	/** Not used within this class. */
	protected static int MAX_ITERS = 1;
	/**
	 * Per-vertex smoothing weights log(1 + |mean curvature|), normalized by
	 * their maximum; filled by {@link #smooth(int)}.
	 */
	double[] curvWeight;

	/**
	 * Smooth spherical mapping using an implicit fairing technique. See M.
	 * Desbrun, M. Meyer, P. Schroder et al., "Implicit fairing of irregular
	 * meshes using diffusion and curvature flow," 1999. for details.
	 * <p>
	 * Solves (I - lambda * L) x' = x once per coordinate. Here L is the matrix
	 * assembled by {@code func}, and the current vertex positions serve as both
	 * the right-hand side and the initial guess. Despite the method name, the
	 * solver is BiCG: with {@link LaplacianWeightFunc}, each row is scaled by
	 * that vertex's own weight, so L is in general not symmetric.
	 * 
	 * @param func
	 *            weighting function
	 * @return smoothed spherical coordinate locations {x, y, z}, or null if a
	 *         solve did not converge.
	 */
	public DenseVector[] iterateConjugateGradient(WeightVectorFunc func) {
		int vertCount = surf.getVertexCount();
		Matrix Ax = new FlexCompRowMatrix(vertCount, vertCount);
		Matrix Ay = new FlexCompRowMatrix(vertCount, vertCount);
		Matrix Az = new FlexCompRowMatrix(vertCount, vertCount);
		DenseVector pX = new DenseVector(vertCount);
		DenseVector pY = new DenseVector(vertCount);
		DenseVector pZ = new DenseVector(vertCount);
		DenseVector pnextX = new DenseVector(vertCount);
		DenseVector pnextY = new DenseVector(vertCount);
		DenseVector pnextZ = new DenseVector(vertCount);
		// Initialize matrix; current positions are right-hand side and guess
		for (int i = 0; i < vertCount; i++) {
			func.populate(Ax, Ay, Az, i);
			Point3f p = surf.getVertex(i);
			pX.set(i, p.x);
			pY.set(i, p.y);
			pZ.set(i, p.z);
			pnextX.set(i, p.x);
			pnextY.set(i, p.y);
			pnextZ.set(i, p.z);
		}

		// Scale curvature matrix by update increment
		Ax.scale(-lambda);
		Ay.scale(-lambda);
		Az.scale(-lambda);
		// Add identity term
		for (int i = 0; i < vertCount; i++) {
			Ax.set(i, i, 1 + Ax.get(i, i));
			Ay.set(i, i, 1 + Ay.get(i, i));
			Az.set(i, i, 1 + Az.get(i, i));
		}

		BiCG solverX = new BiCG(pX);
		BiCG solverY = new BiCG(pY);
		BiCG solverZ = new BiCG(pZ);
		IterationMonitor im;
		try {
			solverX.solve(Ax, pX, pnextX);
			im = solverX.getIterationMonitor();
			System.out.println("X: ITERS " + im.iterations() + " RESIDUAL "
					+ im.residual());
			solverY.solve(Ay, pY, pnextY);
			im = solverY.getIterationMonitor();
			System.out.println("Y: ITERS " + im.iterations() + " RESIDUAL "
					+ im.residual());
			solverZ.solve(Az, pZ, pnextZ);
			im = solverZ.getIterationMonitor();
			System.out.println("Z: ITERS " + im.iterations() + " RESIDUAL "
					+ im.residual());
			return new DenseVector[] { pnextX, pnextY, pnextZ };
		} catch (IterativeSolverNotConvergedException e) {
			// Best effort: callers treat null as "no smoothing this step"
			e.printStackTrace();
			return null;
		}
	}

	/**
	 * Weighting function for smoothing spherical map.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	public interface WeightVectorFunc {
		/**
		 * Add the row for vertex {@code index} to each coordinate's matrix.
		 * 
		 * @return the weight used for this vertex
		 */
		public double populate(Matrix Ax, Matrix Ay, Matrix Az, int index);
	}

	/**
	 * Laplacian weighting function using uniform weights. Row {@code id}
	 * receives {@code curvWeight[id] / degree} for each neighbor and
	 * {@code -curvWeight[id]} on the diagonal.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	protected class LaplacianWeightFunc implements WeightVectorFunc {
		public double populate(Matrix Ax, Matrix Ay, Matrix Az, int id) {
			int len = neighborVertexVertexTable[id].length;
			if (len == 0)
				return 0;
			int nbr;
			double curv = curvWeight[id];
			double w = curv * 1.0 / len;
			for (int i = 0; i < len; i++) {
				nbr = neighborVertexVertexTable[id][i];
				Ax.add(id, nbr, w);
				Ay.add(id, nbr, w);
				Az.add(id, nbr, w);
			}
			Ax.add(id, id, -curv);
			Ay.add(id, id, -curv);
			Az.add(id, id, -curv);
			return curv;
		}
	}

	/**
	 * Smooth surface by alternating between implicit fairing and projection to
	 * sphere.
	 * <p>
	 * Only vertices with a positive curvature weight are moved. After the
	 * iterations, every vertex is normalized back onto the sphere. If the
	 * linear solver fails to converge, the remaining iterations are skipped.
	 * 
	 * @param iters
	 *            number of implicit fairing iterations
	 * @return the working surface, modified in place
	 */
	public EmbeddedSurface smooth(int iters) {
		System.out.println("SMOOTH " + iters + " iterations");
		int vertCount = surf.getVertexCount();
		curvWeight = new double[vertCount];
		updateCurvatureWeights();
		for (int i = 0; i < iters; i++) {
			DenseVector[] pts = iterateConjugateGradient(new LaplacianWeightFunc());
			if (pts == null) {
				// Solver did not converge; keep the current positions
				break;
			}
			// Update surface point locations
			for (int id = 0; id < vertCount; id++) {
				if (curvWeight[id] > 0) {
					surf.setVertex(id, new Point3f((float) pts[0].get(id),
							(float) pts[1].get(id), (float) pts[2].get(id)));
				}
			}
			double maxCurvature = updateCurvatureWeights();
			System.out.println("ITERATION " + (i + 1) + " Max Curvature "
					+ maxCurvature);
		}
		// Project points back onto the sphere
		for (int id = 0; id < vertCount; id++) {
			Vector3f v = new Vector3f(surf.getVertex(id));
			GeometricUtilities.normalize(v);
			surf.setVertex(id, v);
		}
		return surf;
	}

	/**
	 * Recompute {@link #curvWeight} as log(1 + |mean curvature|) per vertex,
	 * normalized by the maximum. When the maximum is 0, the weights stay 0
	 * instead of becoming NaN.
	 * 
	 * @return the maximum un-normalized weight
	 */
	private double updateCurvatureWeights() {
		// The curvature weight would make more theoretical sense if it were
		// approximated with the method from:
		// M. Black, G. Sapiro, D. Marimont et al., "Robust anisotropic
		// diffusion," IEEE Transactions on Image Processing, vol. 7, no. 3,
		// pp. 421-432, 1998.
		double maxCurvature = 0;
		for (int id = 0; id < curvWeight.length; id++) {
			double curv = Math.log(1 + getMeanCurvature(id).length());
			curvWeight[id] = curv;
			maxCurvature = Math.max(curv, maxCurvature);
		}
		if (maxCurvature > 0) {
			for (int id = 0; id < curvWeight.length; id++) {
				curvWeight[id] /= maxCurvature;
			}
		}
		return maxCurvature;
	}

	/**
	 * Hormann proposed a metric that balances angle and area preservation. See
	 * K. Hormann, G. Greiner, and E.-N. U. C. G. GROUP,
	 * "MIPS: An efficient global parametrization method," ERLANGEN-NUERNBERG
	 * UNIV (GERMANY) COMPUTER GRAPHICS GROUP, 2000. for details.
	 * 
	 * @author Blake Lucas
	 * 
	 */
	protected class HormannVertexMetric implements VertexInsertionMetric {
		protected EmbeddedSurface refSurface;
		protected double rho;

		public HormannVertexMetric(EmbeddedSurface refSurface, double rho) {
			this.rho = rho;
			this.refSurface = refSurface;
		}

		/**
		 * Area-weighted average, over the triangle fan around {@code id}, of
		 * an angle term (cotangent-weighted squared reference edge lengths
		 * over reference area) times an area-ratio term raised to rho.
		 */
		public double evaluate(int id) {
			int[] nbrs = neighborVertexVertexTable[id];
			int len = nbrs.length;
			Point3f pt11 = refSurface.getVertex(id);
			Point3f pt21 = surf.getVertex(id);
			double w = 0;
			double wNorm = 0;
			for (int i = 0; i < len; i++) {
				int v2 = nbrs[i];
				int v3 = nbrs[(i + 1) % len];
				Point3f pt12 = refSurface.getVertex(v2);
				Point3f pt13 = refSurface.getVertex(v3);
				Point3f pt22 = surf.getVertex(v2);
				Point3f pt23 = surf.getVertex(v3);
				double a1 = GeometricUtilities.triangleArea(pt11, pt12, pt13);
				double a2 = GeometricUtilities.triangleArea(pt21, pt22, pt23);
				double alpha = GeometricUtilities.cotAngle(pt22, pt23, pt21);
				double beta = GeometricUtilities.cotAngle(pt23, pt21, pt22);
				double gamma = GeometricUtilities.cotAngle(pt21, pt22, pt23);
				double a = GeometricUtilities.distance(pt13, pt12);
				double b = GeometricUtilities.distance(pt11, pt13);
				double c = GeometricUtilities.distance(pt11, pt12);
				double wAng = ((a1 > 1E-10 && a > 1E-6 && b > 1E-6 && c > 1E-6) ? ((alpha
						* a * a + beta * b * b + gamma * c * c) / a1)
						: 0);
				// NOTE(review): wArea is negative whenever a2 < a1, and
				// Math.pow of a negative base with a non-integer rho is NaN.
				// Confirm this matches the intended MIPS area term.
				double wArea = ((a1 > 1E-10) ? a2 / a1 - 1 : 0);
				w += a2 * Math.pow(wArea, rho) * wAng;
				wNorm += a2;
			}
			if (wNorm > 0) {
				w /= wNorm;
			}
			return w;
		}
	}
}
