package edu.jhu.ece.iacl.plugins.classification;

import java.awt.Dimension;



import edu.jhmi.rad.medic.algorithms.AlgorithmAdaptiveEMSegmentation;
import edu.jhmi.rad.medic.algorithms.AlgorithmAtlasEMSegmentation;
import edu.jhmi.rad.medic.methods.DemonToadDeformableAtlas;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.algorithms.PrinceGroupAuthors;

import edu.jhu.ece.iacl.jist.pipeline.parameter.*;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;

import gov.nih.mipav.model.structures.ModelImage;

/**
 * EM segmentation of brain images using adaptive statistical atlases.
 * <p>
 * JIST wrapper around {@link AlgorithmAtlasEMSegmentation}: it declares the plugin's
 * parameters, runs the algorithm on a single T1-weighted input (MPRAGE if supplied,
 * otherwise SPGR) and publishes the selected result images.
 * 
 * @author Navid Shiee
 * 
 */
public class MedicAlgorithmAtlasEMSegmentation extends ProcessingAlgorithm {

	/** Class-path resource used as the default value of {@link #atlasFile}. */
	private static final String DEFAULT_ATLAS_RESOURCE = "Atlas/AtlasEM/atlas_4obj_EM_2012.txt";

	/** Choices of {@link #outputType}; they also decide which result images are exported. */
	private static final String OUTPUT_CLASSIFICATION = "classification";
	private static final String OUTPUT_CLASS_MEMBER = "class+member";
	private static final String OUTPUT_DURA_INPUTS = "dura removal inputs";

	/** Fixed inhomogeneity-correction settings (not exposed as plugin parameters). */
	private static final String CORRECTION_METHOD = "Chebyshev";
	private static final int POLYNOMIAL_DEGREE = 3;
	private static final float KERNEL_SIZE = 30.0f;

	/** Optional T1-weighted MPRAGE input; takes precedence over {@link #SPGR}. */
	ParamVolume MPRAGE;
	/** Optional T1-weighted SPGR input; used only when {@link #MPRAGE} is empty. */
	ParamVolume SPGR;

	/** Atlas description file; its path is handed to {@link DemonToadDeformableAtlas}. */
	ParamFile atlasFile;

	ParamBoolean correctInhomogeneity;
	ParamDouble smoothParam;
	ParamInteger maxIters;
	ParamDouble maxDiff;
	ParamOption registerationMode;
	ParamOption outputType;
	ParamOption priorMode;
	ParamFloat demonsSmoothing;
	ParamFloat demonsScale;
	ParamBoolean outputField;
	ParamBoolean outputAtlases;

	// Output volumes (all optional; which ones are filled depends on the options above).
	ParamVolume classification;
	ParamVolume field;
	ParamVolume memberships;
	ParamVolume old_prior;
	ParamVolume wmfill;
	ParamVolume gm;
	ParamVolume wmMask;

	private static final String revnum = new AlgorithmAtlasEMSegmentation().get_version();

