package edu.vanderbilt.masi.plugins.CRUISE.utilities;

import javax.vecmath.Point3f;
import javax.vecmath.Point3i;

import Jama.Matrix;

import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.*;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataIntent;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.plugins.registration.MedicAlgorithmTransformVolume2;
import edu.jhu.ece.iacl.plugins.registration.MedicAlgorithmFLIRT;


public class MedicAlgorithmIACLSpaceTest extends ProcessingAlgorithm{
	// Various volumes.
	public ParamVolume InputVol, InputSkullMaskVol, MNIAtlasVol;
	public ParamVolume OutputVol, OutputSkullMaskVol;
	public ParamPointInteger OutputDims;
	public ParamPointFloat OutputRes;
	public ParamVolume OutputMNIAtlasResampled;

	// The isotropic resample options param.
	public ParamOption IsotropicResampleOptions;
	public ParamOption DimensionOptions;

	// The resample options.
	public static final String[] IsotropicResampleChoices = { "0.8", "1.0", "0.83", "1.5", "Input Best", "Atlas Best"};
	public static final String[] DimensionChoices = { "320x320x320", "Match Input", "Match Atlas","Dilate Input"};
	
	// Boolean to decide if we reorient the volume to the MNI atlas.
	public ParamBoolean MNIAtlasBool;

	// Output matrix.
	public ParamMatrix OutputMatrix;


	private static final String cvsversion = "$Revision: 1.3 $".replace("Revision: ", "").replace("$", "").replace(" ", "");
	// private static final String revnum = ReorientVolume.getVersion();
	private static final String revnum = cvsversion;
	private static final String shortDescription =
		"Reslice a volume to one of three dimensions:\n" +
		" 1) 0.8mm Isoropic\n" +
		" 2) 1.0mm Isoropic\n" +
		" 3) 0.83mm Isoropic\n" +
		" 4) Isotropic matching the best resolution dimension in input.\n\n" +
		" 5) Isotropic matching the best resolution dimension in atlas." +
		"Can also align to an MNI atlas.\n\n" +
		"Algorithm Version: " + revnum + "\n";
	private static final String longDescription = "";


