package edu.jhu.ece.iacl.algorithms.dti;

import edu.jhu.ece.iacl.algorithms.dti.DiffusionTensor;
import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorLLMSE;
import Jama.Matrix;
import java.util.ArrayList;
import edu.jhu.ece.iacl.algorithms.dti.FiberTrackerWithROI;
import javax.vecmath.Point3f;

import edu.jhu.ece.iacl.jist.structures.fiber.FiberCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;

/**
 * Wild-bootstrap diffusion tensor estimation followed by fiber tracking.
 *
 * <p>The diffusion-weighted volume is processed in eight octants (a 2 x 2 x 2 split of the spatial
 * axes) so the working arrays only need about half the size along each axis. For every octant the
 * tensor is fitted with {@link EstimateTensorLLMSE#estimate}, log-domain residuals against that fit
 * are computed, each residual is given a random sign and scaled by {@code 1 / (1 - h_l)} (with
 * {@code h_l} from {@link #diagOfHCCME}), and the tensor is re-fitted on the resampled signal. FA
 * and {@code DiffusionTensor.vec1()} maps of the resampled tensors are then handed to
 * {@link FiberTrackerWithROI}.
 */
public class testWBFiberDistribution {

	/** Gradient rows whose squared norm is at least this value are excluded from the fit. */
	private static final float MAX_GRAD_NORM_SQ = 100f;

	/** Tensor coefficients stored per voxel, in the order Dxx, Dxy, Dxz, Dyy, Dyz, Dzz. */
	private static final int TENSOR_COEFFS = 6;

	/**
	 * Draws one wild-bootstrap tensor field for the whole volume and tracks fibers, optionally
	 * restricted to a seed ROI volume.
	 *
	 * @param iterWB              currently unused; exactly one bootstrap replicate is drawn
	 * @param inDWdataFloat       diffusion-weighted images, indexed [row][col][slice][volume]
	 * @param inbvalues           b-value of each input volume
	 * @param ingrads             gradient vector of each input volume; rows with squared norm
	 *                            {@code >= 100} are excluded, and the last all-zero row is used as
	 *                            the S0 reference
	 * @param inmask              optional mask (null = every voxel is processed); voxels whose
	 *                            value is 0 are not resampled and get FA 0
	 * @param usePartialEstimates forwarded to {@link EstimateTensorLLMSE#estimate}
	 * @param startFA             forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param stopFA              forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param turningAngle        forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param tensor              output array at least as large as the DWI's spatial extent, with
	 *                            at least 6 coefficients per voxel; overwritten with the
	 *                            bootstrapped tensors
	 * @param seedroi             optional seed volume passed to
	 *                            {@code FiberTrackerWithROI.solve(float[][][])}; null calls
	 *                            {@code solve()} instead
	 * @return the tracked fibers
	 * @throws IllegalArgumentException if no zero-gradient (S0) volume is present
	 */
	public static FiberCollection track(int iterWB, ImageDataFloat inDWdataFloat,
			float []inbvalues, float [][]ingrads, byte [][][]inmask, boolean usePartialEstimates,
			double startFA, double stopFA, double turningAngle, float [][][][]tensor, float[][][] seedroi) {
		FiberTrackerWithROI tracker = bootstrapAndBuildTracker(inDWdataFloat, inbvalues, ingrads,
				inmask, usePartialEstimates, startFA, stopFA, turningAngle, tensor);
		if (seedroi == null) {
			return tracker.solve();
		}
		return tracker.solve(seedroi);
	}

	/**
	 * Same as the {@code float[][][]} overload, but the optional seeds are given as a
	 * {@code double[][]} array that is passed to {@code FiberTrackerWithROI.solve(double[][])}.
	 *
	 * @param iterWB              currently unused; exactly one bootstrap replicate is drawn
	 * @param inDWdataFloat       diffusion-weighted images, indexed [row][col][slice][volume]
	 * @param inbvalues           b-value of each input volume
	 * @param ingrads             gradient vector of each input volume; rows with squared norm
	 *                            {@code >= 100} are excluded, and the last all-zero row is used as
	 *                            the S0 reference
	 * @param inmask              optional mask (null = every voxel is processed)
	 * @param usePartialEstimates forwarded to {@link EstimateTensorLLMSE#estimate}
	 * @param startFA             forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param stopFA              forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param turningAngle        forwarded to the {@link FiberTrackerWithROI} constructor
	 * @param tensor              output array, overwritten with the bootstrapped tensors
	 * @param seedroi             optional seeds; null calls {@code solve()} instead
	 * @return the tracked fibers
	 * @throws IllegalArgumentException if no zero-gradient (S0) volume is present
	 */
	public static FiberCollection track(int iterWB, ImageDataFloat inDWdataFloat,
			float []inbvalues, float [][]ingrads, byte [][][]inmask, boolean usePartialEstimates,
			double startFA, double stopFA, double turningAngle, float [][][][]tensor, double[][] seedroi) {
		FiberTrackerWithROI tracker = bootstrapAndBuildTracker(inDWdataFloat, inbvalues, ingrads,
				inmask, usePartialEstimates, startFA, stopFA, turningAngle, tensor);
		if (seedroi == null) {
			return tracker.solve();
		}
		return tracker.solve(seedroi);
	}