	/**
	 * Creates the input parameters of the Atlas EM segmentation plugin ("Main" and
	 * "Atlas option" groups) and registers the plugin metadata.
	 * <p>
	 * NOTE(review): parameter names (including their spelling) are left unchanged
	 * because they may be referenced by saved pipeline layouts — confirm before renaming.
	 *
	 * @param inputParams collection that receives the parameter groups
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		MPRAGE = new ParamVolume("T1_MPRAGE Image", VoxelType.UBYTE);
		MPRAGE.setMandatory(false);
		SPGR = new ParamVolume("T1_SPGR Image", VoxelType.UBYTE);
		SPGR.setMandatory(false);

		atlasFile = new ParamFile("Atlas file");
		setDefaultAtlas(atlasFile, DEFAULT_ATLAS_RESOURCE);

		correctInhomogeneity = new ParamBoolean("Correct inhomogeneity");
		correctInhomogeneity.setValue(false);
		correctInhomogeneity.setDescription("Correct MR field inhomogeneity.");

		outputAtlases = new ParamBoolean("Output Atlases", false);
		outputAtlases.setDescription("Output the statistical and adaptive atlases.");

		outputField = new ParamBoolean("Output inhomogeniety filed", false);
		outputField.setDescription("Output the estimated inhomogeneity field");

		smoothParam = new ParamDouble("Smooting parameter", 0, 1E10);
		smoothParam.setValue(5.0);
		smoothParam.setDescription("Controls the effect of neighborhood voxels on the membership");

		maxIters = new ParamInteger("Maximum iterations", 0, 100000);
		maxIters.setValue(99);

		maxDiff = new ParamDouble("Maximum difference", 0.0, 1E10);
		maxDiff.setValue(0.01);
		maxDiff.setDescription("Maximum amount of relative change in the energy function considered as the convergence criteria");

		String[] registerationType = {"rigid", "multi_fully_affine"};
		registerationMode = new ParamOption("registeration Type", registerationType);
		registerationMode.setValue("rigid");
		registerationMode.setDescription("The method which is used for the registration of the atlases to the image");

		demonsScale = new ParamFloat("Demons Scale", 0, 10, 2.0f);
		demonsSmoothing = new ParamFloat("Demons Smoothing", 0, 10, 1.0f);

		String[] priorType = {"Equal", "Class dependent", "Atlas"};
		priorMode = new ParamOption("Prior Type", priorType);
		priorMode.setValue("Atlas");

		String[] out = {OUTPUT_CLASSIFICATION, OUTPUT_CLASS_MEMBER, OUTPUT_DURA_INPUTS};
		outputType = new ParamOption("Output images", out);
		outputType.setValue(OUTPUT_CLASS_MEMBER);

		ParamCollection mainParams = new ParamCollection("Main");
		mainParams.add(MPRAGE);
		mainParams.add(SPGR);
		mainParams.add(smoothParam);
		mainParams.add(maxDiff);
		mainParams.add(maxIters);
		mainParams.add(outputType);
		mainParams.add(correctInhomogeneity);
		mainParams.add(outputField);
		mainParams.add(outputAtlases);

		ParamCollection atlasParams = new ParamCollection("Atlas option");
		atlasParams.add(priorMode);
		atlasParams.add(atlasFile);
		atlasParams.add(registerationMode);
		atlasParams.add(demonsSmoothing);
		atlasParams.add(demonsScale);

		inputParams.add(mainParams);
		inputParams.add(atlasParams);
		inputParams.setPackage("IACL");
		inputParams.setCategory("Classification");
		inputParams.setName("atlas_EM_segmentation");
		inputParams.setLabel("Atlas EM Segmentation");

		AlgorithmInformation info = getAlgorithmInformation();
		info.add(PrinceGroupAuthors.navidShiee);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.Dep);
		info.setDescription("Algorithm for brain structure segmentation using EM algorithm" );
	}

	/**
	 * Points {@code param} at an atlas bundled on the class path, if it can be found.
	 * Failure is reported but not fatal: the user can still choose an atlas manually.
	 *
	 * @param param    file parameter to initialize
	 * @param resource class-path location of the default atlas
	 */
	private static void setDefaultAtlas(ParamFile param, String resource) {
		ClassLoader cl = Thread.currentThread().getContextClassLoader();
		java.net.URL url = (cl == null) ? null : cl.getResource(resource);
		if (url == null) {
			System.out.print("Error: Unable to set default atlas\n");
			return;
		}
		try {
			// URL.getFile() keeps percent-escapes (e.g. %20 for spaces); decode file: URLs
			// through a URI so paths containing such characters resolve correctly.
			String path = "file".equals(url.getProtocol())
					? new java.io.File(url.toURI()).getPath()
					: url.getFile();
			param.setValue(path);
		} catch (Exception e) {
			System.out.print("Error: Unable to set default atlas\n");
		}
	}

