package edu.jhu.ece.iacl.plugins.classification;


import java.awt.Dimension;
import java.util.ArrayList;
import java.util.HashMap;

import edu.jhmi.rad.medic.algorithms.AlgorithmLesionToads;
import edu.jhmi.rad.medic.methods.DemonToadDeformableAtlas;

import edu.jhu.ece.iacl.algorithms.PrinceGroupAuthors;
import edu.jhu.ece.iacl.algorithms.ReferencedPapers;
import edu.jhu.ece.iacl.structures.image.ImageDataMath;

import edu.jhu.ece.iacl.jist.pipeline.parameter.*;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipavWrapper;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;

import gov.nih.mipav.model.structures.ModelImage;


/**
 * Lesion TOADS algorithm (a variation of MedicAlgorithmMultioTOADS).
 * Uses TOADS to segment MS lesions as well as other tissues.
 * 
 * @author Navid Shiee
 * 
 */

public class MedicAlgorithmLesionToads extends ProcessingAlgorithm {



//	private ParamWeightedVolumeCollection<String> inputImages;
	// ---- Input volumes (all optional; execute() decides which ones are used) ----
	/** Input volume labelled "T1_MPRAGE Image". */
	private ParamVolume MPRAGE;
	/** Input volume labelled "T1_SPGR Image"; only used when no MPRAGE image is given. */
	private ParamVolume SPGR;
	//private ParamVolume T2;
	//private ParamVolume PD;
	/** Input volume labelled "FLAIR Image"; added as a second channel when a T1 image is present. */
	private ParamVolume FLAIR;
	
		
//	private ParamVolume VentHard;

	// ---- Atlas selection ----
	/** Chooses between the "With Lesion" (index 0) and "No Lesion" (index 1) atlases. */
	private ParamOption atlasSelect;
	/** Atlas file used when the "With Lesion" atlas is selected. */
	private ParamFile atlasFile_Lesions;
	/** Atlas file used for "No Lesion" when both MPRAGE and FLAIR images are given. */
	private ParamFile atlasFile_noLesions_T1_FLAIR;
	/** Atlas file used for "No Lesion" in every other input combination. */
	private ParamFile atlasFile_noLesions_T1only;

	/** Whether to correct MR field inhomogeneity (defaults to false). */
	private ParamBoolean correctInhomogeneity;
	
	//private ParamBoolean scaleData;

	/** Smoothing parameter (defaults to 0.2). */
	private ParamDouble smoothParam;

	/** Maximum number of iterations (defaults to 99). */
	private ParamInteger maxIters;
	
	/** Convergence threshold on the relative change of the energy function (defaults to 0.0001). */
	private ParamDouble maxDiff;
	
	/** "Maximum GM Distance" lesion option (defaults to 3). */
	private ParamInteger maxGMDist;
	
	/** "Maximum Ventircle Distance" lesion option (defaults to 2). */
	private ParamInteger maxVentDist;
	
	/** "Maximum InterVentricular Distance" lesion option (defaults to 25). */
	private ParamInteger maxInterVentDist;
		
	//private ParamDouble spread;
	
	/** Weight of the statistical atlas on the segmentation (defaults to 2.0). */
	private ParamDouble atlasPrior;

	//private ParamDouble atlasRange;
	
	//private ParamFloat demonsSmoothing;
	
	//private ParamFloat demonsSacle;
	
	//private ParamOption distanceMode;
	
	/** Atlas alignment mode: "rigid" or "multi_fully_affine". */
	private ParamOption alignType;
	
	/** Connectivity choice; its options are the keys of {@link #connMap}. */
	private ParamOption connectivity;
	
	/** Selects which set of output images is produced (defaults to "cruise inputs"). */
	private ParamOption outputType;
	
	//private ParamOption normType;
	
	//private ParamOption centroidMode;
	
    //private ParamFloat centroidSmoothness;
	
	//private ParamBoolean useLesionWeight;
	
	/** Whether lesions are included in the WM class of the hard classification (defaults to true). */
	private ParamBoolean includeLesions;
	
