package edu.vanderbilt.masi.plugins.labelfusion;

import java.io.File;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import java.io.FileNotFoundException;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.vanderbilt.masi.algorithms.labelfusion.*;
import edu.jhu.ece.iacl.jist.utility.FileUtil;

public abstract class AbstractVotingFusionPlugin extends ProcessingAlgorithm {
	
	/** Index of the "None" entry of {@link #probtype}: inputs are not probability volumes. */
	private static final int PROBTYPE_NONE = 0;
	/** Index of the "Multiple" entry of {@link #output_W}: one probability volume per label. */
	private static final int OUTPUT_PROBS_MULTIPLE = 1;
	/** Index of the "Single" entry of {@link #output_W}: one volume from get_all_label_probs. */
	private static final int OUTPUT_PROBS_SINGLE = 2;
	
	// Main parameters
	public transient ParamVolume targetim;
	public transient ParamVolumeCollection obsvols;
	public transient ParamVolumeCollection imsvols;
	public transient ParamOption output_W;
	public transient ParamOption probtype;
	
	// weighting parameters
	public transient ParamOption weight_type;
	public transient ParamInteger sv_dim1;
	public transient ParamInteger sv_dim2;
	public transient ParamInteger sv_dim3;
	public transient ParamInteger sv_dim4;
	public transient ParamInteger pv_dim1;
	public transient ParamInteger pv_dim2;
	public transient ParamInteger pv_dim3;
	public transient ParamInteger pv_dim4;
	public transient ParamFloat spatial_sd_dim1;
	public transient ParamFloat spatial_sd_dim2;
	public transient ParamFloat spatial_sd_dim3;
	public transient ParamFloat spatial_sd_dim4;
	public transient ParamFloat intensity_sd;
	public transient ParamFloat global_sel_thresh;
	public transient ParamFloat local_sel_thresh;
	public transient ParamBoolean use_intensity_normalization;
	
	// patch selection parameters
	public transient ParamOption selection_type;
	public transient ParamFloat selection_thresh;
	public transient ParamInteger num_keep;
	
	// output parameters
	public transient ParamVolume labelvol;
	public transient ParamVolumeCollection labelprob_multiple;
	public transient ParamVolume labelprob_single;
	
	/*
	 * (non-Javadoc)
	 *
	 * @see edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm#createInputParameters(edu.jhu.ece.iacl.pipeline.parameter.ParamCollection)
	 */
	protected abstract void createInputParameters(ParamCollection inputParams);	
	/*
	 * (non-Javadoc)
	 *
	 * @see edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm#createOutputParameters(edu.jhu.ece.iacl.pipeline.parameter.ParamCollection)
	 */
	protected abstract void createOutputParameters(ParamCollection outputParams);
	
	/**
	 * Builds the "Main" parameter collection for intensity-based fusion.
	 * It contains the atlas label volumes, the target image, the atlas images,
	 * and the shared probability options.
	 *
	 * @return the populated "Main" collection
	 */
	protected ParamCollection get_main_parameters() {

		ParamCollection mainParams = new ParamCollection("Main");
		
		// set the atlas label volumes (the rater observations)
		mainParams.add(obsvols=new ParamVolumeCollection("Atlas Label Volumes"));
		obsvols.setLoadAndSaveOnValidate(false);
		obsvols.setMandatory(true);
		
		// set the target intensity image
		mainParams.add(targetim=new ParamVolume("Target Image Volume"));
		targetim.setLoadAndSaveOnValidate(false);
		targetim.setMandatory(true);
				
		// set the atlas intensity images (paired with the atlas label volumes)
		mainParams.add(imsvols=new ParamVolumeCollection("Atlas Image Volumes"));
		imsvols.setLoadAndSaveOnValidate(false);
		imsvols.setMandatory(true);
		
		add_probability_options(mainParams);
		
		return(mainParams);
	}
	
