package edu.jhu.pami.spring2009;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;

import Jama.Matrix;
import edu.jhmi.rad.medic.utilities.MedicUtil;
import edu.jhu.bme.smile.commons.textfiles.TextFileReader;
import edu.jhu.ece.iacl.io.CubicVolumeReaderWriter;
import edu.jhu.ece.iacl.io.FileExtensionFilter;
import edu.jhu.ece.iacl.io.ModelImageReaderWriter;
import edu.jhu.ece.iacl.io.StringReaderWriter;
import edu.jhu.ece.iacl.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.*;
import edu.jhu.ece.iacl.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.pipeline.parameter.ParamMatrix;
import edu.jhu.ece.iacl.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.structures.image.ImageDataUByte;
import edu.jhu.ece.iacl.utility.FileUtil;
import gov.nih.mipav.model.structures.ModelImage;

/**
 * Pipeline plugin that refines a per-voxel diffusion tensor estimate.
 * <p>
 * For every voxel of the input tensor-estimate volume, the six stored tensor
 * components are used as the starting point for {@code Opt8DSimplex.estimate},
 * together with the observed DW signal of that voxel, a noise parameter
 * {@code sigma}, and an imaging matrix built from the b-value and gradient
 * tables. The optimized tensors are written out as a single float volume.
 */
public class OptimizeTensorVolume2 extends ProcessingAlgorithm {

	/****************************************************
	 * Input Parameters
	 ****************************************************/
	private ParamFloat param_sigma;
	private ParamVolume param_obsDW;
	private ParamFileCollection param_tensorEstVol;
	private ParamFile bvaluesTable;
	private ParamFile gradsTable;

	/****************************************************
	 * Output Parameters
	 ****************************************************/
	private ParamFileCollection tensorVolume;


