package edu.jhu.ece.iacl.algorithms.gvf;


import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;


/**
 * B-Spline interpolation.
 *
 * <p>Separable three-dimensional B-spline interpolation for spline degrees 2 to 5,
 * using mirror (symmetric) boundary conditions. Image samples are first converted
 * into B-spline coefficients by recursive prefiltering along each axis
 * ({@link #SamplesToCoefficients}); the continuous model is then evaluated with
 * {@link #interpolateBSpline}.
 *
 * <p>The spline degree used by {@link #solve(float[][][], int, int, int)} lives in
 * the public static field {@link #splineDegree}, so it is shared by every instance
 * (and every thread) using this class.
 *
 * @author Philippe Thevenaz
 * @author Aaron Carass
 *
 */
public class Interpolate3dBSpline implements Interpolate3d {
	/** Machine epsilon for {@code double}; used as the tolerance of the causal initialization. */
	private static final double DBL_EPSILON = 2.2204460492503131e-16;

	/**
	 * Spline degree used by {@link #solve(float[][][], int, int, int)}. Any value
	 * {@code <= 0} (the initial value is -1) makes {@code solve} fall back to cubic (3).
	 */
	public static int splineDegree = -1;

	/** Shared instance backing the static {@link #doSolve} entry point. */
	private static final Interpolate3dBSpline interpolate = new Interpolate3dBSpline();

	/**
	 * Static convenience wrapper around {@link #solve(ImageDataFloat, int, int, int)}
	 * using a shared instance.
	 */
	public static ImageDataFloat doSolve(ImageDataFloat in, int xo, int yo, int zo) {
		return interpolate.solve(in, xo, yo, zo);
	}


	/**
	 * @return the current (static, shared) spline degree
	 */
	public int getSplineDegree(){
		return splineDegree;
	}


	/**
	 * Sets the (static, shared) spline degree used by subsequent calls to solve.
	 *
	 * @param s  spline degree; 2 to 5 are supported, {@code <= 0} selects the default (3)
	 */
	public void setSplineDegree(int s){
		splineDegree = s;
	}


	/**
	 * Evaluates the B-spline model at a (possibly non-integer) point.
	 *
	 * <p>Contributing sample positions that fall outside the image are folded back
	 * with mirror boundary conditions. A dimension of size 1 is treated as constant
	 * along that axis.
	 *
	 * @param Bcoeff        B-spline array of coefficients, indexed {@code [x][y][z]}
	 *                      (for instance the output of {@link #SamplesToCoefficients})
	 * @param Width         Width of the image (number of samples along x)
	 * @param Height        Height of the image (number of samples along y)
	 * @param Depth         Depth of the image (number of samples along z)
	 * @param x             x coordinate where to interpolate
	 * @param y             y coordinate where to interpolate
	 * @param z             z coordinate where to interpolate
	 * @param SplineDegree  Degree of the B-Spline; must be 2, 3, 4 or 5
	 * @return the interpolated value, or 0 (after printing a message) when
	 *         {@code SplineDegree} is not supported
	 */
	public float interpolateBSpline(double[][][] Bcoeff, int Width, int Height, int Depth, double x, double y, double z, int SplineDegree) {
		// Validate first: the index/weight arrays are sized by the degree, so an
		// unsupported degree must not reach them.
		if (SplineDegree < 2 || SplineDegree > 5) {
			System.out.println("\n\nInvalid Spline Degree.\n\n");
			return(0.0f);
		}

		final int support = SplineDegree + 1;
		long[] xIndex = new long[support];
		long[] yIndex = new long[support];
		long[] zIndex = new long[support];
		double[] xWeight = new double[support];
		double[] yWeight = new double[support];
		double[] zWeight = new double[support];

		computeIndices(x, SplineDegree, xIndex);
		computeIndices(y, SplineDegree, yIndex);
		computeIndices(z, SplineDegree, zIndex);

		// Weights are functions of the distance to the *unfolded* indices, so they
		// must be computed before the boundary conditions rewrite those indices.
		computeWeights(x, xIndex, SplineDegree, xWeight);
		computeWeights(y, yIndex, SplineDegree, yWeight);
		computeWeights(z, zIndex, SplineDegree, zWeight);

		applyMirrorBoundary(xIndex, Width, SplineDegree);
		applyMirrorBoundary(yIndex, Height, SplineDegree);
		applyMirrorBoundary(zIndex, Depth, SplineDegree);

		// Separable tensor-product sum: x innermost, then y, then z.
		double interpolated = 0.0;

		for (int k = 0; k <= SplineDegree; k++) {
			double plane = 0.0;

			for (int j = 0; j <= SplineDegree; j++) {
				double row = 0.0;

				for (int i = 0; i <= SplineDegree; i++) {
					row += xWeight[i] * Bcoeff[(int) xIndex[i]][(int) yIndex[j]][(int) zIndex[k]];
				}

				plane += yWeight[j] * row;
			}

			interpolated += zWeight[k] * plane;
		}

		return((float) interpolated);
	}


