package edu.vanderbilt.masi.algorithms.adaboost;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Applies a forest of trained AdaBoost classifiers to a test volume.
 * <p>
 * Every JSON file of the forest contains a {@code "Classifiers"} array with one
 * classifier per label. All files must agree on the label, dilation, feature
 * radius and number of iterations of each classifier. Labels are processed one
 * at a time, and the ensembles coming from the individual files are combined
 * with the requested fusion method.
 */
public class SegAdapterTestingForest extends AbstractCalculation {
	
	/** Target label of each classifier ({@code "Label"}). */
	private int [] tlabels;
	/** Feature radius (x, y, z) of each classifier ({@code "RadX"}, {@code "RadY"}, {@code "RadZ"}). */
	private int [][] feature_rads;
	/** Dilation of each classifier ({@code "Dilation"}). */
	private int [] dilations;
	/** Number of weak learners of each classifier ({@code "Iters"}). */
	private int [] numiters;
	/** Number of feature channels (size of the feature volume collection). */
	private int num_channels;
	/** Number of labels, i.e. classifiers per JSON file. */
	private int num_labels;
	/** Test image information; source of the final estimate. */
	private SegAdapterImageTesting im;
	/** Weak learners indexed as {@code learners[label][file][iteration]}. */
	private WeakLearner [][][] learners;
	
	/**
	 * Loads the classifier forest, loads the test image information and
	 * classifies every label.
	 *
	 * @param est_vol       estimate volume handed to {@link SegAdapterImageTesting}
	 * @param feature_vols  feature channel volumes; their count is the number of channels
	 * @param json_files    classifier files, one per ensemble member of the forest
	 * @param initial_masks optional per-label masks; may be {@code null}, otherwise
	 *                      {@code initial_masks[l]} is used for label {@code l}
	 * @param fusion_type   one of {@code AdaBoost.FUSION_TYPE_MEAN_CLASSIFIER},
	 *                      {@code AdaBoost.FUSION_TYPE_MEAN_PROBABILITY} or
	 *                      {@code AdaBoost.FUSION_TYPE_MAXIMUM_LIKELIHOOD}
	 * @throws IllegalArgumentException if {@code json_files} is empty
	 * @throws RuntimeException if the fusion type is invalid, a classifier file
	 *         cannot be read or parsed, or the classifier files disagree
	 */
	public SegAdapterTestingForest(ParamVolume est_vol,
					               ParamVolumeCollection feature_vols,
					               ParamFileCollection json_files,
					               boolean [][][][] initial_masks,
					               int fusion_type) {
		
		// validate the fusion type first so a bad argument fails before the
		// (potentially expensive) classifier and image loading
		String fusion_str = get_fusion_name(fusion_type);
		
		// read all of the information from the json files
		read_json_files(json_files);
		num_channels = feature_vols.size();
		
		// load the pertinent image information
		im = new SegAdapterImageTesting(est_vol, feature_vols, tlabels, dilations, feature_rads, initial_masks);
		im.print_info();
		
		// process the image information
		for (int l = 0; l < num_labels; l++) {
			
			JistLogger.logOutput(JistLogger.INFO, String.format("*** Processing Label %d ***", tlabels[l]));
			
			// set the values for this label
			int [] feature_rad = feature_rads[l];
			int dilation = dilations[l];
			int tlabel = tlabels[l];
			int numiter = numiters[l];
			int win_size = (2*feature_rad[0]+1) * (2*feature_rad[1]+1) * (2*feature_rad[2]+1);
			// feature-vector length; presumably has to match the feature layout
			// used at training time -- NOTE(review): confirm against the trainer
			int num_features = ((1 + num_channels) * win_size + 3) * 4 - 3; 
			int num_samples = im.get_num_samples(l);
			
			JistLogger.logOutput(JistLogger.INFO, String.format("-> Found the following information:"));
			JistLogger.logOutput(JistLogger.INFO, String.format("Feature Radius: %dx%dx%d", feature_rad[0], feature_rad[1], feature_rad[2]));
			JistLogger.logOutput(JistLogger.INFO, String.format("Dilation: %d", dilation));
			JistLogger.logOutput(JistLogger.INFO, String.format("Number of Iterations (classifiers): %d", numiter));
			JistLogger.logOutput(JistLogger.INFO, String.format("Number of Feature Channels: %d", num_channels));
			JistLogger.logOutput(JistLogger.INFO, String.format("Number of Samples to Classify: %d", num_samples));
			JistLogger.logOutput(JistLogger.INFO, String.format("Number of Ensemble Classifiers in Forest: %d", learners[l].length));
			JistLogger.logOutput(JistLogger.INFO, String.format("Ensemble Classifier Fusion Method: %s", fusion_str));

			// process this image
			im.process(num_features, tlabel, l, dilation, feature_rad, numiter, learners[l], (initial_masks == null) ? null : initial_masks[l], fusion_type);
			
			JistLogger.logOutput(JistLogger.INFO, "");
			
		}
	}
	
	/**
	 * @return the estimate produced by the underlying {@link SegAdapterImageTesting}
	 */
	public ImageData get_estimate() { return(im.get_estimate()); }
	
