package edu.jhu.ece.iacl.plugins.registration;

import java.awt.Dimension;
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

import javax.vecmath.Point3f;
import javax.vecmath.Point3i;

import Jama.Matrix;

import edu.jhmi.rad.medic.libraries.ImageFunctionsPublic;
import edu.jhu.ece.iacl.algorithms.CommonAuthors;
import edu.jhu.ece.iacl.algorithms.vabra.VabraAlgorithm;
import edu.jhu.ece.iacl.algorithms.volume.CompareVolumes;
import edu.jhu.ece.iacl.algorithms.volume.TransformVolume;
import edu.jhu.ece.iacl.jist.io.FileExtensionFilter;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.parameter.*;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;

import gov.nih.mipav.model.structures.ModelImage;

public class MedicAlgorithmVABRA extends ProcessingAlgorithm {

	// Module Input Parameters
	/** Subject volumes ("Subjects"), one per channel; these are the images that get registered. */
	public ParamVolumeCollection inParamSubjects;
	/** Target volumes ("Targets"); the first target defines the output grid (dimensions and resolutions). */
	public ParamVolumeCollection inParamTargets;
	// ParamNumberCollection inParamChannelWeights;
	// ParamNumberCollection targetWeights;
	/** Interpolation choice ("0:TriLinear" / "1:NNI") for each of the first three channels. */
	public ParamOption[] inParamInterpChan1to3;
	/** Optional interpolation indices for channels beyond the third (same indices as the options above). */
	public ParamNumberCollection inParamInterpChanN;
	/** XML configuration file handed to the VABRA solver. */
	public ParamFile inParamConfigFile;
	/** Fraction (0.0-1.0) of the image above the robust maximum; 0 disables the upper threshold. */
	public ParamDouble inParamRobustMaxT;
	/** Fraction (0.0-1.0) of the image below the robust minimum; 0 disables the lower threshold. */
	public ParamDouble inParamRobustMinT;
	/** Number of histogram bins passed to the solver. */
	public ParamInteger inParamNumBins;
	/** Number of voxels added to every edge before registration (cropped again afterwards). */
	public ParamInteger inParamPadSize;
	/** Padding strategy: "Extend Zeros", "Extend Ends" or "Mirror Edges". */
	public ParamOption inParamPadMethod;
	/** Whether an affine FLIRT registration is run before VABRA. */
	public ParamBoolean inParamUseFlirt;
	// public ParamBoolean inParamFineOptimize;
	/** Optimization weight for the X direction. */
	public ParamDouble inParamOptimizeXDir;
	/** Optimization weight for the Y direction. */
	public ParamDouble inParamOptimizeYDir;
	/** Optimization weight for the Z direction. */
	public ParamDouble inParamOptimizeZDir;
	/** Whether the solver should save results between levels. */
	public ParamBoolean inParamSaveIntermResults;
	/** Deformation field update mode: "0:Summation" or "1:Composite". */
	public ParamOption inParamDefFieldUpdateMode;
	/** Cost function selection: "NMI", "SSD" or "M-NMI". */
	public ParamOption inParamCostFunction;
	/** Flag passed to the solver to force a diffeomorphic result. */
	public ParamBoolean inParamForceDiffeomorphism;
	/** Optional label volume on which digital homeomorphism should be maintained. */
	public ParamVolume inParamLabelsToMaintainHomeo;

	// Module Output Parameters
	/** Registered (affine + deformed) versions of the input subjects. */
	public ParamVolumeCollection outParamDeformedSubject;
	/** Deformation field with the padding cropped off (three components per voxel). */
	public ParamVolume outParamDeformationField;
	/** 4x4 affine matrix: the FLIRT result, or identity when FLIRT is not run. */
	public ParamMatrix outParamTransformMatrices;