	/**
	 * Fills {@code index[0..SplineDegree]} with the consecutive integer sample
	 * positions whose basis functions are non-zero at coordinate {@code t}.
	 * Odd degrees start from floor(t), even degrees from round(t).
	 */
	private static void computeIndices(double t, int SplineDegree, long[] index) {
		long first = ((SplineDegree % 2) == 1)
				? (long) Math.floor(t) - SplineDegree / 2
				: (long) Math.floor(t + 0.5) - SplineDegree / 2;

		for (int s = 0; s <= SplineDegree; s++) {
			index[s] = first + s;
		}
	}


	/**
	 * Computes the B-spline weights at coordinate {@code t} for the sample
	 * positions in {@code index} (as filled by {@link #computeIndices}).
	 * Supports degrees 2 to 5; the last weight of each case is obtained from the
	 * partition-of-unity property.
	 */
	private static void computeWeights(double t, long[] index, int SplineDegree, double[] weight) {
		double w, w2, w4, u, t0, t1;

		switch (SplineDegree) {
			case 2:
				w = t - (double) index[1];
				weight[1] = 3.0 / 4.0 - w * w;
				weight[2] = (1.0 / 2.0) * (w - weight[1] + 1.0);
				weight[0] = 1.0 - weight[1] - weight[2];
				break;

			case 3:
				w = t - (double) index[1];
				weight[3] = (1.0 / 6.0) * w * w * w;
				weight[0] = (1.0 / 6.0) + (1.0 / 2.0) * w * (w - 1.0) - weight[3];
				weight[2] = w + weight[0] - 2.0 * weight[3];
				weight[1] = 1.0 - weight[0] - weight[2] - weight[3];
				break;

			case 4:
				w = t - (double) index[2];
				w2 = w * w;
				u = (1.0 / 6.0) * w2;
				weight[0] = 1.0 / 2.0 - w;
				weight[0] *= weight[0];
				weight[0] *= (1.0 / 24.0) * weight[0];
				t0 = w * (u - 11.0 / 24.0);
				t1 = 19.0 / 96.0 + w2 * (1.0 / 4.0 - u);
				weight[1] = t1 + t0;
				weight[3] = t1 - t0;
				weight[4] = weight[0] + t0 + (1.0 / 2.0) * w;
				weight[2] = 1.0 - weight[0] - weight[1] - weight[3] - weight[4];
				break;

			case 5:
				w = t - (double) index[2];
				w2 = w * w;
				weight[5] = (1.0 / 120.0) * w * w2 * w2;
				w2 -= w;
				w4 = w2 * w2;
				w -= 1.0 / 2.0;
				u = w2 * (w2 - 3.0);
				weight[0] = (1.0 / 24.0) * (1.0 / 5.0 + w2 + w4) - weight[5];
				t0 = (1.0 / 24.0) * (w2 * (w2 - 5.0) + 46.0 / 5.0);
				t1 = (-1.0 / 12.0) * w * (u + 4.0);
				weight[2] = t0 + t1;
				weight[3] = t0 - t1;
				t0 = (1.0 / 16.0) * (9.0 / 5.0 - u);
				t1 = (1.0 / 24.0) * w * (w4 - w2 - 5.0);
				weight[1] = t0 + t1;
				weight[4] = t0 - t1;
				break;

			default:
				// Unreachable: interpolateBSpline validates the degree beforehand.
				throw new IllegalArgumentException("Unsupported spline degree: " + SplineDegree);
		}
	}


