package edu.vanderbilt.masi.algorithms.adaboost;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Collects training samples from a set of volumes and trains an AdaBoost
 * classifier on them via {@link AdaBoost#train}.
 *
 * <p>The constructor does all of the sample gathering: it builds one
 * {@code SegAdapterImageTraining} per observation, lets a {@code RandomSampler}
 * decide how many samples to draw, allocates the feature matrix {@code X} and
 * the indicator vector {@code Y}, and asks each image to fill them in.
 * {@link #train()} then runs the boosting, and {@link #toJSONObject()} /
 * {@link #toJSONString()} serialize the settings together with the learners.
 */
public class SegAdapterTraining extends AbstractCalculation {
	
	// target label, copied from the constructor argument
	private int tlabel;
	// feature window radii, stored as {radx, rady, radz}
	private int [] feature_rad;
	// dilation radius, copied from the constructor argument
	private int dilation;
	// number of boosting iterations handed to AdaBoost.train
	private int numiters;
	// number of observations (size of the manual volume collection)
	private int num_obs;
	// number of feature volumes per observation
	private int num_channels;
	// number of voxels in the feature window: product of (2*rad+1) over the three axes
	private int win_size;
	// number of features per sample (see the formula in the constructor)
	private int num_features;
	// indicator vector, one entry per sample
	private boolean [] Y;
	// feature matrix, indexed [feature][sample]
	private float [][] X;
	// number of samples; first the sampler's planned total, later its final count
	private int num_samples_total;
	// weak learner handed to AdaBoost.train (decision stump or decision tree)
	private WeakLearner weak_learner_template;
	
	// the output arrays (only assigned by train())
	JSONArray bwlList;
	
	/**
	 * Gathers the training samples for the given volumes.
	 *
	 * @param man_vols         manual volumes, one per observation
	 * @param est_vols         estimated volumes; must contain as many volumes as {@code man_vols}
	 * @param feature_vols     feature volumes; the count must be a whole multiple of the
	 *                         number of observations
	 * @param tlabel_in        target label
	 * @param dilation_in      dilation radius
	 * @param radx_in          feature radius along x
	 * @param rady_in          feature radius along y
	 * @param radz_in          feature radius along z
	 * @param numiters_in      number of iterations passed to {@link AdaBoost#train}
	 * @param max_num_samples  maximum number of samples, passed to the {@code RandomSampler}
	 * @param equal_examples   "equal sampling from each image" flag, passed to the {@code RandomSampler}
	 * @param equal_class      "equal class sampling" flag, passed to the {@code RandomSampler}
	 * @param tree_max_depth   must be at least 1; a value of 1 selects a
	 *                         {@code WeakLearnerDecisionStump}, anything larger a
	 *                         {@code WeakLearnerDecisionTree}
	 * @param tree_min_samples minimum number of samples, passed to the decision tree
	 *                         (not used when a stump is selected)
	 * @param tree_split_type  split type passed to the weak learner
	 * @param initial_masks    indexed by observation; entry {@code i} is handed to the
	 *                         image for observation {@code i}
	 * @throws RuntimeException if {@code tree_max_depth < 1}, if the manual and estimated
	 *                          collections differ in size, if no volumes are provided, or
	 *                          if the feature volume count is not a multiple of the
	 *                          number of observations
	 */
	public SegAdapterTraining(ParamVolumeCollection man_vols,
					          ParamVolumeCollection est_vols,
					          ParamVolumeCollection feature_vols,
					          int tlabel_in,
					          int dilation_in,
					          int radx_in,
					          int rady_in,
					          int radz_in,
					          int numiters_in,
					          int max_num_samples,
					          boolean equal_examples,
					          boolean equal_class,
					       	  int tree_max_depth,
					       	  int tree_min_samples,
					       	  int tree_split_type,
					          boolean [][][][] initial_masks) {
		
		// get all of the input stuff into the class variables
		tlabel = tlabel_in;
		dilation = dilation_in;
		feature_rad = new int [] { radx_in, rady_in, radz_in };
		numiters = numiters_in;
		
		// a depth of exactly one degenerates to a decision stump
		if (tree_max_depth < 1) {
			throw new RuntimeException("Invalid Max Tree Depth");
		} else if (tree_max_depth == 1) {
			weak_learner_template = new WeakLearnerDecisionStump(tree_split_type);
		} else {
			weak_learner_template = new WeakLearnerDecisionTree(tree_max_depth, tree_min_samples, tree_split_type);
		}
		
		// set the number of observations
		num_obs = man_vols.getParamVolumeList().size();
		if (num_obs != est_vols.getParamVolumeList().size()) {
			throw new RuntimeException("Number of manual volumes does not equal number of estimated volumes");
		}
		
		// fail with a clear message instead of a division by zero below
		if (num_obs == 0) {
			throw new RuntimeException("No input volumes were provided");
		}
		
		// set the number of feature channels (must divide evenly across observations)
		int num_feature_vols = feature_vols.getParamVolumeList().size();
		num_channels = num_feature_vols / num_obs;
		if (num_channels * num_obs != num_feature_vols) {
			throw new RuntimeException("Invalid number of feature channels");
		}
		
		// set the number of features per sample
		win_size = (2*feature_rad[0]+1) * (2*feature_rad[1]+1) * (2*feature_rad[2]+1);
		num_features = ((1 + num_channels) * win_size + 3) * 4 - 3; 
			
		// print all of the current information to the screen
		JistLogger.logOutput(JistLogger.INFO, "-> Found the following information:");
		JistLogger.logOutput(JistLogger.INFO, String.format("Target Label: %d", tlabel));
		JistLogger.logOutput(JistLogger.INFO, String.format("Dilation Radius: %d", dilation));
		JistLogger.logOutput(JistLogger.INFO, String.format("Feature Radii: %dx%dx%d", feature_rad[0], feature_rad[1], feature_rad[2]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of Iterations: %d", numiters));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of Input Segmentations: %d", num_obs));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of feature channels: %d", num_channels));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of feature per sample: %d", num_features));
		JistLogger.logOutput(JistLogger.INFO, String.format("Maximum Number of Samples: %d", max_num_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Use Equal Class Sampling: %s", equal_class));
		JistLogger.logOutput(JistLogger.INFO, String.format("Use Equal Sampling from each Image: %s", equal_examples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Maximum Decision Tree Depth: %d", tree_max_depth));
		JistLogger.logOutput(JistLogger.INFO, String.format("Minimum Number of Samples in Leaf Node: %d", tree_min_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Tree Split Criterion: %s", weak_learner_template.get_criterion_string()));
		JistLogger.logOutput(JistLogger.INFO, "");
		
		// allocate space for all of the images used for the AdaBoost algorithm
		SegAdapterImageTraining [] ims = new SegAdapterImageTraining[num_obs];

		// determine the cropping region and number of samples for each image
		JistLogger.logOutput(JistLogger.INFO, "-> Determining Pre-Processing Information");
		for (int i = 0; i < num_obs; i++) {
			JistLogger.printStatusBar(i, num_obs);
			ims[i] = new SegAdapterImageTraining(est_vols.getParamVolume(i),
											     man_vols.getParamVolume(i),
									   	 	     dilation, feature_rad, tlabel,
									   	 	     initial_masks[i]);
		}
		JistLogger.logOutput(JistLogger.INFO, "");
		
		// determine sampling information
		RandomSampler random_sampler = new RandomSampler(max_num_samples, equal_class, equal_examples);
		random_sampler.set_sampling_information(ims);
		
		// allocate space for the feature matrix (X) and the indicator vector (Y)
		JistLogger.logOutput(JistLogger.INFO, "-> Allocating space for classifier");

		num_samples_total = random_sampler.get_num_samples_total();
		Y = new boolean [num_samples_total];
		X = new float [num_features][num_samples_total];
		JistLogger.logOutput(JistLogger.INFO, "");
		
		// process each image individually; each call receives X and Y to fill in
		for (int i = 0; i < num_obs; i++) {
			JistLogger.logOutput(JistLogger.INFO, String.format("*** Processing Volume: %03d ***", i+1));
			ims[i].print_info();
			ims[i].process(i, Y, X, est_vols, man_vols, feature_vols, num_channels,
						   num_features, tlabel, dilation, feature_rad, random_sampler, initial_masks[i]);
		}
		
		// set the final number of samples; X and Y keep their originally allocated
		// sizes, so this count is what tells AdaBoost.train how many samples to use
		num_samples_total = random_sampler.get_num_samples_final();
		JistLogger.logOutput(JistLogger.INFO, "-> Final Sampling Information");
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of samples: %d", num_samples_total));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of positive samples: %d", random_sampler.get_num_pos_samples_final()));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of negative samples: %d", random_sampler.get_num_neg_samples_final()));
		JistLogger.logOutput(JistLogger.INFO, "");
		
	}
	
	/**
	 * Runs {@link AdaBoost#train} on the samples gathered by the constructor and
	 * stores the returned learner list in {@code bwlList}.
	 */
	public void train () {
		bwlList = AdaBoost.train(X, Y, numiters, weak_learner_template, num_samples_total);
	}
	
	/**
	 * Builds a JSON object holding the training settings ("Iters", "Label",
	 * "Dilation", "RadX", "RadY", "RadZ", "Channels") and the learner list
	 * ("Learners").
	 *
	 * <p>The learner list is only assigned by {@link #train()}, so call that first.
	 *
	 * @return the populated JSON object
	 * @throws RuntimeException wrapping the underlying {@link JSONException} if the
	 *                          object could not be built
	 */
	public JSONObject toJSONObject() {
		
		JSONObject obj;
		try {
			obj = new JSONObject();
			obj.put("Iters", numiters);
			obj.put("Label", tlabel);
			obj.put("Dilation", dilation);
			obj.put("RadX", feature_rad[0]);
			obj.put("RadY", feature_rad[1]);
			obj.put("RadZ", feature_rad[2]);
			obj.put("Channels", num_channels);
			obj.put("Learners", bwlList);
		} catch (JSONException e) {
			// keep the original exception as the cause so its stack trace is not lost
			throw new RuntimeException("JSON Exception encountered.", e);
		}
		
		return obj;
	}
	
	/**
	 * Serializes the result of {@link #toJSONObject()} to a string.
	 *
	 * @return the JSON text
	 */
	public String toJSONString() {
		return this.toJSONObject().toString();
	}
	
}
