package edu.jhu.ece.iacl.plugins.dti;

import imaging.Scheme;
import imaging.SchemeV1;
import inverters.AlgebraicDT_Inversion;
import inverters.BallStickInversion;
import inverters.DT_Inversion;
import inverters.DiffusionInversion;
import inverters.LinearDT_Inversion;
import inverters.NonLinearDT_Inversion;
import inverters.RestoreDT_Inversion;
import inverters.TensorModelFitter;
import inverters.WeightedLinearDT_Inversion;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;

import javax.help.TOCView.DefaultTOCFactory;

import com.thoughtworks.xstream.XStream;

import edu.jhu.bme.smile.commons.textfiles.TextFileReader;
import edu.jhu.ece.iacl.algorithms.dti.EstimateTensorLLMSE;
import edu.jhu.ece.iacl.io.CubicVolumeReaderWriter;
import edu.jhu.ece.iacl.io.FileExtensionFilter;
import edu.jhu.ece.iacl.io.ModelImageReaderWriter;
import edu.jhu.ece.iacl.io.StringReaderWriter;
import edu.jhu.ece.iacl.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.*;
import edu.jhu.ece.iacl.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.plugins.dti.DWITensorEstCaminoFileCollection.TensorEstimationWrapper;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.structures.image.ImageDataUByte;
import edu.jhu.ece.iacl.utility.FileUtil;
import gov.nih.mipav.model.structures.ModelImage;

/**
 * Pipeline plugin that fits the Behrens ball-and-stick diffusion model (CAMINO
 * {@code BallStickInversion}) to every DWI volume in a file collection.
 * <p>
 * For each input volume it writes five output volumes to the output directory:
 * stick vector, ball diffusivity, ball fraction, exit code, and intensity.
 */
public class DWIBallStickEstCaminoFileCollection extends ProcessingAlgorithm{ 

	/****************************************************
	 * Input Parameters 
	 ****************************************************/
	private ParamFileCollection DWdata4D; 		// Imaging data: one 4D DWI volume per entry
	private ParamFileCollection Mask3D;			// Optional binary masks, matched to DWdata4D entries by index
	private ParamFile SchemeFile;				// CAMINO SchemeV1 description, stored as XStream XML
	/****************************************************
	 * Output Parameters
	 ****************************************************/
	private ParamFileCollection ballDVolume;	// Diffusivity of the "ball" (3D, stored with a singleton 4th dimension)
	private ParamFileCollection stickVecVolume;	// Vector orientation of the stick (4D, 3 components)
	private ParamFileCollection fractionVolume;	// Ball fraction (3D, stored with a singleton 4th dimension)
	private ParamFileCollection exitCodeVolume;	// Per-voxel estimation exit code
	private ParamFileCollection intensityVolume;// Estimated intensity

	private static final String rcsid =
		"$Id: DWIBallStickEstCaminoFileCollection.java,v 1.1 2009/03/27 01:28:44 bennett Exp $";
	private static final String cvsversion =
		"$Revision: 1.1 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "");

	/**
	 * Registers plugin metadata and the input parameters.
	 * The DWI collection and the scheme file are added as inputs; the mask collection is optional.
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	protected void createInputParameters(ParamCollection inputParams) {

		/****************************************************
		 * Step 1. Set Plugin Information 
		 ****************************************************/
		inputParams.setName("Ball and Stick Estimation Camino (v2)");
		inputParams.setLabel("Ball and Stick Est");
		inputParams.setCategory("Modeling.Diffusion");
		inputParams.setPackage("Camino");
		AlgorithmInformation info=getAlgorithmInformation();
		info.setWebsite("http://sites.google.com/site/jhupami/");
		info.add(new AlgorithmAuthor("Bennett Landman","landman@jhu.edu","http://sites.google.com/site/bennettlandman/"));
		info.add(new AlgorithmAuthor("Philip Cook","camino@cs.ucl.ac.uk","http://www.cs.ucl.ac.uk/research/medic/camino/"));
		info.setDescription("Fit diffusion weighted imaging data with the ball and stick model. ");
		info.setLongDescription("Fits the Behrens ball and stick model to diffusion-weighted data. "+
				"The model is S(g, b) = S_0 (f \\exp[-b d (-g^T v)^2] + [1-f] \\exp[-b d]), where S(g,b) is the " +
				"DW signal along gradient direction g with b-value b, d is a diffusion coefficient, v is the " +
				"orientation of anisotropic diffusion, and f is a mixing parameter (0 <= f <= 1)."); 
		info.setAffiliation("Computer Science Department - University College London");		
		info.add(new Citation("Behrens et al, Magnetic Resonance in Medicine, 50:1077-1088, 2003"));		
		info.setVersion(revnum);	


