package edu.jhu.ece.iacl.algorithms.gvf;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMath;


/**
 * Java implementation of the C GVF MultiGrid approach.
 *
 * <p>Each component of the vector field is obtained by solving, with
 * multigrid techniques, the GGVF equation
 * {@code div( g(x) \nabla u(x)) - h(x) (u - f(x)) = 0}.
 * For convenience, in the program, {@code g(x)} is called weight and
 * {@code h} is called hu.
 *
 * <p>All numerical routines index volumes C-style as {@code [z][y][x]}.
 *
 * <p>Not safe for concurrent use: {@link #multigrid} writes the static fields
 * {@link #maxlevel}, {@link #n1} and {@link #n2}, and {@link #process} writes
 * instance fields of the solver.
 *
 * @author Xiao Han
 * @author Aaron Carass
 * @author Min Chen
 */
public class GradVecFlowOptimized extends GradVecFlow {
	/** Instance backing the static {@link #doSolve} convenience entry point. */
	private static final GradVecFlowOptimized gvf = new GradVecFlowOptimized();

	/** Upper bound on the number of outer multigrid iterations per component. */
	private static final int MAX_OUTER_ITERATIONS = 10;

	/** Outer iterations stop once the largest absolute residual is below this. */
	private static final double RESIDUAL_TOLERANCE = 0.0001;

	/** Number of smoothing sweeps used as the "solve" on the coarsest grid. */
	private static final int COARSEST_SWEEPS = 2;

	/**
	 * Parity offsets {@code {kc, ic, jc}} of the eight sub-lattices, in the
	 * order they are relaxed by {@link #Gauss_Seidel}.
	 */
	private static final int[][] OCTANT_ORDER = {
		{0, 0, 0}, {0, 1, 1}, {1, 0, 1}, {1, 1, 0},
		{1, 0, 0}, {0, 0, 1}, {0, 1, 0}, {1, 1, 1}
	};

	/**
	 * Creates a solver attached to a parent calculation.
	 *
	 * @param parent calculation passed through to the superclass constructor
	 */
	public GradVecFlowOptimized(AbstractCalculation parent) {
		super(parent);
	}


	/** Creates a stand-alone solver. */
	public GradVecFlowOptimized() {
		super();
	}


	/**
	 * Static convenience wrapper around {@link #solve} using a shared instance.
	 *
	 * @param f         scalar input image
	 * @param lambda    parameter forwarded to {@link #solve}
	 * @param normalize whether the resulting field is normalized
	 * @param iter      forwarded to {@link #solve}
	 * @return the computed vector field
	 */
	public static ImageData doSolve(ImageData f, float lambda, boolean normalize, int iter){
		System.out.print("\n\nGVF Optimized -- doSolve()\n\n");
		return gvf.solve(f, lambda, normalize, iter);
	}


	/** Selects the smoother: {@code true} uses {@link #Jacobi}, {@code false} uses {@link #Gauss_Seidel}. */
	boolean JACOBI = false;

	/** Index of the finest grid level; set by {@link #multigrid}, read by the smoothers and {@link #resid}. */
	static int maxlevel;

	/** Pre-smoothing ({@code n1}) and post-smoothing ({@code n2}) sweep counts; set by {@link #multigrid}. */
	static int n1, n2;

	/** Input vector field in {@code [z][y][x][component]} order; filled by {@link #process}. */
	float[][][][] inField;

	/** Input image most recently handed to {@link #process}. */
	ImageDataFloat inImage;