	/** Whether to also output the maximum-membership hard classification (defaults to false). */
	private ParamBoolean outputMembershipClassification;
	
	// ---- Outputs (all optional; populated by execute() depending on the options) ----
	/** Output "Hard segmentation". */
	private ParamVolume classification;
	
	/** Output hard segmentation from memberships; set only when outputMembershipClassification is true. */
	private ParamVolume classification_mem;

	/** Output "Inhomogeneity Field"; set only when both outputField and correctInhomogeneity are true. */
	private ParamVolume field;
	
	/** Whether the estimated inhomogeneity field should be output (defaults to false). */
	private ParamBoolean outputField;
	
	//private ParamInteger polynomialDegree;
	
	//private ParamOption correctionMethod;
	
	//private ParamFloat kernelSize;

	/** Output "Membership Functions"; set only for the "hard segmentation+memberships" output type. */
	private ParamVolume memberships;

	/** Output "Lesion Segmentation". */
	private ParamVolume lesions;
	
	//private ParamVolume prior;
	
	/** Output "Filled WM Membership"; set for the "cruise inputs" and "dura removal inputs" output types. */
	private ParamVolume wmfill;
	
	/** Output "Cortical GM Membership"; set for the "cruise inputs" and "dura removal inputs" output types. */
	private ParamVolume gm;
	
	//private ParamVolume gm_cl;
	
	//private ParamVolume wm_cl;
	
	/** Output "Sulcal CSF Membership"; set for the "cruise inputs" and "dura removal inputs" output types. */
	private ParamVolume csf;
	
	/** Output "WM Mask"; set for the "cruise inputs" and "dura removal inputs" output types. */
	private ParamVolume wmMask;
	
	//private ParamVolume totalwmMask;
	
	/** Maps the connectivity labels shown to the user, e.g. "(18,6)", to the strings passed to the algorithm, e.g. "18/6". */
	private HashMap<String,String> connMap;

