package edu.jhu.ece.iacl.algorithms.dti.tractography;

import java.util.ArrayList;
import java.util.Random;

import javax.vecmath.Point3f;

import Jama.Matrix;
import edu.jhu.ece.iacl.algorithms.dti.DiffusionTensor;
import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorLLMSE;
import edu.jhu.ece.iacl.algorithms.dti.tractography.FiberTracker;
import edu.jhu.ece.iacl.jist.structures.fiber.FiberCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;

/**
 * Whole-brain fiber tracking repeated over wild-bootstrap resamples of the
 * diffusion-weighted (DW) data, producing a distribution of fibers.
 *
 * <p>The DW signal is fitted once with {@link EstimateTensorLLMSE}. The fitted
 * log-signal and the log-domain residuals are then kept fixed; each iteration
 * builds a synthetic data set by adding or subtracting every leverage-scaled
 * residual with a random sign, re-estimates the tensors, derives FA and
 * {@code vec1} maps, and runs {@link FiberTracker}. The fibers of all
 * iterations are pooled into one {@link FiberCollection}.
 */
public class WBFiberDistribution {

	/**
	 * Gradient rows whose squared norm is at or above this value are not treated
	 * as measurements and are dropped before fitting (presumably placeholder rows
	 * for non-diffusion volumes; confirm against the data loaders).
	 */
	private static final float MAX_GRADIENT_NORM_SQ = 100f;

	/**
	 * Runs {@code iterWB} wild-bootstrap tracking iterations, using a fresh,
	 * unseeded {@link Random} for the residual sign flips.
	 *
	 * <p>See the overload that takes a {@link Random} for the parameter contract.
	 */
	public static FiberCollection track(int iterWB, ImageDataFloat inDWdataFloat, 
			float []inbvalues, float [][]ingrads, byte [][][]inmask, boolean usePartialEstimates,
			double startFA, double stopFA, double turningAngle) {
		return track(iterWB, inDWdataFloat, inbvalues, ingrads, inmask, usePartialEstimates,
				startFA, stopFA, turningAngle, new Random());
	}