	/**
	 * Diffuses a 3-component vector image with the multigrid solver.
	 *
	 * @param In        vector image read through {@code getFloat(i, j, k, c)} for {@code c} in 0..2
	 * @param lambda    parameter forwarded to {@link #scaleFlowGen}
	 * @param normalize if {@code true}, {@code normalize()} is called on the output image
	 * @param iter      not used by this implementation
	 * @return a new vector image named {@code In.getName() + "_gvf"}
	 */
	public ImageDataFloat process(ImageDataFloat In, float lambda, boolean normalize, int iter){
		System.out.print("\n\nGVF Optimized -- process() start\n\n");
		int XN = In.getRows();
		int YN = In.getCols();
		int ZN = In.getSlices();
		int i, j, k, c;

		String name = In.getName()+"_gvf";
		System.out.print("\nGVF Optimized -- Output Name " + name + "\n\n");

		this.inField = new float[ZN][YN][XN][3];
		this.inImage = In;

		/*
		 * The numerical routines run on C indexing for volumes,
		 * [Z][Y][X], so the input is converted first.
		 */
		for (i = 0; i < XN; i++){
			for (j = 0; j < YN; j++){
				for (k = 0; k < ZN; k++){
					for (c = 0; c < 3; c++){
						this.inField[k][j][i][c] = In.getFloat(i, j, k, c);
					}
				}
			}
		}

		float[][][][] outField = new float[ZN][YN][XN][3];
		float[][][] SqrMagf = new float[ZN][YN][XN];
		double x, y, z;

		/*
		 * NOTE(review): despite its name, SqrMagf receives the magnitude
		 * (square root of the sum of squares), while scaleFlowGen divides it by
		 * lambda^2 as if it were the squared magnitude. Left unchanged to keep
		 * numerical results stable -- confirm against the original C code.
		 */
		System.out.print("\n\nGVF Optimized -- process() Square Magnitude Image");
		for (i = 0; i < XN; i++){
			for (j = 0; j < YN; j++){
				for (k = 0; k < ZN; k++){
					x = inField[k][j][i][0];
					y = inField[k][j][i][1];
					z = inField[k][j][i][2];

					SqrMagf[k][j][i] = (float)Math.sqrt(x*x + y*y + z*z);
				}
			}
		}
		System.out.print(" ... done\n\n");

		outField = scaleFlowGen(outField, inField, SqrMagf, lambda, ZN, YN, XN);

		/* Undo the [Z][Y][X] ordering while copying into the output image. */
		ImageDataFloat outImage = new ImageDataFloat(XN, YN, ZN, 3);
		for (i = 0; i < XN; i++){
			for (j = 0; j < YN; j++){
				for (k = 0; k < ZN; k++){
					for (c = 0; c < 3; c++){
						outImage.set(i, j, k, c, outField[k][j][i][c]);
					}
				}
			}
		}

		if (normalize) {
			outImage.normalize();
		}
		outImage.setName(name);

		System.out.print("\n\nGVF Optimized -- process() finish\n\n");
		return outImage;
	}


	/**
	 * Rescales a scalar image to [0, 1], takes its gradient with
	 * {@code Gradient3d.doSolve} and diffuses the gradient via {@link #process}.
	 *
	 * <p>A constant image (zero intensity range) is rescaled to all zeros
	 * instead of producing NaN from a division by zero.
	 *
	 * @param image     scalar input image
	 * @param lambda    parameter forwarded to {@link #process}
	 * @param normalize forwarded to {@link #process}
	 * @param iter      forwarded to {@link #process}
	 * @return the vector field produced by {@link #process}
	 */
	public ImageData solve(ImageData image, float lambda, boolean normalize, int iter){
		System.out.print("\n\nGVF Optimized -- solve() start\n\n");
		ImageDataFloat floatImage = new ImageDataFloat(image);
		int XN = image.getRows();
		int YN = image.getCols();
		int ZN = image.getSlices();
		int i, j, k;

		/*
		 * Float.MIN_VALUE is the smallest POSITIVE float, so it is not a valid
		 * seed for a running maximum; use the infinities instead.
		 */
		float imageMax = Float.NEGATIVE_INFINITY;
		float imageMin = Float.POSITIVE_INFINITY;

		for (i = 0; i < XN; i++){
			for (j = 0; j < YN; j++){
				for (k = 0; k < ZN; k++){
					float value = image.getFloat(i, j, k);
					imageMax = Math.max(imageMax, value);
					imageMin = Math.min(imageMin, value);
				}
			}
		}

		float range = imageMax - imageMin;

		for (i = 0; i < XN; i++){
			for (j = 0; j < YN; j++){
				for (k = 0; k < ZN; k++){
					float scaled = (range != 0) ? (image.getFloat(i, j, k) - imageMin)/range : 0.0f;
					floatImage.set(i, j, k, scaled);
				}
			}
		}

		ImageDataFloat gradientImage = Gradient3d.doSolve(floatImage);
		gradientImage.setName(image.getName());

		ImageDataFloat V = process(gradientImage, lambda, normalize, iter);

		System.out.print("\n\nGVF Optimized -- solve() finish\n\n");
		return V;
	}