	/** Version string reported by AlgorithmLesionToads, read once when this class is initialised. */
	private static final String revnum = new AlgorithmLesionToads().get_version();
	/**
	 * Creates the input parameters for Lesion TOADS as specified in
	 * AlgorithmLesionToads. The bounds set here do not necessarily match those
	 * of the original dialog.
	 * <p>
	 * The parameters are grouped into four collections ("Main Inputs",
	 * "Atlas Files", "Lesion Options" and "Advanced Options"), after which the
	 * algorithm metadata (author, citation, version, status, description) is
	 * registered.
	 *
	 * @param inputParams collection that receives the parameter groups
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		// Label shown to the user -> connectivity string handed to the algorithm.
		connMap = new HashMap<String,String>();
		connMap.put("(18,6)","18/6");
		connMap.put("(6,18)","6/18");
		connMap.put("(26,6)","26/6");
		connMap.put("(6,26)","6/26");

		// Input images: none is mandatory on its own; execute() checks the combination.
		MPRAGE = new ParamVolume("T1_MPRAGE Image", VoxelType.UBYTE);
		MPRAGE.setMandatory(false);
		SPGR = new ParamVolume("T1_SPGR Image", VoxelType.UBYTE);
		SPGR.setMandatory(false);
		FLAIR = new ParamVolume("FLAIR Image", VoxelType.UBYTE);
		FLAIR.setMandatory(false);

		String[] atlasTypes = {"With Lesion", "No Lesion"};
		atlasSelect = new ParamOption("Atlas to Use",atlasTypes);

		// Atlas files, each pre-filled with the atlas found on the classpath (if any).
		atlasFile_Lesions = new ParamFile("Atlas File - With Lesions");
		assignBundledAtlas(atlasFile_Lesions,
				"Atlas/lesionToads-atlas-2012/cruise-atlas-12obj-lesiontoads2012.txt");

		atlasFile_noLesions_T1_FLAIR = new ParamFile("Atlas File - No Lesion - T1 and FLAIR");
		assignBundledAtlas(atlasFile_noLesions_T1_FLAIR,
				"Atlas/lesionToads-atlas-2012/cruise-atlas-12obj-toads2012.txt");

		atlasFile_noLesions_T1only = new ParamFile("Atlas File - No Lesion - T1 Only");
		assignBundledAtlas(atlasFile_noLesions_T1only,
				"Atlas/lesionToads-atlas-2012/cruise-atlas-10obj-toads2010.txt");

		correctInhomogeneity = new ParamBoolean("Correct inhomogeneity");
		correctInhomogeneity.setValue(false);
		correctInhomogeneity.setDescription("Correct MR field inhomogeneity.");

		outputField= new ParamBoolean("Output inhomogeniety filed", false);
		outputField.setDescription("Output the estimated inhomogeneity field");

		smoothParam = new ParamDouble("Smooting parameter", 0, 1E10);
		smoothParam.setValue(0.2);
		smoothParam.setDescription("Controls the effect of neighberhood voxels on the membership");

		maxIters = new ParamInteger("Maximum iterations", 0, 100000);
		maxIters.setValue(99);

		maxDiff = new ParamDouble("Maximum difference", 0.0, 1E10);
		maxDiff.setValue(0.0001);
		maxDiff.setDescription("Maximum amount of relative change in the energy function considered as the convergence criteria");

		atlasPrior = new ParamDouble("Atlas prior", 0, 1E10);
		atlasPrior.setValue(2.0);
		atlasPrior.setDescription("Controls the effect of the statistical atlas on the segmentation");

		includeLesions = new ParamBoolean("Include lesions in white matter", true);
		includeLesions.setDescription("Include lesion in WM class in hard classification");

		outputMembershipClassification = new ParamBoolean ("Output max membership classification", false);
		outputMembershipClassification.setDescription("Output the hard classification using maximum membership (not neceesarily topologically correct)");

		// Each distance parameter gets its own description. Previously all three
		// descriptions were set on maxGMDist, so the last one overwrote the GM
		// text and the two ventricle parameters had no description at all.
		maxGMDist = new ParamInteger("Maximum GM Distance",0,256);
		maxGMDist.setValue(3);
		maxGMDist.setDescription("Maximum distance from the GM boundary to downweight the lesion membership to avoid false postives");
		maxVentDist = new ParamInteger("Maximum Ventircle Distance",0,256);
		maxVentDist.setValue(2);
		maxVentDist.setDescription("Maximum distance from the Ventricles boundary to downweight the lesion membership to avoid false postives");
		maxInterVentDist = new ParamInteger("Maximum InterVentricular Distance",0,256);
		maxInterVentDist.setValue(25);
		maxInterVentDist.setDescription("Maximum distance from the interventricular WM boundary to downweight the lesion membership to avoid false postives");

		// NOTE(review): the option order follows HashMap key iteration order, so
		// which connectivity is listed first is not explicit here - confirm the
		// intended default.
		connectivity=new ParamOption("Connectivity (foreground,background)",new ArrayList<String>(connMap.keySet()));

		String[] out = {"hard segmentation","hard segmentation+memberships","cruise inputs","dura removal inputs"};
		outputType=new ParamOption("Output images", out);
		outputType.setValue("cruise inputs");

		String[] align = {"rigid","multi_fully_affine"};
		alignType=new ParamOption("Atlas alignment", align);

		ParamCollection main = new ParamCollection("Main Inputs");
		ParamCollection atlases = new ParamCollection("Atlas Files");
		// Named lesionOptions so it does not shadow the output field "lesions".
		ParamCollection lesionOptions = new ParamCollection("Lesion Options");
		ParamCollection advanced = new ParamCollection("Advanced Options");

		main.add(MPRAGE);
		main.add(SPGR);
		main.add(FLAIR);
		main.add(atlasSelect);
		main.add(outputType);
		main.add(outputMembershipClassification);
		main.add(correctInhomogeneity);
		main.add(outputField);
		inputParams.add(main);

		atlases.add(atlasFile_Lesions);
		atlases.add(atlasFile_noLesions_T1_FLAIR);
		atlases.add(atlasFile_noLesions_T1only);
		inputParams.add(atlases);

		lesionOptions.add(maxGMDist);
		lesionOptions.add(maxVentDist);
		lesionOptions.add(maxInterVentDist);
		lesionOptions.add(includeLesions);
		inputParams.add(lesionOptions);

		advanced.add(atlasPrior);
		advanced.add(smoothParam);
		advanced.add(maxDiff);
		advanced.add(maxIters);
		advanced.add(alignType);
		advanced.add(connectivity);
		inputParams.add(advanced);

		inputParams.setPackage("IACL");
		inputParams.setCategory("Classification");
		inputParams.setName("lesion toads");
		inputParams.setLabel("Lesion TOADS");

		AlgorithmInformation info = getAlgorithmInformation();
		info.add(PrinceGroupAuthors.navidShiee);
		info.add(ReferencedPapers.lesionToads);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.Release);
		info.setDescription("Algorithm for simulataneous brain structures and MS lesion segmentation of MS Brains. The brain segmentation is topologically consistent and the algorithm can use multiple MR sequences as input data." );
		info.setAdditionalDocURL("html/edu/jhu/ece/iacl/plugins/classification/MedicAlgorithmLesionToads/index.html");
	}

	/**
	 * Pre-fills an atlas file parameter with an atlas looked up on the classpath
	 * through the context class loader. This is best effort: when the resource
	 * cannot be found or assigned, a message is printed and the parameter is
	 * left without a default.
	 *
	 * @param atlasParam   parameter that receives the default value
	 * @param resourcePath classpath-relative location of the atlas file
	 */
	private void assignBundledAtlas(ParamFile atlasParam, String resourcePath) {
		try {
			ClassLoader loader = Thread.currentThread().getContextClassLoader();
			java.net.URL location = (loader == null) ? null : loader.getResource(resourcePath);
			if (location == null) {
				// Missing resource is an expected situation, not an exception.
				System.out.print("Error: Unable to set default atlas\n");
				return;
			}
			atlasParam.setValue(location.getFile());
		} catch (Exception e) {
			// Deliberately non-fatal: the user can still pick an atlas by hand.
			System.out.print("Error: Unable to set default atlas\n");
		}
	}