	/**
	 * Folds {@code index[0..SplineDegree]} into {@code [0, size - 1]} using mirror
	 * boundary conditions with period {@code 2 * size - 2}. A dimension of size 1
	 * maps every index to 0; it must be special-cased because its period is 0.
	 */
	private static void applyMirrorBoundary(long[] index, int size, int SplineDegree) {
		if (size == 1) {
			for (int s = 0; s <= SplineDegree; s++) {
				index[s] = 0;
			}
			return;
		}

		long period = 2L * size - 2L;

		for (int s = 0; s <= SplineDegree; s++) {
			// Reflection about 0 followed by wrapping into one period.
			long folded = Math.abs(index[s]) % period;

			// Reflection about size - 1.
			if (size <= folded) {
				folded = period - folded;
			}

			index[s] = folded;
		}
	}


	/**
	 * Converts a 1-D line of samples into B-spline coefficients in place, by
	 * applying a causal and an anti-causal recursive filter for every pole.
	 *
	 * @param c  Samples on input, coefficients on output (modified in place)
	 * @param DataLength Number of samples or coefficients
	 * @param z  Poles
	 * @param NbPoles  Number of poles
	 * @param Tolerance  Admissible relative error
	 */
	static void ConvertToInterpolationCoefficients(double c[], int DataLength, double z[], int NbPoles, double Tolerance) {
		double Lambda = 1.0;
		int n, k;


		/**
		 * Special case required by mirror boundaries (length 1); a length-0 line has
		 * nothing to convert and would otherwise index c[0].
		 */
		if (DataLength <= 1) {
			return;
		}


		/**
		 * Compute the overall gain
		 */
		for (k = 0; k < NbPoles; k++) {
			Lambda = Lambda * (1.0 - z[k]) * (1.0 - 1.0 / z[k]);
		}


		/**
		 * Apply the gain
		 */
		for (n = 0; n < DataLength; n++) {
			c[n] *= Lambda;
		}


		/**
		 * Loop over all poles
		 */
		for (k = 0; k < NbPoles; k++) {
			// Causal pass.
			c[0] = InitialCausalCoefficient(c, DataLength, z[k], Tolerance);

			for (n = 1; n < DataLength; n++) {
				c[n] += z[k] * c[n - 1];
			}

			// Anti-causal pass.
			c[DataLength - 1] = InitialAntiCausalCoefficient(c, DataLength, z[k]);

			for (n = DataLength - 2; 0 <= n; n--) {
				c[n] = z[k] * (c[n + 1] - c[n]);
			}
		}
	}


	/**
	 * Computes the initial value of the causal filter for pole {@code z} under
	 * mirror boundary conditions.
	 *
	 * <p>When {@code Tolerance > 0} the infinite sum is truncated after the number of
	 * terms at which |z|^n drops to about {@code Tolerance}; if that horizon does
	 * not fit in the data, the exact (full-loop) mirror-symmetric form is used.
	 */
	static double InitialCausalCoefficient(double c[], int DataLength, double z, double Tolerance) {
		double Sum, zn, z2n, iz;
		int n, Horizon;

		Horizon = DataLength;
		if (Tolerance > 0.0) {
			Horizon = (int) Math.ceil(Math.log(Tolerance) / Math.log(Math.abs(z)));
		}


		if (Horizon < DataLength) {
			/**
			 * Accelerated Loop
			 */
			zn = z;
			Sum = c[0];

			for (n = 1; n < Horizon; n++) {
				Sum += zn * c[n];
				zn *= z;
			}

			return(Sum);
		} else {
			/**
			 * Full Loop
			 */
			zn = z;
			iz = 1.0 / z;
			z2n = Math.pow(z, (DataLength - 1));
			Sum = c[0] + z2n * c[DataLength - 1];
			z2n *= z2n * iz;
			for (n = 1; n <= DataLength - 2; n++) {
				Sum += (zn + z2n) * c[n];
				zn *= z;
				z2n *= iz;
			}

			return(Sum / (1.0 - zn * zn));
		}
	}


	/**
	 * Computes the initial value of the anti-causal filter for pole {@code z}
	 * under mirror boundary conditions. Requires {@code DataLength >= 2}.
	 */
	static double InitialAntiCausalCoefficient(double c[], int DataLength, double z) {
		return((z / (z * z - 1.0)) * (z * c[DataLength - 2] + c[DataLength - 1]));
	}