	/**
	 * Builds the "Main" parameter collection for label-only fusion.
	 * It contains the rater volumes and the shared probability options.
	 * {@link #targetim} and {@link #imsvols} are not created here. As a result,
	 * the execute wrapper never selects the partial-volume observation.
	 *
	 * @return the populated "Main" collection
	 */
	protected ParamCollection get_main_parameters_no_intensity() {
		ParamCollection mainParams = new ParamCollection("Main");
		
		// set the input rater observations
		mainParams.add(obsvols=new ParamVolumeCollection("Rater Volumes"));
		obsvols.setLoadAndSaveOnValidate(false);
		obsvols.setMandatory(true);
		
		add_probability_options(mainParams);
		
		return(mainParams);
	}
	
	/**
	 * Appends the probability-dimension and label-probability-output options.
	 * Both main-parameter layouts share these options and add them in the same order.
	 *
	 * @param mainParams the collection to append to
	 */
	private void add_probability_options(ParamCollection mainParams) {
		// set the probability fusion values
		mainParams.add(probtype = new ParamOption("Probability Dimension ('None' if not probability volumes)", new String[] { "None", "3rd", "4th"}));
		probtype.setValue(PROBTYPE_NONE);
		probtype.setMandatory(false);
				
		// set the options for writing the label probabilities ("None" by default)
		mainParams.add(output_W = new ParamOption("Output Label Probabilities?", new String[] { "None", "Multiple", "Single"}));
		output_W.setValue(0);
		output_W.setMandatory(false);
	}
	
	/**
	 * Builds the "Weighting Parameters" collection.
	 * It contains a "Correspondence Parameters" sub-collection with the similarity
	 * metric, search/patch sizes, standard deviations, selection thresholds and
	 * normalization flag. It also contains a "Patch Selection" sub-collection.
	 *
	 * @return the populated "Weighting Parameters" collection
	 */
	protected ParamCollection get_weighting_parameters() {
		ParamCollection weightParams = new ParamCollection("Correspondence Parameters");
		
		// set the similarity metric used for weighting (default: MSD)
		weightParams.add(weight_type = new ParamOption("Weighting Type", new String[] { "LNCC", "MSD", "Mixed"}));
		weight_type.setValue(1);
		weight_type.setMandatory(false);
		
		// Set the search volume for each dimension
		weightParams.add(sv_dim1 = new ParamInteger("Search Volume Dimension 1 (>0: Image Units, <0: Voxels)"));
		sv_dim1.setValue(0);
		sv_dim1.setMandatory(false);
		weightParams.add(sv_dim2 = new ParamInteger("Search Volume Dimension 2 (>0: Image Units, <0: Voxels)"));
		sv_dim2.setValue(0);
		sv_dim2.setMandatory(false);
		weightParams.add(sv_dim3 = new ParamInteger("Search Volume Dimension 3 (>0: Image Units, <0: Voxels)"));
		sv_dim3.setValue(0);
		sv_dim3.setMandatory(false);
		weightParams.add(sv_dim4 = new ParamInteger("Search Volume Dimension 4 (>0: Image Units, <0: Voxels)"));
		sv_dim4.setValue(0);
		sv_dim4.setMandatory(false);
		
		// Set the patch volume for each dimension
		weightParams.add(pv_dim1 = new ParamInteger("Patch Volume Dimension 1 (>0: Image Units, <0: Voxels)"));
		pv_dim1.setValue(0);
		pv_dim1.setMandatory(false);
		weightParams.add(pv_dim2 = new ParamInteger("Patch Volume Dimension 2 (>0: Image Units, <0: Voxels)"));
		pv_dim2.setValue(0);
		pv_dim2.setMandatory(false);
		weightParams.add(pv_dim3 = new ParamInteger("Patch Volume Dimension 3 (>0: Image Units, <0: Voxels)"));
		pv_dim3.setValue(0);
		pv_dim3.setMandatory(false);
		weightParams.add(pv_dim4 = new ParamInteger("Patch Volume Dimension 4 (>0: Image Units, <0: Voxels)"));
		pv_dim4.setValue(0);
		pv_dim4.setMandatory(false);
		
		// Set the spatial standard deviation of the search volume for each dimension
		weightParams.add(spatial_sd_dim1 = new ParamFloat("Search Volume Standard Deviation Dimension 1 (>0: Image Units, <0: Voxels)"));
		spatial_sd_dim1.setValue(1.5);
		spatial_sd_dim1.setMandatory(false);
		weightParams.add(spatial_sd_dim2 = new ParamFloat("Search Volume Standard Deviation Dimension 2 (>0: Image Units, <0: Voxels)"));
		spatial_sd_dim2.setValue(1.5);
		spatial_sd_dim2.setMandatory(false);
		weightParams.add(spatial_sd_dim3 = new ParamFloat("Search Volume Standard Deviation Dimension 3 (>0: Image Units, <0: Voxels)"));
		spatial_sd_dim3.setValue(1.5);
		spatial_sd_dim3.setMandatory(false);
		weightParams.add(spatial_sd_dim4 = new ParamFloat("Search Volume Standard Deviation Dimension 4 (>0: Image Units, <0: Voxels)"));
		spatial_sd_dim4.setValue(1.5);
		spatial_sd_dim4.setMandatory(false);
		
		// set the difference metric standard deviation
		weightParams.add(intensity_sd = new ParamFloat("Difference Metric Standard Deviation"));
		intensity_sd.setValue(0.5);
		intensity_sd.setMandatory(false);
		
		// atlas selection thresholds; 0 disables the corresponding selection
		weightParams.add(global_sel_thresh = new ParamFloat("Global Selection Threshold (Range: [0 1], 0 = No Selection)", 0, 1));
		global_sel_thresh.setValue(0);
		global_sel_thresh.setMandatory(false);
		
		weightParams.add(local_sel_thresh = new ParamFloat("Local Selection Threshold (Range: [0 1], 0 = No Selection)", 0, 1));
		local_sel_thresh.setValue(0);
		local_sel_thresh.setMandatory(false);
		
		// set whether intensity normalization is applied (default: enabled)
		weightParams.add(use_intensity_normalization = new ParamBoolean("Use Intensity Normalization"));
		use_intensity_normalization.setValue(true);
		use_intensity_normalization.setMandatory(false);
		
		// set the Patch selection options
		ParamCollection selParams = new ParamCollection("Patch Selection");
		
		// set the selection type (default: None)
		selParams.add(selection_type = new ParamOption("Selection Type", new String[] { "SSIM", "Jaccard", "None"}));
		selection_type.setValue(2);
		selection_type.setMandatory(false);
		
		selParams.add(selection_thresh = new ParamFloat("Selection Type Threshold", 0, 1));
		selection_thresh.setValue(0.05);
		selection_thresh.setMandatory(false);
		
		selParams.add(num_keep = new ParamInteger("Number of Patches to Keep (-1 = all patches)"));
		num_keep.setValue(-1);
		num_keep.setMandatory(false);
		
		ParamCollection fullWeightParams = new ParamCollection("Weighting Parameters");
		fullWeightParams.add(weightParams);
		fullWeightParams.add(selParams);
		
		return(fullWeightParams);
	}
	