	/**
	 * Shared body of both {@code track} overloads: fills {@code tensor} with one wild-bootstrap
	 * replicate and returns a tracker built on the resulting FA and {@code vec1()} maps.
	 */
	private static FiberTrackerWithROI bootstrapAndBuildTracker(ImageDataFloat inDWdataFloat,
			float[] inbvalues, float[][] ingrads, byte[][][] inmask, boolean usePartialEstimates,
			double startFA, double stopFA, double turningAngle, float[][][][] tensor) {
		System.out.println("Wild Boot tensor estimation started");
		float[][][][] inDWdata = inDWdataFloat.toArray4d();
		estimateBootstrapTensors(inDWdata, inbvalues, ingrads, inmask, usePartialEstimates, tensor);

		ImageHeader hdr = inDWdataFloat.getHeader();
		float[] res = hdr.getDimResolutions();
		Point3f resForFACT = new Point3f();
		resForFACT.set(res[0], res[1], res[2]);

		int nx = inDWdata.length;
		int ny = inDWdata[0].length;
		int nz = inDWdata[0][0].length;
		// Masked-out voxels keep FA 0 and a zero vec1 (fresh arrays are zero-initialized).
		float[][][] FA = new float[nx][ny][nz];
		float[][][][] VEC1 = new float[nx][ny][nz][3];
		for (int i = 0; i < nx; i++) {
			for (int j = 0; j < ny; j++) {
				for (int k = 0; k < nz; k++) {
					if (inmask != null && inmask[i][j][k] == 0) {
						continue;
					}
					DiffusionTensor dt = new DiffusionTensor(tensor[i][j][k]);
					FA[i][j][k] = dt.FA();
					VEC1[i][j][k] = dt.vec1();
				}
			}
		}

		return new FiberTrackerWithROI(startFA, stopFA, turningAngle, resForFACT,
				new ImageDataFloat(FA), new ImageDataFloat(VEC1));
	}