	/**
	 * Registers the output volumes of Lesion TOADS. Every output is declared
	 * optional because not all of them are populated on a given run; per the
	 * original note, non-populated fields create an error that is caught by the
	 * dialog.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		classification = registerOptionalOutput(outputParams,
				new ParamVolume("Hard segmentation",VoxelType.UBYTE));
		classification_mem = registerOptionalOutput(outputParams,
				new ParamVolume("Hard segmentationfrom memberships",VoxelType.UBYTE));
		field = registerOptionalOutput(outputParams,
				new ParamVolume("Inhomogeneity Field",VoxelType.FLOAT));
		memberships = registerOptionalOutput(outputParams,
				new ParamVolume("Membership Functions",VoxelType.FLOAT,-1,-1,-1,-1));
		lesions = registerOptionalOutput(outputParams,
				new ParamVolume("Lesion Segmentation",VoxelType.FLOAT));
		csf = registerOptionalOutput(outputParams,
				new ParamVolume("Sulcal CSF Membership",VoxelType.FLOAT));
		gm = registerOptionalOutput(outputParams,
				new ParamVolume("Cortical GM Membership",VoxelType.FLOAT));
		wmfill = registerOptionalOutput(outputParams,
				new ParamVolume("Filled WM Membership",VoxelType.UBYTE));
		wmMask = registerOptionalOutput(outputParams,
				new ParamVolume("WM Mask",VoxelType.UBYTE));

		outputParams.setLabel("Lesion TOADS");
		outputParams.setName("lesion toads" );
	}

	/**
	 * Adds a volume to the output collection and then marks it as non-mandatory.
	 *
	 * @param outputs collection the volume is added to
	 * @param volume  freshly created output volume
	 * @return the same volume, for assignment to its field
	 */
	private ParamVolume registerOptionalOutput(ParamCollection outputs, ParamVolume volume) {
		outputs.add(volume);
		volume.setMandatory(false);
		return volume;
	}

