package edu.jhu.ece.iacl.algorithms.hardi;

import Jama.Matrix;
import edu.jhu.bme.smile.commons.math.specialFunctions.ComplexVector;
import edu.jhu.bme.smile.commons.math.specialFunctions.SphericalHarmonicRepresentation;
import edu.jhu.bme.smile.commons.math.specialFunctions.SphericalHarmonics;

public class QBall {

	/**
	 * A volume is treated as diffusion-weighted only if its first gradient component is
	 * strictly below this value. NOTE(review): presumably a sentinel written by upstream
	 * readers for "no valid direction" — confirm.
	 */
	private static final float INVALID_GRADIENT_MARKER = 90f;

	/**
	 * Returns the number of spherical-harmonic coefficients in an even-order expansion.
	 *
	 * @param order SH order forwarded to {@code SphericalHarmonics.numberOfCoefficientsInEvenOrder}
	 * @return the coefficient count reported by {@code SphericalHarmonics}
	 */
	public static int getNumberOfCoefficients(int order) {
		return SphericalHarmonics.numberOfCoefficientsInEvenOrder(order);
	}

	/**
	 * Fits even-order spherical-harmonic coefficients to the b0-normalized
	 * diffusion-weighted signal (S / mean(S0)) of every voxel inside the mask.
	 *
	 * <p>Volumes with {@code bval == 0} are used as b0 images. Volumes with {@code bval > 0}
	 * and {@code grads[v][0] < 90} are used as diffusion-weighted samples. The rows of
	 * {@code grads} for those samples are normalized to unit length <em>in place</em>.
	 * Voxels that are excluded get {@code NaN} in both coefficient outputs. Excluded voxels
	 * are those that are masked out, whose mean b0 is at or below {@code b0Threshold}, or
	 * that have no non-NaN b0 sample.
	 *
	 * @param coeff_real  output indexed [x][y][z][coefficient]; receives the real parts
	 * @param coeff_imag  output with the same layout; receives the imaginary parts
	 * @param lm          forwarded to {@code SphericalHarmonics.recordLM} (presumably filled
	 *                    with the (l, m) index of each coefficient — confirm)
	 * @param DWdata      input signal indexed [x][y][z][volume]
	 * @param bval        one b-value per volume
	 * @param grads       one direction per volume (first three components are used)
	 * @param mask        voxels with value 0 are skipped; {@code null} means "all voxels".
	 *                    When non-null, entries are set to 0 in place for voxels rejected
	 *                    by the b0 test.
	 * @param SHorder     SH order forwarded to {@code SphericalHarmonics.evenOrderTransform}
	 * @param b0Threshold voxels whose mean b0 is {@code <=} this value are excluded
	 * @return {@code coeff_imag}, the same instance that was passed in
	 * @throws RuntimeException if the volume counts or mask dimensions do not match the
	 *                          data, or a diffusion direction has zero length
	 */
	public static float [][][][] estimateSphericalHarmonics(float[][][][] coeff_real,float[][][][] coeff_imag,
			int[][] lm, float[][][][] DWdata, float[] bval, float[][] grads,
			byte[][][] mask, int SHorder, float b0Threshold) {
		return fitVoxels(coeff_real, coeff_imag, lm, DWdata, bval, grads, mask, SHorder, b0Threshold, false);
	}

	/**
	 * Same as {@link #estimateSphericalHarmonics} except that each sample is
	 * {@code ln(mean(S0) / S) = -ln(S / S0)} instead of {@code S / S0}.
	 *
	 * <p>NOTE(review): despite the name, the b-value is not divided out here. The fitted
	 * quantity is the log-attenuation, which equals the ADC only up to a factor of b.
	 * A sample with {@code S <= 0} yields an infinite or NaN value that is passed to
	 * the transform unchanged.
	 *
	 * @return {@code coeff_imag}, the same instance that was passed in
	 * @see #estimateSphericalHarmonics
	 */
	public static float [][][][] estimateSphericalHarmonicsADC(float[][][][] coeff_real,float[][][][] coeff_imag,
			int[][] lm, float[][][][] DWdata, float[] bval, float[][] grads,
			byte[][][] mask, int SHorder, float b0Threshold) {
		return fitVoxels(coeff_real, coeff_imag, lm, DWdata, bval, grads, mask, SHorder, b0Threshold, true);
	}

	/**
	 * Inverts the even-order SH relationship: computes the order whose coefficient count
	 * equals {@code components}.
	 *
	 * @param components number of SH coefficients
	 * @return the even order that yields exactly {@code components} coefficients
	 * @throws RuntimeException if no even order yields exactly that many coefficients
	 */
	public static int getOrderFromNumberOfCoeff(int components) {
		int order = 0;
		int comp = SphericalHarmonics.numberOfCoefficientsInEvenOrder(order);
		while (components > comp) {
			order += 2;
			comp = SphericalHarmonics.numberOfCoefficientsInEvenOrder(order);
		}
		if (comp == components)
			return order;
		throw new RuntimeException("Invalid number of SH coefficients for an even order:"+components);
	}

