package edu.jhu.ece.iacl.plugins.dti;

import java.io.*;
import java.util.List;
import java.util.LinkedList;

import javax.vecmath.Point3f;

import edu.jhu.ece.iacl.algorithms.dti.DiffusionTensor;
import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorLLMSE;
import edu.jhu.bme.smile.commons.textfiles.TextFileReader;
//import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorWildBoot;
import edu.jhu.ece.iacl.algorithms.dti.tractography.FiberTracker;
import edu.jhu.ece.iacl.algorithms.dti.testWBFiberDistribution;
import edu.jhu.ece.iacl.jist.io.CurveVtkReaderWriter;
import edu.jhu.ece.iacl.jist.io.FiberCollectionReaderWriter;
import edu.jhu.ece.iacl.jist.io.FileExtensionFilter;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
//import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.Citation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamDouble;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;

import edu.jhu.ece.iacl.jist.structures.fiber.Fiber;
import edu.jhu.ece.iacl.jist.structures.fiber.FiberCollection;


import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataUByte;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.io.StringReaderWriter;
import edu.jhu.ece.iacl.jist.io.ImageDataReaderWriter;

/**
 * JIST plugin that runs wild-bootstrap tensor estimation plus fiber tracking.
 *
 * <p>For every input DW volume ("slab"), the plugin calls
 * {@code testWBFiberDistribution.track} {@code iterWB} times. For each
 * iteration it writes:
 * <ul>
 * <li>the resulting fiber collection ({@code <source>_wild_boot_fibers_<i>});</li>
 * <li>the estimated tensor volume ({@code <source>_tensor_<i>.xml}).</li>
 * </ul>
 * Both go into a run-specific output directory. Finally it writes a text file
 * listing the per-iteration fiber paths, one per line, and exposes that file
 * through the {@code fibers} output parameter.
 */
public class TestNewWB extends ProcessingAlgorithm { 
	/****************************************************
	 * Input Parameters 
	 ****************************************************/
	private ParamVolumeCollection DWdata4D; 		// SLAB-enabled Imaging Data
	private ParamVolumeCollection Mask3D;			// SLAB-enabled Binary mask to indicate computation volume
	private ParamOption estOptions;		// Option to attempt to estimate with missing data
	private ParamFile bvaluesTable;		// .b file with a list of b-values
	private ParamFile gradsTable;		// .grad or .dpf file with a list of gradient directions
	private ParamInteger iterWB;		// Number of wild bootstrap iterations per input volume

	private ParamVolume seedVol;		// Optional seeding ROI, used only when specroi is true
	private ParamBoolean specroi;		// Restrict tracking seeds to seedVol?
	private ParamDouble startFA;
	private ParamDouble stopFA;
	private ParamInteger turningAngle;

	/****************************************************
	 * Output Parameters
	 ****************************************************/
	private File dir;					// Output directory for this run; assigned in execute()
	private ParamFile fibers;			// Text file listing the per-iteration fiber file paths
	private FiberCollectionReaderWriter fcrw = FiberCollectionReaderWriter.getInstance();
	private static final String cvsversion = "$Revision: 1.3 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Wild Bootstrap Fiber Distribution";
	private static final String longDescription = "Wild Bootstrap Fiber Distribution";
	// NOTE(review): evaluated while the object is being constructed and never read in
	// this class (execute() calls getOutputDirectory() itself). Kept for compatibility.
	File outputdir = this.getOutputDirectory();

	/**
	 * Registers plugin metadata and the input parameters (DW data, gradient and
	 * b-value tables, optional mask, bootstrap iteration count, optional seeding
	 * ROI and the fiber-tracking thresholds).
	 *
	 * @param inputParams collection the parameters are added to
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		/****************************************************
		 * Step 1. Set Plugin Information 
		 ****************************************************/
		inputParams.setPackage("IACL");
		inputParams.setCategory("DTI");
		inputParams.setLabel("Wild Bootstrap Fiber Distribution");	
		inputParams.setName("Wild Bootstrap Fiber Distribution");