	/**
	 * Calculation that runs Lesion TOADS on the configured inputs and stores the
	 * resulting images in the output parameters of the enclosing plugin.
	 */
	protected class LesionToadsWrapper extends AbstractCalculation{
		public LesionToadsWrapper(){
			setLabel("Lesion TOADS");
		}

		/**
		 * Runs the algorithm.
		 * <p>
		 * A T1 image is required: MPRAGE is used when present, otherwise SPGR.
		 * FLAIR, when present, is added as a second channel. The copies made of
		 * the input images are disposed when this method exits, whether it
		 * succeeds or fails.
		 *
		 * @throws IllegalStateException if neither T1 image is set, or if the
		 *         selected atlas file has no value
		 */
		public void execute(){
			boolean hasMprage = (MPRAGE.getImageData() != null);
			boolean hasSpgr = (SPGR.getImageData() != null);
			boolean hasFlair = (FLAIR.getImageData() != null);

			// Fail with a clear message instead of a NullPointerException further
			// down when no usable T1 input (including the FLAIR-only case) is given.
			if (!hasMprage && !hasSpgr) {
				throw new IllegalStateException(
						"Lesion TOADS requires a T1 image: set the T1_MPRAGE or the T1_SPGR input.");
			}

			String selectedAtlas = selectAtlasPath(hasMprage, hasFlair);

			// MPRAGE takes precedence over SPGR when both are supplied.
			ParamVolume t1 = hasMprage ? MPRAGE : SPGR;
			String t1Modality = hasMprage ? "T1_MPRAGE" : "T1_SPGR";
			String[] modals = hasFlair
					? new String[] { t1Modality, "FLAIR" }
					: new String[] { t1Modality };
			ModelImage[] images = new ModelImage[modals.length];

			try {
				images[0] = t1.getImageData().getModelImageCopy();
				if (hasFlair) {
					images[1] = FLAIR.getImageData().getModelImageCopy();
				}

				DemonToadDeformableAtlas atlas = new DemonToadDeformableAtlas(selectedAtlas);
				setTotalUnits(1);

				// Settings that are not exposed as parameters. The names of the
				// previously anonymous literals follow the constructor arguments
				// of AlgorithmLesionToadsWrapper they are passed to.
				float out = 0.1f;
				float bgth = 0;
				float fLim = 3.0f;
				float lLim = 0.5f;
				float spread = 1.0f;
				float atlasRange = 0.2f;
				String relMode = "diffeomorphism";
				String centroidMode = "prior";
				float centroidSmoothness = 0.1f;
				boolean useLesionWeight = true;
				String normType = "Identity";
				boolean register = true;
				int levels = 4;
				int initAlignIter = 50;
				int mainAlignIter = 1;
				float demonsSmoothing = 1.0f;
				float demonsScale = 2.0f;
				String correctionMethod = "Chebyshev";
				int polynomialDegree = 3;
				float kernelSize = 30.0f;
				String distanceMode = "normal distance";

				AlgorithmLesionToadsWrapper algo = new AlgorithmLesionToadsWrapper(images, images.length,
						modals,
						selectedAtlas, atlas, outputType.getValue(),
						smoothParam.getFloat(), out, maxIters.getInt(), maxDiff.getFloat(), bgth,
						outputMembershipClassification.getValue(),
						fLim, lLim, spread,
						(short)maxGMDist.getInt(),(short)maxInterVentDist.getInt(), (short)maxVentDist.getInt(),
						includeLesions.getValue(),
						atlasPrior.getFloat(), atlasRange,
						relMode,
						centroidMode, centroidSmoothness,
						useLesionWeight,
						alignType.getValue(),
						normType,
						register,
						levels, initAlignIter, mainAlignIter, demonsSmoothing, demonsScale,
						correctInhomogeneity.getValue(),outputField.getValue(), correctionMethod,
						polynomialDegree,kernelSize,
						distanceMode,
						connMap.get(connectivity.getValue())) ;
				algo.setObserver(this);
				algo.runAlgorithm();

				// All outputs take the header of the first input image.
				ImageHeader imagesHeader = new ImageDataMipav(images[0]).getHeader();
				storeResults(algo.getResultImages(), imagesHeader);
			} finally {
				// Release the input copies even when the algorithm fails.
				for (int k = 0; k < images.length; k++) {
					if (images[k] != null) {
						images[k].disposeLocal();
					}
				}
			}

			markCompleted();
		}

