/**
 * 
 */
package edu.jhu.ece.iacl.algorithms.dti;

import static org.junit.Assert.*;

import java.text.NumberFormat;
import java.util.Random;

import junit.framework.TestCase;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import Jama.Matrix;

/**
 * Tests for {@link DTEMRLTensor}: accessor round-trips, and recovery of a known
 * diffusion tensor ({@link #trueTensor}) from synthetic, noise-free observations
 * by {@code optimize8D} and {@code optimize6D2D} under both the direct and the
 * non-direct ("Rod.") tensor representation.
 *
 * <p>NOTE(review): this class extends JUnit 3 {@link TestCase} but also carries
 * JUnit 4 annotations. When executed as a {@code TestCase}, the annotations are
 * ignored and tests are discovered by the public no-arg {@code test*} naming
 * convention, so keep helper methods private and not prefixed with {@code test}.
 *
 * @author bennett
 */
public class DTEMRLTensorTest extends TestCase {

	private DTEMRLTensor tensor;

	/** Absolute tolerance for accessor round-trip checks. */
	private static final double TOLERANCE = 1e-6;

	/** Absolute tolerance for comparing an optimized tensor with the true tensor. */
	private static final double OPTIMIZATION_TOLERANCE = TOLERANCE * 100;

	/** Value the sigma check expects after a negative sigma has been set. */
	private static final double EXPECTED_CLAMPED_SIGMA = 1e-6;

	/** Number of gradient directions in the synthetic data set. */
	private static final int NUM_GRADIENTS = 12;

	private static final double b0 = 1000, bvalue = 700, sigma = 5;
	private static final double[][] g = new double[NUM_GRADIENTS][3];
	private static double[] obsDW;
	private static double[] obsB0;
	private static Matrix imgMatrix;
	private static final Matrix trueTensor = new Matrix(new double[]{1e-3,0,0,2e-3,0,3e-3},6);
	private static double[] kspaceAverages;


	/**
	 * Regenerates the shared synthetic data set. Draws random gradient
	 * directions and builds {@link #imgMatrix} for {@link #bvalue}. Then it
	 * computes noise-free diffusion-weighted observations
	 * {@code obsDW[i] = b0 * exp(-bvalue * (GG * trueTensor)[i])}.
	 * Results are written to the static fields {@code g}, {@code imgMatrix},
	 * {@code obsB0}, {@code obsDW} and {@code kspaceAverages}.
	 */
	public static void randomizeObservedData() {
		for (int i = 0; i < g.length; i++)
			for (int j = 0; j < 3; j++)
				g[i][j] = Math.random() - 0.5;
		Matrix G = normalizeGrad(g);
		imgMatrix = makeImgMatrix(G, bvalue);
		// bvalue = -1 cancels makeImgMatrix's built-in negation, giving the
		// unscaled design rows [xx, 2xy, 2xz, yy, 2yz, zz].
		Matrix GG = makeImgMatrix(G, -1);

		obsB0 = new double[] { b0 };
		obsDW = GG.times(trueTensor).times(-1 * bvalue).getColumnPackedCopy();
		for (int i = 0; i < obsDW.length; i++)
			obsDW[i] = b0 * Math.exp(obsDW[i]);
		// One k-space average for every measurement (DW and B0).
		kspaceAverages = new double[obsDW.length + obsB0.length];
		for (int i = 0; i < kspaceAverages.length; i++)
			kspaceAverages[i] = 1;
		// NOTE: observations are noise-free; sigma is only passed to the model.
	}

	/**
	 * Normalizes each row of the gradient matrix to unit Euclidean length.
	 * The input array is modified in place and is wrapped (not copied) by the
	 * returned matrix.
	 *
	 * @param g gradient matrix to normalize, one direction per row
	 * @return normalized gradient matrix backed by {@code g}
	 * @throws IllegalArgumentException if a row has zero length (it cannot be normalized)
	 */
	public static Matrix normalizeGrad(double[][] g) {
		for (int i = 0; i < g.length; i++) {
			double sum = 0;
			for (int j = 0; j < g[i].length; j++)
				sum += g[i][j] * g[i][j];
			double norm = Math.sqrt(sum);
			if (norm == 0)
				throw new IllegalArgumentException("Gradient row " + i + " has zero length");
			for (int j = 0; j < g[i].length; j++)
				g[i][j] /= norm;
		}
		return new Matrix(g);
	}