	/** Copies the x-line {@code Image[0..Width-1][y][z]} into a new array. */
	static double[] GetRow(int y, int z, double Image[][][], int Width){
		double Line[] = new double[Width];

		for (int i = 0; i < Width; i++) {
			Line[i] = Image[i][y][z];
		}

		return Line;
	}


	/** Writes {@code Line} back into the x-line {@code Image[0..Width-1][y][z]}. */
	static void PutRow(double Line[], int y, int z, double Image[][][], int Width){
		for (int i = 0; i < Width; i++) {
			Image[i][y][z] = Line[i];
		}
	}


	/** Copies the y-line {@code Image[x][0..Height-1][z]} into a new array. */
	static double[] GetColumn(int x, int z, double Image[][][], int Height) {
		double Line[] = new double[Height];

		for (int j = 0; j < Height; j++) {
			Line[j] = Image[x][j][z];
		}

		return Line;
	}


	/** Writes {@code Line} back into the y-line {@code Image[x][0..Height-1][z]}. */
	static void PutColumn(double Line[], int x, int z, double Image[][][], int Height) {
		for (int j = 0; j < Height; j++) {
			Image[x][j][z] = Line[j];
		}
	}


	/** Copies the z-line {@code Image[x][y][0..Depth-1]} into a new array. */
	static double[] GetVertical(int x, int y, double Image[][][], int Depth) {
		double Line[] = new double[Depth];

		for (int k = 0; k < Depth; k++){
			Line[k] = Image[x][y][k];
		}

		return Line;
	}


	/**
	 * Writes {@code Line} back into the z-line {@code Image[x][y][0..Depth-1]}.
	 * (Previously the line was overwritten with the image before being written back,
	 * which silently discarded the z-axis prefilter.)
	 */
	static void PutVertical(double Line[], int x, int y, double Image[][][], int Depth) {
		for (int k = 0; k < Depth; k++){
			Image[x][y][k] = Line[k];
		}
	}


	/**
	 * Converts image samples into B-spline interpolation coefficients, in place,
	 * by separable recursive prefiltering along x, then y, then z.
	 *
	 * @param Image         samples on input, coefficients on output, indexed {@code [x][y][z]}
	 * @param Width         number of samples along x
	 * @param Height        number of samples along y
	 * @param Depth         number of samples along z
	 * @param SplineDegree  degree of the B-spline; must be 2, 3, 4 or 5
	 * @return 0 on success, 1 (after printing a message) for an unsupported degree
	 */
	public int SamplesToCoefficients(double Image[][][], int Width, int Height, int Depth, int SplineDegree) {
		double Line[];
		double Pole[] = new double[2];
		int NbPoles;
		int x, y, z;


		// Poles of the recursive prefilter for each supported degree.
		switch (SplineDegree) {
			case 2:
				NbPoles = 1;
				Pole[0] = Math.sqrt(8.0) - 3.0;
				break;

			case 3:
				NbPoles = 1;
				Pole[0] = Math.sqrt(3.0) - 2.0;
				break;

			case 4:
				NbPoles = 2;
				Pole[0] = Math.sqrt(664.0 - Math.sqrt(438976.0)) + Math.sqrt(304.0) - 19.0;
				Pole[1] = Math.sqrt(664.0 + Math.sqrt(438976.0)) - Math.sqrt(304.0) - 19.0;
				break;

			case 5:
				NbPoles = 2;
				Pole[0] = Math.sqrt(135.0 / 2.0 - Math.sqrt(17745.0 / 4.0)) + Math.sqrt(105.0 / 4.0) - 13.0 / 2.0;
				Pole[1] = Math.sqrt(135.0 / 2.0 + Math.sqrt(17745.0 / 4.0)) - Math.sqrt(105.0 / 4.0) - 13.0 / 2.0;
				break;

			default:
				System.out.print("\n\nInvalid spline degree:" + SplineDegree + "\n\n");
				return(1);
		}


		/**
		 * Convert the image samples into interpolation coefficients
		 * Separable process: X-Axis
		 */
		for (z = 0; z < Depth; z++){
			for (y = 0; y < Height; y++) {
				Line = GetRow(y, z, Image, Width);
				ConvertToInterpolationCoefficients(Line, Width, Pole, NbPoles, DBL_EPSILON);
				PutRow(Line, y, z, Image, Width);
			}
		}


		/**
		 * Convert the image samples into interpolation coefficients
		 * Separable process: Y-Axis
		 */
		for (z = 0; z < Depth; z++){
			for (x = 0; x < Width; x++) {
				Line = GetColumn(x, z, Image, Height);
				ConvertToInterpolationCoefficients(Line, Height, Pole, NbPoles, DBL_EPSILON);
				PutColumn(Line, x, z, Image, Height);
			}
		}


		/**
		 * Convert the image samples into interpolation coefficients
		 * Separable process: Z-Axis
		 */
		for (y = 0; y < Height; y++) {
			for (x = 0; x < Width; x++) {
				Line = GetVertical(x, y, Image, Depth);
				ConvertToInterpolationCoefficients(Line, Depth, Pole, NbPoles, DBL_EPSILON);
				PutVertical(Line, x, y, Image, Depth);
			}
		}

		return(0);
	}