		/**
		 * Resolves the absolute path of the atlas file matching the selected atlas
		 * type and the available inputs.
		 *
		 * @param hasMprage whether an MPRAGE image was supplied
		 * @param hasFlair  whether a FLAIR image was supplied
		 * @return absolute path of the atlas file to load
		 * @throws IllegalStateException if the chosen atlas parameter has no value
		 */
		private String selectAtlasPath(boolean hasMprage, boolean hasFlair) {
			ParamFile chosen;
			if (atlasSelect.getIndex() == 1) {
				// NOTE(review): SPGR+FLAIR falls through to the T1-only atlas
				// here, as it always has - confirm that this is intended.
				chosen = (hasMprage && hasFlair) ? atlasFile_noLesions_T1_FLAIR : atlasFile_noLesions_T1only;
			} else {
				// Index 0 and any unexpected index use the lesion atlas.
				chosen = atlasFile_Lesions;
			}
			if (chosen.getValue() == null) {
				throw new IllegalStateException(
						"No atlas file is set for the selected atlas type; choose one under \"Atlas Files\".");
			}
			return chosen.getValue().getAbsolutePath();
		}

		/**
		 * Distributes the algorithm's result images over the output parameters.
		 * The results are consumed in order: memberships (only for
		 * "hard segmentation+memberships"), hard segmentation, max-membership
		 * segmentation (only when requested), lesions, the four cruise images
		 * (only for "cruise inputs" / "dura removal inputs") and finally the
		 * inhomogeneity field (only when requested and correction is enabled).
		 *
		 * @param resultImage images returned by the algorithm
		 * @param header      header applied to every output image
		 */
		private void storeResults(ModelImage[] resultImage, ImageHeader header) {
			String requested = outputType.getValue();
			int next = 0;

			if (requested.equals("hard segmentation+memberships")) {
				store(memberships, resultImage[next++], header);
			}

			store(classification, resultImage[next++], header);

			if (outputMembershipClassification.getValue()) {
				store(classification_mem, resultImage[next++], header);
			}

			store(lesions, resultImage[next++], header);

			if (requested.equals("cruise inputs") || requested.equals("dura removal inputs")) {
				store(csf, resultImage[next], header);
				store(gm, resultImage[next + 1], header);
				store(wmfill, resultImage[next + 2], header);
				store(wmMask, resultImage[next + 3], header);
				next += 4;

				// The cruise images are always passed through scaleFloatValue with 255.
				ImageDataMath.scaleFloatValue(csf.getImageData(), 255.0f);
				ImageDataMath.scaleFloatValue(gm.getImageData(), 255.0f);
				ImageDataMath.scaleFloatValue(wmfill.getImageData(), 255.0f);
				ImageDataMath.scaleFloatValue(wmMask.getImageData(), 255.0f);
			}

			if (outputField.getValue() && correctInhomogeneity.getValue()) {
				store(field, resultImage[next], header);
			}
		}

		/** Wraps a result image, assigns it to the output parameter and sets its header. */
		private void store(ParamVolume target, ModelImage image, ImageHeader header) {
			target.setValue(new ImageDataMipavWrapper(image));
			target.getImageData().setHeader(header);
		}
	}
	
	/**
	 * AlgorithmLesionToads subclass that mirrors the algorithm's progress
	 * notifications onto a calculation, so that progress can be observed through it.
	 * <p>
	 * The observer is optional: when none has been set, the algorithm runs
	 * normally and simply reports no progress to a calculation.
	 */
	protected class AlgorithmLesionToadsWrapper extends AlgorithmLesionToads{
		/** Calculation that receives progress updates; may be null until setObserver is called. */
		protected AbstractCalculation observer;

