package edu.jhu.ece.iacl.algorithms.thickness.laplace;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;

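/**
 * Solves Laplace's equation on a regular 3D grid by iterative relaxation,
 * sweeping the voxels in red/black (checkerboard) Gauss-Seidel order.
 * 
 * The unknown is the scalar field u(i,j,k) sampled at the grid points.
 *
 * @author Blake Lucas
 * 
 */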
public class LaplaceSolveOnGrid{
	/**
	 * Parity offsets visited during one full sweep, as {ic, jc, kc} triples in
	 * the argument order of the octant methods. The first four triples have an
	 * even offset sum and the last four an odd one, so each half of a sweep only
	 * reads face neighbours that belong to the other half.
	 */
	private static final int[][] SWEEP_ORDER={
		{0,0,0},{0,1,1},{1,0,1},{1,1,0},
		{1,0,0},{0,0,1},{0,1,0},{1,1,1}
	};
	/** Shared default-configured solver used by the static doSolve entry points. */
	private static final LaplaceSolveOnGrid linearSolve = new LaplaceSolveOnGrid();
	/** Half of the time step passed to the constructor; never modified afterwards. */
	private final double dt;
	/** Maximum number of full sweeps per solve call. */
	private final int maxIters;
	/** A solve stops once the averaged squared update drops below this value. */
	private final double maxError;
	/** Lower bound passed to the constructor; stored but not read by any solve method. */
	private final float minBound;
	/** Upper bound passed to the constructor; stored but not read by any solve method. */
	private final float maxBound;

	/**
	 * Creates a solver with explicit settings.
	 *
	 * @param dt       time step; the solver relaxes each voxel with a factor of dt/6
	 * @param maxIters maximum number of full sweeps per solve call
	 * @param maxError convergence threshold on the averaged squared update
	 * @param minBound lower bound (stored only)
	 * @param maxBound upper bound (stored only)
	 */
	public LaplaceSolveOnGrid(double dt,int maxIters,double maxError,float minBound,float maxBound){
		this.dt=dt/2.0;
		this.maxIters=maxIters;
		this.maxError=maxError;
		// These two were previously dropped by the constructor.
		this.minBound=minBound;
		this.maxBound=maxBound;
	}

	/** Creates a solver with dt=1, 200 sweeps, tolerance 1E-7 and bounds [0,1]. */
	public LaplaceSolveOnGrid(){
		this(1, 200,1E-7,0,1);
	}

	/**
	 * Runs {@link #solve(ImageData)} on the shared default solver.
	 *
	 * @param V initial volume
	 * @return relaxed copy of V
	 */
	public static ImageDataFloat doSolve(ImageData V) {
		return linearSolve.solve(V);
	}

	/**
	 * Runs {@link #solve(ImageData, byte[][][])} on the shared default solver.
	 *
	 * @param V    initial volume
	 * @param mask non-zero entries mark the voxels to relax
	 * @return relaxed copy of V
	 */
	public static ImageDataFloat doSolve(ImageData V,byte[][][] mask) {
		return linearSolve.solve(V,mask);
	}

	/**
	 * Runs {@link #solve(ImageData, double[])} on the shared default solver.
	 *
	 * @param V   initial volume
	 * @param res voxel resolutions; elements 0, 1 and 2 are read
	 * @return relaxed copy of V
	 */
	public static ImageDataFloat doSolve(ImageData V,double[] res) {
		return linearSolve.solve(V,res);
	}

	/**
	 * Relaxes the voxels selected by a mask; all other voxels keep their values.
	 *
	 * @param V    initial volume; the result is built from it via
	 *             {@code new ImageDataFloat(V)}
	 * @param mask indexed as [x][y][z] like the volume; a voxel is updated only
	 *             where the mask is non-zero
	 * @return the relaxed volume
	 */
	public ImageDataFloat solve(ImageData V,byte[][][] mask) {
		ImageDataFloat result=new ImageDataFloat(V);
		// The step is derived locally. It used to be written back into the dt
		// field, so every call (including calls through the shared static
		// solver) shrank the step used by all later calls by a factor of 3.
		final double step=dt/3.0;
		for (int i = 0; i < maxIters; i++) {
			double err=0;
			for (int s = 0; s < SWEEP_ORDER.length; s++) {
				int[] o=SWEEP_ORDER[s];
				err+=octant(result,mask, step, o[0], o[1], o[2]);
			}
			err/=SWEEP_ORDER.length;
			if(err<maxError)break;
		}
		return result;
	}

	/**
	 * Updates every masked voxel of one parity octant in place.
	 *
	 * @param U     volume to update
	 * @param maskb voxels with a non-zero entry are updated
	 * @param dt    relaxation factor applied to the 7-point Laplacian
	 * @param ic    starting offset (0 or 1) of the second (y) index
	 * @param jc    starting offset (0 or 1) of the third (z) index
	 * @param kc    starting offset (0 or 1) of the first (x) index
	 * @return mean squared update over the visited voxels, or 0 if none was visited
	 */
	private static double octant(ImageDataFloat U,byte[][][] maskb,double dt, int ic, int jc, int kc) {
		int xn = U.getRows();
		int yn = U.getCols();
		int zn = U.getSlices();
		// NOTE(review): the update is written through this array, so toArray3d()
		// presumably exposes the image's backing storage - confirm.
		float[][][] Umat=U.toArray3d();
		double err=0;
		int count=0;
		for (int x = kc; x < xn; x += 2) {
			// Neighbour indices are clamped at the volume faces.
			int xPrev = (x > 0) ? (x - 1) : 0;
			int xNext = (x == (xn - 1)) ? (xn - 1) : (x + 1);
			for (int y = ic; y < yn; y += 2) {
				int yPrev = (y > 0) ? (y - 1) : 0;
				int yNext = (y == (yn - 1)) ? (yn - 1) : (y + 1);
				for (int z = jc; z < zn; z += 2) {
					int zPrev = (z > 0) ? (z - 1) : 0;
					int zNext = (z == (zn - 1)) ? (zn - 1) : (z + 1);
					if(maskb[x][y][z]!=0){
						double res=dt*(Umat[xPrev][y][z] + Umat[xNext][y][z]
								+ Umat[x][yPrev][z] + Umat[x][yNext][z]
								+ Umat[x][y][zPrev] + Umat[x][y][zNext]-6*Umat[x][y][z]);
						Umat[x][y][z] += res;
						err+=res*res;
						count++;
					}
				}
			}
		}
		return (count>0) ? err/count : 0;
	}

	/**
	 * Relaxes every voxel whose current value is not one of the fixed label
	 * values 0, 1, -3 or 2, and prints the sweep count and final error.
	 *
	 * @param V initial volume; the result is built from it via
	 *          {@code new ImageDataFloat(V)}
	 * @return the relaxed volume
	 */
	public ImageDataFloat solve(ImageData V) {
		int i;
		ImageDataFloat result=new ImageDataFloat(V);
		double err=0;
		// Local step: the dt field must not change between calls.
		final double step=dt/3.0;
		for (i = 0; i < maxIters; i++) {
			err=0;
			for (int s = 0; s < SWEEP_ORDER.length; s++) {
				int[] o=SWEEP_ORDER[s];
				err+=octant(result, step, o[0], o[1], o[2]);
			}
			err/=SWEEP_ORDER.length;
			if(err<maxError)break;
		}
		System.out.println("Laplace Solver Iterations "+i+" Error "+err);
		return result;
	}

	/**
	 * Updates one parity octant in place, skipping voxels whose value is exactly
	 * 0, 1, -3 or 2.
	 *
	 * @param U  volume to update
	 * @param dt relaxation factor applied to the 7-point Laplacian
	 * @param ic starting offset (0 or 1) of the second (y) index
	 * @param jc starting offset (0 or 1) of the third (z) index
	 * @param kc starting offset (0 or 1) of the first (x) index
	 * @return mean squared update over the visited voxels, or 0 if none was visited
	 */
	private static double octant(ImageDataFloat U,double dt, int ic, int jc, int kc) {
		int xn = U.getRows();
		int yn = U.getCols();
		int zn = U.getSlices();
		float[][][] Umat=U.toArray3d();
		double err=0;
		int count=0;
		for (int x = kc; x < xn; x += 2) {
			// Neighbour indices are clamped at the volume faces.
			int xPrev = (x > 0) ? (x - 1) : 0;
			int xNext = (x == (xn - 1)) ? (xn - 1) : (x + 1);
			for (int y = ic; y < yn; y += 2) {
				int yPrev = (y > 0) ? (y - 1) : 0;
				int yNext = (y == (yn - 1)) ? (yn - 1) : (y + 1);
				for (int z = jc; z < zn; z += 2) {
					int zPrev = (z > 0) ? (z - 1) : 0;
					int zNext = (z == (zn - 1)) ? (zn - 1) : (z + 1);
					float u=Umat[x][y][z];
					// NOTE(review): 0, 1, -3 and 2 look like label values marking
					// fixed voxels - confirm with the code that builds the volume.
					if(u!=0&&u!=1&&u!=-3&&u!=2){
						double res=dt*(Umat[xPrev][y][z] + Umat[xNext][y][z]
								+ Umat[x][yPrev][z] + Umat[x][yNext][z]
								+ Umat[x][y][zPrev] + Umat[x][y][zNext]-6*u);
						Umat[x][y][z] += res;
						err+=res*res;
						count++;
					}
				}
			}
		}
		return (count>0) ? err/count : 0;
	}

	/**
	 * Same as {@link #solve(ImageData)} but weights the neighbours by the voxel
	 * resolutions, for grids with anisotropic spacing.
	 *
	 * @param V   initial volume; the result is built from it via
	 *            {@code new ImageDataFloat(V)}
	 * @param res voxel resolutions; elements 0, 1 and 2 are read
	 * @return the relaxed volume
	 */
	public ImageDataFloat solve(ImageData V, double[] res) {
		int i;
		ImageDataFloat result=new ImageDataFloat(V);
		double err=0;
		// Local step: the dt field must not change between calls.
		final double step=dt/3.0;

		// Only need the cross terms
		double[] res2 = new double[4];
		res2[0] = res[0]*res[0]*res[1]*res[1];
		res2[1] = res[0]*res[0]*res[2]*res[2];
		res2[2] = res[1]*res[1]*res[2]*res[2];
		res2[3] = res2[0] + res2[1] + res2[2];

		for (i = 0; i < maxIters; i++) {
			err=0;
			for (int s = 0; s < SWEEP_ORDER.length; s++) {
				int[] o=SWEEP_ORDER[s];
				err+=octant(result, step, o[0], o[1], o[2], res2);
			}
			err/=SWEEP_ORDER.length;
			if(err<maxError)break;
		}
		System.out.println("Laplace Solver Iterations "+i+" Error "+err);
		return result;
	}

	/**
	 * Resolution-weighted update of one parity octant, skipping voxels whose
	 * value is exactly 0, 1, -3 or 2.
	 * Equation modifications from Diep et al. (2007).
	 *
	 * @param U    volume to update
	 * @param dt   relaxation factor; it is additionally divided by res2[3]
	 * @param ic   starting offset (0 or 1) of the second (y) index
	 * @param jc   starting offset (0 or 1) of the third (z) index
	 * @param kc   starting offset (0 or 1) of the first (x) index
	 * @param res2 weights: [2] for the x neighbours, [1] for the y neighbours,
	 *             [0] for the z neighbours, [3] the sum of the three
	 * @return mean squared update over the visited voxels, or 0 if none was visited
	 */
	private static double octant(ImageDataFloat U,double dt, int ic, int jc, int kc, double[] res2) {
		int xn = U.getRows();
		int yn = U.getCols();
		int zn = U.getSlices();
		float[][][] Umat=U.toArray3d();
		double err=0;
		int count=0;
		for (int x = kc; x < xn; x += 2) {
			// Neighbour indices are clamped at the volume faces.
			int xPrev = (x > 0) ? (x - 1) : 0;
			int xNext = (x == (xn - 1)) ? (xn - 1) : (x + 1);
			for (int y = ic; y < yn; y += 2) {
				int yPrev = (y > 0) ? (y - 1) : 0;
				int yNext = (y == (yn - 1)) ? (yn - 1) : (y + 1);
				for (int z = jc; z < zn; z += 2) {
					int zPrev = (z > 0) ? (z - 1) : 0;
					int zNext = (z == (zn - 1)) ? (zn - 1) : (z + 1);
					float u=Umat[x][y][z];
					if(u!=0&&u!=1&&u!=-3&&u!=2){
						// NOTE(review): with unit resolutions res2[3] is 3, so this
						// step is one third of the isotropic variant's - confirm
						// that the extra damping is intended.
						double res=(dt/res2[3])*(res2[2]*(Umat[xPrev][y][z] + Umat[xNext][y][z])
								+ res2[1]*(Umat[x][yPrev][z] + Umat[x][yNext][z])
								+ res2[0]*(Umat[x][y][zPrev] + Umat[x][y][zNext]) - 2*res2[3]*u);
						Umat[x][y][z] += res;
						err+=res*res;
						count++;
					}
				}
			}
		}
		return (count>0) ? err/count : 0;
	}
}