	/**
	 * Registers the outputs on the inherited {@code outputParams}: the label
	 * estimate, the per-label probability collection, and the single
	 * probability volume. The last two are optional.
	 */
	protected void set_output_parameters() {
		// handle the label estimate output
		outputParams.add(labelvol = new ParamVolume("Label Volume"));
		
		// handle the label probabilities output
		outputParams.add(labelprob_multiple = new ParamVolumeCollection("Label Probabilities Collection"));
		labelprob_multiple.setLoadAndSaveOnValidate(false);
		labelprob_multiple.setMandatory(false);
		
		outputParams.add(labelprob_single = new ParamVolume("Label Probabilities Volume"));
		labelprob_single.setLoadAndSaveOnValidate(false);
		labelprob_single.setMandatory(false);
	}
	
	/**
	 * Maps an atlas-selection threshold to the value actually used.
	 * Values above 1 are treated as 0, which disables selection. Other values
	 * pass through unchanged. Callers only act on values strictly greater than 0.
	 *
	 * @param threshold the raw parameter value
	 * @return the threshold to use
	 */
	private static float sanitize_selection_threshold(float threshold) {
		return (threshold > 1) ? 0f : threshold;
	}
	
	/**
	 * Default execution: builds the observation object, runs majority voting,
	 * stores the label estimate, and optionally writes label probabilities.
	 */
	protected class AbstractExecuteWrapper extends AbstractCalculation {
		