		AlgorithmInformation info = getAlgorithmInformation();
		info.add(new AlgorithmAuthor("Chuyang Ye", "cye4@jhu.edu", ""));
		info.setAffiliation("Johns Hopkins University, Department of Electrical and Computer Engineering");
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription +"\n" +longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.RC);

		/****************************************************
		 * Step 2. Add input parameters to control system 
		 ****************************************************/
		inputParams.add(DWdata4D=new ParamVolumeCollection("DWI and Reference Image(s) Data (4D)"));
		DWdata4D.setLoadAndSaveOnValidate(false);
		inputParams.add(gradsTable=new ParamFile("Table of diffusion weighting directions",new FileExtensionFilter(new String[]{"grad","dpf"})));
		inputParams.add(bvaluesTable=new ParamFile("Table of b-values",new FileExtensionFilter(new String[]{"b"})));
		inputParams.add(Mask3D=new ParamVolumeCollection("Mask Volume to Determine Region of Tensor Estimation (3D)"));
		Mask3D.setLoadAndSaveOnValidate(false);
		Mask3D.setMandatory(false); // Not required. A null mask will estimate all voxels.	
		inputParams.add(estOptions=new ParamOption("Attempt to estimate tensors for voxels with missing data?",new String[]{"Yes","No"}));
		// Label text (including its spelling) is left unchanged; saved pipeline layouts may refer to it.
		inputParams.add(iterWB=new ParamInteger("Wild Boot Iterarion", 20));
		ParamCollection tracking = new ParamCollection("Fiber Tracking");
		inputParams.add(specroi=new ParamBoolean("Specify ROI"));
		inputParams.add(seedVol=new ParamVolume("Seeding ROI",VoxelType.FLOAT,-1,-1,-1,1));
		seedVol.setMandatory(false);
		tracking.add(startFA=new ParamDouble("Start FA",0,1,0.3));
		tracking.add(stopFA=new ParamDouble("Stop FA",0,1,0.13));
		tracking.add(turningAngle=new ParamInteger("Max Turn Angle",0,180,30));
		startFA.setDescription("The FA at which to begin tracking");
		stopFA.setDescription("The FA at which to terminate tracking");
		turningAngle.setDescription("The angle (in degrees) which terminates tracking");
		inputParams.add(tracking);
	}

	/**
	 * Registers the single output: the text file listing the fiber files
	 * written for each bootstrap iteration.
	 *
	 * @param outputParams collection the output parameter is added to
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(fibers=new ParamFile("fibers (filenames)", new FileExtensionFilter(new String[]{"txt"})));	
	}

	/**
	 * Creates the run-specific output directory and runs the estimation/tracking
	 * wrapper under the supplied monitor.
	 *
	 * @param monitor progress monitor that observes the wrapper calculation
	 * @throws AlgorithmRuntimeException as declared by the framework signature
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {	
		dir = new File(this.getOutputDirectory()+File.separator+edu.jhu.ece.iacl.jist.utility.FileUtil.forceSafeFilename(this.getAlgorithmName()));
		System.out.println(dir);
		// mkdirs() also creates missing parents. Every later write targets this
		// directory, so fail fast instead of silently producing no output.
		if(!dir.mkdirs() && !dir.isDirectory()){
			throw logAndFail("WB: Unable to create output directory " + dir);
		}
		TensorEstimationWrapper wrapper=new TensorEstimationWrapper();
		monitor.observe(wrapper);
		wrapper.execute(this);
	}

	/**
	 * Calculation that performs the per-volume, per-iteration wild-bootstrap
	 * tracking and writes all outputs into {@link #dir}.
	 */
	protected class TensorEstimationWrapper extends AbstractCalculation {

		/**
		 * Processes every DW volume in {@code DWdata4D}.
		 *
		 * @param thisParent the owning algorithm (not used directly)
		 * @throws RuntimeException if an input table cannot be parsed or is
		 *         malformed, if the iteration count is not positive, or if
		 *         "Specify ROI" is set without a seeding ROI
		 */
		protected void execute(ProcessingAlgorithm thisParent) {
			System.out.println("Estimation Started");

			List<ParamVolume> dwList = DWdata4D.getParamVolumeList();
			List<ParamVolume> maskList = Mask3D.getParamVolumeList();
			int iter = iterWB.getInt();
			if(iter < 1){
				// new String[iter] / files[0] below would otherwise fail with an opaque exception.
				throw logAndFail("WB: Wild Boot iteration count must be at least 1, got " + iter);
			}

			// These options do not change between volumes or iterations.
			boolean estimateMissing = estOptions.getValue().compareToIgnoreCase("Yes")==0;
			boolean useRoi = specroi.getValue();
			double maxTurnRadians = turningAngle.getFloat()*Math.PI/180.0;

			this.addTotalUnits(dwList.size());

			for(int jSlab=0;jSlab<dwList.size();jSlab++) {
				this.setLabel("Load");
				System.out.println("Processing subject: " + jSlab);

				/* Read the DW data; keep only a float copy and release the source. */
				System.out.println("Checking out data");
				ImageData dwd=dwList.get(jSlab).getImageData();
				ImageHeader hdr = dwd.getHeader();
				String sourceName = dwd.getName();
				ImageDataFloat DWFloat=new ImageDataFloat(dwd);
				dwList.get(jSlab).dispose();
				dwd.dispose();
				dwd=null;

				System.out.println("Checking out mask");
				byte[][][] mask = loadMaskArray(maskList, jSlab);

				System.out.println("Checking out b values");
				float[][] bs = parseFloatTable(bvaluesTable, "b");
				System.out.println("Checking out gradient table");
				float[][] grads = normalizeGradientTable(parseFloatTable(gradsTable, "grad"));
				float[] bval = extractBValues(bs);

				/****************************************************
				 * Wild Bootstrap Tensor Estimation and Tractography
				 ****************************************************/
				String[] files = new String[iter];
				for(int i = 0; i < iter; i++){
					this.setLabel("Estimate");
					// Filled in by track() with six tensor components per voxel.
					float[][][][] tensor = new float[DWFloat.getRows()][DWFloat.getCols()][DWFloat.getSlices()][6];
					System.out.println(DWFloat.getRows()+" "+DWFloat.getCols()+" "+DWFloat.getSlices());

					// Seeds are converted afresh each iteration, as before; track() is not known not to modify them.
					float[][][] roi = null;
					if(useRoi){
						ImageData seedData = seedVol.getImageData();
						if(seedData==null){
							throw logAndFail("WB: 'Specify ROI' is set but no seeding ROI was provided.");
						}
						roi = new ImageDataFloat(seedData).toArray3d();
					}
					FiberCollection fibercollection = testWBFiberDistribution.track(iter, DWFloat, bval, grads, mask, estimateMissing,
							startFA.getDouble(), stopFA.getDouble(), maxTurnRadians, tensor, roi);
					System.out.println("Tracking done");

					this.setLabel("Save");
					System.out.println("Saving fibers");
					String fiberName = sourceName+"_wild_boot_fibers_"+(i+1);
					fibercollection.setName(fiberName);
					File written = writeFibers(fibercollection,dir);
					if(written==null){
						// Presumably the reader/writer returns null on failure; surface it instead of ignoring it.
						JistLogger.logError(JistLogger.WARNING, "WB: Unable to write fibers " + fiberName);
					}
					files[i] = dir+File.separator+fiberName;

					writeTensorVolume(tensor, DWFloat, hdr, sourceName+"_tensor_"+(i+1));
					System.out.println("Iteration "+i+" finished");
				}

				String output_fibers = joinLines(files);
				System.out.println(output_fibers);
				File fibername = new File(dir+File.separator+sourceName+"_wild_boot_fibers.txt");
				System.out.println(fibername);
				StringReaderWriter.getInstance().write(output_fibers, fibername);

				// NOTE(review): reassigned per volume, so only the last volume's list is exposed.
				fibers.setValue(fibername);
				this.incrementCompletedUnits();
				System.out.println("Image finished");
				DWFloat.dispose();
				DWFloat=null;
			}

			JistLogger.logError(JistLogger.INFO, "WildBootTensorTractography: FINISHED");
			System.out.println("All finished");
		}
	}

	/**
	 * Loads mask {@code index} from {@code maskList} as a byte array and releases
	 * the image data backing it.
	 *
	 * @return the mask array, or {@code null} when no mask is available for this
	 *         index (in which case all voxels are processed)
	 */
	private static byte[][][] loadMaskArray(List<ParamVolume> maskList, int index) {
		ImageData maskVol = null;
		if(maskList!=null && maskList.size()>index && maskList.get(index)!=null)
			maskVol = maskList.get(index).getImageData();
		if(maskVol==null){
			System.out.println("Null mask");
			return null;
		}
		ImageDataUByte maskByte = new ImageDataUByte(maskVol);
		byte[][][] mask = maskByte.toArray3d();
		maskByte.dispose();
		maskVol.dispose();
		maskList.get(index).dispose();
		return mask;
	}

	/**
	 * Parses the numeric text table referenced by {@code table}.
	 *
	 * @param kind short name used in error messages ("b" or "grad")
	 * @throws RuntimeException wrapping the underlying IOException if parsing fails
	 */
	private static float[][] parseFloatTable(ParamFile table, String kind) {
		TextFileReader text = new TextFileReader(table.getValue());
		try {
			return text.parseFloatFile();
		} catch (IOException e) {
			JistLogger.logError(JistLogger.WARNING, "WB: Unable to parse "+kind+"-file.");
			throw new RuntimeException("WB: Unable to parse "+kind+"-file", e);
		}
	}

	/**
	 * Validates a gradient table and drops a leading index column if present.
	 *
	 * @return an N x 3 table of gradient directions
	 * @throws RuntimeException if the table is empty or does not have 3 or 4 columns
	 */
	private static float[][] normalizeGradientTable(float[][] grads) {
		if(grads==null || grads.length==0){
			throw logAndFail("WB: Gradient table is empty.");
		}
		// With 4 columns, the first column holds indices; keep the last three.
		if(grads[0].length==4) {
			float[][] g2 = new float[grads.length][3];
			for(int i=0;i<grads.length;i++)
				System.arraycopy(grads[i], 1, g2[i], 0, 3);
			grads = g2;
		}
		if(grads[0].length!=3){
			throw logAndFail("Invalid gradient table. Must have 3 or 4 columns.");
		}
		return grads;
	}

	/**
	 * Extracts the single column of a b-value table into a flat array.
	 *
	 * @throws RuntimeException if the table is empty or does not have exactly one column
	 */
	private static float[] extractBValues(float[][] bs) {
		if(bs==null || bs.length==0){
			throw logAndFail("WB: b-value table is empty.");
		}
		if(bs[0].length!=1){
			throw logAndFail("Invalid b-value table. Must have 1 column.");
		}
		float[] bval = new float[bs.length];
		for(int i=0;i<bval.length;i++)
			bval[i] = bs[i][0];
		return bval;
	}

	/**
	 * Copies {@code tensor} into a 6-component volume shaped like {@code ref},
	 * writes it as {@code <name>.xml} in {@link #dir}, and disposes it so the
	 * per-iteration volume does not accumulate in memory.
	 */
	private void writeTensorVolume(float[][][][] tensor, ImageDataFloat ref, ImageHeader hdr, String name) {
		ImageDataMipav image_tensor = new ImageDataMipav(name, ref.getType(), ref.getRows(), ref.getCols(), ref.getSlices(), 6);
		System.out.println(name);
		try {
			try{
				for(int ii=0;ii<ref.getRows();ii++)
					for(int jj=0;jj<ref.getCols();jj++)
						for(int kk=0;kk<ref.getSlices();kk++)
							for(int c=0;c<6;c++)
								image_tensor.set(ii, jj, kk, c, tensor[ii][jj][kk][c]);
			}
			catch(Exception e){
				// Best-effort as before: report the failure but still write what was copied.
				JistLogger.logError(JistLogger.WARNING, "WB: Failed to copy tensor data for " + name + ": " + e);
			}
			image_tensor.setHeader(hdr);
			File ftensor = new File(dir+File.separator+name+".xml");
			ImageDataReaderWriter.getInstance().write(image_tensor, ftensor);
			System.out.println("Tensor set");
		} finally {
			image_tensor.dispose();
		}
	}

	/** Joins {@code lines} with '\n' separators (no trailing newline). */
	private static String joinLines(String[] lines) {
		StringBuilder sb = new StringBuilder();
		for(int i=0;i<lines.length;i++){
			if(i>0) sb.append('\n');
			sb.append(lines[i]);
		}
		return sb.toString();
	}

	/** Logs {@code message} as a JIST warning and returns an exception carrying it for the caller to throw. */
	private static RuntimeException logAndFail(String message) {
		JistLogger.logError(JistLogger.WARNING, message);
		return new RuntimeException(message);
	}

	/**
	 * Writes {@code fibers} into {@code dir} via the shared fiber reader/writer.
	 *
	 * @return the file returned by the reader/writer
	 */
	private File writeFibers(FiberCollection fibers, File dir){
		return fcrw.write(fibers, dir);
	}
}