	/**
	 * Evaluates SH coefficients on a set of directions, voxel by voxel.
	 *
	 * <p>Every row of {@code grads} is normalized to unit length <em>in place</em>.
	 * A voxel whose first real coefficient is NaN or infinite gets NaN outputs.
	 *
	 * @param reconReal output indexed [x][y][z][direction]; receives the real parts
	 * @param reconImag optional output with the same layout; may be {@code null}
	 * @param realCoef  real SH coefficients indexed [x][y][z][coefficient]
	 * @param imagCoef  imaginary SH coefficients with the same layout
	 * @param grads     directions to evaluate on (first three components are used)
	 * @param SHorder   SH order forwarded to {@code SphericalHarmonics.evenOrderTransform}
	 * @throws RuntimeException if any direction has zero length
	 */
	public static void projectSphericalHarmonics(float[][][][] reconReal,
			float[][][][] reconImag, float[][][][] realCoef,
			float[][][][] imagCoef, float[][] grads, int SHorder) {

		// Step 1: normalize every requested direction (mutates grads).
		for (int i = 0; i < grads.length; i++)
			normalizeDirection(grads, i, "projectSphericalHarmonics");

		// Step 2: build the transform for these directions.
		SphericalHarmonics sht = SphericalHarmonics.evenOrderTransform(SHorder, grads);
		int nCoef = realCoef[0][0][0].length;
		double[] realDataForOneVoxel = new double[nCoef];
		double[] imagDataForOneVoxel = new double[nCoef];

		// Step 3: evaluate each voxel.
		for (int i = 0; i < realCoef.length; i++) {
			for (int j = 0; j < realCoef[0].length; j++) {
				for (int k = 0; k < realCoef[0][0].length; k++) {
					float first = realCoef[i][j][k][0];
					if (!(Float.isInfinite(first) || Float.isNaN(first))) {
						for (int ii = 0; ii < nCoef; ii++) {
							realDataForOneVoxel[ii] = realCoef[i][j][k][ii];
							imagDataForOneVoxel[ii] = imagCoef[i][j][k][ii];
						}
						ComplexVector result = sht.inverseTransform(realDataForOneVoxel, imagDataForOneVoxel);
						for (int m = 0; m < result.length(); m++)
							reconReal[i][j][k][m] = (float) result.getReal(m);
						if (reconImag != null) {
							for (int m = 0; m < result.length(); m++)
								reconImag[i][j][k][m] = (float) result.getImag(m);
						}
					} else {
						fillNaN(reconReal[i][j][k]);
						// Bound by reconImag's own length (previously used reconReal's).
						if (reconImag != null)
							fillNaN(reconImag[i][j][k]);
					}
				}
			}
		}
	}

	/**
	 * Shared implementation of the two estimate methods.
	 *
	 * @param adc when true, each sample is {@code ln(mb0 / S)}; otherwise it is {@code S / mb0}
	 */
	private static float[][][][] fitVoxels(float[][][][] coeffReal, float[][][][] coeffImag,
			int[][] lm, float[][][][] dwData, float[] bval, float[][] grads,
			byte[][][] mask, int shOrder, float b0Threshold, boolean adc) {

		// Step 1: validate inputs; build an all-ones mask when none is given.
		validateVolumeCount(dwData, bval, grads);
		mask = checkedMask(mask, dwData);

		// Step 2: index b0 and DW volumes; normalize the DW directions (mutates grads).
		int[] b0List = indexB0Volumes(bval);
		int[] gradList = indexAndNormalizeDiffusionVolumes(bval, grads);
		float[][] gradOnly = new float[gradList.length][3];
		for (int i = 0; i < gradList.length; i++)
			System.arraycopy(grads[gradList[i]], 0, gradOnly[i], 0, 3);

		// Step 3: build the transform for the DW directions.
		SphericalHarmonics sht = SphericalHarmonics.evenOrderTransform(shOrder, gradOnly);
		sht.recordLM(lm);
		double[] samples = new double[gradList.length];

		// Step 4: fit every voxel.
		for (int i = 0; i < dwData.length; i++) {
			for (int j = 0; j < dwData[0].length; j++) {
				for (int k = 0; k < dwData[0][0].length; k++) {
					float[] voxel = dwData[i][j][k];
					float mb0 = 0;
					if (mask[i][j][k] != 0) {
						mb0 = meanIgnoringNaN(voxel, b0List);
						// NaN means no usable b0 sample. Exclude the voxel explicitly;
						// otherwise NaN fails the <= test and would propagate into the fit.
						if (Float.isNaN(mb0) || mb0 <= b0Threshold)
							mask[i][j][k] = 0;
					}
					if (mask[i][j][k] != 0) {
						for (int ii = 0; ii < gradList.length; ii++) {
							float dw = voxel[gradList[ii]];
							// Float division first, then widened to double (matches prior numerics).
							samples[ii] = adc ? Math.log(mb0 / dw) : dw / mb0;
						}
						SphericalHarmonicRepresentation result = sht.transform(samples);
						for (int m = 0; m < result.length(); m++) {
							coeffReal[i][j][k][m] = (float) result.getReal(m);
							coeffImag[i][j][k][m] = (float) result.getImag(m);
						}
					} else {
						fillNaN(coeffReal[i][j][k]);
						fillNaN(coeffImag[i][j][k]);
					}
				}
			}
		}
		return coeffImag;
	}