	/**
	 * Builds the weight fields from {@code SqrMagf} and runs {@link #multigrid}
	 * once per vector component.
	 *
	 * <p>Per voxel, {@code g = exp(-SqrMagf / lambda^2)}, {@code h = 1 - g},
	 * and the weight handed to the solver is {@code g / 6}. For each
	 * component the right-hand side is {@code -h * Gf} and the initial guess
	 * is {@code Gf} itself.
	 *
	 * @param V       output field {@code [z][y][x][3]}, overwritten and returned
	 * @param Gf      input field {@code [z][y][x][3]}
	 * @param SqrMagf per-voxel value used in the exponent, {@code [z][y][x]}
	 * @param lambda  scale; its square divides {@code SqrMagf}
	 * @param ZN      size of the first array dimension
	 * @param YN      size of the second array dimension
	 * @param XN      size of the third array dimension
	 * @return {@code V}
	 */
	float[][][][] scaleFlowGen(float[][][][] V, float[][][][] Gf, float[][][] SqrMagf, float lambda, int ZN, int YN, int XN)
	{
		int x, y, z;
		float lambda2 = lambda*lambda;

		float[][][] U = new float[ZN][YN][XN];
		float[][][] f = new float[ZN][YN][XN];
		float[][][] g = new float[ZN][YN][XN];
		float[][][] h = new float[ZN][YN][XN];

		for (z = 0; z < ZN; z++){
			for (y = 0; y < YN; y++){
				for (x = 0; x < XN; x++){
					double sqrmagf = SqrMagf[z][y][x];
					float gval = (float)Math.exp(-sqrmagf/lambda2);
					h[z][y][x] = 1 - gval;
					g[z][y][x] = (float)(gval/6.0);
				}
			}
		}

		/* Solve the x-, y- and z-components in turn; g and h are shared. */
		for (int c = 0; c < 3; c++){
			for (z = 0; z < ZN; z++){
				for (y = 0; y < YN; y++){
					for (x = 0; x < XN; x++){
						f[z][y][x] = 0 - h[z][y][x]*Gf[z][y][x][c];
						U[z][y][x] = Gf[z][y][x][c];
					}
				}
			}

			multigrid(U, f, g, h, XN, YN, ZN);

			for (z = 0; z < ZN; z++){
				for (y = 0; y < YN; y++){
					for (x = 0; x < XN; x++){
						V[z][y][x][c] = U[z][y][x];
					}
				}
			}
		}

		return V;
	}


	/**
	 * Restricts a fine volume to a coarse one by averaging 2x2x2 cells.
	 *
	 * <p>The cell origin is twice the coarse index, clamped to
	 * {@code size - 2} so the cell always lies inside the fine volume.
	 *
	 * @param out coarse volume to fill
	 * @param in  fine volume to read
	 * @param Xo  coarse size of the third array dimension
	 * @param Yo  coarse size of the second array dimension
	 * @param Zo  coarse size of the first array dimension
	 * @param Xi  fine size of the third array dimension
	 * @param Yi  fine size of the second array dimension
	 * @param Zi  fine size of the first array dimension
	 */
	void rstrct(float[][][] out, float[][][] in, int Xo, int Yo, int Zo,
			int Xi, int Yi, int Zi){
		int inew, jnew, knew, iold, jold, kold;

		for (knew = 0; knew < Zo; knew++){
			kold = Math.min(knew << 1, Zi - 2);
			for (inew = 0; inew < Yo; inew++){
				iold = Math.min(inew << 1, Yi - 2);
				for (jnew = 0; jnew < Xo; jnew++){
					jold = Math.min(jnew << 1, Xi - 2);

					/* Summation order is kept fixed so float rounding does not change. */
					out[knew][inew][jnew] = 0.125f*(in[kold][iold][jold]+in[kold][iold][jold+1]+in[kold][iold+1][jold]+in[kold][iold+1][jold+1] +
							in[kold+1][iold][jold]+in[kold+1][iold][jold+1]+in[kold+1][iold+1][jold]+in[kold+1][iold+1][jold+1]);
				}
			}
		}
	}