	/**
	 * Builds the imaging matrix from normalized gradients. Row i is
	 * {@code -bvalue * [xx, 2xy, 2xz, yy, 2yz, zz]} for gradient (x, y, z).
	 *
	 * @param m normalized gradient matrix (N x 3)
	 * @param bvalue the b-value scale factor
	 * @return N x 6 imaging matrix
	 */
	public static Matrix makeImgMatrix(Matrix m, double bvalue) {
		int rows = m.getRowDimension();
		Matrix imagMatrix = new Matrix(rows,6);
		rows--;
		Matrix x = m.getMatrix(0,rows,0,0);
		Matrix y = m.getMatrix(0,rows,1,1);
		Matrix z = m.getMatrix(0,rows,2,2);
		imagMatrix.setMatrix(0,rows,0,0,x.arrayTimes(x).times(bvalue*-1));		//xx
		imagMatrix.setMatrix(0,rows,1,1,x.arrayTimes(y).times(2*bvalue*-1));	//2xy
		imagMatrix.setMatrix(0,rows,2,2,x.arrayTimes(z).times(2*bvalue*-1));	//2xz
		imagMatrix.setMatrix(0,rows,3,3,y.arrayTimes(y).times(bvalue*-1));		//yy
		imagMatrix.setMatrix(0,rows,4,4,y.arrayTimes(z).times(2*bvalue*-1));	//2yz
		imagMatrix.setMatrix(0,rows,5,5,z.arrayTimes(z).times(bvalue*-1));		//zz
		return imagMatrix;
	}

	/**
	 * Returns a fresh copy of the initial tensor guess. A new array is returned
	 * on each call so that a tensor mutating its input cannot leak state
	 * between tests.
	 */
	private static double[] initialGuess() {
		return new double[]{0.8e-3,0.2e-3,0.2e-3,1.8e-3,0.2e-3,2.8e-3};
	}

	/** Formats a matrix's column-packed values for failure messages. */
	private static String describe(Matrix m) {
		double[] values = m.getColumnPackedCopy();
		StringBuilder sb = new StringBuilder("[");
		for (int i = 0; i < values.length; i++) {
			if (i > 0)
				sb.append(", ");
			sb.append(values[i]);
		}
		return sb.append(']').toString();
	}

	/**
	 * Fails with {@code message} plus both tensors if any of the six tensor
	 * components differ by more than {@code tol}.
	 */
	private static void assertTensorsClose(String message, Matrix expected, Matrix actual, double tol) {
		for (int i = 0; i < 6; i++) {
			if (Math.abs(expected.get(i, 0) - actual.get(i, 0)) > tol)
				fail(message + " Expected:" + describe(expected) + " Actual:" + describe(actual));
		}
	}

	/**
	 * @throws java.lang.Exception
	 */
	@Before
	public void setUp() throws Exception {
		tensor = new DTEMRLTensor();
		randomizeObservedData();
		tensor.init(imgMatrix, b0, sigma, obsB0, obsDW, initialGuess(), kspaceAverages);
	}

	/**
	 * @throws java.lang.Exception
	 */
	@After
	public void tearDown() throws Exception {
		tensor = null;
	}

	/**
	 * Test method for {@link edu.jhu.ece.iacl.algorithms.dti.DTEMRLTensor#getB0()}.
	 */
	@Test
	public void testGetB0() {
		double R = Math.random()*1000;
		double oldB0 = tensor.getB0();
		tensor.setB0(R);
		double newB0 = tensor.getB0();
		if(newB0!=R)
			fail("GetB0 failed: B0 did not update with set. Old="+oldB0+" Set="+R+" Get="+newB0);
	}