	/**
	 * Creates the (all optional) output volumes of the plugin. Which ones are
	 * populated depends on the selected output type and flags; unpopulated outputs
	 * are left empty.
	 *
	 * @param outputParams collection that receives the output volumes
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(classification = new ParamVolume("Hard segmentation",VoxelType.UBYTE));
		classification.setMandatory(false);
		outputParams.add(field = new ParamVolume("Inhomogeneity Field",VoxelType.FLOAT));
		field.setMandatory(false);
		outputParams.add(memberships = new ParamVolume("Membership Functions",VoxelType.FLOAT));
		memberships.setMandatory(false);
		gm = new ParamVolume("Cortical GM Membership",VoxelType.FLOAT);
		outputParams.add(gm);
		gm.setMandatory(false);
		wmfill = new ParamVolume("Filled WM Membership",VoxelType.UBYTE);
		outputParams.add(wmfill);
		wmfill.setMandatory(false);
		wmMask = new ParamVolume("WM Mask",VoxelType.UBYTE);
		outputParams.add(wmMask);
		wmMask.setMandatory(false);
		old_prior = new ParamVolume("Atlas_Priors",VoxelType.FLOAT);
		outputParams.add(old_prior);
		old_prior.setMandatory(false);
		outputParams.setLabel("Atlas EM Segmentation");
		outputParams.setName("atlas_EM_segmentation" );
	}

	/**
	 * Calculation that runs the Atlas EM segmentation with the current input
	 * parameters and publishes the results to the output parameters.
	 */
	protected class AtlasEMWrapper extends AbstractCalculation{
		public AtlasEMWrapper(){
			setLabel("Atlas EM segmentation");
		}

		/**
		 * Runs the segmentation.
		 *
		 * @throws IllegalArgumentException if neither the MPRAGE nor the SPGR input is set
		 */
		public void execute(){
			System.out.println("Loading Images");

			// Pick exactly one T1 input; MPRAGE wins when both are supplied.
			ParamVolume input;
			String modality;
			if (MPRAGE.getImageData() != null) {
				input = MPRAGE;
				modality = "T1_MPRAGE";
			} else if (SPGR.getImageData() != null) {
				input = SPGR;
				modality = "T1_SPGR";
			} else {
				throw new IllegalArgumentException(
						"Atlas EM segmentation requires a T1_MPRAGE or T1_SPGR input image");
			}
			ModelImage[] images = new ModelImage[] { input.getImageData().getModelImageCopy() };
			String[] modals = new String[] { modality };

			ModelImage[] resultImage = null;
			try {
				String atlasPath = atlasFile.getValue().getAbsolutePath();
				DemonToadDeformableAtlas atlas = new DemonToadDeformableAtlas(atlasPath);
				setTotalUnits(1);
				AlgorithmAtlasEMWrapper algo = new AlgorithmAtlasEMWrapper(images, images.length,
						modals,
						atlasPath, atlas, outputType.getValue(),
						smoothParam.getFloat(), maxIters.getInt(), maxDiff.getFloat(),
						0,                                  // bgth_
						registerationMode.getValue(),
						true,                               // register_
						4, 50, 1,                           // lev_, init_, main_
						demonsSmoothing.getFloat(), demonsScale.getFloat(),
						correctInhomogeneity.getValue(), outputField.getValue(),
						CORRECTION_METHOD, POLYNOMIAL_DEGREE, KERNEL_SIZE,
						priorMode.getValue());
				algo.setObserver(this);
				algo.runAlgorithm();

				resultImage = algo.getResultImages();
				publishResults(resultImage, new ImageDataMipav(images[0]).getHeader());
			} finally {
				// Release the MIPAV images even if loading or the algorithm failed.
				disposeAll(images);
				disposeAll(resultImage);
			}
			System.gc();
			markCompleted();
		}