	/**
	 * Prolongs a coarse volume to a fine one by piecewise-constant injection:
	 * each fine voxel copies the coarse voxel at half its index.
	 *
	 * @param out fine volume to fill
	 * @param in  coarse volume to read
	 * @param Xo  fine size of the third array dimension
	 * @param Yo  fine size of the second array dimension
	 * @param Zo  fine size of the first array dimension
	 * @param Xi  coarse size (unused)
	 * @param Yi  coarse size (unused)
	 * @param Zi  coarse size (unused)
	 */
	void interp(float[][][] out, float[][][] in, int Xo, int Yo, int Zo,
			int Xi, int Yi, int Zi){
		int ic, jc, kc;

		for (kc = 0; kc < Zo; kc++){
			for (ic = 0; ic < Yo; ic++){
				for (jc = 0; jc < Xo; jc++){
					out[kc][ic][jc] = in[kc >> 1][ic >> 1][jc >> 1];
				}
			}
		}
	}


	/**
	 * Adds the interpolated coarse-grid correction to the fine-grid solution.
	 *
	 * @param uf  fine solution, incremented in place
	 * @param uc  coarse correction
	 * @param res fine-sized scratch volume; overwritten with the interpolated correction
	 * @param Xo  fine size of the third array dimension
	 * @param Yo  fine size of the second array dimension
	 * @param Zo  fine size of the first array dimension
	 * @param Xi  coarse size, forwarded to {@link #interp}
	 * @param Yi  coarse size, forwarded to {@link #interp}
	 * @param Zi  coarse size, forwarded to {@link #interp}
	 */
	void addin(float[][][] uf, float[][][] uc, float[][][] res, int Xo, int Yo, int Zo, int Xi, int Yi, int Zi){
		int i, j, k;

		interp(res, uc, Xo, Yo, Zo, Xi, Yi, Zi);

		for (k = 0; k < Zo; k++){
			for (i = 0; i < Yo; i++){
				for (j = 0; j < Xo; j++){
					uf[k][i][j] += res[k][i][j];
				}
			}
		}
	}


	/**
	 * Computes the residual
	 * {@code res = hu*u + rhs - (1/h^2) * w * (sum of 6 neighbours - 6*u)}
	 * with {@code h = 2^(maxlevel - lev)}.
	 *
	 * <p>Neighbour indices are clamped at the volume boundary.
	 *
	 * @param res output residual
	 * @param u   current solution
	 * @param rhs right-hand side
	 * @param w   weight field
	 * @param hu  field multiplying {@code u}
	 * @param XN  size of the third array dimension
	 * @param YN  size of the second array dimension
	 * @param ZN  size of the first array dimension
	 * @param lev grid level of the supplied volumes
	 */
	void resid(float[][][] res, float[][][] u, float[][][] rhs, float[][][] w, float[][][] hu, int XN, int YN, int ZN, int lev){
		int i, j, k;
		int prek, nextk, prei, nexti, prej, nextj;

		int h = 1 << (maxlevel - lev);
		h = h*h;
		float cmu = (float)1.0/(float)h;

		/*
		 * If change to half-point symmetric extension, no longer
		 * converges
		 */
		for (k = 0; k < ZN; k++){
			prek = (k > 0) ? (k-1) : 0;
			nextk = (k == (ZN-1)) ? (ZN-1) : (k+1);
			for (i = 0; i < YN; i++){
				prei = (i > 0) ? (i-1) : 0;
				nexti = (i == (YN-1)) ? (YN-1) : (i+1);

				for (j = 0; j < XN; j++){
					prej = (j > 0) ? (j-1) : 0;
					nextj = (j == (XN-1)) ? (XN-1) : (j+1);

					res[k][i][j] = hu[k][i][j]*u[k][i][j] + rhs[k][i][j] -
					cmu*w[k][i][j]*
					(u[prek][i][j] +
							u[nextk][i][j] +
							u[k][prei][j] +
							u[k][nexti][j] +
							u[k][i][prej] +
							u[k][i][nextj] -
							u[k][i][j]*6);
				}
			}
		}
	}