	protected void createInputParameters(ParamCollection inputParams) {
		InputVol = new ParamVolume("Input volume");
		InputVol.setDescription("Input volume.");


		InputSkullMaskVol = new ParamVolume("Input Skull Mask");
		InputSkullMaskVol.setDescription("Input skull stripping mask, which is applied after the resampling and any reorientation.");
		InputSkullMaskVol.setMandatory(false);


		IsotropicResampleOptions = new ParamOption("Isotropic Resample Options", IsotropicResampleChoices);
		IsotropicResampleOptions.setDescription("Determines which isotropic resampling option to take.");

		DimensionOptions = new ParamOption("Output Image Dimensions Options", DimensionChoices);
		DimensionOptions.setDescription("Determines the dimensions of the output image");


		MNIAtlasVol = new ParamVolume("MNI Atlas");
		MNIAtlasVol.setDescription("The MNI Atlas used as target in the MNI spatial allignment.");
		MNIAtlasVol.setMandatory(false);


		MNIAtlasBool = new ParamBoolean("Reorient data to MNI Atlas", false);
		MNIAtlasBool.setDescription("Determines if the data is kept in the original orientation or reorientated to the MNI Atlas.");


		inputParams.add(InputVol);
		inputParams.add(InputSkullMaskVol);
		inputParams.add(IsotropicResampleOptions);
		inputParams.add(DimensionOptions);
		inputParams.add(MNIAtlasVol);
		inputParams.add(MNIAtlasBool);
		

		inputParams.setPackage("IACL");
		inputParams.setCategory("Utilities.Volume");
		inputParams.setLabel("IACL Space");
		inputParams.setName("IACL_Space");


		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://www.iacl.ece.jhu.edu/");
		info.add(new AlgorithmAuthor("Aaron Carass", "aaron_carass@jhu.edu", "http://www.iacl.ece.jhu.edu/"));
		info.setAffiliation("Johns Hopkins University, Departments of Electrical and Computer Engineering");
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.Release);
	}


	protected void createOutputParameters(ParamCollection outputParams) {
		OutputVol = new ParamVolume("Output volume");
		OutputVol.setDescription("The input image after it has been resliced and reorientated.");


		OutputSkullMaskVol = new ParamVolume("Output mask");
		OutputSkullMaskVol.setDescription("The input skull mask after it has been resliced and reorientated.");
		OutputSkullMaskVol.setDataIntent(ImageDataIntent.NIFTI_INTENT_LABEL);
		OutputSkullMaskVol.setMandatory(false);


		Matrix OutMatrix = new Matrix(4, 4);

		OutputMatrix = new ParamMatrix("Output transformation matrix", OutMatrix);
		OutputMatrix.setDescription("The transformation matrix that generated the output.");

		outputParams.add(OutputVol);
		outputParams.add(OutputSkullMaskVol);
		outputParams.add(OutputMatrix);
		outputParams.add(OutputRes = new ParamPointFloat("Output Resolutions"));
		outputParams.add(OutputDims = new ParamPointInteger("Output Dimensions"));
		outputParams.add(OutputMNIAtlasResampled = new ParamVolume("Resampled MNI Atlas"));
		OutputMNIAtlasResampled.setMandatory(false);
		
	}


	protected void execute(CalculationMonitor monitor)
			throws AlgorithmRuntimeException {
		System.out.format("\n\nMedicAlgorithmIACLSpace\n\n\n");

		ImageData inputVol = InputVol.getImageData();
		ImageData maskVol;
		ImageData maskedInputVol = InputVol.getImageData().clone();

		if (InputSkullMaskVol.getImageData() != null) {
			maskVol = InputSkullMaskVol.getImageData();
			maskedInputVol = InputVol.getImageData().clone();
			
			//get skullstripped version of input if available
			for (int i = 0; i < maskVol.getRows(); i++)
				for (int j = 0; j < maskVol.getCols(); j++) 
					for (int k = 0; k < maskVol.getSlices(); k++) {
						if (maskVol.getFloat(i, j, k) < 0.5) {
							maskedInputVol.set(i, j, k, 0.0);
						}
					}
		}else{
			maskVol = null;
			maskedInputVol = null;
		}

		ImageData mniVol;//MNI atlas
		ImageData outVol = null, outMask = null;//output volumes
		String outputName = inputVol.getName();;//output name

		ImageHeader oldhdr = inputVol.getHeader();//old header info
		ImageHeader newHdr = oldhdr.clone();//new header info (for output)

		//set resolutions and dimensions
		float[] oldRes = oldhdr.getDimResolutions();//Input Resolution
		int[] oldDim = new int[3];//Input Dimensions
		oldDim[0] = inputVol.getRows();
		oldDim[1] = inputVol.getCols();
		oldDim[2] = inputVol.getSlices();

		
		float[] newRes = new float[3];//Output Resolution
		int[] newDim = new int[3];//Output Dimensions
		
		float[] oldO, newO; /* Origin(s) */
		int splineDegree = 3;//Spline Degree in Interpolation

		Matrix transToAtlas = new Matrix(4,4); /* Transformation Matrix */
		Matrix idMatrix =  new Matrix(4,4); /* Identity Matrix */

		System.out.format("Initializing matrices.\n");
		for (int i = 0; i < 4; i++){
			for (int j = 0; j < 4; j++) {
				if(i == j ) idMatrix.set(i, j, 1);
				else idMatrix.set(i, j, 0.0);
			}
		}
		
		
		//update origin
		oldO = oldhdr.getOrigin();//old origin
		newO = findNewOrigin(oldO, newRes, oldRes);//reset origin using new resolution (Is this correctly implemented?? - Min)
		newHdr.setOrigin(newO); //The new origin.

		//find and set output resolution and dimension, update name
		outputName = setOutputResolutionAndDimensions(newRes, newDim, outputName);
		newHdr.setDimResolutions(newRes);

		System.out.format("\nOld Resolution: %g %g %g\n", oldRes[0], oldRes[1], oldRes[2]);
		System.out.format("New Resolution: %g %g %g\n\n", newRes[0], newRes[1], newRes[2]);

		
		//Perform input to atlas transformation if needed
		if (MNIAtlasBool.getValue() && MNIAtlasVol.getImageData() != null){
			System.out.format("\n\nMNI Alignment\n\n\n");
			//first transform MNIAtlas to Target Resolution and Dimensions.
			mniVol = transform(MNIAtlasVol.getImageData(), idMatrix, newDim,newRes, splineDegree);
			
			
			OutputMNIAtlasResampled.setValue(mniVol);
			System.out.println(mniVol.getName() + "\n\n");
			outputName += "_mni";

			//register to atlas depending on if skullstripped version is available
			if(maskedInputVol != null)
				transToAtlas = findTransformationToAtlas(maskedInputVol, mniVol);
			else
				transToAtlas = findTransformationToAtlas(inputVol, mniVol);
		}else{//transform is identity if not performing alignment
			transToAtlas = idMatrix.copy();
		}

		System.out.format("\n\nm =\n");
		for (int i = 0; i < 4 ; i++) {
			System.out.format("%2.2g %2.2g %2.2g %2.2g\n", transToAtlas.get(i, 0), transToAtlas.get(i, 1), transToAtlas.get(i, 2), transToAtlas.get(i, 3));
		}


		//set output volume using transformation from alignment
		outVol = transform(inputVol, transToAtlas, newDim,newRes, splineDegree);

		//apply transformation to brainmask, and apply mask to output if available.
		if (maskVol != null) {
			System.out.format("\n\nUsing Mask\n\n\n");

			/*
			 * Make the mask have sane values, just in case.
			 */
			for (int i = 0; i < maskVol.getRows(); i++)
				for (int j = 0; j < maskVol.getCols(); j++) 
					for (int k = 0; k < maskVol.getSlices(); k++) {
						if (maskVol.getFloat(i, j, k) < 0.5) {
							maskVol.set(i, j, k, 0.0);
						} else {
							maskVol.set(i, j, k, 1.0);
						}
					}

			//transform mask using input to atlas transformation if needed,
			//otherwise transToAtlas == idMatrix, and only a resampling is done
			outMask = transform(maskVol, transToAtlas, newDim,newRes, splineDegree);
			

			System.out.format("\n\nMask Dimensions: %d %d %d\n", outMask.getRows(), outMask.getCols(), outMask.getSlices());
			System.out.format("Output Dimensions: %d %d %d\n\n\n", outVol.getRows(), outVol.getCols(), outVol.getSlices());


			//Make mask binary and Apply to output
			for (int i = 0; i < outMask.getRows(); i++){
				for (int j = 0; j < outMask.getCols(); j++) {
					for (int k = 0; k < outMask.getSlices(); k++) {
						if (outMask.getFloat(i, j, k) < 0.5) {
							outVol.set(i, j, k, 0.0);
							outMask.set(i, j, k, 0.0);
						} else {
							outMask.set(i, j, k, 1.0);
						}
					}
				}
			}
		}

		//set output volumes
		outVol.setHeader(newHdr);
		outVol.setName(outputName);
		OutputVol.setValue(outVol);


		//set output for mask
		if (maskVol != null) {
			outMask.setHeader(newHdr);
			outMask.setName(outputName + "_mask");
			OutputSkullMaskVol.setValue(outMask);
		}


		OutputMatrix.setValue(transToAtlas);
	}
	
	
	//set new resolution and dimensions using input information
	//new resolutions and dimensions are written to "newRes" and "newDim"
	//returns updated output name
	String setOutputResolutionAndDimensions(float[] newRes, int[] newDim, String outputName){
		
		
		float rNewIso = 0;
		float[] oldRes;
		/*
		 * set output resolution
		 */
		switch(IsotropicResampleOptions.getIndex()){
			case 0://.8mm isotropic
				rNewIso= 0.8000f;
				outputName += "_0_8Res";
				break;

			case 1://1mm isotropic
				rNewIso = 1.0000f;
				outputName += "_1_0Res";
				break;
				
			case 2://0.83mm isotropic
				rNewIso = 0.8300f;
				outputName += "_0_83Res";
				break;
				
			case 3://1.5mm isotropic
				rNewIso = 1.5000f;
				outputName += "_1_5Res";
				break;

			case 4://Take best resolution from input as isotropic
				oldRes = InputVol.getImageData().getHeader().getDimResolutions();
				rNewIso = Float.MAX_VALUE;
				for(int i = 0; i < 3; i++){
					if(rNewIso > oldRes[i]) rNewIso = oldRes[i];					
				}
				outputName += "_inputBestRes";
				break;

			case 5://Take best resolution from atlas as isotropic
				if (MNIAtlasVol.getImageData() != null){
					rNewIso = Float.MAX_VALUE;
					oldRes = MNIAtlasVol.getImageData().getHeader().getDimResolutions();
					for(int i = 0; i < 3; i++){
						if(rNewIso > oldRes[i]) rNewIso = oldRes[i];					
					}
				}else{
					rNewIso = -1;
				}
				outputName += "_atlasBestRes";
				break;
				
			default://default to .8mm isotropic
				rNewIso= 0.8000f;
				outputName += "_0_8Res";
				break;
		}


		if (rNewIso < 0){
			System.out.format("\nSomething weird happened while setting Resolution.\n\n");
			System.exit(-1);
		}


		/*
		 * The output resolution is isotropic with rNew x rNew x rNew.
		 */
		newRes[0] = newRes[1] = newRes[2] = rNewIso;
		
		OutputRes.setValue(new Point3f(newRes));
		
		/*
		 * Now set output dimensions
		 */
		
		switch(DimensionOptions.getIndex()){
		case 0://320x320x320
			newDim[0] = 320;
			newDim[1] = 320;
			newDim[2] = 320;
			outputName += "_320Dim";
			break;
			
		case 1://dimensions of input
			newDim[0] = InputVol.getImageData().getRows();
			newDim[1] = InputVol.getImageData().getCols();
			newDim[2] = InputVol.getImageData().getSlices();
			outputName += "_inputDim";
			break;
			
		case 2://dimensions of atlas
			newDim[0] = MNIAtlasVol.getImageData().getRows();
			newDim[1] = MNIAtlasVol.getImageData().getCols();
			newDim[2] = MNIAtlasVol.getImageData().getSlices();
			outputName += "_atlasDim";
			break;
			
		case 3://dilate the dimensions of atlas by 
			float newD1 = (float)InputVol.getImageData().getRows()/rNewIso;
			float newD2 = (float)InputVol.getImageData().getCols()/rNewIso;
			float newD3 = (float)InputVol.getImageData().getSlices()/rNewIso;
			newDim[0] = (int) Math.ceil(newD1);
			newDim[1] = (int) Math.ceil(newD2);
			newDim[2] = (int) Math.ceil(newD3);
			outputName += "_dilateDim";
			break;
			
		default://default to 320x320x320
			newDim[0] = 320;
			newDim[1] = 320;
			newDim[2] = 320;
			outputName += "_320Dim";
			break;

		}
		
		OutputDims.setValue(new Point3i(newDim));
		return outputName;
	}
	
	
	/*Transform a volume "volIn" using
	 * matIn - transformation matrix
	 * newDim - output dimensions after transform
	 * newRes - output resolution after transform
	 * splineDegree - spline degree during interpolation
	 * 
	 * returns - transformed image
	 */	
	ImageData transform(ImageData volIn, Matrix matIn, int[] newDim, float[] newRes,int splineDegree){
		// Black box for doing the interpolation.
		MedicAlgorithmTransformVolume2 transformMagic = new MedicAlgorithmTransformVolume2();
		transformMagic.dims.setValue(new Point3i(newDim));
		transformMagic.res.setValue(new Point3f(newRes));
		transformMagic.inputVolumeCollection.add(volIn);
		transformMagic.paramSplineDegree.setValue(splineDegree);
		transformMagic.matrix.setValue(matIn);
		transformMagic.setOutputDirectory(this.getOutputDirectory());
		transformMagic.runAlgorithm();

		return transformMagic.outputVolumeCollection.getImageDataList().get(0);
	}

	
	//use FLIRT to find a rigid transformation matrix between volumes "subject" and "target"
	//returns transformation matrix
	Matrix findTransformationToAtlas(ImageData subject, ImageData target){

		MedicAlgorithmFLIRT flirtMagic = new MedicAlgorithmFLIRT();
		flirtMagic.source.setValue(subject);
		flirtMagic.target.setValue(target);
		flirtMagic.dof.setValue(0);
		flirtMagic.costFunction.setValue(3);
		flirtMagic.setOutputDirectory(this.getOutputDirectory());
		flirtMagic.runAlgorithm();
		return flirtMagic.trans.getValue().copy();
	}
	
	
	//update Origin, given old origin, and old and new resolutions
	float[] findNewOrigin(float[] oldO, float[] newRes, float[] oldRes){
		
		float[] newO = new float[3];
		//scaling matrix
		Matrix m = new Matrix(4,4);
		m.set(0, 0, (newRes[0]/oldRes[0]));
		m.set(1, 1, (newRes[1]/oldRes[1]));
		m.set(2, 2, (newRes[2]/oldRes[2]));
		m.set(3, 3, 1.0);


		for (int i = 0; i < 3 ; i++) {
			newO[i] = oldO[0] * ((float) m.get(0, i)) + oldO[1] * ((float) m.get(1, i)) + oldO[2] * ((float) m.get(2, i));
		}

		return newO;
		
	}
	
}