	/**
	 * Runs {@code iterWB} wild-bootstrap tracking iterations and pools the fibers.
	 *
	 * @param iterWB number of bootstrap iterations; a value &lt;= 0 yields an empty collection
	 * @param inDWdataFloat DW image, indexed [i][j][k][volume] by {@code toArray4d()};
	 *        its header supplies the voxel resolution passed to the tracker
	 * @param inbvalues b-value for each row of {@code ingrads}
	 * @param ingrads gradient direction per volume; rows with squared norm &gt;= 100
	 *        (and their volumes/b-values) are ignored
	 * @param inmask voxel mask (0 = skip the voxel), or {@code null} to process every voxel
	 * @param usePartialEstimates forwarded to {@link EstimateTensorLLMSE#estimate}
	 * @param startFA forwarded to the {@link FiberTracker} constructor
	 * @param stopFA forwarded to the {@link FiberTracker} constructor
	 * @param turningAngle forwarded to the {@link FiberTracker} constructor
	 * @param rng source of the random residual signs; pass a seeded instance for
	 *        reproducible results
	 * @return the fibers of all iterations; the collection returned by the first
	 *         tracker run is reused as the result and later runs are appended to it
	 * @throws NullPointerException if {@code rng} is null
	 */
	public static FiberCollection track(int iterWB, ImageDataFloat inDWdataFloat,
			float []inbvalues, float [][]ingrads, byte [][][]inmask, boolean usePartialEstimates,
			double startFA, double stopFA, double turningAngle, Random rng) {
		if (rng == null) {
			throw new NullPointerException("rng");
		}
		if (iterWB <= 0) {
			// Nothing would be tracked; skip the (expensive) initial fit entirely.
			return new FiberCollection();
		}
		System.out.println("Wild Boot tensor estimation started");
		float [][][][]inDWdata = inDWdataFloat.toArray4d();
		byte [][][]mask = inmask;
		FiberCollection fibercollection = new FiberCollection();

		// Keep only the volumes whose gradient row is a real measurement.
		boolean []keep = new boolean[ingrads.length];
		int trueLenOfGrads = 0;
		for (int v = 0; v < ingrads.length; v++) {
			keep[v] = isRealGradient(ingrads[v]);
			if (keep[v]) {
				trueLenOfGrads++;
			}
		}
		float []bvalues = new float[trueLenOfGrads];
		float [][]grads = new float[trueLenOfGrads][];
		int next = 0;
		for (int v = 0; v < ingrads.length; v++) {
			if (keep[v]) {
				bvalues[next] = inbvalues[v];
				grads[next] = ingrads[v];
				next++;
			}
		}
		float [][][][]DWdata = selectVolumes(inDWdata, keep, trueLenOfGrads);

		// Leverages depend only on the acquisition scheme, so the residual scale
		// factors are computed once for all voxels and iterations.
		float []scale = residualScales(diagOfHCCME(grads, bvalues));

		float [][][][]tensors = EstimateTensorLLMSE.estimate(DWdata, bvalues, grads, mask, usePartialEstimates);
		System.out.println("Initial tensor estimation complete");

		System.out.println("Computing residuals");
		int nx = DWdata.length;
		int ny = DWdata[0].length;
		int nz = DWdata[0][0].length;
		// Only voxels inside the mask get fitted/residual vectors; the others stay
		// null and are never read, because resample() skips them too.
		float [][][][]Dprojection = new float[nx][ny][nz][];
		float [][][][]residuals = new float[nx][ny][nz][];
		for (int i = 0; i < nx; i++) {
			for (int j = 0; j < ny; j++) {
				for (int k = 0; k < nz; k++) {
					if (isOutsideMask(mask, i, j, k)) {
						continue;
					}
					// Measurement 0 is used as the reference signal S0 (assumes the
					// first retained volume is the b=0 image; confirm with callers).
					float []fitted = projectTensorLog(DWdata[i][j][k][0], tensors[i][j][k], grads, bvalues);
					Dprojection[i][j][k] = fitted;
					residuals[i][j][k] = logResiduals(fitted, DWdata[i][j][k]);
				}
			}
		}

		System.out.println("Starting WB and tracking");
		// The voxel resolution does not change between iterations; read it once.
		ImageHeader hdr = inDWdataFloat.getHeader();
		float []res = hdr.getDimResolutions();
		for (int iter = 0; iter < iterWB; iter++) {
			System.out.println("Iteration: " + iter);
			resample(DWdata, Dprojection, residuals, scale, mask, rng);
			tensors = EstimateTensorLLMSE.estimate(DWdata, bvalues, grads, mask, usePartialEstimates);
			System.out.println(iter + " tensor estimation complete");

			float [][][]FA = new float[nx][ny][nz];
			float [][][][]VEC1 = new float[nx][ny][nz][3];
			fillTrackingMaps(tensors, mask, FA, VEC1);
			Point3f resForFACT = new Point3f(res[0], res[1], res[2]);
			FiberTracker tracker = new FiberTracker(startFA, stopFA, turningAngle, resForFACT,
					new ImageDataFloat(FA), new ImageDataFloat(VEC1));
			FiberCollection fibers = tracker.solve();
			if (iter == 0) {
				// The first run's collection becomes the result object itself.
				fibercollection = fibers;
			} else {
				for (int index = 0; index < fibers.size(); index++) {
					fibercollection.add(fibers.get(index));
				}
			}
		}
		return fibercollection;
	}

	/** Returns true when the gradient row is a real measurement (squared norm below the cutoff). */
	private static boolean isRealGradient(float []q) {
		return q[0]*q[0] + q[1]*q[1] + q[2]*q[2] < MAX_GRADIENT_NORM_SQ;
	}

	/** Returns true when a mask is supplied and it excludes voxel (i, j, k). */
	private static boolean isOutsideMask(byte [][][]mask, int i, int j, int k) {
		return mask != null && mask[i][j][k] == 0;
	}

	/** Equivalent of {@code Float.isFinite} (Java 8+), kept for older language levels. */
	private static boolean isFinite(float v) {
		return !Float.isNaN(v) && !Float.isInfinite(v);
	}