	/**
	 * Copies the leading {@code ZN x YN x XN} region of {@code in} into {@code out}.
	 *
	 * @param out destination volume
	 * @param in  source volume
	 * @param XN  number of elements copied along the third array dimension
	 * @param YN  number of rows copied along the second array dimension
	 * @param ZN  number of planes copied along the first array dimension
	 */
	void copymem(float[][][] out, float[][][] in, int XN, int YN, int ZN){
		for (int k = 0; k < ZN; k++){
			for (int i = 0; i < YN; i++){
				System.arraycopy(in[k][i], 0, out[k][i], 0, XN);
			}
		}
	}


	/**
	 * Relaxes, in place, every voxel whose indices share the parity given by
	 * {@code (kc, ic, jc)}, stepping by 2 along each axis.
	 *
	 * <p>Each visited voxel is set to
	 * {@code (w * sum of 6 neighbours - rhs*hsquare) / (6*w + hu*hsquare)},
	 * with neighbour indices clamped at the boundary.
	 *
	 * @param u       solution, updated in place
	 * @param rhs     right-hand side
	 * @param w       weight field
	 * @param hu      field multiplying {@code u}
	 * @param hsquare squared grid spacing for this level
	 * @param XN      size of the third array dimension
	 * @param YN      size of the second array dimension
	 * @param ZN      size of the first array dimension
	 * @param kc      starting index along the first array dimension
	 * @param ic      starting index along the second array dimension
	 * @param jc      starting index along the third array dimension
	 */
	void octant(float[][][] u, float[][][] rhs, float[][][] w, float[][][] hu, float hsquare, int XN, int YN, int ZN, int kc, int ic, int jc){
		int k, i, j;
		int prek, nextk, prei, nexti, prej, nextj;
		float tmpv;

		/*
		 * hsquare = 2*hsquare;
		 *
		 * For the first case that g is inside div()
		 */
		for (k = kc; k < ZN; k += 2){
			prek = (k > 0) ? (k-1) : 0;
			nextk = (k == (ZN-1)) ? (ZN-1) : (k+1);
			for (i = ic; i < YN; i += 2){
				prei = (i > 0) ? (i-1) : 0;
				nexti = (i == (YN-1)) ? (YN-1) : (i+1);

				for (j = jc; j < XN; j += 2){
					prej = (j > 0) ? (j-1) : 0;
					nextj = (j == (XN-1)) ? (XN-1) : (j+1);

					tmpv = hu[k][i][j]*hsquare;

					u[k][i][j] = ((u[prek][i][j] +
							u[nextk][i][j] + u[k][prei][j] +
							u[k][nexti][j] + u[k][i][prej] +
							u[k][i][nextj]) * w[k][i][j]
							- rhs[k][i][j]*hsquare)/(6*w[k][i][j] + tmpv);
				}
			}
		}
	}