		/**
		 * Creates the algorithm. All arguments are forwarded unchanged, and in
		 * the same order, to the AlgorithmLesionToads constructor.
		 */
		public AlgorithmLesionToadsWrapper(ModelImage[] srcImg_, int nInput_,
				String[] imgModal_,
				String aName_, DemonToadDeformableAtlas atlas_, String segOutput_, 
				float smooth_, float out_, int nIterMax_, float distMax_, float bgth_,
				boolean outputMaxMem_, 
				float fLim_, float lLim_, float spread_,
				short maxGMDist_, short maxBstemDist_, short maxVentDist_,
				boolean inludeLesions_,
				float atlasCoeff_, float atlasScale_,
				String relMode_, 
				String centMode_, float smoothCentr_,
				boolean lesionWeight_,
				String regMode_,
				String nrmType_,
				boolean register_, 
				int lev_, int initAlignIter_, int mainAlignIter_,
				float dSmooth_, float dScale_,
				boolean correct_, boolean outputField_,
				String correctType_,
				int poly_, float kernel_,
				String algo_,
				String connect_) {
			super( srcImg_,  nInput_, imgModal_, aName_,  atlas_,  segOutput_, smooth_,  out_,  nIterMax_,  distMax_,  bgth_,
					 outputMaxMem_, fLim_,  lLim_, spread_,  maxGMDist_, maxBstemDist_, maxVentDist_, inludeLesions_,
					atlasCoeff_, atlasScale_, relMode_, centMode_, smoothCentr_,lesionWeight_, 
					regMode_,nrmType_,register_,lev_, initAlignIter_, mainAlignIter_, dSmooth_, dScale_,  correct_, outputField_, correctType_, poly_, kernel_,
					algo_, connect_);
		}

		/**
		 * Sets the calculation that should receive progress updates.
		 *
		 * @param observer calculation to notify, or null to stop reporting
		 */
		public void setObserver(AbstractCalculation observer){
			this.observer=observer;
		}

		/**
		 * Runs the algorithm. When an observer is set, its total is set to 100
		 * units before the run and it is marked completed afterwards.
		 */
		public void runAlgorithm(){
			if (observer != null) {
				observer.setTotalUnits(100);
			}
			super.runAlgorithm();
			if (observer != null) {
				observer.markCompleted();
			}
		}

	    /**
	     * Notifies all listeners that have registered interest for notification on this event type,
	     * and forwards the value to the observer (if any) as its completed units.
	     *
	     * @param  value  the value of the progress bar.
	     */
	    protected void fireProgressStateChanged(int value) {
	        super.fireProgressStateChanged(value);
	        // Guarded so that a progress event fired without an observer cannot crash the run.
	        if (observer != null) {
	            observer.setCompletedUnits(value);
	        }
	    }

	    /**
	     * Updates listeners of progress status without changing the numerical value,
	     * and uses the message as the observer's label (if an observer is set).
	     *
	     * @param  imageName  the name of the image
	     * @param  message    the new message to display
	     */
	    protected void fireProgressStateChanged(String imageName, String message) {
	    	super.fireProgressStateChanged(imageName, message);
	    	if (observer != null) {
	    		observer.setLabel(message);
	    	}
	    }
	}
	/**
	 * Entry point of the plugin: creates the Lesion TOADS calculation, registers
	 * it with the monitor and then runs it.
	 *
	 * @param monitor monitor that observes the calculation
	 */
	protected void execute(CalculationMonitor monitor) {
		final LesionToadsWrapper calculation = new LesionToadsWrapper();
		// Register with the monitor before starting the calculation.
		monitor.observe(calculation);
		calculation.execute();
	}
	/**
	 * Returns the preferred size for this plugin.
	 *
	 * @return a 450 x 550 dimension (width x height)
	 */
	public Dimension getPreferredSize() {
		final int preferredWidth = 450;
		final int preferredHeight = 550;
		return new Dimension(preferredWidth, preferredHeight);
	}
}