	// algorithm Variables
	/** Embedded FLIRT module; its options are exposed under "FLIRT Options". */
	public MedicAlgorithmFLIRT flirt;
	/** Solver instance, created anew by initialize(). */
	private VabraAlgorithm vabra;
	/** Targets as supplied (not cloned) and cloned copies of the subjects. */
	private List<ImageData> rawTargets, rawSubjects;
	/** Working image lists: padded inputs for the solver and affine-transformed subjects. */
	private List<ImageData> paddedSubjects, paddedTargets, transformedSubjects;
	/** Interpolation index per subject, as passed to the solver. */
	private int[] interpType;
	/** Raw values of inParamInterpChanN (channels four and up). */
	private List<Number> interpTypeN;
	/** TransformVolume equivalent of interpType, used for affine resampling. */
	private TransformVolume.Interpolation[] transformInterpType;
	/** Cached value of inParamPadSize. */
	private int padSize;
	/** Optimization weights in the order {x, y, z}. */
	private double[] directionsOptmizationWeight;
	/** Label volume at its three stages: as supplied, affine-transformed, and padded. */
	private ImageData rawLabelToMaintainHomeo, transformedLabelToMaintainHomeo,
			paddedLabelToMaintainHomeo;
	/** Set by checkSameInputs() when a subject or target has a dimension of size 1; FLIRT is then skipped. */
	private boolean is2DImage;

	/** Version string reported in the algorithm information, taken from VabraAlgorithm. */
	private static final String revnum = VabraAlgorithm.getVersion();

	/**
	 * Supplies the JVM options this module requests by default.
	 *
	 * @return the heap free-ratio and generation size-increment flags
	 */
	public String[] getDefaultJVMArgs() {
		final String[] jvmArgs = new String[4];
		jvmArgs[0] = "-XX:MinHeapFreeRatio=60";
		jvmArgs[1] = "-XX:MaxHeapFreeRatio=90";
		jvmArgs[2] = "-XX:YoungGenerationSizeIncrement=100";
		jvmArgs[3] = "-XX:TenuredGenerationSizeIncrement=100";
		return jvmArgs;
	}