	/****************************************************
	 * Other Parameters
	 ****************************************************/
	private static final String rcsid = "";
	private static final String cvsversion = "";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "");

	/**
	 * Registers plugin metadata and the input parameters: the tensor estimate
	 * volume, sigma, the observed DW volume, and the gradient / b-value tables.
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		// Plugin info
		inputParams.setName("OptimizeTensorVolume2");
		inputParams.setLabel("OptimizeTensorVolume2");
		inputParams.setCategory("PAMI");
		inputParams.setPackage("Spring2009");
		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://sites.google.com/site/jhupami/");
		info.add(new AlgorithmAuthor("Name","Email","URL"));
		info.setDescription("Optimize a tensor volume with an 8D, Downhill Simplex Method.");
		info.setAffiliation("Johns Hopkins University, Department of Biomedical Engineering");
		info.setVersion(revnum);

		// Input parameters
		inputParams.add(param_tensorEstVol = new ParamFileCollection("Tensor Estimation Volume",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions)));
		inputParams.add(param_sigma = new ParamFloat("sigma"));
		inputParams.add(param_obsDW = new ParamVolume("Observed DW Volume"));
		inputParams.add(gradsTable = new ParamFile("Table of diffusion weighting directions",new FileExtensionFilter(new String[]{"grad","dpf"})));
		inputParams.add(bvaluesTable = new ParamFile("Table of b-values",new FileExtensionFilter(new String[]{"b"})));
	}

	/**
	 * Registers the single output: a file collection holding the optimized
	 * tensor volume.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		tensorVolume = new ParamFileCollection("Tensor Estimate",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		tensorVolume.setName("Tensor (xx,xy,xz,yy,yz,zz)");
		outputParams.add(tensorVolume);
	}

	/**
	 * Optimizes the tensor of every voxel and writes the result volume.
	 *
	 * @param monitor pipeline progress monitor (currently unused)
	 * @throws AlgorithmRuntimeException declared by the pipeline contract
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {
		// The imaging matrix depends only on the b-value and gradient tables, so build it once.
		Matrix imgMatrix = buildImgMatrix();
		float sigma = param_sigma.getFloat();

		ImageData obsDW = param_obsDW.getImageData(); // observed DW data
		CubicVolumeReaderWriter rw = CubicVolumeReaderWriter.getInstance();
		ImageData tensorEstVol = rw.read(param_tensorEstVol.getValue(0)); // initial tensor estimates
		ArrayList<File> outVols = new ArrayList<File>();

		int rows = tensorEstVol.getRows();
		int cols = tensorEstVol.getCols();
		int slices = tensorEstVol.getSlices();
		int totalVoxels = rows * cols * slices; // loop-invariant, used only for progress output
		int nDW = obsDW.getComponents();

		// Output has the same component count as the input estimate (expected to be 6).
		float[][][][] tensors = new float[rows][cols][slices][tensorEstVol.getComponents()];

		int voxelcounter = 1;
		for (int x = 0; x < rows; x++) {
			for (int y = 0; y < cols; y++) {
				for (int z = 0; z < slices; z++) {
					System.out.println("Optmizing voxel " + voxelcounter + " of " + totalVoxels);

					// NOTE(review): b0 (and obsB0) is taken from the LAST component of the
					// DW volume; the b-value table is not consulted to locate it — confirm
					// that inputs always store the reference image last.
					double b0 = obsDW.getDouble(x, y, z, nDW - 1);
					double voxel_obsB0 = b0;

					// NOTE(review): this includes every DW component, including the one
					// read as b0 above — confirm this is what Opt8DSimplex expects.
					double[] voxel_obsDW = new double[nDW];
					for (int a = 0; a < nDW; a++) {
						voxel_obsDW[a] = obsDW.getDouble(x, y, z, a);
					}

					// Initial tensor: the first six components of the estimate volume.
					double[] voxel_dInit = new double[6];
					for (int b = 0; b < voxel_dInit.length; b++) {
						voxel_dInit[b] = tensorEstVol.getDouble(x, y, z, b);
					}

					float[] tensor = Opt8DSimplex.estimate(imgMatrix, b0, sigma, voxel_obsB0, voxel_obsDW, voxel_dInit);
					System.arraycopy(tensor, 0, tensors[x][y][z], 0, tensor.length);

					voxelcounter++;
				} // end loop over z
			} // end loop over y
		} // end loop over x

		// NOTE(review): the output name is derived from the parameter's name, not the
		// input file's name — confirm this is intended.
		String outName = param_tensorEstVol.getName() + "_TensorOpt";
		ImageData out = new ImageDataFloat(tensors);
		tensors = null; // drop our reference to the (potentially large) result array
		out.setName(outName);
		File outputSlab = rw.write(out, getOutputDirectory());
		outVols.add(outputSlab);
		out.dispose();
		tensorVolume.setName(outName + ".xml");
		tensorVolume.setValue(outVols);
	}

	/**
	 * Builds the imaging matrix from the b-value and gradient tables.
	 * Logic adapted from DWITensorEstLLMSE and EstimateTensorLLMSE.
	 * <p>
	 * Rows with b == 0 count as reference images. Rows with b &gt; 0 whose first
	 * gradient component is &lt; 90 count as diffusion-weighted. Presumably values
	 * &gt;= 90 are a sentinel for unused directions — confirm. Each DW direction is
	 * normalized in place and contributes one matrix row
	 * {@code b * [gx*gx, 2*gx*gy, 2*gx*gz, gy*gy, 2*gy*gz, gz*gz]}.
	 *
	 * @return imaging matrix with one row per DW direction and 6 columns
	 * @throws RuntimeException if a table cannot be parsed, is empty, or has the
	 *         wrong column count; if the gradient table is missing a row needed by
	 *         a DW b-value; if there is no b == 0 image or fewer than 6 DW
	 *         directions; or if a DW direction has zero norm
	 */
	public Matrix buildImgMatrix() {
		float[][] bs = readFloatTable(bvaluesTable, "Unable to parse b-file");
		float[][] grads = readFloatTable(gradsTable, "Unable to parse grad-file");

		/****************************************************
		 * Perform limited error checking
		 ****************************************************/
		if (grads == null || grads.length == 0)
			throw new RuntimeException("LLMSE: Invalid gradient table. Table is empty.");
		if (bs == null || bs.length == 0)
			throw new RuntimeException("LLMSE: Invalid b-value table. Table is empty.");

		// With 4 columns, the first column holds indices; keep only the x, y, z columns.
		if (grads[0].length == 4) {
			float[][] g2 = new float[grads.length][3];
			for (int i = 0; i < grads.length; i++)
				System.arraycopy(grads[i], 1, g2[i], 0, 3);
			grads = g2;
		}

		if (grads[0].length != 3)
			throw new RuntimeException("LLMSE: Invalid gradient table. Must have 3 or 4 columns.");
		if (bs[0].length != 1)
			throw new RuntimeException("LLMSE: Invalid b-value table. Must have 1 column.");
		float[] bvalues = new float[bs.length];
		for (int i = 0; i < bvalues.length; i++)
			bvalues[i] = bs[i][0];

		/****************************************************
		 * Step 1: Validate Input Arguments
		 ****************************************************/
		int Ngrad = 0;
		int Nb0 = 0;
		for (int i = 0; i < bvalues.length; i++) {
			if (bvalues[i] == 0)
				Nb0++;
			if (isUsableDirection(bvalues, grads, i))
				Ngrad++;
		}
		if (Nb0 == 0)
			throw new RuntimeException("EstimateTensorLLMSE: No reference images specified.");
		if (Ngrad < 6)
			throw new RuntimeException("EstimateTensorLLMSE: Less than 6 diffusion weighted volumes specified.");

		/****************************************************
		 * Step 2: Index DW images and normalize DW directions
		 ****************************************************/
		int[] gradList = new int[Ngrad];
		int next = 0;
		for (int i = 0; i < bvalues.length; i++) {
			if (!isUsableDirection(bvalues, grads, i))
				continue;
			float[] g = grads[i];
			float norm = (float)Math.sqrt(g[0]*g[0] + g[1]*g[1] + g[2]*g[2]);
			if (norm == 0)
				throw new RuntimeException("EstimateTensorLLMSE: Invalid DW Direction "+i+": ("+g[0]+","+g[1]+","+g[2]+");");
			g[0] /= norm;
			g[1] /= norm;
			g[2] /= norm;
			gradList[next++] = i;
		}

		Matrix imagMatrix = new Matrix(gradList.length, 6);
		for (int row = 0; row < gradList.length; row++) {
			int idx = gradList[row];
			float b = bvalues[idx];
			float gx = grads[idx][0];
			float gy = grads[idx][1];
			float gz = grads[idx][2];
			imagMatrix.set(row, 0, b*gx*gx);   // xx
			imagMatrix.set(row, 1, b*gx*gy*2); // 2xy
			imagMatrix.set(row, 2, b*gx*gz*2); // 2xz
			imagMatrix.set(row, 3, b*gy*gy);   // yy
			imagMatrix.set(row, 4, b*gy*gz*2); // 2yz
			imagMatrix.set(row, 5, b*gz*gz);   // zz
		}

		return imagMatrix;
	} // end buildImgMatrix

	/**
	 * Reads a numeric text table from the file selected in {@code table}.
	 *
	 * @param table        file parameter naming the table to read
	 * @param errorMessage message for the exception thrown on a parse failure
	 * @return parsed rows as returned by {@code TextFileReader.parseFloatFile}
	 * @throws RuntimeException wrapping the underlying IOException
	 */
	private static float[][] readFloatTable(ParamFile table, String errorMessage) {
		TextFileReader reader = new TextFileReader(table.getValue());
		try {
			return reader.parseFloatFile();
		} catch (IOException e) {
			// Keep the cause so the real parse/IO error stays visible.
			throw new RuntimeException(errorMessage, e);
		}
	}

	/**
	 * Returns whether row {@code i} is a diffusion-weighted direction: b &gt; 0
	 * and first gradient component &lt; 90.
	 *
	 * @throws RuntimeException if b &gt; 0 but the gradient table has no row {@code i}
	 */
	private static boolean isUsableDirection(float[] bvalues, float[][] grads, int i) {
		if (!(bvalues[i] > 0))
			return false;
		if (i >= grads.length)
			throw new RuntimeException("LLMSE: Gradient table has " + grads.length
					+ " rows but b-value row " + i + " is diffusion weighted.");
		return grads[i][0] < 90;
	}


}
