package edu.vanderbilt.masi.plugins.classification;

import java.io.*;
import java.util.ArrayList;

import org.json.JSONException;
import org.json.JSONObject;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.utility.FileUtil;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * JIST plugin that configures a scikit-learn decision tree classifier.
 *
 * <p>The selected hyper-parameters are serialized to {@code classifier.json} in the
 * algorithm's output directory. An external Python builder script is then run with that
 * file and with {@code classifier.pkl} as its output-file argument. The path of
 * {@code classifier.pkl} is published through the {@code classifiers} parameter
 * inherited from {@code LearnerPlugin}.
 */
public class DecisionTreePlugin2 extends LearnerPlugin {


	// Input Parameters
	/** Split criterion forwarded to the builder ("Entropy" or "Gini"). */
	public ParamOption split_criterion;
	/** Maximum depth of the tree (default 10). */
	public ParamInteger max_depth;
	/** Minimum number of elements required to split a node (default 25). */
	public ParamInteger min_split;
	/** Minimum number of elements required in a leaf (default 10). */
	public ParamInteger min_leaf;

	/****************************************************
	 * CVS Version Control
	 ****************************************************/
	private static final String cvsversion = "$Revision: 1.1 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Sets up the framework to build a decision tree using python scikit-learn";
	private static final String longDescription = "";

	/** System property that overrides the Python interpreter used to run the builder. */
	private static final String PYTHON_PROPERTY = "masi.sklearn.python";
	/** Interpreter used when {@link #PYTHON_PROPERTY} is unset. */
	private static final String DEFAULT_PYTHON = "python";
	/** System property that overrides the location of the builder script. */
	private static final String BUILDER_PROPERTY = "masi.sklearn.builder";
	/** Builder script used when {@link #BUILDER_PROPERTY} is unset (original hard-coded path). */
	private static final String DEFAULT_BUILDER_SCRIPT =
			"/home/local/VANDERBILT/plassaaj/git/JISTSklearn/classifier_builder.py";

	/**
	 * Registers algorithm metadata and the decision-tree hyper-parameters.
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	@Override
	protected void createInputParameters(ParamCollection inputParams) {
		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("https://masi.vuse.vanderbilt.edu/");
		info.setAffiliation("MASI - Vanderbilt");
		info.add(new AlgorithmAuthor("Andrew Plassard","andrew.j.plassard@vanderbilt.edu","https://masi.vuse.vanderbilt.edu/index.php/MASI:Andrew_Plassard"));
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.BETA);

		inputParams.setPackage("MASI");
		inputParams.setCategory("Classification");
		inputParams.setLabel("Decision Tree2");
		inputParams.setName("Decision_Tree2");
		ArrayList<String> options = new ArrayList<String>();
		options.add("Entropy");
		options.add("Gini");
		inputParams.add(split_criterion = new ParamOption("Split Criterion",options));
		split_criterion.setMandatory(false);
		split_criterion.setValue("Entropy");

		inputParams.add(max_depth = new ParamInteger("Maximum Tree Depth"));
		max_depth.setMandatory(false);
		max_depth.setValue(10);

		inputParams.add(min_split = new ParamInteger("Minimum Number of Elements to Split a Node"));
		min_split.setMandatory(false);
		min_split.setValue(25);

		inputParams.add(min_leaf = new ParamInteger("Minimum Number of Elements to Have in a Leaf"));
		min_leaf.setMandatory(false);
		min_leaf.setValue(10);
	}


	/**
	 * Runs the classifier set-up inside a wrapper observed by {@code monitor}.
	 *
	 * @param monitor monitor that observes the wrapper calculation
	 * @throws AlgorithmRuntimeException if the configuration cannot be written or the
	 *         builder script fails
	 */
	@Override
	protected void execute(CalculationMonitor monitor)
			throws AlgorithmRuntimeException {
		ExecuteWrapper wrapper = new ExecuteWrapper();
		monitor.observe(wrapper);
		wrapper.execute(this);
	}

	/**
	 * Calculation that writes the classifier configuration and invokes the builder script.
	 * It is an inner (non-static) class because it reads the enclosing plugin's parameters.
	 */
	protected class ExecuteWrapper extends AbstractCalculation {

