package edu.vanderbilt.masi.plugins.segadapter;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.JistPreferences;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.adaboost.SegAdapterTestingForest;

import java.io.*;

/**
 * JIST processing plugin that applies a SegAdapter AdaBoost forest to an
 * estimated segmentation and outputs the corrected segmentation.
 *
 * <p>Inputs: an estimated segmentation, a collection of SegAdapter JSON files,
 * optional feature volumes, an optional 4-D initial mask volume and the
 * classifier fusion method. All of these are handed to
 * {@link SegAdapterTestingForest}, whose estimate becomes the output.
 */
public class PluginSegAdapterApplyForest extends ProcessingAlgorithm {
	
	// input parameters
	/** Estimated segmentation to be corrected (mandatory). */
	public ParamVolume est_vol;
	/** SegAdapter JSON files passed to the testing forest (mandatory). */
	public ParamFileCollection json_files;
	/** Optional feature volumes passed to the testing forest. */
	public ParamVolumeCollection feature_vols;
	/** Optional 4-D mask volume; each component is turned into one boolean mask (voxel value > 0). */
	public ParamVolume initial_mask_vol;
	/** Classifier fusion method; its selected index is passed to the testing forest. */
	public ParamOption fusion_type;
	
	// output parameters
	/** Corrected segmentation returned by the testing forest. */
	public ParamVolume out_seg;
	
	/****************************************************
	 * CVS Version Control
	 ****************************************************/
	private static final String cvsversion = "$Revision: 1.2 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Apply the SegAdapter algorithm using an AdaBoost Forest.";
	private static final String longDescription = "";
	
	/*
	 * (non-Javadoc)
	 *
	 * @see edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm#createInputParameters(edu.jhu.ece.iacl.pipeline.parameter.ParamCollection)
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("https://masi.vuse.vanderbilt.edu/");
		info.setAffiliation("MASI - Vanderbilt");
		info.add(new AlgorithmAuthor("Andrew Asman","andrew.j.asman@vanderbilt.edu","https://masi.vuse.vanderbilt.edu/index.php/MASI:Andrew_Asman"));
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.BETA);
		
		inputParams.setPackage("MASI");
		inputParams.setCategory("SegAdapter");
		inputParams.setLabel("SegAdapter Apply Forest");
		inputParams.setName("SegAdapter_Apply_Forest");
		
		// the estimated segmentation to be corrected
		inputParams.add(est_vol = new ParamVolume("Estimated Segmentation"));
		est_vol.setLoadAndSaveOnValidate(false);
		est_vol.setMandatory(true);
		
		// the SegAdapter JSON files consumed by the testing forest
		inputParams.add(json_files = new ParamFileCollection("SegAdapter JSON File Collection"));
		json_files.setLoadAndSaveOnValidate(false);
		json_files.setMandatory(true);
		
		// optional additional feature volumes
		inputParams.add(feature_vols = new ParamVolumeCollection("Feature Volumes"));
		feature_vols.setLoadAndSaveOnValidate(false);
		feature_vols.setMandatory(false);
		
		// optional 4-D initial mask volume (one mask per component)
		inputParams.add(initial_mask_vol = new ParamVolume("Initial Mask (4-D)"));
		initial_mask_vol.setLoadAndSaveOnValidate(false);
		initial_mask_vol.setMandatory(false);
		
		// how the forest's classifier outputs are fused (default: index 0, "Mean Classifier")
		inputParams.add(fusion_type = new ParamOption("Classifier Fusion Method", new String[] { "Mean Classifier", "Mean Probability", "Maximum Likelihood"}));
		fusion_type.setValue(0);
		fusion_type.setMandatory(false);
	}
	
	/*
	 * (non-Javadoc)
	 *
	 * @see edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm#createOutputParameters(edu.jhu.ece.iacl.pipeline.parameter.ParamCollection)
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(out_seg = new ParamVolume("Corrected Segmentation"));
	}
	
	/**
	 * Runs the SegAdapter forest under the supplied monitor.
	 *
	 * @param monitor calculation monitor that observes the worker
	 * @throws AlgorithmRuntimeException if the worker reports a
	 *         {@link FileNotFoundException}; the failure is reported to the
	 *         pipeline instead of being silently swallowed
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {
		try {
			ExecuteWrapper wrapper = new ExecuteWrapper();
			monitor.observe(wrapper);
			wrapper.execute(this);
		} catch (FileNotFoundException e) {
			// Previously this was only printed, so the run looked successful
			// while out_seg was never set. Surface the failure instead.
			String msg = "SegAdapter Apply Forest: input file not found: " + e.getMessage();
			JistLogger.logOutput(JistLogger.SEVERE, "-> " + msg);
			throw new AlgorithmRuntimeException(msg);
		}
	}
	
	/**
	 * Worker that optionally loads the initial masks, runs the SegAdapter
	 * testing forest and stores its estimate in {@link #out_seg}.
	 */
	protected class ExecuteWrapper extends AbstractCalculation {
		
		/**
		 * Runs the forest using the enclosing plugin's parameters.
		 *
		 * @param alg the invoking algorithm (not used; the enclosing instance's
		 *            parameters are read directly)
		 * @throws FileNotFoundException if a callee cannot find an input file
		 */
		public void execute(ProcessingAlgorithm alg) throws FileNotFoundException {
			
			// masks are optional; null tells the forest that none were given
			boolean [][][][] initial_masks = null;
			if (initial_mask_vol.getValue() != null) {
				initial_masks = loadInitialMasks();
			}
			
			SegAdapterTestingForest ada = new SegAdapterTestingForest(est_vol,
																	  feature_vols,
																	  json_files,
																	  initial_masks,
																	  fusion_type.getIndex());
			out_seg.setValue(ada.get_estimate());
		}
		
		/**
		 * Loads {@link #initial_mask_vol} into a boolean array indexed
		 * [component][x][y][z]. An entry is true when the voxel's integer value
		 * is greater than 0. Image dimensions smaller than 1 are treated as 1.
		 *
		 * <p>JIST debug output is lowered to SEVERE while loading. The original
		 * debug level is always restored and the image is always disposed, even
		 * if loading fails.
		 *
		 * @return the per-component boolean masks
		 */
		private boolean [][][][] loadInitialMasks() {
			JistLogger.logOutput(JistLogger.INFO, "-> Loading initial masks");
			
			// load the image information (quietly)
			JistPreferences prefs = JistPreferences.getPreferences();
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			try {
				ImageData img = initial_mask_vol.getImageData(true);
				try {
					int [] dims = new int [4];
					dims[0] = Math.max(img.getRows(), 1);
					dims[1] = Math.max(img.getCols(), 1);
					dims[2] = Math.max(img.getSlices(), 1);
					dims[3] = Math.max(img.getComponents(), 1);
					
					boolean [][][][] masks = new boolean [dims[3]][dims[0]][dims[1]][dims[2]];
					for (int x = 0; x < dims[0]; x++)
						for (int y = 0; y < dims[1]; y++)
							for (int z = 0; z < dims[2]; z++)
								for (int l = 0; l < dims[3]; l++)
									masks[l][x][y][z] = img.getInt(x, y, z, l) > 0;
					return masks;
				} finally {
					img.dispose();
				}
			} finally {
				// restore logging even if loading failed, so later output is not silenced
				prefs.setDebugLevel(orig_level);
			}
		}
	}
	
}
