package edu.vanderbilt.masi.plugins.segadapter;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.JistPreferences;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.FileUtil;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.adaboost.SegAdapterTraining;

import java.io.*;
import java.util.*;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

/**
 * JIST plugin that trains one SegAdapter classifier per label.
 *
 * <p>The label set is the sorted union of every voxel value found in the
 * estimated and manual segmentation volumes. For each label a
 * {@link SegAdapterTraining} instance is trained; its JSON is written to
 * {@code AdaBoost-Training-%04d.json} (collected in {@link #summaryfiles}), and
 * all per-label classifiers are also gathered under the {@code "Classifiers"}
 * key of {@code AdaBoost-Training-Full.json} ({@link #full_json_file}).
 */
public class PluginSegAdapterTrainAll extends ProcessingAlgorithm {
	
	// input parameters
	public ParamVolumeCollection man_vols;
	public ParamVolumeCollection est_vols;
	public ParamVolumeCollection feature_vols;
	public ParamVolumeCollection initial_mask_vols;
	public ParamInteger dilation;
	public ParamInteger radx;
	public ParamInteger rady;
	public ParamInteger radz;
	public ParamInteger numiters;
	
	// sampling parameters
	public ParamInteger max_samples;
	public ParamBoolean use_equal_class;
	public ParamBoolean use_equal_examples;
	
	// decision tree parameters
	public ParamInteger tree_max_depth;
	public ParamInteger tree_min_samples;
	public ParamOption tree_split_type;
	
	// output parameters
	public ParamFileCollection summaryfiles;
	public ParamFile full_json_file;
	
	/****************************************************
	 * CVS Version Control
	 ****************************************************/
	private static final String cvsversion = "$Revision: 1.1 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Runs SegAdapter training algorithm (All Labels).";
	private static final String longDescription = "";
	
	/**
	 * Declares the algorithm metadata and the "Main", "Sampler" and
	 * "Decision Tree" input parameter groups.
	 *
	 * @param inputParams the collection the parameter groups are added to
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("https://masi.vuse.vanderbilt.edu/");
		info.setAffiliation("MASI - Vanderbilt");
		info.add(new AlgorithmAuthor("Andrew Asman","andrew.j.asman@vanderbilt.edu","https://masi.vuse.vanderbilt.edu/index.php/MASI:Andrew_Asman"));
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.BETA);
		
		inputParams.setPackage("MASI");
		inputParams.setCategory("SegAdapter");
		inputParams.setLabel("SegAdapter Training (All Labels)");
		inputParams.setName("SegAdapter_Training_All");
		
		ParamCollection mainParams = new ParamCollection("Main");
		
		// manual (reference) segmentations; one per estimated segmentation
		mainParams.add(man_vols=new ParamVolumeCollection("Manual Segmentation Volumes"));
		man_vols.setLoadAndSaveOnValidate(false);
		man_vols.setMandatory(true);
		
		// estimated segmentations to be corrected
		mainParams.add(est_vols=new ParamVolumeCollection("Estimated Segmentation Volumes"));
		est_vols.setLoadAndSaveOnValidate(false);
		est_vols.setMandatory(true);
		
		// optional feature volumes
		mainParams.add(feature_vols=new ParamVolumeCollection("Feature Volumes"));
		feature_vols.setLoadAndSaveOnValidate(false);
		feature_vols.setMandatory(false);
		
		// optional initial masks; component l of each volume is the mask for the l-th label
		mainParams.add(initial_mask_vols = new ParamVolumeCollection("Initial Masks (4-D)"));
		initial_mask_vols.setLoadAndSaveOnValidate(false);
		initial_mask_vols.setMandatory(false);
		
		// set the dilation radius
		mainParams.add(dilation = new ParamInteger("Dilation Radius"));
		dilation.setValue(1);
		dilation.setMandatory(false);
		
		// set the feature radii
		mainParams.add(radx = new ParamInteger("Feature Radius Dimension 1 (Image Units)"));
		radx.setValue(2);
		radx.setMandatory(false);
		mainParams.add(rady = new ParamInteger("Feature Radius Dimension 2 (Image Units)"));
		rady.setValue(2);
		rady.setMandatory(false);
		mainParams.add(radz = new ParamInteger("Feature Radius Dimension 3 (Image Units)"));
		radz.setValue(2);
		radz.setMandatory(false);
		
		// set the number of training iterations
		mainParams.add(numiters = new ParamInteger("Number of Iterations"));
		numiters.setValue(50);
		numiters.setMandatory(false);
		
		/*
		 * Sampling Parameters
		 */
		ParamCollection sampParams = new ParamCollection("Sampler");
		
