package edu.jhu.ece.iacl.plugins.classification;

import edu.jhu.ece.iacl.jist.io.MipavController;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamDouble;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;

import java.awt.Dimension;

import Jama.Matrix;

import edu.jhu.ece.iacl.algorithms.volume.CompareVolumes;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingApplication;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.Citation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.*;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipavWrapper;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;
import gov.nih.mipav.model.algorithms.AlgorithmIHN3Correction;
import gov.nih.mipav.model.structures.ModelImage;
import gov.nih.mipav.view.dialogs.JDialogBase;

/**
 * JIST processing module that runs MIPAV's {@link AlgorithmIHN3Correction} (N3 intensity
 * non-uniformity correction) on a single input volume.
 * <p>
 * Outputs are the inhomogeneity-corrected volume and the estimated inhomogeneity field,
 * both declared as {@code VoxelType.FLOAT} volumes.
 */
public class MedicAlgorithmN3 extends ProcessingAlgorithm {
	ParamDouble sigThreshold;
	ParamDouble EndTolerance;
	ParamDouble FieldDistance;
	ParamDouble KernelFWHM;
	ParamDouble WeinerNoise;
	ParamInteger MaxIter;
	ParamDouble SubsampleFactor;
	ParamBoolean autoThresh;
	// NOTE(review): outputVol is never assigned or registered in this class.
	ParamVolume srcImg,destImg,fieldImg,outputVol;