		/****************************************************
		 * Step 2. Add input parameters to control system 
		 ****************************************************/
		inputParams.add(DWdata4D=new ParamFileCollection("DWI and Reference Image(s) Data (4D)",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions)));

		inputParams.add(SchemeFile=new ParamFile("CAMINO DTI Description (SchemeV1)",new FileExtensionFilter(new String[]{"scheme","schemev1"})));
		inputParams.add(Mask3D=new ParamFileCollection("Mask Volume to Determine Region of Tensor Estimation (3D)",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions)));

		Mask3D.setMandatory(false); // Not required. A null mask will estimate all voxels.	
	}

	/**
	 * Registers the five output file collections.
	 * Each collection receives one file per input DWI volume.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		/****************************************************
		 * Step 1. Add output parameters to control system 
		 ****************************************************/
		ballDVolume = new ParamFileCollection("Ball Diffusivity Estimate",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		ballDVolume.setName("Ball (mm2/s)");
		outputParams.add(ballDVolume);
		stickVecVolume = new ParamFileCollection("Stick Vector Estimate",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		stickVecVolume.setName("Vector (x,y,z)");
		outputParams.add(stickVecVolume);
		fractionVolume = new ParamFileCollection("Ball Fraction Estimate",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		fractionVolume.setName("Ball Fraction");
		outputParams.add(fractionVolume);
		exitCodeVolume = new ParamFileCollection("Estimation Exit Code",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		exitCodeVolume.setName("Exit Code");
		outputParams.add(exitCodeVolume);	
		intensityVolume = new ParamFileCollection("Intensity Estimate",new FileExtensionFilter(ModelImageReaderWriter.supportedFileExtensions));
		intensityVolume.setName("Intensity");
		outputParams.add(intensityVolume);

	}

	/**
	 * Runs the estimation through a monitored {@link TensorEstimationWrapper}.
	 *
	 * @param monitor monitor that observes the wrapper's progress
	 * @throws AlgorithmRuntimeException declared by the pipeline contract
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {		
		TensorEstimationWrapper wrapper=new TensorEstimationWrapper();
		monitor.observe(wrapper);
		wrapper.execute();
	}

	/**
	 * Closes {@code c} if it is non-null and ignores any close failure.
	 * Used only for cleanup after the interesting work has succeeded or failed.
	 */
	private static void closeQuietly(java.io.Closeable c) {
		if (c == null) {
			return;
		}
		try {
			c.close();
		} catch (IOException ignored) {
			// Either the data was already read, or a more informative exception is
			// already propagating; a failed close is not actionable here.
		}
	}

	/**
	 * Progress-reporting calculation that performs the per-volume ball-and-stick fit.
	 */
	protected class TensorEstimationWrapper extends AbstractCalculation {

		/**
		 * Fits every input DWI volume and publishes the written output files.
		 * The outputs go to the plugin's output parameters.
		 */
		protected void execute() {

			/****************************************************
			 * Step 1. Indicate that the plugin has started.
			 * 		 	Tip: Use limited System.out.println statements
			 * 			to allow end users to monitor the status of
			 * 			your program and report potential problems/bugs
			 * 			along with information that will allow you to 
			 * 			know when the bug happened.  
			 ****************************************************/
			System.out.println("DWIBallStickEstCamino: Start");

			List<File> dwList = DWdata4D.getValue();
			List<File> maskList = Mask3D.getValue();
			CubicVolumeReaderWriter rw  = CubicVolumeReaderWriter.getInstance();
			ArrayList<File> outStickVecVols = new ArrayList<File>();
			ArrayList<File> outBallFracVols = new ArrayList<File>();
			ArrayList<File> outBallDiffVols = new ArrayList<File>();
			ArrayList<File> outExitVols = new ArrayList<File>();
			ArrayList<File> outIntensityVols = new ArrayList<File>();
			// Suffix appended to the exit-code and intensity output names.
			String code = "ballAndStick";

			/****************************************************
			 * Step 2. Load the acquisition scheme. It is identical for
			 * 			every volume, so it is parsed once (and a bad
			 * 			scheme file fails before any volume is loaded).
			 ****************************************************/
			this.setLabel("Load");
			System.out.println("Load scheme.");System.out.flush();
			SchemeV1 DTIscheme = loadScheme();

			this.addTotalUnits(dwList.size());
			for(int jSlab=0;jSlab<dwList.size();jSlab++) {
				/****************************************************
				 * Step 3. Parse the input data 
				 ****************************************************/
				System.out.println("Load data.");System.out.flush();
				this.setLabel("Load");
				ImageData dwd=rw.read(dwList.get(jSlab));
				String imageName = dwd.getName();
				ImageDataFloat DWFloat=new ImageDataFloat(dwd);
				dwd.dispose();

				byte [][][]mask=loadMask(rw, maskList, jSlab);

				BallStickInversion dtiFit=new BallStickInversion(DTIscheme);

				/****************************************************
				 * Step 4. Run the core algorithm. Note that this program 
				 * 		   has NO knowledge of the MIPAV data structure and 
				 * 		   uses NO MIPAV specific components. This dramatic 
				 * 		   separation is a bit inefficient, but it dramatically 
				 * 		   lowers the barriers to code re-use in other applications.  		  
				 ****************************************************/
				System.out.println("Allocate memory."); System.out.flush();
				float [][][][]data=DWFloat.toArray4d();
				int rows = data.length;
				int cols= data[0].length;
				int slices= data[0][0].length;
				float [][][][]ballDiff = new float[rows][cols][slices][1];
				float [][][][]ballFrac = new float[rows][cols][slices][1];
				float [][][][]stickVec = new float[rows][cols][slices][3];
				float [][][][]exitCode= new float[rows][cols][slices][1];
				float [][][][]intensity= new float[rows][cols][slices][1];

				this.setLabel("Estimate");
				System.out.println("Run CAMINO estimate."); System.out.flush();
				EstimateTensorLLMSE.estimateBallAndStickCamino(data,mask,dtiFit,ballDiff,ballFrac,stickVec,exitCode,intensity);

				/****************************************************
				 * Step 5. Write each result with the input volume's header
				 * 			so it keeps the correct field of view, resolution, etc.  
				 ****************************************************/
				System.out.println("Data export."); System.out.flush();
				this.setLabel("Save");

				outStickVecVols.add(writeVolume(rw, DWFloat, stickVec, imageName+"_Stick"));
				outBallDiffVols.add(writeVolume(rw, DWFloat, ballDiff, imageName+"_BallDiff"));
				outBallFracVols.add(writeVolume(rw, DWFloat, ballFrac, imageName+"_BallFrac"));
				outExitVols.add(writeVolume(rw, DWFloat, exitCode, imageName+"_ExitCode"+code));
				outIntensityVols.add(writeVolume(rw, DWFloat, intensity, imageName+"_Intensity"+code));

				this.incrementCompletedUnits();
			}

			stickVecVolume.setValue(outStickVecVols);
			ballDVolume.setValue(outBallDiffVols);
			fractionVolume.setValue(outBallFracVols);
			exitCodeVolume.setValue(outExitVols);
			intensityVolume.setValue(outIntensityVols);

			/****************************************************
			 * Step 6. Let the user know that your code is finished.  
			 ****************************************************/
			System.out.println("DWIBallStickEstCamino: FINISHED");
		}

		/**
		 * Deserializes the CAMINO SchemeV1 object from the scheme-file parameter.
		 * Both streams are closed on every path, including when parsing fails.
		 *
		 * @return the deserialized scheme
		 * @throws RuntimeException if the file cannot be read or deserialized
		 */
		private SchemeV1 loadScheme() {
			XStream xstream = new XStream();
			xstream.alias("CaminoDWScheme-V1",imaging.SchemeV1.class);
			// NOTE(review): XStream will instantiate whatever types the XML names; only load
			// scheme files from trusted sources (or add type allow-listing if the bundled
			// XStream version supports it).
			FileReader reader = null;
			ObjectInputStream in = null;
			try {
				reader = new FileReader(SchemeFile.getValue());
				in = xstream.createObjectInputStream(reader);
				return (SchemeV1)in.readObject();
			} catch (IOException e) {
				throw new RuntimeException("Unable to read CAMINO scheme file " + SchemeFile.getValue() + ": " + e, e);
			} catch (ClassNotFoundException e) {
				throw new RuntimeException("Unable to deserialize CAMINO scheme file " + SchemeFile.getValue() + ": " + e, e);
			} finally {
				// Close both; closing an already-closed reader has no effect.
				closeQuietly(in);
				closeQuietly(reader);
			}
		}

		/**
		 * Loads the optional mask paired with the DWI volume at {@code index}.
		 *
		 * @param rw        volume reader
		 * @param maskList  mask files (may be null or shorter than the DWI list)
		 * @param index     index of the DWI volume being processed
		 * @return the mask as a byte array, or null when no mask applies (all voxels estimated)
		 */
		private byte[][][] loadMask(CubicVolumeReaderWriter rw, List<File> maskList, int index) {
			if (maskList == null || maskList.size() <= index) {
				return null;
			}
			ImageData maskVol = rw.read(maskList.get(index));
			if (maskVol == null) {
				return null;
			}
			ImageDataUByte maskByte = new ImageDataUByte(maskVol);
			byte[][][] mask = maskByte.toArray3d();
			maskVol.dispose();
			maskByte.dispose();
			return mask;
		}

		/**
		 * Wraps {@code volume} in an image carrying {@code headerSource}'s header and writes it.
		 * The image is written to the output directory.
		 *
		 * @param rw           volume writer
		 * @param headerSource image whose header is copied onto the output
		 * @param volume       voxel data to write
		 * @param name         name given to the output image
		 * @return the file that was written (also echoed to stdout)
		 */
		private File writeVolume(CubicVolumeReaderWriter rw, ImageDataFloat headerSource, float[][][][] volume, String name) {
			ImageDataFloat out=new ImageDataFloat(volume);
			out.setHeader(headerSource.getHeader());
			out.setName(name);
			File outputSlab = rw.write(out, getOutputDirectory());
			System.out.println(outputSlab);System.out.flush();
			return outputSlab;
		}
	}
}