package edu.jhu.ece.iacl.plugins.dti;

import java.io.*;
import java.util.List;
import java.util.LinkedList;

import javax.vecmath.Point3f;

import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorLLMSE;
import edu.jhu.bme.smile.commons.textfiles.TextFileReader;
import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorWildBoot;
import edu.jhu.ece.iacl.algorithms.dti.tractography.FiberTracker;
import edu.jhu.ece.iacl.algorithms.dti.tractography.WBFiberDistribution;
import edu.jhu.ece.iacl.jist.io.CurveVtkReaderWriter;
import edu.jhu.ece.iacl.jist.io.FiberCollectionReaderWriter;
import edu.jhu.ece.iacl.jist.io.FileExtensionFilter;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
//import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.Citation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamDouble;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;

import edu.jhu.ece.iacl.jist.structures.fiber.Fiber;
import edu.jhu.ece.iacl.jist.structures.fiber.FiberCollection;


import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataUByte;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * JIST plugin that runs wild-bootstrap tractography ({@link WBFiberDistribution#track})
 * on each input diffusion-weighted volume and publishes the resulting fibers in
 * DTI Studio format.
 *
 * <p>NOTE(review): there is a single {@code fibers} output, and it is overwritten for
 * every input volume. When several volumes are supplied, only the last volume's fibers
 * are emitted. Supporting one output per volume would require changing the output
 * parameter type.
 */
public class MedicAlgorithmWBFiberDistribution extends ProcessingAlgorithm { 
	/****************************************************
	 * Input Parameters 
	 ****************************************************/
	private ParamVolumeCollection DWdata4D; 		// SLAB-enabled Imaging Data
	private ParamVolumeCollection Mask3D;			// SLAB-enabled Binary mask to indicate computation volume
	private ParamOption estOptions;		// Option to attempt to estimate with missing data
	private ParamFile bvaluesTable;		// .b file with a list of b-values
	private ParamFile gradsTable;		// .grad or .dpf file with a list of gradient directions
	private ParamInteger iterWB;		// Number of wild bootstrap iterations passed to the tracker

	private ParamDouble startFA;		// FA at which tracking begins
	private ParamDouble stopFA;			// FA at which tracking terminates
	private ParamInteger turningAngle;	// Maximum turning angle, in degrees

	/****************************************************
	 * Output Parameters
	 ****************************************************/
	private ParamObject<FiberCollection> fibers;	// Tracked fibers (DTI Studio format)

	private static final String cvsversion = "$Revision: 1.1 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Wild Bootstrap Fiber Distribution";
	private static final String longDescription = "Wild Bootstrap Fiber Distribution";

	// NOTE(review): this initializer runs while the object is being constructed, and the
	// field is never read in this class. It may not reflect the output directory in effect
	// at execution time. It is kept only because it is package-visible; confirm whether
	// anything still uses it.
	File outputdir = this.getOutputDirectory();

	/**
	 * Registers the plugin metadata and the user-facing input parameters.
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		/****************************************************
		 * Step 1. Set Plugin Information 
		 ****************************************************/
		inputParams.setPackage("IACL");
		inputParams.setCategory("DTI");
		inputParams.setLabel("Wild Bootstrap Fiber Distribution");	
		inputParams.setName("Wild Bootstrap Fiber Distribution");

		AlgorithmInformation info = getAlgorithmInformation();
		info.add(new AlgorithmAuthor("Chuyang Ye", "cye4@jhu.edu", ""));
		info.setAffiliation("Johns Hopkins University, Department of Electrical and Computer Engineering");
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription +"\n" +longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.RC);

		/****************************************************
		 * Step 2. Add input parameters to control system 
		 ****************************************************/
		inputParams.add(DWdata4D=new ParamVolumeCollection("DWI and Reference Image(s) Data (4D)"));
		DWdata4D.setLoadAndSaveOnValidate(false);
		inputParams.add(gradsTable=new ParamFile("Table of diffusion weighting directions",new FileExtensionFilter(new String[]{"grad","dpf"})));
		inputParams.add(bvaluesTable=new ParamFile("Table of b-values",new FileExtensionFilter(new String[]{"b"})));
		inputParams.add(Mask3D=new ParamVolumeCollection("Mask Volume to Determine Region of Tensor Estimation (3D)"));
		Mask3D.setLoadAndSaveOnValidate(false);
		Mask3D.setMandatory(false); // Not required. A null mask will estimate all voxels.	
		inputParams.add(estOptions=new ParamOption("Attempt to estimate tensors for voxels with missing data?",new String[]{"Yes","No"}));
		// The label text, including its spelling, is left unchanged. It may be how saved
		// pipelines identify this parameter.
		inputParams.add(iterWB=new ParamInteger("Wild Boot Iterarion", 20));
		iterWB.setDescription("Number of wild bootstrap iterations");
		ParamCollection tracking = new ParamCollection("Fiber Tracking");
		tracking.add(startFA=new ParamDouble("Start FA",0,1,0.3));
		tracking.add(stopFA=new ParamDouble("Stop FA",0,1,0.13));
		tracking.add(turningAngle=new ParamInteger("Max Turn Angle",0,180,30));
		startFA.setDescription("The FA at which to begin tracking");
		stopFA.setDescription("The FA at which to terminate tracking");
		turningAngle.setDescription("The angle (in degrees) which terminates tracking");
		inputParams.add(tracking);
	}


	/**
	 * Registers the output parameters.
	 *
	 * <p>Only the DTI Studio fiber output is produced. VTK export is not wired up.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(fibers=new ParamObject<FiberCollection>("Fibers (DTI Studio)",new FiberCollectionReaderWriter()));	
	}


	/**
	 * Runs the tractography through a monitored calculation wrapper.
	 *
	 * @param monitor monitor that observes progress of the wrapped calculation
	 * @throws AlgorithmRuntimeException declared by the framework contract
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {		
		TensorEstimationWrapper wrapper=new TensorEstimationWrapper();
		monitor.observe(wrapper);
		wrapper.execute(this);
	}


	/**
	 * Parses a numeric table from the file selected in {@code tableParam}.
	 *
	 * @param tableParam  parameter holding the table file
	 * @param description short name used in messages ("b-file", "grad-file")
	 * @return the parsed rows; never null and never empty
	 * @throws RuntimeException if the file cannot be parsed or contains no rows
	 */
	private static float[][] readFloatTable(ParamFile tableParam, String description) {
		TextFileReader reader = new TextFileReader(tableParam.getValue());
		float[][] table;
		try {
			table = reader.parseFloatFile();
		} catch (IOException e) {
			JistLogger.logError(JistLogger.WARNING, "WB: Unable to parse " + description + ".");
			// Keep the I/O error as the cause so the underlying problem is not lost.
			throw new RuntimeException("WB: Unable to parse " + description, e);
		}
		if (table == null || table.length == 0) {
			JistLogger.logError(JistLogger.WARNING, "WB: " + description + " contains no rows.");
			throw new RuntimeException("WB: " + description + " contains no rows");
		}
		return table;
	}

	/**
	 * Reads the b-value table, which must contain exactly one value per row.
	 *
	 * @param bvaluesParam parameter holding the .b file
	 * @return one b-value per row of the file
	 * @throws RuntimeException if the file cannot be parsed, is empty, or any row does not
	 *         have exactly one column
	 */
	private static float[] loadBValues(ParamFile bvaluesParam) {
		float[][] bs = readFloatTable(bvaluesParam, "b-file");
		float[] bval = new float[bs.length];
		for (int i = 0; i < bs.length; i++) {
			if (bs[i].length != 1) {
				JistLogger.logError(JistLogger.WARNING, "Invalid b-value table. Must have 1 column.");
				throw new RuntimeException("Invalid b-value table. Must have 1 column.");
			}
			bval[i] = bs[i][0];
		}
		return bval;
	}

	/**
	 * Reads the gradient table as rows of three direction components.
	 *
	 * <p>A 4-column table is treated as having a leading index column, which is dropped.
	 *
	 * @param gradsParam parameter holding the .grad/.dpf file
	 * @return one 3-element row per gradient direction
	 * @throws RuntimeException if the file cannot be parsed, is empty, does not have 3 or 4
	 *         columns, or has rows of differing width
	 */
	private static float[][] loadGradientTable(ParamFile gradsParam) {
		float[][] grads = readFloatTable(gradsParam, "grad-file");
		int columns = grads[0].length;
		if (columns != 3 && columns != 4) {
			JistLogger.logError(JistLogger.WARNING, "Invalid gradient table. Must have 3 or 4 columns.");
			throw new RuntimeException("Invalid gradient table. Must have 3 or 4 columns.");
		}
		int skip = columns - 3;	// 1 when a leading index column is present, otherwise 0
		float[][] xyz = new float[grads.length][3];
		for (int i = 0; i < grads.length; i++) {
			if (grads[i].length != columns) {
				String msg = "Invalid gradient table. Row " + i + " has " + grads[i].length
						+ " columns; expected " + columns + ".";
				JistLogger.logError(JistLogger.WARNING, msg);
				throw new RuntimeException(msg);
			}
			System.arraycopy(grads[i], skip, xyz[i], 0, 3);
		}
		return xyz;
	}

	/**
	 * Loads the binary mask for one input volume, if one was supplied.
	 *
	 * <p>The mask's image data and its parameter volume are disposed once they have been
	 * converted to an array.
	 *
	 * @param maskList list of mask volumes (may be null or shorter than the DWI list)
	 * @param index    index of the input volume being processed
	 * @return the mask as a 3D byte array, or null when no mask is available
	 */
	private static byte[][][] loadMaskSlab(List<ParamVolume> maskList, int index) {
		ParamVolume maskParam = null;
		if (maskList != null && maskList.size() > index) {
			maskParam = maskList.get(index);
		}
		ImageData maskVol = (maskParam == null) ? null : maskParam.getImageData();
		if (maskVol == null) {
			System.out.println("Null mask");
			return null;
		}
		ImageDataUByte maskByte = new ImageDataUByte(maskVol);
		byte[][][] mask = maskByte.toArray3d();
		maskByte.dispose();
		maskVol.dispose();
		maskParam.dispose();
		return mask;
	}

	/**
	 * Returns a copy of {@code table} in which every row is cloned.
	 *
	 * @param table table to copy
	 * @return independent copy of the table
	 */
	private static float[][] copyRows(float[][] table) {
		float[][] copy = new float[table.length][];
		for (int i = 0; i < table.length; i++) {
			copy[i] = table[i].clone();
		}
		return copy;
	}


	/**
	 * Calculation that runs wild-bootstrap tractography for every input DWI volume.
	 */
	protected class TensorEstimationWrapper extends AbstractCalculation {

		/**
		 * Loads the shared tables once, then tracks fibers for each input volume.
		 *
		 * @param thisParent the owning algorithm. Unused: parameters are read from the
		 *        enclosing instance.
		 * @throws RuntimeException if the b-value or gradient table is missing, unparsable
		 *         or malformed
		 */
		protected void execute(ProcessingAlgorithm thisParent) {
			System.out.println("Estimation Started");

			List<ParamVolume> dwList = DWdata4D.getParamVolumeList();
			List<ParamVolume> maskList = Mask3D.getParamVolumeList();
			int iter = iterWB.getInt();

			// The tables and tracking settings are the same for every volume, so read them
			// once. Doing it before any volume is loaded also fails fast on a bad table.
			System.out.println("Checking out b values");
			float[] bval = loadBValues(bvaluesTable);
			System.out.println("Checking out gradient table");
			float[][] grads = loadGradientTable(gradsTable);
			boolean estimateMissing = "Yes".equalsIgnoreCase(estOptions.getValue());
			double startFAValue = startFA.getDouble();
			double stopFAValue = stopFA.getDouble();
			double maxTurnRadians = turningAngle.getFloat()*Math.PI/180.0;

			this.addTotalUnits(dwList.size());

			for (int jSlab = 0; jSlab < dwList.size(); jSlab++) {
				this.setLabel("Load");
				System.out.println("Processing subject: " + jSlab);

				// Convert the DWI volume to float and release the source data right away.
				System.out.println("Checking out data");
				ParamVolume dwParam = dwList.get(jSlab);
				ImageData dwd = dwParam.getImageData();
				String sourceName = dwd.getName();
				ImageDataFloat dwFloat = new ImageDataFloat(dwd);
				dwParam.dispose();
				dwd.dispose();
				dwd = null;

				System.out.println("Checking out mask");
				byte[][][] mask = loadMaskSlab(maskList, jSlab);

				// Wild bootstrap tensor estimation and tractography.
				this.setLabel("Estimate");
				FiberCollection fibercollection;
				try {
					// Pass copies of the tables, in case the tracker modifies its inputs
					// in place. Such changes must not carry over to the next volume.
					fibercollection = WBFiberDistribution.track(iter, dwFloat, bval.clone(), copyRows(grads), mask,
							estimateMissing, startFAValue, stopFAValue, maxTurnRadians);
					System.out.println("Tracking done");
				} finally {
					// Release the float volume even if tracking fails.
					dwFloat.dispose();
				}

				this.setLabel("Save");
				System.out.println("Saving fibers");
				String outputName = sourceName + "_wild_boot_fibers";
				fibercollection.setName(outputName);
				fibers.setFileName(outputName);
				// NOTE(review): overwritten on every iteration, so only the last volume's
				// fibers are kept.
				fibers.setObject(fibercollection);

				this.incrementCompletedUnits();
				System.out.println("Image finished");
			}

			JistLogger.logError(JistLogger.INFO, "WildBootTensorTractography: FINISHED");
			System.out.println("All finished");
		}
	}
}