	/** Checks that bval and grads each have one entry per volume of dwData. */
	private static void validateVolumeCount(float[][][][] dwData, float[] bval, float[][] grads) {
		int nVols = dwData[0][0][0].length;
		if (nVols != bval.length)
			throw new RuntimeException("estimateSphericalHarmonics: Number of volumes does not match number of bvalues.");
		if (nVols != grads.length)
			throw new RuntimeException("estimateSphericalHarmonics: Number of volumes does not match number of gradient directions.");
	}

	/** Returns mask (or a new all-ones mask if null) after checking its spatial dimensions. */
	private static byte[][][] checkedMask(byte[][][] mask, float[][][][] dwData) {
		if (mask == null) {
			mask = new byte[dwData.length][dwData[0].length][dwData[0][0].length];
			for (int i = 0; i < mask.length; i++)
				for (int j = 0; j < mask[i].length; j++)
					for (int k = 0; k < mask[i][j].length; k++)
						mask[i][j][k] = 1;
		}
		if (mask.length != dwData.length)
			throw new RuntimeException("estimateSphericalHarmonics: Mask does not match data in dimension: 0.");
		if (mask[0].length != dwData[0].length)
			throw new RuntimeException("estimateSphericalHarmonics: Mask does not match data in dimension: 1.");
		if (mask[0][0].length != dwData[0][0].length)
			throw new RuntimeException("estimateSphericalHarmonics: Mask does not match data in dimension: 2.");
		return mask;
	}

	/** Returns the indices of volumes with a b-value of exactly zero, in order. */
	private static int[] indexB0Volumes(float[] bval) {
		int n = 0;
		for (int i = 0; i < bval.length; i++)
			if (bval[i] == 0)
				n++;
		int[] indices = new int[n];
		n = 0;
		for (int i = 0; i < bval.length; i++)
			if (bval[i] == 0)
				indices[n++] = i;
		return indices;
	}

	/** True when volume v has a positive b-value and a non-sentinel direction. */
	private static boolean isDiffusionVolume(float[] bval, float[][] grads, int v) {
		return bval[v] > 0 && grads[v][0] < INVALID_GRADIENT_MARKER;
	}

	/** Returns the indices of diffusion-weighted volumes and normalizes their directions in place. */
	private static int[] indexAndNormalizeDiffusionVolumes(float[] bval, float[][] grads) {
		int n = 0;
		for (int i = 0; i < bval.length; i++)
			if (isDiffusionVolume(bval, grads, i))
				n++;
		int[] indices = new int[n];
		n = 0;
		for (int i = 0; i < bval.length; i++) {
			if (isDiffusionVolume(bval, grads, i)) {
				indices[n++] = i;
				normalizeDirection(grads, i, "estimateSphericalHarmonics");
			}
		}
		return indices;
	}

	/** Scales grads[i] (first three components) to unit length; throws on a zero vector. */
	private static void normalizeDirection(float[][] grads, int i, String caller) {
		float[] g = grads[i];
		float norm = (float) Math.sqrt(g[0]*g[0] + g[1]*g[1] + g[2]*g[2]);
		if (norm == 0)
			throw new RuntimeException(caller+": Invalid DW Direction "+i+": ("+g[0]+","+g[1]+","+g[2]+");");
		g[0] /= norm;
		g[1] /= norm;
		g[2] /= norm;
	}

	/** Mean of values[indices[*]] skipping NaNs; NaN if no non-NaN sample exists. */
	private static float meanIgnoringNaN(float[] values, int[] indices) {
		float sum = 0;
		int count = 0;
		for (int idx : indices) {
			float v = values[idx];
			if (!Float.isNaN(v)) {
				sum += v;
				count++;
			}
		}
		return count == 0 ? Float.NaN : sum / count;
	}

	/** Sets every element of the array to NaN. */
	private static void fillNaN(float[] values) {
		for (int m = 0; m < values.length; m++)
			values[m] = Float.NaN;
	}

}