	/**
	 * Overwrites {@code tensor} with one wild-bootstrap replicate of the tensor fit, processing the
	 * volume octant by octant.
	 *
	 * @throws IllegalArgumentException if no zero-gradient (S0) volume is present
	 */
	private static void estimateBootstrapTensors(float[][][][] inDWdata, float[] inbvalues,
			float[][] ingrads, byte[][][] inmask, boolean usePartialEstimates, float[][][][] tensor) {
		int nx = inDWdata.length;
		int ny = inDWdata[0].length;
		int nz = inDWdata[0][0].length;

		// Input-volume indices that take part in the fit, in input order.
		int[] used = selectUsedVolumes(ingrads);
		int nUsed = used.length;
		System.out.println(nx);
		System.out.println(ny);
		System.out.println(nz);
		System.out.println("Directions: " + nUsed);

		// Half sizes must round UP. Plain integer division (the old Math.ceil(n/2) never rounded)
		// left the last index of every odd-sized axis outside all octants, and made a size-1 axis
		// produce empty octants, so those voxels were never estimated.
		int row = (nx + 1) / 2;
		int col = (ny + 1) / 2;
		int slice = (nz + 1) / 2;
		System.out.println("allocating memories");
		float[][][][] dwBlock = new float[row][col][slice][nUsed];
		System.out.println("allocating residuals");
		float[][][][] residuals = new float[row][col][slice][nUsed];
		System.out.println("allocating Dprojection");
		float[][][][] projections = new float[row][col][slice][nUsed];
		System.out.println("3 images");

		float[] bvalues = new float[nUsed];
		float[][] grads = new float[nUsed][];
		int b0 = -1;
		for (int l = 0; l < nUsed; l++) {
			float[] q = ingrads[used[l]];
			bvalues[l] = inbvalues[used[l]];
			grads[l] = q;
			if (normSq(q) == 0) {
				b0 = l; // the last zero-gradient volume is the S0 reference
			}
		}
		if (b0 < 0) {
			throw new IllegalArgumentException(
					"Wild bootstrap needs a zero-gradient (b=0) volume; none found among "
							+ ingrads.length + " gradient rows");
		}

		float[] diagOfHCCME = diagOfHCCME(grads, bvalues);
		// Leverage correction 1 / (1 - h_l), hoisted out of the voxel loop.
		float[] residualScale = new float[nUsed];
		for (int l = 0; l < nUsed; l++) {
			residualScale[l] = 1 / (1 - diagOfHCCME[l]);
		}

		for (int block = 0; block < 8; block++) {
			System.out.println("Block: " + block + " of DWI");
			// Bits 0/1/2 of the block index select the upper half along row/col/slice.
			int ox = (block & 1) * row;
			int oy = ((block >> 1) & 1) * col;
			int oz = ((block >> 2) & 1) * slice;
			// Octant extents clipped to the volume (upper halves of odd axes are one voxel short).
			// Entries of the work arrays beyond these extents hold data left over from an earlier
			// octant; they may be fitted but are never copied back into tensor.
			int ex = Math.min(row, nx - ox);
			int ey = Math.min(col, ny - oy);
			int ez = Math.min(slice, nz - oz);

			// Gather the retained volumes of this octant.
			for (int i = 0; i < ex; i++) {
				for (int j = 0; j < ey; j++) {
					for (int k = 0; k < ez; k++) {
						float[] src = inDWdata[ox + i][oy + j][oz + k];
						float[] dst = dwBlock[i][j][k];
						for (int l = 0; l < nUsed; l++) {
							dst[l] = src[used[l]];
						}
					}
				}
			}

			// Octant view of the mask; out-of-volume entries stay 0.
			byte[][][] mask = null;
			if (inmask != null) {
				mask = new byte[row][col][slice];
				for (int i = 0; i < ex; i++) {
					for (int j = 0; j < ey; j++) {
						for (int k = 0; k < ez; k++) {
							mask[i][j][k] = inmask[ox + i][oy + j][oz + k];
						}
					}
				}
			}

			float[][][][] tensorBlock =
					EstimateTensorLLMSE.estimate(dwBlock, bvalues, grads, mask, usePartialEstimates);
			System.out.println("Initial tensor estimation complete");

			System.out.println("Computing residuals");
			for (int i = 0; i < ex; i++) {
				for (int j = 0; j < ey; j++) {
					for (int k = 0; k < ez; k++) {
						if (mask != null && mask[i][j][k] == 0) {
							continue;
						}
						float[] dw = dwBlock[i][j][k];
						float[] proj = projectTensorLog(dw[b0], tensorBlock[i][j][k], grads, bvalues);
						projections[i][j][k] = proj;
						float[] res = residuals[i][j][k];
						for (int l = 0; l < proj.length; l++) {
							res[l] = proj[l] - (float) Math.log(dw[l]);
						}
					}
				}
			}

			System.out.println("Starting WB and tracking");
			// Wild bootstrap: flip each scaled residual's sign with probability 1/2; S0 is kept.
			for (int i = 0; i < ex; i++) {
				for (int j = 0; j < ey; j++) {
					for (int k = 0; k < ez; k++) {
						if (mask != null && mask[i][j][k] == 0) {
							continue;
						}
						float[] dw = dwBlock[i][j][k];
						float[] proj = projections[i][j][k];
						float[] res = residuals[i][j][k];
						for (int l = 0; l < nUsed; l++) {
							if (l == b0) {
								continue;
							}
							float step = residualScale[l] * res[l];
							boolean plus = Math.random() > .5;
							dw[l] = (float) Math.exp(plus ? proj[l] + step : proj[l] - step);
						}
					}
				}
			}

			tensorBlock =
					EstimateTensorLLMSE.estimate(dwBlock, bvalues, grads, mask, usePartialEstimates);
			System.out.println("Estimation completed");

			// Scatter the octant's resampled tensors back into the full-volume output.
			for (int i = 0; i < ex; i++) {
				for (int j = 0; j < ey; j++) {
					for (int k = 0; k < ez; k++) {
						System.arraycopy(tensorBlock[i][j][k], 0,
								tensor[ox + i][oy + j][oz + k], 0, TENSOR_COEFFS);
					}
				}
			}
			System.out.println("Block " + block + " completed");
		}
	}