		sampParams.add(max_samples = new ParamInteger("Maximum Number of Samples"));
		max_samples.setValue(500000);
		max_samples.setMandatory(false);
		
		sampParams.add(use_equal_class = new ParamBoolean("Use Equal Class Sampling"));
		use_equal_class.setValue(false);
		use_equal_class.setMandatory(false);

		sampParams.add(use_equal_examples = new ParamBoolean("Use Equal Example Sampling"));
		use_equal_examples.setValue(false);
		use_equal_examples.setMandatory(false);
		
		/*
		 * Decision Tree Parameters
		 */
		ParamCollection treeParams = new ParamCollection("Decision Tree");
		
		treeParams.add(tree_max_depth = new ParamInteger("Maximum Tree Depth", 1, 5));
		tree_max_depth.setValue(1);
		tree_max_depth.setMandatory(false);
		
		treeParams.add(tree_min_samples = new ParamInteger("Minimum Number of Samples in Leaf Node", 1, 1000));
		tree_min_samples.setValue(50);
		tree_min_samples.setMandatory(false);
		
		// set the split criterion used when growing the decision trees
		treeParams.add(tree_split_type = new ParamOption("Split Criterion", new String[] { "Classification Rate", "Gini", "Information Gain", "Gain Ratio"}));
		tree_split_type.setValue(0);
		tree_split_type.setMandatory(false);
		