		/**
		 * Copies the algorithm's result images into the output parameters. The
		 * array is consumed in order: atlas priors (slot 0, always present),
		 * memberships (only for "class+member"), hard classification, the GM /
		 * filled-WM / WM-mask volumes (only for "dura removal inputs"), and finally
		 * the inhomogeneity field (only when it was both estimated and requested).
		 *
		 * @param resultImage images returned by the algorithm
		 * @param header      header of the input image, copied onto most outputs
		 */
		private void publishResults(ModelImage[] resultImage, ImageHeader header) {
			String type = outputType.getValue();
			int i = 0;

			if (outputAtlases.getValue()) {
				publish(old_prior, resultImage[i], header);
			}
			i++; // the priors occupy slot 0 whether or not they are exported

			if (OUTPUT_CLASS_MEMBER.equals(type)) {
				// NOTE(review): the original never copied the input header onto the
				// memberships; kept as-is — confirm whether that is intentional.
				memberships.setValue(new ImageDataMipav(resultImage[i++]));
			}

			publish(classification, resultImage[i++], header);

			if (OUTPUT_DURA_INPUTS.equals(type)) {
				publish(gm, resultImage[i++], header);
				publish(wmfill, resultImage[i++], header);
				publish(wmMask, resultImage[i++], header);
			}

			if (outputField.getValue() && correctInhomogeneity.getValue()) {
				field.setValue(new ImageDataMipav(resultImage[i]));
			}
		}

		/** Stores {@code image} in {@code target} and stamps it with {@code header}. */
		private void publish(ParamVolume target, ModelImage image, ImageHeader header) {
			target.setValue(new ImageDataMipav(image));
			target.getImageData().setHeader(header);
		}

		/** Calls {@code disposeLocal()} on every non-null image; tolerates a null array. */
		private void disposeAll(ModelImage[] imgs) {
			if (imgs == null) {
				return;
			}
			for (ModelImage img : imgs) {
				if (img != null) {
					img.disposeLocal();
				}
			}
		}
	}

	/**
	 * {@link AlgorithmAtlasEMSegmentation} that forwards its progress value and
	 * progress messages to a JIST {@link AbstractCalculation}.
	 */
	protected class AlgorithmAtlasEMWrapper extends AlgorithmAtlasEMSegmentation{
		/** Calculation that mirrors this algorithm's progress; may be null until set. */
		protected AbstractCalculation observer;

		public AlgorithmAtlasEMWrapper(ModelImage[] srcImg_, int nInput_,String[] imgModal_, 
				String aName_, DemonToadDeformableAtlas atlas_, String segOutput_, float smooth_, 
				int nIterMax_, float distMax_, float bgth_, 
				String spcMode_, boolean register_, 
				int lev_, int init_, int main_, float dSmooth_, float dScale_, boolean correct_, boolean outputField_, String correctType_, int poly_, float kernel_,
				String prior_) {
			super( srcImg_,  nInput_, imgModal_, aName_,  atlas_,  segOutput_, smooth_,  nIterMax_,  distMax_,  bgth_,
					spcMode_,register_,lev_, init_, main_, dSmooth_, dScale_,  correct_, outputField_, correctType_, poly_, kernel_,
					prior_);
		}

		public void setObserver(AbstractCalculation observer){
			this.observer=observer;
		}

		/** Runs the algorithm, reporting progress on a 0–100 scale to the observer. */
		@Override
		public void runAlgorithm(){
			if (observer != null) {
				observer.setTotalUnits(100);
			}
			super.runAlgorithm();
			if (observer != null) {
				observer.markCompleted();
			}
		}

		/**
		 * Notifies all listeners that have registered interest for notification on this
		 * event type, and mirrors the value to the observer.
		 *
		 * @param  value  the value of the progress bar.
		 */
		@Override
		protected void fireProgressStateChanged(int value) {
			super.fireProgressStateChanged(value);
			if (observer != null) {
				observer.setCompletedUnits(value);
			}
		}

		/**
		 * Updates listeners of progress status without changing the numerical value,
		 * and mirrors the message to the observer's label.
		 *
		 * @param  imageName  the name of the image
		 * @param  message    the new message to display
		 */
		@Override
		protected void fireProgressStateChanged(String imageName, String message) {
			super.fireProgressStateChanged(imageName, message);
			if (observer != null) {
				observer.setLabel(message);
			}
		}
	}

	/**
	 * Runs the segmentation under the given monitor.
	 *
	 * @param monitor monitor that observes the segmentation calculation
	 */
	protected void execute(CalculationMonitor monitor) {
		AtlasEMWrapper atlas_em=new AtlasEMWrapper();
		monitor.observe(atlas_em);
		atlas_em.execute();
	}

	public Dimension getPreferredSize() {
		return new Dimension(600,400);
	}
}