	/**
	 * Copies the volumes flagged in {@code keep} into a new, densely packed array,
	 * preserving their order. Iterates voxel-outer so each voxel's vector is
	 * read and written contiguously.
	 */
	private static float[][][][] selectVolumes(float [][][][]in, boolean []keep, int count) {
		int nx = in.length;
		int ny = in[0].length;
		int nz = in[0][0].length;
		float [][][][]out = new float[nx][ny][nz][count];
		for (int i = 0; i < nx; i++) {
			for (int j = 0; j < ny; j++) {
				for (int k = 0; k < nz; k++) {
					float []src = in[i][j][k];
					float []dst = out[i][j][k];
					int next = 0;
					for (int v = 0; v < keep.length; v++) {
						if (keep[v]) {
							dst[next++] = src[v];
						}
					}
				}
			}
		}
		return out;
	}

	/**
	 * Converts hat-matrix leverages h into the residual factors 1/(1-h).
	 * A leverage of 1 (or above, from rounding) means the fit reproduces that
	 * measurement exactly; its factor is 0 rather than the infinite 1/(1-h),
	 * which would turn the (zero) residual into NaN.
	 */
	private static float[] residualScales(float []leverage) {
		float []scale = new float[leverage.length];
		for (int l = 0; l < leverage.length; l++) {
			scale[l] = leverage[l] < 1f ? 1f / (1f - leverage[l]) : 0f;
		}
		return scale;
	}

	/**
	 * Log-domain residuals: fitted log-signal minus log of the observed signal.
	 * Non-finite entries (observed signal &lt;= 0, or a non-finite fit) are set
	 * to 0 so they cannot inject Inf/NaN into the bootstrap samples.
	 */
	private static float[] logResiduals(float []fitted, float []observed) {
		float []r = new float[fitted.length];
		for (int l = 0; l < fitted.length; l++) {
			float value = fitted[l] - (float)Math.log(observed[l]);
			r[l] = isFinite(value) ? value : 0f;
		}
		return r;
	}

	/**
	 * Overwrites {@code DWdata} in place with one wild-bootstrap sample,
	 * exp(fitted +/- scale * residual), drawing an independent random sign for
	 * each measurement.
	 *
	 * <p>The following are left untouched:
	 * <ul>
	 * <li>voxels outside the mask;</li>
	 * <li>measurement 0, which projectTensorLog uses as S0;</li>
	 * <li>measurements whose fitted log-signal is not finite.</li>
	 * </ul>
	 */
	private static void resample(float [][][][]DWdata, float [][][][]fitted, float [][][][]residuals,
			float []scale, byte [][][]mask, Random rng) {
		for (int i = 0; i < DWdata.length; i++) {
			for (int j = 0; j < DWdata[0].length; j++) {
				for (int k = 0; k < DWdata[0][0].length; k++) {
					if (isOutsideMask(mask, i, j, k)) {
						continue;
					}
					float []dw = DWdata[i][j][k];
					float []fit = fitted[i][j][k];
					float []res = residuals[i][j][k];
					for (int l = 1; l < dw.length; l++) {
						if (!isFinite(fit[l])) {
							continue;
						}
						float term = scale[l] * res[l];
						boolean plus = rng.nextDouble() > .5;
						dw[l] = (float)Math.exp(plus ? fit[l] + term : fit[l] - term);
					}
				}
			}
		}
	}

	/**
	 * Fills the FA and {@code vec1} maps from the tensor field.
	 * Voxels outside the mask keep their zero-initialised values.
	 */
	private static void fillTrackingMaps(float [][][][]tensors, byte [][][]mask,
			float [][][]FA, float [][][][]VEC1) {
		for (int i = 0; i < FA.length; i++) {
			for (int j = 0; j < FA[0].length; j++) {
				for (int k = 0; k < FA[0][0].length; k++) {
					if (isOutsideMask(mask, i, j, k)) {
						continue;
					}
					DiffusionTensor dt = new DiffusionTensor(tensors[i][j][k]);
					FA[i][j][k] = dt.FA();
					VEC1[i][j][k] = dt.vec1();
				}
			}
		}
	}