		/**
		 * Runs majority voting on the configured observations.
		 *
		 * @param alg the owning algorithm; supplies the output name and directory
		 * @throws FileNotFoundException declared for subclasses and callers
		 */
		public void execute(ProcessingAlgorithm alg) throws FileNotFoundException{

			// get the observation object
			ObservationBase obs = get_observation_type();
			
			// run majority vote
			String algname = alg.getAlgorithmName();
			MajorityVote mv = new MajorityVote(obs, String.format("%s_Estimate", algname));
			labelvol.setValue(mv.run());
							
			// write the label probabilities if desired
			write_label_probabilities(obs, mv, alg);
		}
		
		/**
		 * Chooses and builds the observation object for the configured inputs.
		 * <ul>
		 * <li>{@code ObservationVolumePartial} when a target image and a matching
		 *     number of atlas images are available, and the weighting parameters exist.</li>
		 * <li>{@code ObservationVolume} when the inputs are not probability volumes.</li>
		 * <li>{@code ObservationVolumeProbability} otherwise.</li>
		 * </ul>
		 *
		 * @return the observation object
		 */
		protected ObservationBase get_observation_type() {
			
			JistLogger.logOutput(JistLogger.INFO, "*** Determining appropriate type for the observation object. ***");
			
			ObservationBase obs;
			if (can_use_partial_observation()) {
				JistLogger.logOutput(JistLogger.INFO, "-> Using ObservationVolumePartial.");
				obs = create_partial_observation();
				apply_atlas_selection(obs);
			} else if (probtype.getIndex() == PROBTYPE_NONE) {
				JistLogger.logOutput(JistLogger.INFO, "-> Using ObservationVolume.");
				obs = new ObservationVolume(obsvols, 1f);
			} else {
				JistLogger.logOutput(JistLogger.INFO, "-> Using ObservationVolumeProbability.");
				obs = new ObservationVolumeProbability(obsvols, 1f, probtype.getIndex());
			}
			
			JistLogger.logOutput(JistLogger.INFO, "");
			
			return(obs);
		}
		
		/**
		 * Checks whether every input needed by {@code ObservationVolumePartial}
		 * is present. The reason for a rejection is logged.
		 *
		 * @return true if the partial-volume observation can be built
		 */
		private boolean can_use_partial_observation() {
			// in this class, targetim/imsvols are only created by get_main_parameters()
			if (targetim == null || imsvols == null)
				return false;
			
			// check to see if the target image was specified
			if (targetim.getValue() == null) {
				JistLogger.logOutput(JistLogger.INFO, "No target image specified");
				return false;
			}
			
			// check to see if they specified the right number of atlas images
			if (obsvols.getParamVolumeList().size() != imsvols.getParamVolumeList().size()) {
				JistLogger.logOutput(JistLogger.INFO, "Number of atlas images does not match number of atlas labels");
				return false;
			}
			
			// The weighting/selection fields are only created by get_weighting_parameters().
			// Without this check, a subclass that skipped that call would hit an NPE below.
			if (weight_type == null) {
				JistLogger.logOutput(JistLogger.INFO, "No weighting parameters specified");
				return false;
			}
			
			return true;
		}
		
		/**
		 * Builds an {@code ObservationVolumePartial} from the main, weighting and
		 * patch-selection parameters.
		 *
		 * @return the new observation object
		 */
		private ObservationBase create_partial_observation() {
			int [] sv = { sv_dim1.getInt(), sv_dim2.getInt(), sv_dim3.getInt(), sv_dim4.getInt() };
			int [] pv = { pv_dim1.getInt(), pv_dim2.getInt(), pv_dim3.getInt(), pv_dim4.getInt() };
			float [] sp_stdevs = { spatial_sd_dim1.getFloat(), spatial_sd_dim2.getFloat(),
								   spatial_sd_dim3.getFloat(), spatial_sd_dim4.getFloat() };
			
			return new ObservationVolumePartial(targetim, obsvols, imsvols,
											   weight_type.getIndex(),
											   sv, pv, sp_stdevs,
											   intensity_sd.getFloat(),
											   selection_type.getIndex(),
											   selection_thresh.getFloat(),
											   num_keep.getInt(), 1f,
											   probtype.getIndex(),
											   use_intensity_normalization.getValue());
		}
		