	/**
	 * Builds the module's input parameter tree: the main inputs (subjects,
	 * targets, configuration file, cost function), the per-channel
	 * interpolation options, and the advanced section (general settings,
	 * per-axis optimization weights and the embedded FLIRT options). Also
	 * fills in the algorithm information (authors, citation, version).
	 *
	 * @param inputParams
	 *            root collection that receives the parameter groups
	 */
	protected void createInputParameters(ParamCollection inputParams) {

		inParamConfigFile = new ParamFile("Configuration",
				ParamFile.DialogType.FILE);
		inParamConfigFile.setExtensionFilter(new FileExtensionFilter(
				new String[] { "xml" }));

		// Default to the config.xml resource located relative to
		// VabraAlgorithm. An empty path is the fallback so that building the
		// module library does not crash when the resource cannot be resolved.
		File defaultConfig = new File("");
		URL url = VabraAlgorithm.class.getResource("config.xml");
		if (url != null) {
			try {
				defaultConfig = new File(url.toURI());
			} catch (URISyntaxException e) {
				e.printStackTrace();
				JistLogger.logError(JistLogger.INFO,
						"Continuing with library build.");
			} catch (IllegalArgumentException e) {
				// new File(URI) rejects URIs that are not plain "file:" URIs,
				// which happens when the resource is packaged inside a jar.
				JistLogger.logError(JistLogger.INFO,
						"Default VABRA configuration is not a plain file ("
								+ url + "); leaving the default empty.");
			}
		}
		inParamConfigFile.setValue(defaultConfig);

		ParamCollection mainParams = new ParamCollection("Main");
		mainParams.add(inParamSubjects = new ParamVolumeCollection("Subjects"));
		mainParams.add(inParamTargets = new ParamVolumeCollection("Targets"));
		mainParams.add(inParamConfigFile);
		String[] costTypes = { "NMI", "SSD", "M-NMI"};
		mainParams.add(inParamCostFunction = new ParamOption("Cost Function", costTypes));

		// Per-channel interpolation: explicit options for the first three
		// channels, an optional index list for any further channels.
		ParamCollection interpParams = new ParamCollection("Interpolation Type");
		inParamInterpChan1to3 = new ParamOption[3];
		String[] interpTypes = { "0:TriLinear", "1:NNI" };
		interpParams.add(inParamInterpChan1to3[0] = new ParamOption(
				"Type for Channel 1", interpTypes));
		interpParams.add(inParamInterpChan1to3[1] = new ParamOption(
				"Type for Channel 2", interpTypes));
		interpParams.add(inParamInterpChan1to3[2] = new ParamOption(
				"Type for Channel 3", interpTypes));
		for (int i = 0; i < 3; i++)
			inParamInterpChan1to3[i].setValue(0);
		interpParams.add(inParamInterpChanN = new ParamNumberCollection(
				"Type for Remaining Channels (Use Index from Options Above)"));
		inParamInterpChanN.setMandatory(false);

		// Advanced Parameters
		ParamCollection advParams = new ParamCollection("Advanced");

		ParamCollection advGenParams = new ParamCollection("General");
		advGenParams.add(inParamUseFlirt = new ParamBoolean(
				"Run Affine Registration (FLIRT) First", true));
		advGenParams.add(inParamForceDiffeomorphism = new ParamBoolean(
				"Force Diffeomorphism", false));
		advGenParams.add(inParamLabelsToMaintainHomeo = new ParamVolume(
				"Labels to Maintain Digital Homeomorphism on:(Optional)"));
		inParamLabelsToMaintainHomeo.setMandatory(false);
		advGenParams.add(inParamSaveIntermResults = new ParamBoolean(
				"Save Results Between Levels", false));
		String[] defUpdateMode = { "0:Summation", "1:Composite" };
		advGenParams.add(inParamDefFieldUpdateMode = new ParamOption("Deformation Field Update Mode", defUpdateMode));
		advGenParams.add(inParamRobustMaxT = new ParamDouble("Percentile of Image Above Robust Maximum (0.0-1.0)", 0.0, 1.0,0.0));
		advGenParams.add(inParamRobustMinT = new ParamDouble("Percentile of Image Below Robust Minimum (0.0-1.0)", 0.0, 1.0,0.0));
		advGenParams.add(inParamNumBins = new ParamInteger("Number of Bins used in Histogram", 0, 128, 64));
		advGenParams.add(inParamPadSize = new ParamInteger("Number of Voxels to Pad Edges", 0, 100, 5));
		String[] padMethods = { "Extend Zeros", "Extend Ends", "Mirror Edges" };
		advGenParams.add(inParamPadMethod = new ParamOption("Method for Padding Edges", padMethods));
		inParamPadMethod.setValue(0);
		advParams.add(advGenParams);

		// Optimization Parameters
		ParamCollection optParams = new ParamCollection("Optimization Settings");
		optParams.add(inParamOptimizeXDir = new ParamDouble("X Optimization Weight", 1));
		optParams.add(inParamOptimizeYDir = new ParamDouble("Y Optimization Weight", 1));
		optParams.add(inParamOptimizeZDir = new ParamDouble("Z Optimization Weight", 1));
		// optParams.add(inParamFineOptimize = new
		// ParamBoolean("Use Fine Optimizer(~4x Runtime Increase)", false));
		advParams.add(optParams);

		// flirt Params: the inputs this module supplies itself (images,
		// weights, interpolation) are hidden; the remaining FLIRT options are
		// exposed under "FLIRT Options".
		flirt = new MedicAlgorithmFLIRT();
		flirt.target.setHidden(true);
		flirt.source.setHidden(true);
		flirt.refWeight.setHidden(true);
		flirt.inputWeight.setHidden(true);
		flirt.inputInterpolation.setHidden(true);
		flirt.outputInterpolation.setHidden(true);
		flirt.costFunction.setValue(3);
		flirt.dof.setValue(2);
		ParamCollection flirtParams = flirt.getInput();
		flirtParams.setName("FLIRT Options");
		flirtParams.setLabel("FLIRT Options");
		advParams.add(flirtParams);

		inputParams.add(mainParams);
		inputParams.add(interpParams);
		inputParams.add(advParams);

		inputParams.setName("VABRA");
		inputParams.setLabel("VABRA");

		setPreferredSize(new Dimension(300, 600));

		inputParams.setPackage("IACL");
		inputParams.setCategory("Registration.Volume");

		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://www.iacl.ece.jhu.edu/");
		info.add(CommonAuthors.minChen);
		info.add(CommonAuthors.blakeLucas);
		info.add(CommonAuthors.bryanWheeler);
		info.add(new AlgorithmInformation.Citation(
				"G. K. Rohde, A. Aldroubi, and B. M. Dawant, \"The adaptive bases algorithm for intensity-based nonrigid image registration,\" Medical Imaging, IEEE Trans. on, vol. 22, pp. 1470-1479, 2003."));
		info.setDescription("Vectorized Adaptive Bases Registration Algorithm - A deformable registration algorithm.  Returns registered image and the applied deformation field.");
		info.setLongDescription("This is a port of Bryan Wheeler's VABRA. The Java implementation never achieved exact numerical equivalence with the C++ version, but the final results are close enough to conclude the algorithms produce the same results.");
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.RC);
	}

	/**
	 * Declares the module outputs: the registered images, the deformation
	 * field and the 4x4 transformation matrix (seeded with the current value
	 * of the embedded FLIRT module's matrix).
	 *
	 * @param outputParams
	 *            collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outParamDeformedSubject = new ParamVolumeCollection("Registered Image");
		outParamDeformationField = new ParamVolume("Deformation Field",
				VoxelType.FLOAT, -1, -1, -1, -1);
		outParamTransformMatrices = new ParamMatrix("Transformation Matrix",
				4, 4);

		outputParams.add(outParamDeformedSubject);
		outputParams.add(outParamDeformationField);
		outputParams.add(outParamTransformMatrices);

		outParamTransformMatrices.setValue(flirt.trans.getValue());
		outputParams.setName("VABRA");
	}

	/**
	 * Runs the complete registration: initialization, the identical-input
	 * shortcut, preprocessing, the VABRA solve and postprocessing.
	 *
	 * @param monitor
	 *            calculation monitor that is attached to the solver
	 */
	protected void execute(CalculationMonitor monitor) {
		// Step 0: set up working state and let the monitor watch the solver.
		initialize();
		monitor.observe(vabra);

		// Step 1: identical subjects and targets need no registration.
		if (checkSameInputs()) {
			return;
		}

		// Step 2: affine alignment (if enabled) and padding.
		preProcessing();
		logVabraStage("VABRA - BEFORE VABRA");

		// Step 3: the deformable registration itself.
		final File configFile = inParamConfigFile.getValue();
		final int numBins = inParamNumBins.getInt();
		vabra.solve(paddedSubjects, paddedTargets, configFile, numBins,
				interpType, this.getOutputDirectory(),
				inParamSaveIntermResults.getValue(),
				directionsOptmizationWeight,
				inParamDefFieldUpdateMode.getIndex(),
				inParamCostFunction.getIndex(),
				inParamForceDiffeomorphism.getValue(),
				paddedLabelToMaintainHomeo);
		logVabraStage("VABRA - AFTER VABRA");

		// Step 4: crop the padding and produce the outputs.
		postProcessing();

		// Only the references are dropped: the images may belong to input
		// parameters that other algorithms are linked to, so they must not
		// be destroyed here.
		rawTargets = null;
		rawSubjects = null;

		logVabraStage("VABRA - FINISHED");
	}

	/** Prints a stage marker prefixed with the class name, followed by a memory report. */
	private void logVabraStage(String stage) {
		System.out.println(getClass().getCanonicalName() + "\t" + stage);
		edu.jhu.ece.iacl.plugins.labeling.MedicAlgorithmMultiAtlasSurfaceLabeling
				.generateMemoryReport();
	}

	/**
	 * Resets the working state for a run: fetches the targets, clones the
	 * subjects, creates a fresh solver, caches the pad size and direction
	 * weights, and resolves the interpolation type of every subject channel.
	 * <p>
	 * Channels 1-3 take their type from the dedicated options. Further
	 * channels take it from the optional "remaining channels" list; a channel
	 * with no entry in that list falls back to trilinear interpolation
	 * instead of failing with an index error.
	 */
	private void initialize() {
		// Targets are used as supplied: no need to clone, since the original
		// is not needed afterwards.
		rawTargets = inParamTargets.getImageDataList();

		// Subjects are cloned because the originals are needed again in
		// postProcessing().
		List<ImageData> inputSubjects = inParamSubjects.getImageDataList();
		rawSubjects = new ArrayList<ImageData>(inputSubjects.size());
		for (ImageData subject : inputSubjects) {
			rawSubjects.add(subject.clone());
		}

		rawLabelToMaintainHomeo = inParamLabelsToMaintainHomeo.getImageData();
		transformedSubjects = new ArrayList<ImageData>();
		paddedSubjects = new ArrayList<ImageData>();
		paddedTargets = new ArrayList<ImageData>();
		vabra = new VabraAlgorithm();
		is2DImage = false;

		padSize = inParamPadSize.getInt();

		directionsOptmizationWeight = new double[3];
		directionsOptmizationWeight[0] = inParamOptimizeXDir.getDouble();
		directionsOptmizationWeight[1] = inParamOptimizeYDir.getDouble();
		directionsOptmizationWeight[2] = inParamOptimizeZDir.getDouble();

		// initialize Interpolation Types
		final int channelCount = rawSubjects.size();
		transformInterpType = new TransformVolume.Interpolation[channelCount];
		interpType = new int[channelCount];
		interpTypeN = inParamInterpChanN.getValue();
		for (int i = 0; i < channelCount; i++) {
			if (i < 3) {
				interpType[i] = inParamInterpChan1to3[i].getIndex();
			} else {
				int extraIdx = i - 3;
				Number chosen = (interpTypeN != null && extraIdx < interpTypeN
						.size()) ? interpTypeN.get(extraIdx) : null;
				if (chosen != null) {
					interpType[i] = chosen.intValue();
				} else {
					// The list is optional, so it may be shorter than the
					// number of extra channels.
					interpType[i] = RegistrationUtilities.InterpolationType.TRILINEAR;
					JistLogger.logError(JistLogger.INFO,
							"No interpolation type given for channel "
									+ (i + 1) + "; using trilinear.");
				}
			}

			if (interpType[i] == RegistrationUtilities.InterpolationType.TRILINEAR)
				transformInterpType[i] = TransformVolume.Interpolation.Trilinear;
			else
				transformInterpType[i] = TransformVolume.Interpolation.Nearest_Neighbor;
		}

	}

	/**
	 * Prepares the solver inputs: optional robust intensity thresholding,
	 * affine alignment of the subjects (skipped for 2D data), and padding of
	 * subjects, targets and the optional label volume.
	 * <p>
	 * For 2D data no affine-transformed label exists, because
	 * affineAlignment() is not run, so the raw label volume is padded
	 * instead of a null reference.
	 */
	private void preProcessing() {

		// robust histogram inputs; both thresholds at 0 means "disabled"
		final float robustMin = inParamRobustMinT.getFloat();
		final float robustMax = inParamRobustMaxT.getFloat();
		if (robustMax != 0 || robustMin != 0) {
			RegistrationUtilities.threshAtRobustMaxAndMin(rawTargets,
					robustMin, robustMax);
			RegistrationUtilities.threshAtRobustMaxAndMin(rawSubjects,
					robustMin, robustMax);
		}

		if (is2DImage) {
			// 2D images skip FLIRT (since MIPAV cannot handle it). The flag is
			// cleared so postProcessing() does not combine an affine matrix.
			transformedSubjects = rawSubjects;
			inParamUseFlirt.setValue(false);
		} else {
			// run FLIRT and Transform first if needed
			affineAlignment();
		}

		// Padding Subject and Targets
		final int padMethod = inParamPadMethod.getIndex();
		paddedSubjects = RegistrationUtilities.padImageList(
				transformedSubjects, padSize, padMethod);
		paddedTargets = RegistrationUtilities.padImageList(rawTargets, padSize,
				padMethod);

		if (rawLabelToMaintainHomeo != null) {
			// affineAlignment() is what produces the transformed label; in the
			// 2D path it never ran, so fall back to the untransformed label.
			ImageData labelToPad = is2DImage ? rawLabelToMaintainHomeo
					: transformedLabelToMaintainHomeo;
			List<ImageData> temp = new ArrayList<ImageData>();
			temp.add(labelToPad);
			paddedLabelToMaintainHomeo = RegistrationUtilities.padImageList(
					temp, padSize, padMethod).get(0);
		}

		for (ImageData vol : transformedSubjects)
			vol.dispose();// destroy, no longer needed
		transformedSubjects.clear();

	}

	/**
	 * Produces the module outputs after the solve: disposes the padded
	 * images, crops the padding off the deformation field, folds the affine
	 * matrix into the field when FLIRT was used, and applies the resulting
	 * field to each original subject.
	 * <p>
	 * Each subject is deformed with its own interpolation type (the one
	 * selected per channel and also given to the solver), so a channel set to
	 * nearest-neighbour is not resampled with trilinear interpolation in the
	 * final step.
	 */
	private void postProcessing() {

		// Get rid of padded Images
		for (ImageData vol : paddedTargets)
			vol.dispose();
		for (ImageData vol : paddedSubjects)
			vol.dispose();

		List<ImageData> originalSubjects = inParamSubjects.getImageDataList();

		// Cropping Deformation from Previous Padding
		ImageDataFloat defField = vabra.getDeformationField();
		final int rows = defField.getRows() - 2 * padSize;
		final int cols = defField.getCols() - 2 * padSize;
		final int slcs = defField.getSlices() - 2 * padSize;
		ImageDataFloat croppedDefField = defField.mimic(rows, cols, slcs, 3);
		for (int c = 0; c < 3; c++) {
			for (int i = 0; i < rows; i++) {
				for (int j = 0; j < cols; j++) {
					for (int k = 0; k < slcs; k++) {
						croppedDefField.set(i, j, k, c, defField.getDouble(i
								+ padSize, j + padSize, k + padSize, c));
					}
				}
			}
		}
		croppedDefField.setHeader(defField.getHeader());
		croppedDefField.setName(originalSubjects.get(0).getName()
				+ "_def_field");
		defField.dispose();
		// set output for deformation field (before the affine is folded in)
		outParamDeformationField.setValue(croppedDefField);

		// Apply affine transformation and deformation field to original images.
		Matrix xfmIdentity = new Matrix(4, 4);
		for (int ii = 0; ii < 4; ii++)
			xfmIdentity.set(ii, ii, 1);
		ImageData firstTarget = rawTargets.get(0);
		float[] res = firstTarget.getHeader().getDimResolutions();
		Point3f resolutions = new Point3f(res[0], res[1], res[2]);
		Point3i dimensions = new Point3i(firstTarget.getRows(),
				firstTarget.getCols(), firstTarget.getSlices());

		// Combine affine and deformation field into one field if necessary
		if (inParamUseFlirt.getValue()) {
			croppedDefField = (ImageDataFloat) RegistrationUtilities
					.combineTransAndDef(outParamTransformMatrices.getValue(),
							croppedDefField);
		}

		for (int i = 0; i < originalSubjects.size(); i++) {
			// Resample the original subject onto the target grid with an
			// identity transform; the affine part is already in the field.
			ModelImage currentModelImage = originalSubjects.get(i)
					.getModelImageCopy();
			ImageData transCurrentImg = TransformVolume.transform(
					currentModelImage, transformInterpType[i], xfmIdentity,
					resolutions, dimensions);
			currentModelImage.disposeLocal();

			ImageData deformedCurrentImg = transCurrentImg.clone();
			deformedCurrentImg.setHeader(transCurrentImg.getHeader());
			deformedCurrentImg.setName(transCurrentImg.getName() + "_reg");
			// Honour the per-channel interpolation choice.
			RegistrationUtilities.DeformImage3D(transCurrentImg,
					deformedCurrentImg, croppedDefField, rows, cols, slcs,
					interpType[i]);
			outParamDeformedSubject.add(deformedCurrentImg);
		}
	}

	/**
	 * Determines the affine matrix and resamples the subjects with it.
	 * <p>
	 * When FLIRT is enabled it is run on the first target and first subject
	 * and its matrix becomes the transformation output; otherwise the output
	 * is the identity. Every subject, and the optional label volume, is then
	 * transformed with that matrix onto the first target's grid. Temporary
	 * MIPAV image copies are released as soon as each transform is done.
	 */
	private void affineAlignment() {
		// affine first if needed
		if (inParamUseFlirt.getValue()) {
			System.out.println("*********Running Affine First**************");
			// set inputs
			flirt.target.setValue(rawTargets.get(0));
			flirt.source.setValue(rawSubjects.get(0));
			//flirt.dof.setValue(2);
			flirt.run();
			outParamTransformMatrices.setValue(flirt.trans.getValue());
			flirt.registered.getImageData().dispose(); // don't need this output
			System.out.println(getClass().getCanonicalName() + "\t"
					+ "*********Affine Finished**************");
		} else {
			System.out
					.println("*********Using Identity Affine Transform**************");
			Matrix xfmIdentity = new Matrix(4, 4);
			for (int ii = 0; ii < 4; ii++)
				xfmIdentity.set(ii, ii, 1);
			outParamTransformMatrices.setValue(xfmIdentity);
		}

		// transform if needed; the output grid is that of the first target
		Matrix transMatrix = outParamTransformMatrices.getValue();
		ImageData firstTarget = rawTargets.get(0);
		float[] res = firstTarget.getHeader().getDimResolutions();
		Point3f resolutions = new Point3f(res[0], res[1], res[2]);
		Point3i dimensions = new Point3i(firstTarget.getRows(),
				firstTarget.getCols(), firstTarget.getSlices());

		for (int q = 0; q < rawSubjects.size(); q++) {
			ModelImage tempModelImage = rawSubjects.get(q).getModelImageCopy();

			transformedSubjects.add(q, TransformVolume.transform(
					tempModelImage, transformInterpType[q], transMatrix,
					resolutions, dimensions));

			// print(), not format(): the type name is data, not a format
			// string
			System.out.print(transformedSubjects.get(q).getType().name());
			tempModelImage.disposeLocal();
		}

		if (rawLabelToMaintainHomeo != null) {
			// Keep a handle on the copy so it can be released, as is done for
			// the subject copies.
			ModelImage labelModelImage = rawLabelToMaintainHomeo
					.getModelImageCopy();
			transformedLabelToMaintainHomeo = TransformVolume.transform(
					labelModelImage, transformInterpType[0], transMatrix,
					resolutions, dimensions);
			labelModelImage.disposeLocal();
		}

	}

	/**
	 * Tests whether every subject is identical to the target at the same
	 * index. If so, the outputs are filled in directly: the subjects
	 * themselves, a zero-initialized deformation field and an identity
	 * matrix.
	 * <p>
	 * As a side effect, the first subject/target pair found with a dimension
	 * of size 1 sets {@code is2DImage} and ends the check with {@code false}.
	 *
	 * @return {@code true} if the inputs are identical and the outputs have
	 *         been set, {@code false} if registration must be run
	 */
	private boolean checkSameInputs() {
		for (int ch = 0; ch < rawSubjects.size(); ch++) {
			final ImageData subject = rawSubjects.get(ch);
			final ImageData target = rawTargets.get(ch);

			// Any singleton dimension marks the data as 2D.
			if (subject.getRows() == 1 || subject.getCols() == 1
					|| subject.getSlices() == 1 || target.getRows() == 1
					|| target.getCols() == 1 || target.getSlices() == 1) {
				is2DImage = true;
				return false;
			}

			final CompareVolumes comparison = new CompareVolumes(subject,
					target, -1E30, 1);
			comparison.compare();
			final boolean identical = comparison.isComparable()
					&& comparison.getMinError() == 0
					&& comparison.getMaxError() == 0;
			if (!identical) {
				return false;
			}
		}

		// All pairs match: output the original subjects, a zero deformation
		// field and the identity transform matrix.
		System.out.println(getClass().getCanonicalName() + "\t"
				+ "Same Image, no registration needed");
		outParamDeformedSubject.setValue(rawSubjects);

		final ImageData firstTarget = rawTargets.get(0);
		ImageDataFloat zeroField = new ImageDataFloat(firstTarget.getRows(),
				firstTarget.getCols(), firstTarget.getSlices(), 3);
		zeroField.setName("def_field");
		outParamDeformationField.setValue(zeroField);

		outParamTransformMatrices.setValue(Matrix.identity(4, 4));
		return true;
	}

	/**
	 * Hides the subject and target volume inputs from the parameter display.
	 */
	public void hideVolumeInputs() {
		final ParamVolumeCollection[] volumeInputs = { inParamSubjects,
				inParamTargets };
		for (ParamVolumeCollection input : volumeInputs) {
			input.setHidden(true);
		}
	}
}