	/**
	 * Image wrapper around {@link #solve(float[][][], int, int, int)}.
	 *
	 * @return the resulting image; note that {@code new ImageDataFloat(null)} is
	 *         attempted if the underlying solve fails — NOTE(review): confirm how
	 *         ImageDataFloat handles a null array
	 */
	public ImageDataFloat solve(ImageDataFloat in, int xo, int yo, int zo) {
		float[][][] inMat = in.toArray3d();
		float[][][] outMat = solve(inMat, xo, yo, zo);

		ImageDataFloat out = new ImageDataFloat(outMat);

		return out;
	}


	/**
	 * Prefilters the first {@code xo x yo x zo} samples of {@code inMat} into
	 * B-spline coefficients and re-evaluates the spline at every integer grid
	 * point. Each result is rounded to the nearest integer and clamped to the
	 * input's [min, max] range.
	 *
	 * <p>If {@link #splineDegree} is {@code <= 0} it is first set to 3 (cubic).
	 *
	 * <p>NOTE(review): the spline is sampled on the same integer grid as the input,
	 * so {@code xo/yo/zo} must not exceed the input dimensions; confirm whether
	 * resampling to a different grid was intended.
	 *
	 * @param inMat  input samples indexed {@code [x][y][z]}; must be at least
	 *               {@code xo x yo x zo} with every dimension >= 1
	 * @return the rounded, clamped interpolated values, or {@code null} if the
	 *         change of basis failed (unsupported spline degree)
	 */
	public float[][][] solve(float[][][] inMat, int xo, int yo, int zo) {
		// The default must be in place before the prefilter reads the degree;
		// otherwise the initial value (-1) makes the change of basis fail.
		if (splineDegree <= 0) {
			setSplineDegree(3);
		}

		// Snapshot so the prefilter and the evaluation use the same degree.
		final int degree = splineDegree;

		double workingMat[][][] = new double[xo][yo][zo];
		float outMat[][][] = new float[xo][yo][zo];
		float maxFloat = inMat[0][0][0];
		float minFloat = inMat[0][0][0];

		for (int k = 0; k < zo; k++) {
			for (int j = 0; j < yo; j++) {
				for (int i = 0; i < xo; i++) {
					float value = inMat[i][j][k];
					workingMat[i][j][k] = value;

					if (maxFloat < value) {
						maxFloat = value;
					}

					if (minFloat > value) {
						minFloat = value;
					}
				}
			}
		}


		if (SamplesToCoefficients(workingMat, xo, yo, zo, degree) == 1) {
			System.out.print("\n\nChange of basis failed\n\n");
			return null;
		}


		for (int k = 0; k < zo; k++) {
			for (int j = 0; j < yo; j++) {
				for (int i = 0; i < xo; i++) {
					double interpolated = interpolateBSpline(workingMat, xo, yo, zo, i, j, k, degree);
					// Round half up to the nearest integer.
					float result = (float) Math.floor(interpolated + 0.5);

					// Clamp to the input range to suppress spline overshoot.
					if (result > maxFloat) {
						result = maxFloat;
					}

					if (result < minFloat) {
						result = minFloat;
					}

					outMat[i][j][k] = result;
				}
			}
		}


		return outMat;
	}
}