		/**
		 * Builds the atlas selection matrix on {@code obs} when either the global
		 * or the local threshold (after sanitizing) is positive.
		 *
		 * @param obs the observation object to configure
		 */
		private void apply_atlas_selection(ObservationBase obs) {
			float global_selection_threshold = sanitize_selection_threshold(global_sel_thresh.getFloat());
			float local_selection_threshold = sanitize_selection_threshold(local_sel_thresh.getFloat());
			
			if (global_selection_threshold > 0 || local_selection_threshold > 0)
				obs.create_atlas_selection_matrix(global_selection_threshold,
												  local_selection_threshold);
		}
	
		/**
		 * Writes the label probabilities selected by {@link #output_W}.
		 * <ul>
		 * <li>"Multiple" writes one volume per label into
		 *     {@code <output dir>/<algorithm>/LabelProbabilities}.</li>
		 * <li>"Single" writes one volume into {@code <output dir>/<algorithm>}.</li>
		 * <li>"None" writes nothing.</li>
		 * </ul>
		 *
		 * @param obs the observations (label count, label remapping, header)
		 * @param mv  the fusion result that supplies the probabilities
		 * @param alg the owning algorithm (name and output directory)
		 */
		protected void write_label_probabilities(ObservationBase obs,
											     LabelFusionBase mv,
											     ProcessingAlgorithm alg) {
			int mode = output_W.getIndex();
			if (mode == OUTPUT_PROBS_MULTIPLE) {
				write_multiple_label_probabilities(obs, mv, alg);
			} else if (mode == OUTPUT_PROBS_SINGLE) {
				write_single_label_probabilities(obs, mv, alg);
			}
		}
		
		/** Writes one probability volume per label via {@link #labelprob_multiple}. */
		private void write_multiple_label_probabilities(ObservationBase obs,
														LabelFusionBase mv,
														ProcessingAlgorithm alg) {
			File outdir = make_output_dir(alg, "LabelProbabilities");
			
			JistLogger.logOutput(JistLogger.INFO, "Writing Label Probabilities");
			// NOTE(review): short index presumably matches the label type expected by
			// get_label_prob/get_label_remap; it would overflow beyond Short.MAX_VALUE labels.
			for (short l = 0; l < obs.num_labels(); l++) {
				String name = String.format("%s_Label_Probability_%04d", alg.getAlgorithmName(), obs.get_label_remap(l));
				ImageData f = mv.get_label_prob(l, name);
				f.setHeader(obs.get_header());
				labelprob_multiple.add(f);
				labelprob_multiple.writeAndFreeNow(outdir);
				f.dispose();
			}
		}
		
		/** Writes all label probabilities as one volume via {@link #labelprob_single}. */
		private void write_single_label_probabilities(ObservationBase obs,
													  LabelFusionBase mv,
													  ProcessingAlgorithm alg) {
			File outdir = make_output_dir(alg, null);
			
			JistLogger.logOutput(JistLogger.INFO, "Writing Label Probabilities");
			
			String name = String.format("%s_Label_Probabilities", alg.getAlgorithmName());
			ImageData f = mv.get_all_label_probs(name);
			f.setHeader(obs.get_header());
			labelprob_single.setValue(f);
			labelprob_single.writeAndFreeNow(outdir);
			f.dispose();
		}
		
		/**
		 * Builds {@code <output dir>/<safe algorithm name>[/subdir]} and creates it if needed.
		 * A creation failure is logged rather than silently ignored.
		 *
		 * @param alg    the owning algorithm
		 * @param subdir optional sub-directory name, or null for none
		 * @return the output directory
		 */
		private File make_output_dir(ProcessingAlgorithm alg, String subdir) {
			String path = alg.getOutputDirectory() +
						  File.separator +
						  FileUtil.forceSafeFilename(alg.getAlgorithmName());
			if (subdir != null)
				path = path + File.separator + subdir;
			
			File outdir = new File(path);
			if (!outdir.isDirectory() && !outdir.mkdirs())
				JistLogger.logOutput(JistLogger.INFO, "Unable to create output directory: " + outdir);
			return outdir;
		}
	}
}