	/**
	 * Test method for {@link edu.jhu.ece.iacl.algorithms.dti.DTEMRLTensor#getSigma()}.
	 * Checks the set/get round-trip. Also checks that a negative sigma results
	 * in a strictly positive value close to {@link #EXPECTED_CLAMPED_SIGMA}.
	 */
	@Test
	public void testGetSigma() {
		double S = Math.random()*1000;
		double oldSigma = tensor.getSigma();
		tensor.setSigma(S);
		double newSigma = tensor.getSigma();
		if(Math.abs(newSigma-S)>TOLERANCE)
			fail("GetSigma failed: Sigma did not update with set. Old="+oldSigma+" Set="+S+" Get="+newSigma);
		tensor.setSigma(-1);
		double clamped = tensor.getSigma();
		// TOLERANCE equals the expected clamped value, so the closeness check
		// alone would accept 0. Positivity is therefore checked explicitly; the
		// negated form also rejects NaN.
		if(!(clamped > 0) || Math.abs(clamped-EXPECTED_CLAMPED_SIGMA)>TOLERANCE)
			fail("GetSigma failed: Negative sigma resulted in non-positive sigma. Old="+oldSigma+" Set="+-1+" Get="+clamped);
	}

	/**
	 * Test method for {@link edu.jhu.ece.iacl.algorithms.dti.DTEMRLTensor#getD()}.
	 * Round-trips setD/getD under both tensor representations.
	 */
	@Test
	public void testGetD() {
		Matrix D = new Matrix(new double[]{1e-3,0,0,2e-3,0,3e-3},6);
		tensor.setUseDirectTensorRepresentation(true);
		tensor.setD(D.getColumnPackedCopy());
		assertTensorsClose("getD/setD: Direct representation not equiv.", D, tensor.getD(), TOLERANCE);

		D = new Matrix(new double[]{2e-3,0,0,5e-3,0,7e-3},6);
		tensor.setUseDirectTensorRepresentation(false);
		tensor.setD(D.getColumnPackedCopy());
		assertTensorsClose("getD/setD: Rod. representation not equiv.", D, tensor.getD(), TOLERANCE);
	}

	/**
	 * Runs one optimizer on freshly generated data and checks that the result
	 * matches {@link #trueTensor}. A new tensor replaces the one built in setUp.
	 *
	 * @param useDirect whether to use the direct tensor representation
	 * @param eightD {@code true} for {@code optimize8D}, {@code false} for {@code optimize6D2D}
	 */
	private void checkOptimizationRecoversTrueTensor(boolean useDirect, boolean eightD) {
		tensor = new DTEMRLTensor();
		tensor.setUseDirectTensorRepresentation(useDirect);
		randomizeObservedData();
		tensor.init(imgMatrix, b0, sigma, obsB0, obsDW, initialGuess(), kspaceAverages);

		boolean converged = eightD ? tensor.optimize8D() : tensor.optimize6D2D();
		if(!converged)
			fail("Optimization "+(eightD ? "8D" : "6D")+" did not converge");

		String method = eightD ? "optimize8D" : "optimize6D2D";
		String representation = useDirect ? "Direct" : "Rod.";
		assertTensorsClose(method+": "+representation+" representation not equiv."
				+" sigma: true="+sigma+" fit="+tensor.getSigma()
				+" b0: true="+b0+" fit="+tensor.getB0(),
				trueTensor, tensor.getD(), OPTIMIZATION_TOLERANCE);
	}

	@Test
	public void testOptimize8DDirect() {
		checkOptimizationRecoversTrueTensor(true, true);
	}

	@Test
	public void testOptimize6DDirect() {
		checkOptimizationRecoversTrueTensor(true, false);
	}

	@Test
	public void testOptimize8DRod() {
		checkOptimizationRecoversTrueTensor(false, true);
	}

	@Test
	public void testOptimize6DRod() {
		checkOptimizationRecoversTrueTensor(false, false);
	}

}