		inputParams.add(mainParams);
		inputParams.add(sampParams);
		inputParams.add(treeParams);
	}
	
	/**
	 * Declares the outputs: the per-label JSON files and the combined JSON file.
	 *
	 * @param outputParams the collection the output parameters are added to
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		
		// one JSON file per trained label
		outputParams.add(summaryfiles = new ParamFileCollection("SegAdapter JSON File Collection"));
		
		// a single JSON file holding every per-label classifier
		outputParams.add(full_json_file = new ParamFile("Full SegAdapter JSON File"));
	}
	
	/**
	 * Runs the training through an observable {@link ExecuteWrapper}.
	 *
	 * @param monitor the monitor that observes the wrapper's progress
	 * @throws AlgorithmRuntimeException if the wrapper reports a missing file
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {
		ExecuteWrapper wrapper = new ExecuteWrapper();
		monitor.observe(wrapper);
		try {
			wrapper.execute(this);
		} catch (FileNotFoundException e) {
			// surface the failure to the pipeline instead of silently "succeeding"
			AlgorithmRuntimeException failure =
					new AlgorithmRuntimeException("SegAdapter training failed: " + e.getMessage());
			failure.initCause(e);
			throw failure;
		}
	}
	
	/**
	 * Performs the actual per-label training and writes the JSON outputs.
	 */
	protected class ExecuteWrapper extends AbstractCalculation {
		
		/**
		 * Discovers the labels, trains one SegAdapter classifier per label and
		 * writes the per-label and combined JSON files.
		 *
		 * @param alg the owning algorithm (used for the output directory and
		 *            for writing the output parameters)
		 * @throws FileNotFoundException retained for interface compatibility
		 * @throws RuntimeException if the inputs are inconsistent or an output
		 *         file cannot be written
		 */
		public void execute(ProcessingAlgorithm alg) throws FileNotFoundException {
			
			// number of training volumes
			int num_vols = est_vols.size();
			
			// make sure we have the same number of manual segmentations
			if (num_vols != man_vols.size())
				throw new RuntimeException("Number of Estimated Segmentations must equal number of Manual Segmentations");
			
			// get the sorted unique labels across all estimated and manual volumes
			ArrayList<Integer> ul = get_labels(num_vols);
			
			// get the number of labels
			int num_labels = ul.size();
			JistLogger.logOutput(JistLogger.INFO, String.format("Found %d labels", num_labels));
			
			// get the initial masks (all null entries when none were supplied)
			boolean [][][][][] initial_masks = get_initial_masks(num_labels, num_vols);
			
			File outdir = get_output_directory(alg);

			// store the per label json objects
			JSONArray per_label_json = new JSONArray();
			
			// iterate over all of the labels
			for (int l = 0; l < num_labels; l++) {
				
				// set the current label of interest
				short tlabel = to_short_label(ul.get(l).intValue());
				JistLogger.logOutput(JistLogger.INFO, "");
				JistLogger.logOutput(JistLogger.INFO, "***************************");
				JistLogger.logOutput(JistLogger.INFO, String.format("*** Processing Label %d ***", tlabel));
				JistLogger.logOutput(JistLogger.INFO, "***************************");
				JistLogger.logOutput(JistLogger.INFO, "");
			
				// construct the training classifier
				// NOTE(review): use_equal_examples is passed before use_equal_class
				// (the reverse of their declaration order) -- confirm this matches
				// the SegAdapterTraining constructor signature.
				SegAdapterTraining ada = new SegAdapterTraining(man_vols,
						                                        est_vols,
						                                        feature_vols,
						                                        tlabel,
						                                        dilation.getInt(),
						                                        radx.getInt(),
						                                        rady.getInt(),
						                                        radz.getInt(),
						                                        numiters.getInt(),
						                                        max_samples.getInt(),
						                                        use_equal_examples.getValue(),
						                                        use_equal_class.getValue(),
						                                        tree_max_depth.getInt(),
						                                        tree_min_samples.getInt(),
						                                        tree_split_type.getIndex(),
						                                        initial_masks[l]);
				
				// train the classifiers
				ada.train();
				
				// add this json object to the json array
				per_label_json.put(ada.toJSONObject());
				
				// write the per-label file and register it as an output
				File f = new File(outdir, String.format("AdaBoost-Training-%04d.json", tlabel));
				write_text_file(f, ada.toJSONString());
				summaryfiles.add(f);
			}
			
			// construct the final JSON object
			JSONObject obj = new JSONObject();
			try {
				obj.put("Classifiers", per_label_json);
			} catch (JSONException e) {
				throw new RuntimeException("Unable to construct the full SegAdapter JSON object", e);
			}
			
			// write the combined file
			File full = new File(outdir, "AdaBoost-Training-Full.json");
			write_text_file(full, obj.toString());
			full_json_file.setValue(full);
			
			// save the output files
			summaryfiles.writeAndFreeNow(alg);
			full_json_file.writeAndFreeNow(alg);
		}
		
		/**
		 * Returns (creating it if needed) the directory
		 * {@code <output dir>/<safe algorithm name>/} used for the JSON outputs.
		 */
		private File get_output_directory(ProcessingAlgorithm alg) {
			File outdir = new File(
					alg.getOutputDirectory() +
					File.separator +
					FileUtil.forceSafeFilename(alg.getAlgorithmName()) +
					File.separator);
			outdir.mkdirs();
			return outdir;
		}
		
		/**
		 * Writes {@code contents} to {@code f}, always closing the writer.
		 *
		 * @throws RuntimeException wrapping the {@link IOException} if the file
		 *         cannot be written
		 */
		private void write_text_file(File f, String contents) {
			try {
				BufferedWriter writer = new BufferedWriter(new FileWriter(f));
				try {
					writer.write(contents);
				} finally {
					// closing the BufferedWriter also closes the underlying FileWriter
					writer.close();
				}
			} catch (IOException e) {
				throw new RuntimeException("Unable to write SegAdapter JSON file: " + f, e);
			}
		}
		
		/**
		 * Narrows a label to {@code short}, which SegAdapterTraining receives,
		 * rejecting values that would otherwise be silently truncated.
		 */
		private short to_short_label(int label) {
			if (label < Short.MIN_VALUE || label > Short.MAX_VALUE)
				throw new RuntimeException(String.format("Label %d cannot be represented as a short", label));
			return (short) label;
		}
		
		/**
		 * Loads the optional initial masks.
		 *
		 * <p>The result is indexed {@code [label][volume][x][y][z]}. When no
		 * mask volumes were supplied every {@code [label][volume]} entry is
		 * {@code null}. Otherwise component {@code l} of mask volume {@code i}
		 * (voxels with value {@code > 0}) becomes the mask for the {@code l}-th
		 * label in ascending label order.
		 *
		 * @throws RuntimeException if the number of mask volumes or mask
		 *         components does not match the inputs
		 */
		private boolean [][][][][] get_initial_masks(int num_labels, int num_vols) {
			
			// Java initialises the per-volume references to null
			boolean [][][][][] initial_masks = new boolean[num_labels][num_vols][][][];
			
			// nothing to load
			if (initial_mask_vols.size() == 0)
				return initial_masks;
			
			JistLogger.logOutput(JistLogger.INFO, "-> Loading initial masks");
			
			if (initial_mask_vols.size() != num_vols)
				throw new RuntimeException("Number of mask volumes does not match number of input segmentations");
			
			// load the image information quietly, restoring the level even on failure
			JistPreferences prefs = JistPreferences.getPreferences();
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			try {
				for (int i = 0; i < num_vols; i++) {
					ImageData img = initial_mask_vols.getParamVolume(i).getImageData(true);
					try {
						int rows = Math.max(img.getRows(), 1);
						int cols = Math.max(img.getCols(), 1);
						int slices = Math.max(img.getSlices(), 1);
						int comps = Math.max(img.getComponents(), 1);
						
						if (comps != num_labels)
							throw new RuntimeException("Number of mask labels does not match the number of discovered labels");
						
						for (int l = 0; l < num_labels; l++) {
							boolean [][][] mask = new boolean [rows][cols][slices];
							for (int x = 0; x < rows; x++)
								for (int y = 0; y < cols; y++)
									for (int z = 0; z < slices; z++)
										mask[x][y][z] = img.getInt(x, y, z, l) > 0;
							initial_masks[l][i] = mask;
						}
					} finally {
						img.dispose();
					}
				}
			} finally {
				prefs.setDebugLevel(orig_level);
			}
			
			return initial_masks;
		}
		
		/**
		 * Collects the sorted set of distinct voxel values found in every
		 * estimated and manual segmentation volume.
		 *
		 * @param num_vols number of estimated/manual volume pairs to scan
		 * @return the distinct labels in ascending order
		 */
		private ArrayList<Integer> get_labels(int num_vols) {
			// TreeSet gives O(log n) de-duplication and keeps the labels sorted
			SortedSet<Integer> labels = new TreeSet<Integer>();
			
			JistLogger.logOutput(JistLogger.INFO, "-> Determining the number of labels to classify");
			
			// load the image information quietly, restoring the level even on failure
			JistPreferences prefs = JistPreferences.getPreferences();
			int orig_level = prefs.getDebugLevel();
			prefs.setDebugLevel(JistLogger.SEVERE);
			try {
				for (int i = 0; i < num_vols; i++) {
					add_labels(est_vols.getParamVolume(i).getImageData(true), labels);
					add_labels(man_vols.getParamVolume(i).getImageData(true), labels);
				}
			} finally {
				prefs.setDebugLevel(orig_level);
			}
			
			return new ArrayList<Integer>(labels);
		}
		
		/**
		 * Adds every voxel value of the first component of {@code img} to
		 * {@code labels}, then disposes of the image.
		 */
		private void add_labels(ImageData img, Set<Integer> labels) {
			try {
				int rows = Math.max(img.getRows(), 1);
				int cols = Math.max(img.getCols(), 1);
				int slices = Math.max(img.getSlices(), 1);
				for (int x = 0; x < rows; x++)
					for (int y = 0; y < cols; y++)
						for (int z = 0; z < slices; z++)
							labels.add(img.getInt(x, y, z));
			} finally {
				img.dispose();
			}
		}
	}
	
}