	/** Returns, in order, the indices of gradient rows whose squared norm is below the cut-off. */
	private static int[] selectUsedVolumes(float[][] ingrads) {
		int count = 0;
		for (float[] q : ingrads) {
			if (normSq(q) < MAX_GRAD_NORM_SQ) {
				count++;
			}
		}
		int[] used = new int[count];
		int next = 0;
		for (int i = 0; i < ingrads.length; i++) {
			if (normSq(ingrads[i]) < MAX_GRAD_NORM_SQ) {
				used[next++] = i;
			}
		}
		return used;
	}

	/** Squared Euclidean norm of the first three components of {@code q}. */
	private static float normSq(float[] q) {
		return q[0] * q[0] + q[1] * q[1] + q[2] * q[2];
	}

	/**
	 * Predicts the log signal of every gradient direction from a tensor:
	 * {@code log(S0) - b_i * g_i^T D g_i}, with {@code g_i} normalized to unit length when non-zero.
	 *
	 * @param S0         reference (b=0) signal; its natural log is the intercept
	 * @param outTensors tensor coefficients Dxx, Dxy, Dxz, Dyy, Dyz, Dzz (indices 0..5)
	 * @param grads      gradient vectors; entries whose norm exceeds 100 are skipped and leave 0
	 * @param bvalues    b-value for each gradient
	 * @return predicted log signal, one entry per gradient
	 */
	static public float[] projectTensorLog(float S0, float[] outTensors, float [][]grads, float []bvalues) {
		float[] proj = new float[grads.length];
		float Dxx = outTensors[0];
		float Dxy = outTensors[1];
		float Dxz = outTensors[2];
		float Dyy = outTensors[3];
		float Dyz = outTensors[4];
		float Dzz = outTensors[5];
		double logS0 = Math.log(S0); // loop-invariant

		for (int i = 0; i < grads.length; i++) {
			float Gx = grads[i][0];
			float Gy = grads[i][1];
			float Gz = grads[i][2];

			float norm = (float) Math.sqrt(Gx * Gx + Gy * Gy + Gz * Gz);
			if (norm > 100) {
				continue;
			}
			if (norm != 0) {
				Gx /= norm;
				Gy /= norm;
				Gz /= norm;
			}

			float quadratic = Gx * Gx * Dxx
					+ 2 * Gx * Gy * Dxy
					+ 2 * Gx * Gz * Dxz
					+ Gy * Gy * Dyy
					+ 2 * Gy * Gz * Dyz
					+ Gz * Gz * Dzz;
			proj[i] = (float) (-bvalues[i] * quadratic + logS0);
		}
		return proj;
	}

	/**
	 * Computes the diagonal {@code h_ii} of the hat matrix {@code G (G^T G)^-1 G^T} of the
	 * log-linear tensor design, where row i of {@code G} is
	 * {@code [-b gx^2, -b gx gy, -b gx gz, -b gy^2, -b gy gz, -b gz^2, 1]} for the normalized
	 * gradient {@code g}.
	 *
	 * @param grads   gradient vectors (normalized here when non-zero)
	 * @param bvalues b-value for each gradient
	 * @return {@code h_ii} for every gradient row
	 * @throws RuntimeException from Jama if {@code G^T G} is singular
	 */
	static public float[] diagOfHCCME(float [][]grads, float []bvalues) {
		System.out.println("Calculating HCCME");
		int n = grads.length;
		Matrix G = new Matrix(n, 7, 0);
		for (int i = 0; i < n; i++) {
			double Gx = grads[i][0];
			double Gy = grads[i][1];
			double Gz = grads[i][2];
			double norm = Math.sqrt(Gx * Gx + Gy * Gy + Gz * Gz);
			if (norm != 0) {
				Gx /= norm;
				Gy /= norm;
				Gz /= norm;
			}
			double b = bvalues[i];
			G.set(i, 0, -Gx * Gx * b);
			G.set(i, 1, -Gx * Gy * b);
			G.set(i, 2, -Gx * Gz * b);
			G.set(i, 3, -Gy * Gy * b);
			G.set(i, 4, -Gy * Gz * b);
			G.set(i, 5, -Gz * Gz * b);
			G.set(i, 6, 1);
		}

		Matrix GTG_inv = G.transpose().times(G).inverse();
		// h_ii = sum_c (G (G^T G)^-1)[i][c] * G[i][c]: only the diagonal is needed, so the
		// n x n hat matrix is never formed.
		Matrix GA = G.times(GTG_inv);
		float[] diag = new float[n];
		for (int i = 0; i < n; i++) {
			double h = 0;
			for (int c = 0; c < 7; c++) {
				h += GA.get(i, c) * G.get(i, c);
			}
			diag[i] = (float) h;
		}
		return diag;
	}
}