		/**
		 * Writes {@code classifier.json}, publishes the {@code classifier.pkl} path, and runs
		 * the builder script.
		 *
		 * @param alg the owning algorithm; supplies the output directory and algorithm name
		 * @throws AlgorithmRuntimeException if the configuration cannot be built or written,
		 *         the output directory cannot be created, or the builder fails to start,
		 *         is interrupted, or exits with a non-zero status
		 */
		protected void execute(ProcessingAlgorithm alg)
				throws AlgorithmRuntimeException {
			JSONObject config;
			try {
				config = buildConfig();
			} catch (JSONException e) {
				throw failure("Could not build the classifier configuration", e);
			}

			File outdir = new File(
					alg.getOutputDirectory() +
					File.separator +
					FileUtil.forceSafeFilename(alg.getAlgorithmName()) +
					File.separator);
			outdir.mkdirs();
			// mkdirs() returns false both on failure and when the directory already exists.
			if (!outdir.isDirectory()) {
				throw failure("Could not create output directory " + outdir.getAbsolutePath(), null);
			}
			File configFile = new File(outdir, "classifier.json");
			File classifierFile = new File(outdir, "classifier.pkl");
			classifiers.setValue(classifierFile);

			try {
				writeText(configFile, config.toString());
				runBuilder(configFile, classifierFile);
			} catch (IOException e) {
				throw failure("Failed to build the decision tree classifier", e);
			}
		}

		/** Collects the current parameter values into the JSON object passed to the builder. */
		private JSONObject buildConfig() throws JSONException {
			JSONObject obj = new JSONObject();
			obj.put("max_depth", max_depth.getInt());
			obj.put("min_leaf", min_leaf.getInt());
			obj.put("min_split", min_split.getInt());
			obj.put("split_criterion", split_criterion.getValue());
			obj.put("classifier", "decision_tree");
			return obj;
		}

		/** Writes {@code text} to {@code file} as UTF-8, closing the stream even on failure. */
		private void writeText(File file, String text) throws IOException {
			Writer out = new BufferedWriter(
					new OutputStreamWriter(new FileOutputStream(file), "UTF-8"));
			try {
				out.write(text);
			} finally {
				out.close();
			}
		}

		/**
		 * Runs the builder script and logs its combined stdout/stderr.
		 *
		 * <p>Arguments are passed as a list rather than a single command string, so paths
		 * that contain spaces stay single arguments.
		 *
		 * @throws IOException if the process cannot be started or its output cannot be read
		 * @throws AlgorithmRuntimeException if interrupted or the script exits non-zero
		 */
		private void runBuilder(File configFile, File classifierFile)
				throws IOException, AlgorithmRuntimeException {
			ArrayList<String> cmd = new ArrayList<String>();
			cmd.add(System.getProperty(PYTHON_PROPERTY, DEFAULT_PYTHON));
			cmd.add(System.getProperty(BUILDER_PROPERTY, DEFAULT_BUILDER_SCRIPT));
			cmd.add("--generate-classifier");
			cmd.add(configFile.getAbsolutePath());
			cmd.add("--output-file");
			cmd.add(classifierFile.getAbsolutePath());
			JistLogger.logOutput(JistLogger.INFO, "Running Command: " + join(cmd));

			ProcessBuilder builder = new ProcessBuilder(cmd);
			// Merge stderr into stdout: an undrained stderr pipe can fill up and block the child.
			builder.redirectErrorStream(true);
			Process p = builder.start();
			String output = readAll(p.getInputStream());
			JistLogger.logOutput(JistLogger.INFO, "The standard out was " + output);

			int status;
			try {
				status = p.waitFor();
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				p.destroy();
				throw failure("Interrupted while waiting for the classifier builder", e);
			}
			if (status != 0) {
				throw failure("Classifier builder exited with status " + status, null);
			}
		}

		/** Reads {@code in} to the end, joining lines with '\n'; always closes the stream. */
		private String readAll(InputStream in) throws IOException {
			BufferedReader reader = new BufferedReader(new InputStreamReader(in));
			try {
				StringBuilder sb = new StringBuilder();
				boolean first = true;
				String line;
				while ((line = reader.readLine()) != null) {
					if (!first) {
						sb.append('\n');
					}
					sb.append(line);
					first = false;
				}
				return sb.toString();
			} finally {
				reader.close();
			}
		}

		/** Joins command arguments with single spaces, for logging only. */
		private String join(ArrayList<String> parts) {
			StringBuilder sb = new StringBuilder();
			for (String part : parts) {
				if (sb.length() > 0) {
					sb.append(' ');
				}
				sb.append(part);
			}
			return sb.toString();
		}

		/**
		 * Creates an {@link AlgorithmRuntimeException} for {@code message}, attaching
		 * {@code cause} when it is non-null.
		 */
		private AlgorithmRuntimeException failure(String message, Throwable cause) {
			AlgorithmRuntimeException ex = new AlgorithmRuntimeException(
					cause == null ? message : message + ": " + cause);
			if (cause != null) {
				ex.initCause(cause);
			}
			return ex;
		}
	}

}