	/**
	 * Maps a fusion type constant to its human-readable name.
	 *
	 * @param fusion_type one of the {@code AdaBoost.FUSION_TYPE_*} constants handled below
	 * @return the display name of the fusion method
	 * @throws RuntimeException if {@code fusion_type} is not recognized
	 */
	private static String get_fusion_name(int fusion_type) {
		if (fusion_type == AdaBoost.FUSION_TYPE_MEAN_CLASSIFIER)
			return "Mean Classifier";
		if (fusion_type == AdaBoost.FUSION_TYPE_MEAN_PROBABILITY)
			return "Mean Probability";
		if (fusion_type == AdaBoost.FUSION_TYPE_MAXIMUM_LIKELIHOOD)
			return "Maximum Likelihood";
		throw new RuntimeException("Invalid Fusion Type");
	}
	
	/**
	 * Reads every classifier file and fills the per-label fields and
	 * {@link #learners}.
	 * <p>
	 * The first file defines the layout of the forest (number of labels and,
	 * per label, label value, dilation, iterations and feature radius). Every
	 * following file must match it exactly.
	 *
	 * @param files the classifier JSON files
	 * @throws IllegalArgumentException if {@code files} is empty
	 * @throws RuntimeException if a file cannot be read or parsed, if the files
	 *         disagree, or if a classifier does not provide exactly
	 *         {@code "Iters"} learners
	 */
	private void read_json_files(ParamFileCollection files) {
		
		// get the number of files
		int num_files = files.size();
		if (num_files == 0)
			throw new IllegalArgumentException("No AdaBoost classifier files were provided!");
		
		// process everything
		for (int f = 0; f < num_files; f++) {
			try {
				// read the complete file (not just its first line) so that
				// multi-line / pretty-printed JSON is parsed correctly; JSON
				// text is UTF-8, so do not rely on the platform charset
				String contents = read_all(new BufferedReader(new InputStreamReader(
						new FileInputStream(files.getValue(f)), "UTF-8")));
				JSONArray classifiers = new JSONObject(contents).getJSONArray("Classifiers");

				// if this is the first file, initialize
				if (f == 0) {
					num_labels = classifiers.length();
					dilations = new int [num_labels];
					tlabels = new int [num_labels];
					numiters = new int [num_labels];
					feature_rads = new int [num_labels][3];
					learners = new WeakLearner [num_labels][num_files][];

				// else, make sure that everything stays the same
				} else if (classifiers.length() != num_labels) {
					throw new RuntimeException("Not all adaboost match! (number of classifiers differs in "
							+ files.getValue(f) + ")");
				}
				
				for (int li = 0; li < num_labels; li++) {
					JSONObject classifier_obj = classifiers.getJSONObject(li);

					// get all of the learners for this object 
					JSONArray bwlList = classifier_obj.getJSONArray("Learners");
					
					int dilation = classifier_obj.getInt("Dilation");
					int label = classifier_obj.getInt("Label");
					int iters = classifier_obj.getInt("Iters");
					int rad_x = classifier_obj.getInt("RadX");
					int rad_y = classifier_obj.getInt("RadY");
					int rad_z = classifier_obj.getInt("RadZ");
					
					// if this is the first file, initialize
					if (f == 0) {
						dilations[li] = dilation;
						tlabels[li] = label;
						numiters[li] = iters;
						feature_rads[li][0] = rad_x;
						feature_rads[li][1] = rad_y;
						feature_rads[li][2] = rad_z;

					// else, make sure everything stays the same
					} else if (dilations[li] != dilation
							|| tlabels[li] != label
							|| numiters[li] != iters
							|| feature_rads[li][0] != rad_x
							|| feature_rads[li][1] != rad_y
							|| feature_rads[li][2] != rad_z) {
						throw new RuntimeException("Not all adaboost match! (classifier " + li
								+ " differs in " + files.getValue(f) + ")");
					}
					
					// every file, including the first, must list exactly the
					// expected number of learners
					if (bwlList.length() != numiters[li])
						throw new RuntimeException("Number of found learners " + bwlList.length()
								+ " does not match the expected amount " + numiters[li] + "!");

					// get the instance of each learner
					learners[li][f] = new WeakLearner [numiters[li]];
					for (int j = 0; j < numiters[li]; j++)
						learners[li][f][j] = WeakLearner.getInstance(bwlList.getJSONObject(j));
				}
			} catch (IOException e) {
				// includes FileNotFoundException; continuing would leave the
				// forest half-initialized, so fail loudly with the cause
				throw new RuntimeException("Unable to read AdaBoost classifier file " + files.getValue(f), e);
			} catch (JSONException e) {
				throw new RuntimeException("Unable to parse AdaBoost classifier file " + files.getValue(f), e);
			}
		}
	}
	
	/**
	 * Reads all remaining characters from {@code reader}, then closes it.
	 *
	 * @param reader the reader to drain; always closed on return
	 * @return everything that was read
	 * @throws IOException if reading or closing fails
	 */
	private static String read_all(Reader reader) throws IOException {
		try {
			StringBuilder sb = new StringBuilder();
			char [] buf = new char [8192];
			for (int n = reader.read(buf); n != -1; n = reader.read(buf))
				sb.append(buf, 0, n);
			return sb.toString();
		} finally {
			reader.close();
		}
	}
}
