package edu.jhu.ece.iacl.algorithms.skull_strip;

import javax.vecmath.Point3f;
import javax.vecmath.Vector3f;

import edu.jhmi.rad.medic.libraries.Morphology;
import edu.jhmi.rad.medic.utilities.CropParameters;
import edu.jhmi.rad.medic.utilities.CubicVolumeCropper;
import edu.jhmi.rad.medic.utilities.Numerics;
import edu.jhu.ece.iacl.algorithms.graphics.GeometricUtilities;
import edu.jhu.ece.iacl.algorithms.graphics.edit.Tessellate;
import edu.jhu.ece.iacl.algorithms.graphics.isosurf.IsoSurfaceOnGrid;
import edu.jhu.ece.iacl.algorithms.graphics.smooth.MultiResolutionImplicitSmoothing;
import edu.jhu.ece.iacl.algorithms.graphics.smooth.SmoothAndRegularize;
import edu.jhu.ece.iacl.algorithms.graphics.smooth.MultiResolutionImplicitSmoothing.WeightVectorFunc;
import edu.jhu.ece.iacl.algorithms.graphics.smooth.SmoothAndRegularize.Method;
import edu.jhu.ece.iacl.algorithms.graphics.surf.HierarchicalSurface;
import edu.jhu.ece.iacl.algorithms.graphics.topo.ConnectedSurfaceComponents;
import edu.jhu.ece.iacl.algorithms.graphics.utilities.SurfaceToMask;
import edu.jhu.ece.iacl.algorithms.graphics.utilities.quickhull.QuickHull3D;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient;
import edu.jhu.ece.iacl.algorithms.gvf.FastMarchingGradient.Normalization;
import edu.jhu.ece.iacl.algorithms.tgdm.GenericTGDM;
import edu.jhu.ece.iacl.algorithms.topology.ConnectivityRule;
import edu.jhu.ece.iacl.algorithms.topology.TopologyCorrection;
import edu.jhu.ece.iacl.algorithms.topology.TopologyCorrection.PropagationTypes;
import edu.jhu.ece.iacl.algorithms.volume.DistanceField;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.geom.EmbeddedSurface;
import edu.jhu.ece.iacl.jist.structures.geom.NormalGenerator;
import edu.jhu.ece.iacl.jist.structures.geom.Surface;
import edu.jhu.ece.iacl.jist.structures.geom.EmbeddedSurface.Direction;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;

import gov.nih.mipav.model.structures.ModelImage;

import no.uib.cipr.matrix.Matrix;

/**
 * Smooths a brain mask and re-applies it to the original image.
 * <p>
 * {@link #solve} turns the mask into a distance-based speed image, runs
 * {@link TopologyCorrection} on it, extracts an iso-surface, smooths that surface
 * with {@link MultiResolutionImplicitSmoothing}, and rasterizes the smoothed surface
 * back into a binary mask that is used to strip the original image.
 *
 * @author Blake Lucas (bclucas@jhu.edu)
 */
public class SmoothBrainMask extends MultiResolutionImplicitSmoothing {
	private static final String cvsversion = "$Revision: 1.1 $";
	public static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");

	/** @return the CVS revision number of this class (e.g. {@code "1.1"}). */
	public static String get_version() {
		return revnum;
	}


	/** Level set computed by the last {@link #solve} call (in the cropped voxel grid). */
	protected ImageDataFloat levelSet;
	/** Original image with everything outside the final mask zeroed (uncropped). */
	protected ImageData maskedBrain;
	/** Final mask produced by the last {@link #solve} call (uncropped). */
	protected ImageData newBrainMask;
	/** Image sampled by {@link #speedFunc}; not assigned in this class. */
	protected ModelImage levelSetImg;
	/** {@code maxDist} bounds the distance transform; {@code scaleWeight} scales the speed sigmoid. */
	protected float maxDist = 20, scaleWeight = 1;
	/** Gradient field reserved for subclasses; not assigned in this class. */
	protected float[][][][] gradient;

	public SmoothBrainMask() {
		super(15, 0.5f, 200, 1);
		setLabel("Shrink Wrap Level Set");
	}

	/** @return the level set from the last {@link #solve} call (cropped grid). */
	public ImageDataFloat getLevelSet() {
		return levelSet;
	}

	/** @return the masked (skull-stripped) image from the last {@link #solve} call. */
	public ImageData getMaskedBrain(){
		return maskedBrain;
	}

	/** @return the brain mask from the last {@link #solve} call. */
	public ImageData getBrainMask(){
		return newBrainMask;
	}