	/**
	 * Runs {@code iter} damped Jacobi sweeps on {@code u} (damping factor 0.9).
	 *
	 * <p>Each sweep is computed into a temporary volume from the previous
	 * iterate and then copied back. Neighbour indices are clamped at the
	 * boundary; the grid term uses {@code 2*h*h} with
	 * {@code h = 2^(maxlevel - lev)}.
	 *
	 * @param u    solution, updated in place
	 * @param rhs  right-hand side
	 * @param w    weight field
	 * @param hu   field multiplying {@code u}
	 * @param lev  grid level of the supplied volumes
	 * @param XN   size of the third array dimension
	 * @param YN   size of the second array dimension
	 * @param ZN   size of the first array dimension
	 * @param iter number of sweeps
	 */
	void Jacobi(float[][][] u, float[][][] rhs, float[][][] w,float[][][] hu, int lev, int XN, int YN, int ZN, int iter){
		int k, i, j, index;
		int prek, nextk, prei, nexti, prej, nextj;

		float alpha = 0.9f; /* alpha has to be large to converge in 3D */
		float tmpv;

		int h = 1 << (maxlevel - lev);
		float hsquare = 2*h*h;

		float[][][] tmpvol = new float[ZN][YN][XN];

		for (index = 1; index <= iter; index++){
			for (k = 0; k < ZN; k++){
				prek = (k > 0) ? (k-1) : 0;
				nextk = (k == (ZN-1)) ? (ZN-1) : (k+1);

				for (i = 0; i < YN; i++){
					prei = (i > 0) ? (i-1) : 0;
					nexti = (i == (YN-1)) ? (YN-1) : (i+1);

					for (j = 0; j < XN; j++){
						prej = (j > 0) ? (j-1) : 0;
						nextj = (j == (XN - 1)) ? (XN - 1) : (j + 1);

						tmpv = hu[k][i][j] * hsquare;

						tmpvol[k][i][j] = (1-alpha)*u[k][i][j]
							+ alpha*(u[prek][i][j]*(w[prek][i][j]+w[k][i][j])
							+ u[nextk][i][j]*(w[nextk][i][j]+w[k][i][j])
							+ u[k][prei][j]*(w[k][prei][j]+w[k][i][j])
							+ u[k][nexti][j]*(w[k][nexti][j]+w[k][i][j])
							+ u[k][i][prej]*(w[k][i][prej]+w[k][i][j])
							+ u[k][i][nextj]*(w[k][i][nextj]+w[k][i][j])
							- rhs[k][i][j]*hsquare)/(6*w[k][i][j]
							+ w[prek][i][j] + w[nextk][i][j]
							+ w[k][prei][j] + w[k][nexti][j]
							+ w[k][i][prej] + w[k][i][nextj] + tmpv);
					}
				}
			}

			copymem(u, tmpvol, XN, YN, ZN);
		}
	}


	/**
	 * Runs {@code iter} Gauss-Seidel sweeps on {@code u}; each sweep relaxes
	 * the eight parity sub-lattices one after another via {@link #octant}.
	 *
	 * <p>If any dimension equals 1 a message is printed and {@code u} is left
	 * untouched.
	 *
	 * @param u      solution, updated in place
	 * @param rhs    right-hand side
	 * @param weight weight field
	 * @param hu     field multiplying {@code u}
	 * @param lev    grid level of the supplied volumes
	 * @param XN     size of the third array dimension
	 * @param YN     size of the second array dimension
	 * @param ZN     size of the first array dimension
	 * @param iter   number of sweeps
	 */
	void Gauss_Seidel(float[][][] u, float[][][] rhs, float[][][] weight, float[][][] hu, int lev, int XN, int YN, int ZN, int iter){
		int h = 1 << (maxlevel - lev);
		float hsquare = h*h;

		if (XN == 1 || YN == 1 || ZN == 1){
			System.out.print("OOPS!!! PROBLEM IN Gauss_Seidel USING octant\n");
			return;
		}

		for (int sweep = 1; sweep <= iter; sweep++){
			/*
			 * Process black first like follows gives slightly better
			 * convergence.
			 */
			for (int o = 0; o < OCTANT_ORDER.length; o++){
				octant(u, rhs, weight, hu, hsquare, XN, YN, ZN,
						OCTANT_ORDER[o][0], OCTANT_ORDER[o][1], OCTANT_ORDER[o][2]);
			}
		}
	}


	/**
	 * "Solves" on the coarsest grid: zeroes {@code u}, then applies two sweeps
	 * of the configured smoother at level 1.
	 *
	 * @param u      solution, overwritten
	 * @param rhs    right-hand side
	 * @param weight weight field
	 * @param hu     field multiplying {@code u}
	 * @param XN     size of the third array dimension
	 * @param YN     size of the second array dimension
	 * @param ZN     size of the first array dimension
	 */
	void slvsml(float[][][] u, float[][][] rhs, float[][][] weight, float[][][] hu, int XN, int YN, int ZN){
		zeroVolume(u, XN, YN, ZN);
		smooth(u, rhs, weight, hu, 1, XN, YN, ZN, COARSEST_SWEEPS);
	}