	private static final String cvsversion = "$Revision: 1.2 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");

	/**
	 * MIPAV data-type code used when allocating the two output images. Presumably this is
	 * MIPAV's float type code (ModelStorageBase.FLOAT), matching the VoxelType.FLOAT output
	 * parameters declared in createOutputParameters -- TODO confirm against the MIPAV version in use.
	 */
	private static final int OUTPUT_DATA_TYPE = 7;

	/**
	 * Declares the input volume and the N3 tuning parameters, and fills in the module's
	 * algorithm information (description, version, status, citation).
	 *
	 * @param inputParams collection the input parameters are registered into
	 */
	@Override
	protected void createInputParameters(ParamCollection inputParams) {
		// Numeric ParamDouble/ParamInteger arguments are presumably (name, min, max, default).
		sigThreshold = new ParamDouble("Signal Threshold", -1E10, 1E10,1);
		sigThreshold.setDescription("Default = min + 1, Values at less than threshold are treated as part of the background");
		EndTolerance = new ParamDouble("End tolerance",0.00001,0.01,0.001);
		EndTolerance.setDescription("Usually 0.01-0.00001, The measure used to terminate the iterations is the coefficient of variation of change in field estimates between successive iterations.");
		FieldDistance = new ParamDouble("Field Distance(mm)",0,1E10,62.0);
		FieldDistance.setDescription("Characteristic distance over which the field varies. The distance between adjacent knots in bspline fitting with at least 4 knots going in every dimension. The default in the dialog is one third the distance (resolution * extents) of the smallest dimension.");
		SubsampleFactor = new ParamDouble("Subsample Factor",1,32,4);
		SubsampleFactor.setDescription("Usually between 1-32, The factor by which the data is subsampled to a lower resolution in estimating the slowly varying non-uniformity field. Reduce sampling in the finest sampling direction by the shrink factor.");
		KernelFWHM = new ParamDouble("Kernel FWHM",0.05,0.50,0.15);
		KernelFWHM.setDescription("Usually between 0.05-0.50, Width of deconvolution kernel used to sharpen the histogram. Larger values give faster convergence while smaller values give greater accuracy.");
		WeinerNoise = new ParamDouble("Weiner Filter Noise",0.0,1.0,0.01);
		WeinerNoise.setDescription("Usually between 0.0-1.0");
		MaxIter = new ParamInteger("Maximum number of Iterations",1,1000,50);
		autoThresh = new ParamBoolean("Automatic Histogram Thresholding",false);
		autoThresh.setDescription("If true determines the threshold by histogram analysis. If true a VOI cannot be used and the input threshold is ignored.");
		srcImg = new ParamVolume("Input Volume");

		inputParams.add(srcImg);
		inputParams.add(sigThreshold);
		inputParams.add(MaxIter);
		inputParams.add(EndTolerance);
		inputParams.add(FieldDistance);
		inputParams.add(SubsampleFactor);
		inputParams.add(KernelFWHM);
		inputParams.add(WeinerNoise);
		inputParams.add(autoThresh);

		inputParams.setLabel("N3 Correction");
		inputParams.setName("N3Correction");
		inputParams.setPackage("IACL");
		inputParams.setCategory("Classification");

		AlgorithmInformation info = getAlgorithmInformation();
		info.setDescription("Non-parametric Intensity Non-uniformity Correction, N3, by J. G. Sled.");
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.Release);
		info.add(new Citation("Bjórn Hamre \"Three-dimensional image registration of magnetic resonance (MRI) head volumes\", Section for Medical Image Analysis and Informatics Department of Physiology & Department of Informatics University of Bergen, Norway."));
	}

	/**
	 * Declares the two float output volumes: the corrected image and the estimated field.
	 *
	 * @param outputParams collection the output parameters are registered into
	 */
	@Override
	protected void createOutputParameters(ParamCollection outputParams) {
		destImg = new ParamVolume("Inhomogeneity Corrected Volume", VoxelType.FLOAT, -1, -1, -1, 1);
		fieldImg = new ParamVolume("Inhomogeneity Field", VoxelType.FLOAT, -1, -1, -1, 1);
		outputParams.add(destImg);
		outputParams.add(fieldImg);

		outputParams.setLabel("N3 Correction");
		outputParams.setName("N3Correction");
	}

	/**
	 * Runs the N3 correction, letting {@code monitor} observe its progress.
	 *
	 * @param monitor progress monitor supplied by the pipeline
	 * @throws AlgorithmRuntimeException declared by the ProcessingAlgorithm contract
	 */
	@Override
	protected void execute(CalculationMonitor monitor)
			throws AlgorithmRuntimeException {
		N3Wrapper N3 = new N3Wrapper();
		monitor.observe(N3);
		N3.execute();
	}

	/**
	 * Calculation unit that performs the actual N3 run. It is an inner (non-static) class
	 * because it reads the enclosing module's parameters.
	 */
	protected class N3Wrapper extends AbstractCalculation {
		public N3Wrapper() {
			setLabel("N3");
		}

		/**
		 * Runs N3 on a copy of the input volume and publishes the corrected volume and the
		 * estimated field to {@code destImg} and {@code fieldImg}.
		 */
		public void execute() {
			String baseName = srcImg.getImageData().getName();
			String name1 = baseName + "_N3Corrected";
			String name2 = baseName + "_N3Field";

			// A single working copy serves as both the algorithm input and the source of the
			// header information copied to the outputs. A second, never-disposed copy is no
			// longer made.
			ModelImage modelsrcImg = srcImg.getImageData().getModelImageCopy();
			AlgorithmIHN3CorrectionWrapper N3 = null;
			try {
				ModelImage a = new ModelImage(OUTPUT_DATA_TYPE, modelsrcImg.getExtents(), name1);
				ModelImage b = new ModelImage(OUTPUT_DATA_TYPE, modelsrcImg.getExtents(), name2);
				System.out.println(
					"\nthresh = " + sigThreshold.getFloat() +
					"\nmaxiter = " + MaxIter.getInt() +
					"\nname = " + name1 +
					"\nmodality = " + modelsrcImg.getImageModality() +
					"\n");
				N3 = new AlgorithmIHN3CorrectionWrapper(a, b,
						modelsrcImg, sigThreshold.getFloat(), MaxIter.getInt(), EndTolerance.getFloat(),
						FieldDistance.getFloat(), SubsampleFactor.getFloat(), KernelFWHM.getFloat(), WeinerNoise.getFloat(),
						true, autoThresh.getValue(), false);

				// Forward the algorithm's progress updates to this calculation.
				N3.setObserver(this);
				N3.run();

				copyHeader(modelsrcImg, a);
				MipavController.setModelImageName(a, name1);
				destImg.setValue(new ImageDataMipavWrapper(a));

				copyHeader(modelsrcImg, b);
				MipavController.setModelImageName(b, name2);
				fieldImg.setValue(new ImageDataMipavWrapper(b));
			} finally {
				// Release algorithm state and the working copy even if the run fails.
				if (N3 != null) {
					N3.finalize();
				}
				modelsrcImg.disposeLocal();
			}
		}

		/**
		 * Copies the axis orientation, image orientation and resolutions of file-info entry 0
		 * from {@code source} to {@code target}. Only entry 0 is touched.
		 */
		private void copyHeader(ModelImage source, ModelImage target) {
			target.getFileInfo(0).setAxisOrientation(source.getFileInfo(0).getAxisOrientation());
			target.getFileInfo(0).setImageOrientation(source.getFileInfo(0).getImageOrientation());
			target.getFileInfo(0).setResolutions(source.getFileInfo(0).getResolutions());
		}
	}

	/**
	 * {@link AlgorithmIHN3Correction} subclass that mirrors MIPAV progress events onto a
	 * JIST {@link AbstractCalculation}. The observer is optional; when none has been set,
	 * progress is only reported through the MIPAV superclass.
	 */
	protected static class AlgorithmIHN3CorrectionWrapper extends AlgorithmIHN3Correction {
		/** Calculation that receives progress updates; may be null. */
		protected AbstractCalculation observer;

		public AlgorithmIHN3CorrectionWrapper(ModelImage _destImg, ModelImage _fieldImg,
				ModelImage _srcImg, float threshold, int maxIters, float endTol,
				float fieldDistance, float shrink, float kernelfwhm, float noise,
				boolean entireImage, boolean autoThreshold, boolean useScript)  {
			super(_destImg,_fieldImg,_srcImg, threshold,maxIters,endTol,
					fieldDistance, shrink, kernelfwhm,noise,
					entireImage, autoThreshold,useScript);
		}

		/**
		 * Sets the calculation that should mirror this algorithm's progress.
		 *
		 * @param observer target calculation, or null to disable forwarding
		 */
		public void setObserver(AbstractCalculation observer) {
			this.observer = observer;
		}

		/** Runs the algorithm, bracketing it with observer total-units and completion calls. */
		@Override
		public void runAlgorithm() {
			if (observer != null) {
				// Progress values forwarded below are treated as units out of 100.
				observer.setTotalUnits(100);
			}
			super.runAlgorithm();
			if (observer != null) {
				observer.markCompleted();
			}
		}

		@Override
		protected void fireProgressStateChanged(int value) {
			super.fireProgressStateChanged(value);
			if (observer != null) {
				observer.setCompletedUnits(value);
			}
		}

		@Override
		protected void fireProgressStateChanged(String imageName, String message) {
			super.fireProgressStateChanged(imageName, message);
			if (observer != null) {
				observer.setLabel(message);
			}
		}
	}
}