	/**
	 * Predicts the log DW signal of every measurement from one tensor, using the
	 * log-linear model log S_i = log S0 - b_i * g_i^T D g_i. Each g_i is
	 * normalized to unit length; a zero-length gradient is used as-is.
	 *
	 * @param S0 reference signal; its natural log is the model intercept
	 *        (S0 &lt;= 0 gives a non-finite result)
	 * @param outTensors tensor components in the order Dxx, Dxy, Dxz, Dyy, Dyz, Dzz
	 * @param grads gradient directions; rows with norm &gt; 100 are skipped and
	 *        their entry stays 0
	 * @param bvalues b-value per gradient row
	 * @return predicted log-signal per gradient row
	 */
	static public float[] projectTensorLog(float S0, float[] outTensors, float [][]grads, float []bvalues) {
		float []proj = new float[grads.length];
		float Dxx = outTensors[0];
		float Dxy = outTensors[1];
		float Dxz = outTensors[2];
		float Dyy = outTensors[3];
		float Dyz = outTensors[4];
		float Dzz = outTensors[5];
		// The intercept is the same for every measurement.
		double logS0 = Math.log(S0);

		for (int i = 0; i < grads.length; i++) {
			float []q = grads[i];
			float Gx = q[0];
			float Gy = q[1];
			float Gz = q[2];

			float norm = (float)Math.sqrt(Gx * Gx + Gy * Gy + Gz * Gz);
			// NOTE(review): this skips norm > 100, whereas track() drops rows with
			// squared norm >= 100; the criteria differ for 10 <= norm <= 100.
			if (norm > 100) {
				continue;
			}
			if (norm != 0) {
				Gx = Gx/norm;
				Gy = Gy/norm;
				Gz = Gz/norm;
			}

			double temp = -bvalues[i] *
					(Gx*Gx*Dxx +
							2*Gx*Gy*Dxy+
							2*Gx*Gz*Dxz+
							Gy*Gy*Dyy+
							2*Gy*Gz*Dyz+
							Gz*Gz*Dzz) + logS0;
			proj[i] = (float) temp;
		}
		return proj;
	}

	/**
	 * Returns the diagonal of the hat matrix H = G (G^T G)^-1 G^T for the
	 * log-linear tensor design, i.e. the leverage of each measurement.
	 *
	 * <p>G has six tensor columns (-b * g_a * g_b) plus a constant intercept column.
	 * The off-diagonal tensor columns omit the factor 2 used in projectTensorLog.
	 * Scaling a column does not change G's column space, so H is unaffected.
	 *
	 * <p>Only the diagonal is formed, h_i = g_i (G^T G)^-1 g_i^T, summed in the
	 * same order as Jama's {@code times} so results match the full product.
	 *
	 * <p>Gradient rows are not filtered here (track() passes pre-filtered rows).
	 * Per Jama's documentation, {@code inverse()} fails with a RuntimeException when
	 * G^T G is singular, e.g. with fewer than seven independent measurements.
	 *
	 * @param grads gradient direction per measurement (normalized here when non-zero)
	 * @param bvalues b-value per measurement
	 * @return leverage h_i for every row of {@code grads}
	 */
	static public float[] diagOfHCCME(float [][]grads, float []bvalues) {
		System.out.println("Calculating HCCME");
		int n = grads.length;
		Matrix G = new Matrix(n, 7, 0);
		for (int i = 0; i < n; i++) {
			float []q = grads[i];
			double Gx = q[0];
			double Gy = q[1];
			double Gz = q[2];
			double norm = Math.sqrt(Gx * Gx + Gy * Gy + Gz * Gz);
			if (norm != 0) {
				Gx = Gx/norm;
				Gy = Gy/norm;
				Gz = Gz/norm;
			}
			G.set(i,0,-Gx*Gx*bvalues[i]);
			G.set(i,1,-Gx*Gy*bvalues[i]);
			G.set(i,2,-Gx*Gz*bvalues[i]);
			G.set(i,3,-Gy*Gy*bvalues[i]);
			G.set(i,4,-Gy*Gz*bvalues[i]);
			G.set(i,5,-Gz*Gz*bvalues[i]);
			G.set(i,6,1);
		}
		double [][]inv = G.transpose().times(G).inverse().getArray();
		double [][]g = G.getArray();

		float []diag = new float[n];
		for (int i = 0; i < n; i++) {
			double []row = g[i];
			double h = 0;
			for (int c = 0; c < 7; c++) {
				// t = (G * inv)[i][c]
				double t = 0;
				for (int r = 0; r < 7; r++) {
					t += row[r] * inv[r][c];
				}
				h += t * row[c];
			}
			diag[i] = (float)h;
		}
		return diag;
	}
}