	/** Applies {@code iter} sweeps of the smoother selected by {@link #JACOBI}. */
	private void smooth(float[][][] u, float[][][] rhs, float[][][] weight, float[][][] hu, int lev, int XN, int YN, int ZN, int iter){
		if (JACOBI) {
			Jacobi(u, rhs, weight, hu, lev, XN, YN, ZN, iter);
		} else {
			Gauss_Seidel(u, rhs, weight, hu, lev, XN, YN, ZN, iter);
		}
	}


	/** Sets the leading {@code ZN x YN x XN} region of {@code vol} to zero. */
	private static void zeroVolume(float[][][] vol, int XN, int YN, int ZN){
		for (int k = 0; k < ZN; k++){
			for (int i = 0; i < YN; i++){
				for (int j = 0; j < XN; j++){
					vol[k][i][j] = 0;
				}
			}
		}
	}


	/** Returns the largest {@code m >= 0} with {@code (n >> m) > 0}, or 0 when {@code n <= 1}. */
	private static int gridLevels(int n){
		int levels = 0;
		while ((n >> (levels + 1)) > 0){
			levels++;
		}
		return levels;
	}


	/**
	 * Solves, in place, for {@code u} in
	 * {@code div( weight(x) \nabla u(x)) - hu(x) ( u(x) - f(x)) = 0}
	 * using up to 10 full-multigrid V-cycle iterations, stopping early once
	 * the largest absolute residual falls below 1e-4.
	 *
	 * <p>Side effects: sets the static fields {@link #maxlevel}, {@link #n1}
	 * and {@link #n2}, and prints the chosen maximum level.
	 *
	 * @param u      initial guess on entry, solution on exit
	 * @param f      right-hand side as supplied by the caller
	 * @param weight weight field; not modified
	 * @param hu     field multiplying {@code u}; not modified
	 * @param XN     size of the third array dimension
	 * @param YN     size of the second array dimension
	 * @param ZN     size of the first array dimension
	 * @throws IllegalArgumentException if the volume is non-empty but some
	 *         dimension is 1, in which case no grid level can be built
	 */
	void multigrid(float[][][] u, float[][][] f, float[][][] weight, float[][][] hu, int XN, int YN, int ZN){
		int m, n, l;
		int j, jj;

		/* Compute maximum level: limited by the smallest dimension. */
		maxlevel = Math.min(gridLevels(ZN), Math.min(gridLevels(YN), gridLevels(XN)));

		System.out.print("GVF Maxlevel:" + maxlevel +"\n");

		/* Set pre-smoothing steps */
		n1 = 2;
		/* Set post-smoothing steps */
		n2 = 2;

		if (XN <= 0 || YN <= 0 || ZN <= 0){
			/* Empty volume: nothing to solve. */
			return;
		}
		if (maxlevel < 1){
			/* Previously this surfaced as a NullPointerException on an unallocated level. */
			throw new IllegalArgumentException(
					"GVF multigrid needs at least 2 samples along every axis, got XN="
					+ XN + ", YN=" + YN + ", ZN=" + ZN);
		}

		float[][][][] ires = new float[maxlevel+1][][][];
		float[][][][] irho = new float[maxlevel+1][][][];
		float[][][][] irhs = new float[maxlevel+1][][][];
		float[][][][] iu = new float[maxlevel+1][][][];
		float[][][][] iwei = new float[maxlevel+1][][][];
		float[][][][] ih = new float[maxlevel+1][][][];

		/* Per-level sizes along the first (mo), second (no) and third (lo) array dimension. */
		int[] mo = new int[maxlevel+1];
		int[] no = new int[maxlevel+1];
		int[] lo = new int[maxlevel+1];

		int msize = ZN;
		int nsize = YN;
		int lsize = XN;

		/* Level maxlevel is the finest grid; each coarser level halves the size, rounding up. */
		for (j = maxlevel; j >= 1; j--){
			ires[j] = new float[msize][nsize][lsize];
			irho[j] = new float[msize][nsize][lsize];
			irhs[j] = new float[msize][nsize][lsize];
			iu[j] = new float[msize][nsize][lsize];
			ih[j] = new float[msize][nsize][lsize];
			iwei[j] = new float[msize][nsize][lsize];

			mo[j] = msize;
			no[j] = nsize;
			lo[j] = lsize;

			msize = (msize + 1) >> 1;
			nsize = (nsize + 1) >> 1;
			lsize = (lsize + 1) >> 1;
		}

		copymem(iwei[maxlevel], weight, XN, YN, ZN);
		copymem(ih[maxlevel], hu, XN, YN, ZN);

		/* Coefficient fields on the coarser grids. */
		for (j = maxlevel - 1; j >= 1; j--){
			rstrct(iwei[j], iwei[j+1], lo[j], no[j], mo[j], lo[j+1], no[j+1], mo[j+1]);
			rstrct(ih[j], ih[j+1], lo[j], no[j], mo[j], lo[j+1], no[j+1], mo[j+1]);
		}

		/* Now start multigrid processing */
		for (int iters = 1; iters <= MAX_OUTER_ITERATIONS; iters++){
			/* Compute the initial residue */
			resid(irho[maxlevel], u, f, iwei[maxlevel], ih[maxlevel], XN, YN, ZN, maxlevel);

			float maxerr = 0.0f;
			for (m = 0; m < ZN; m++){
				for (n = 0; n < YN; n++){
					for (l = 0; l < XN; l++){
						float tmpv = Math.abs(irho[maxlevel][m][n][l]);
						if (maxerr < tmpv) maxerr = tmpv;
					}
				}
			}

			if (maxerr < RESIDUAL_TOLERANCE) break;

			for (j = maxlevel - 1; j >= 1; j--){
				rstrct(irho[j], irho[j+1], lo[j], no[j], mo[j], lo[j+1], no[j+1], mo[j+1]);
			}

			/* Now solve the PDE at level 1, the coarsest level: initial solution on the coarsest grid */
			slvsml(iu[1], irho[1], iwei[1], ih[1], lo[1], no[1], mo[1]);

			/* Now start Full-Multigrid V cycle */
			for (j = 2; j <= maxlevel; j++){
				/* Get the initial guess of the solution of the original eq */
				interp(iu[j], iu[j-1], lo[j], no[j], mo[j], lo[j-1], no[j-1], mo[j-1]);

				/* Set up right hand side */
				copymem(irhs[j], irho[j], lo[j], no[j], mo[j]);

				/* Down stroke of the V */
				for (jj = j; jj >= 2; jj--){
					smooth(iu[jj], irhs[jj], iwei[jj], ih[jj], jj, lo[jj], no[jj], mo[jj], n1);

					/* Defect */
					resid(ires[jj], iu[jj], irhs[jj], iwei[jj], ih[jj], lo[jj], no[jj], mo[jj], jj);

					/*
					 * The error estimate on the next coarser level must start
					 * from zero at every level, because addin accumulates
					 * with += on the way back up.
					 */
					zeroVolume(iu[jj-1], lo[jj-1], no[jj-1], mo[jj-1]);

					rstrct(irhs[jj-1], ires[jj], lo[jj-1], no[jj-1], mo[jj-1], lo[jj], no[jj], mo[jj]);
				}

				/* Bottom of the V: iu[1] now stores the solution of the error at the coarsest level */
				slvsml(iu[1], irhs[1], iwei[1], ih[1], lo[1], no[1], mo[1]);

				/* Upward stroke of the V */
				for (jj = 2; jj <= j; jj++){
					/* ires[jj] is used for temporary storage inside addin */
					addin(iu[jj], iu[jj-1], ires[jj], lo[jj], no[jj], mo[jj], lo[jj-1], no[jj-1], mo[jj-1]);

					/* Post-smoothing: n2 sweeps for either smoother */
					smooth(iu[jj], irhs[jj], iwei[jj], ih[jj], jj, lo[jj], no[jj], mo[jj], n2);
				}
			}

			/* Update solution */
			for (m = 0; m < ZN; m++){
				for (n = 0; n < YN; n++){
					for (l = 0; l < XN; l++){
						u[m][n][l] += iu[maxlevel][m][n][l];
					}
				}
			}
		}
	}
}