	/**
	 * Builds a smoothed brain mask from {@code brainmask} and applies it to {@code origData}.
	 * <p>
	 * Both volumes are cropped with the same {@link CubicVolumeCropper} bounds before
	 * processing. The resulting mask and stripped image are uncropped again and stored,
	 * retrievable through {@link #getBrainMask()} and {@link #getMaskedBrain()}. The
	 * returned surface is translated by the crop offset on both exit paths, so it lives
	 * in the same frame as those uncropped images.
	 *
	 * @param brainmask   input mask; its {@code isoLevel} contour defines the initial boundary
	 * @param origData    image to be masked
	 * @param isoLevel    level of {@code brainmask} treated as the mask boundary
	 * @param maxCurv     first argument of the {@link MultiResolutionImplicitSmoothing} used
	 *                    to smooth the extracted surface
	 * @param translation forwarded to {@link SurfaceToMask#solve}
	 * @return the brain surface. If the extracted iso-surface is not genus 0, it is
	 *         returned unsmoothed and the level set itself is stored as both outputs.
	 */
	public EmbeddedSurface solve(ImageData brainmask, ImageData origData, float isoLevel,float maxCurv,float translation) {
		DistanceField df = new DistanceField(this);
		CubicVolumeCropper cropper = new CubicVolumeCropper();
		origData = cropper.crop(origData, 0, 3);
		CropParameters params = cropper.getLastCropParams();
		// Crop the mask with identical bounds so both volumes share one voxel grid.
		brainmask = cropper.crop(brainmask, params);
		ImageHeader header = origData.getHeader().clone();

		levelSet = new ImageDataFloat(brainmask);
		int rows = levelSet.getRows();
		int cols = levelSet.getCols();
		int slices = levelSet.getSlices();
		// Signed function that is negative where the mask exceeds isoLevel. It is written
		// in place, relying on toArray3d() exposing the image's backing array.
		float[][][] levelSetMat = levelSet.toArray3d();
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					levelSetMat[i][j][k] = isoLevel - levelSetMat[i][j][k];
				}
			}
		}
		levelSet = df.solve(levelSet, maxDist);
		levelSet.setName(brainmask.getName() + "_speed");
		System.out.println("DIMENSIONS " + rows + ", " + cols + ", " + slices);

		// Speed image: 1 where the distance is non-positive, decreasing linearly with
		// positive distance (0 at maxDist).
		levelSetMat = levelSet.toArray3d();
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					levelSetMat[i][j][k] = 1
							- Math.max(0, levelSetMat[i][j][k]) / maxDist;
				}
			}
		}

		TopologyCorrection tc = new TopologyCorrection();
		this.levelSet = tc.solve(levelSet, null, isoLevel,
				ConnectivityRule.CONNECT_6_18,
				PropagationTypes.BACKGROUND_TO_OBJECT);

		// Invert the topology-corrected image (v -> 1 - v) before extracting the 0 iso-surface.
		levelSetMat = levelSet.toArray3d();
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					levelSetMat[i][j][k] = (1-levelSetMat[i][j][k]);
				}
			}
		}
		IsoSurfaceOnGrid iso = new IsoSurfaceOnGrid();
		EmbeddedSurface mesh = iso.solve(levelSet,ConnectivityRule.CONNECT_18_6, 0, false);
		int genus = EmbeddedSurface.getGenus(mesh);
		if (genus != 0) {
			// Fallback: export the level set itself as both outputs and return the raw
			// iso-surface. It is moved into the uncropped frame like the normal path's result.
			storeResults(levelSetMat, levelSetMat, header, cropper, params,
					brainmask.getName(), origData.getName());
			mesh.setName(origData.getName() + "_mask");
			translateToUncropped(mesh, params);
			return mesh;
		}

		// Keep a single connected component (the first returned; logged as the largest).
		ConnectedSurfaceComponents csc = new ConnectedSurfaceComponents(mesh);
		mesh = csc.solve().get(0);
		System.out.println("LARGEST COMPONENT GENUS "+EmbeddedSurface.getGenus(mesh));
		MultiResolutionImplicitSmoothing mrs = new MultiResolutionImplicitSmoothing(maxCurv,0.8f,5,5);
		mesh = mrs.solve(mesh);
		mesh.setName(origData.getName()+"_mask");

		// Rasterize the smoothed surface, then binarize it at 0.5 and strip the image.
		SurfaceToMask surf2vol = new SurfaceToMask();
		ImageDataFloat maskVol = surf2vol.solve(origData, mesh, 1, translation);
		float[][][] maskMat = maskVol.toArray3d();
		float[][][] resultMat = new float[rows][cols][slices];
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					boolean inside = maskMat[i][j][k] > 0.5;
					resultMat[i][j][k] = inside ? origData.getFloat(i, j, k) : 0;
					maskMat[i][j][k] = inside ? 1 : 0;
				}
			}
		}
		storeResults(resultMat, maskMat, header, cropper, params,
				brainmask.getName(), origData.getName());

		translateToUncropped(mesh, params);
		return mesh;
	}

	/**
	 * Wraps the cropped result arrays as images, uncrops them and stores them in
	 * {@link #maskedBrain} and {@link #newBrainMask}.
	 *
	 * @param strippedMat cropped stripped-image voxels
	 * @param maskMat     cropped mask voxels
	 * @param header      header assigned to both images before uncropping
	 * @param cropper     cropper that produced {@code params}
	 * @param params      crop parameters used to uncrop
	 * @param maskName    base name for the stripped image ({@code "_strip"} is appended)
	 * @param origName    base name for the mask ({@code "_brainmask"} is appended)
	 */
	private void storeResults(float[][][] strippedMat, float[][][] maskMat, ImageHeader header,
			CubicVolumeCropper cropper, CropParameters params, String maskName, String origName) {
		maskedBrain = new ImageDataFloat(strippedMat);
		maskedBrain.setHeader(header);
		maskedBrain = cropper.uncrop(maskedBrain, params);
		maskedBrain.setName(maskName + "_strip");

		newBrainMask = new ImageDataFloat(maskMat);
		newBrainMask.setHeader(header);
		newBrainMask = cropper.uncrop(newBrainMask, params);
		newBrainMask.setName(origName + "_brainmask");
	}

	/** Translates {@code mesh} by the crop origin so it matches the uncropped volume. */
	private static void translateToUncropped(EmbeddedSurface mesh, CropParameters params) {
		Point3f offset = new Point3f(params.xmin, params.ymin, params.zmin);
		mesh.translate(offset);
	}

	/**
	 * Evaluates the shrink-wrap speed at a surface point.
	 * <p>
	 * The speed is {@code 2 * (1 / (1 + exp(-scaleWeight * L)) - 0.5)}, where {@code L} is
	 * {@link #levelSetImg} sampled trilinearly at {@code p}. It therefore lies in [-1, 1].
	 *
	 * @param p    point at which to sample {@link #levelSetImg}
	 * @param norm surface normal at {@code p}; currently not used by this implementation
	 * @return the speed value
	 * @throws IllegalStateException if {@link #levelSetImg} has not been assigned
	 */
	protected double speedFunc(Point3f p, Vector3f norm) {
		if (levelSetImg == null) {
			throw new IllegalStateException(
					"levelSetImg must be assigned before evaluating the speed function");
		}
		float sample = levelSetImg.getFloatTriLinearBounds(p.x, p.y, p.z);
		return 2 * (1.0 / (1 + Math.exp(-scaleWeight * sample)) - 0.5);
	}

	/**
	 * Weight function that fills one uniformly weighted, Laplacian-style row per vertex.
	 * <p>
	 * Each neighbor of vertex {@code id} receives {@code curv / len}, and the diagonal
	 * receives {@code -curv}, where {@code curv = -0.25}. The same row is added to all
	 * three coordinate systems.
	 */
	protected class ShrinkWrapWeightFunc implements WeightVectorFunc {
		public double populate(Matrix Ax, Matrix Ay, Matrix Az, int id) {
			int len = neighborVertexVertexTable[id].length;
			if (len == 0) {
				return 0;
			}
			// Constant unit speed; a per-vertex speed could be substituted here.
			double speed = 1;
			double curv = -0.25 * speed;
			double w = curv / len;
			for (int i = 0; i < len; i++) {
				int nbr = neighborVertexVertexTable[id][i];
				Ax.add(id, nbr, w);
				Ay.add(id, nbr, w);
				Az.add(id, nbr, w);
			}
			Ax.add(id, id, -curv);
			Ay.add(id, id, -curv);
			Az.add(id, id, -curv);
			return curv;
		}
	}

	/**
	 * Moves every vertex along its normal, for {@code maxIters} sweeps.
	 * <p>
	 * Each vertex is displaced by {@code normal * -0.25 * (speed - 0.5)}, where the speed
	 * comes from {@link #speedFunc}.
	 *
	 * @param mesh       surface to deform; its vertices are updated in place
	 * @param maxIters   number of sweeps over all vertices
	 * @param relaxation currently unused
	 * @return {@code mesh}
	 */
	protected EmbeddedSurface shrink(EmbeddedSurface mesh, int maxIters,
			double relaxation) {
		// The neighbor table is only used here to bound the vertex loop.
		int[][] nbhdTable = EmbeddedSurface.buildNeighborVertexVertexTable(
				mesh, Direction.COUNTER_CLOCKWISE);
		for (int iter = 0; iter < maxIters; iter++) {
			// Recompute normals each sweep because the previous sweep moved the vertices.
			Vector3f[] norms = NormalGenerator.generate(mesh);
			for (int i = 0; i < nbhdTable.length; i++) {
				Vector3f norm = norms[i];
				Point3f pivot = mesh.getVertex(i);
				double speed = speedFunc(mesh.getVertex(i), norm);
				norm.scale(-0.25f * (float) (speed - 0.5));
				pivot.add(norm);
				mesh.setVertex(i, pivot);
			}
		}
		return mesh;
	}

}
