package edu.vanderbilt.masi.algorithms.clasisfication;

import java.io.*;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.zip.GZIPInputStream;

import org.json.*;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.FileUtil;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.clasisfication.decisiontree.DecisionTreeClassifier;

/**
 * Trains one classifier per feature file.
 *
 * A template classifier is built from a JSON description file. For every
 * gzipped feature file a basic clone of that template is trained on the rows
 * found in the file and written to the output directory. All of the work is
 * performed by the constructor; the written files are available afterwards
 * through {@link #getFiles()}.
 */
public class ClassificationRunner extends AbstractCalculation {

	/** Template classifier built from the JSON description; cloned for every label. */
	public Classifier classifier;
	/** Classifier files written so far, in processing order. */
	private ArrayList<File> files;
	/** Directory that receives the trained classifier files. */
	private File outdir;

	/**
	 * Loads the template classifier and immediately trains one classifier per
	 * feature file on the calling thread.
	 *
	 * @param input_classifier file whose first line is the JSON description of the classifier
	 * @param features gzipped feature files, one JSON object per line
	 * @param outdir directory the trained classifiers are written to
	 * @throws IllegalStateException if the classifier description cannot be
	 *         read or parsed, or names an unsupported classifier type
	 */
	public ClassificationRunner(File input_classifier,List<File> features,File outdir){
		this.files = new ArrayList<File>(features.size());
		this.outdir = outdir;
		// Fail here with the real cause instead of continuing with a null
		// classifier and hitting an unexplained NullPointerException later.
		try{
			this.classifier = loadClassifier(input_classifier);
		}catch(IOException e){
			throw new IllegalStateException("Could not read classifier file "+input_classifier, e);
		} catch (JSONException e) {
			throw new IllegalStateException("Could not parse classifier file "+input_classifier, e);
		}
		runSingleThread(features);
	}

	/**
	 * Builds the template classifier from the first line of the given file.
	 *
	 * @param input_classifier file holding the JSON description on its first line
	 * @return the classifier built from that description
	 * @throws IOException if the file cannot be read or is empty
	 * @throws JSONException if the first line is not a valid description
	 * @throws IllegalStateException if "classifier_type" is not "Decision Tree"
	 */
	private static Classifier loadClassifier(File input_classifier) throws IOException, JSONException {
		JistLogger.logOutput(JistLogger.FINE, "Loading Classifier");
		String line;
		FileReader fstream = new FileReader(input_classifier);
		BufferedReader br = null;
		try{
			br = new BufferedReader(fstream);
			line = br.readLine();
		}finally{
			closeQuietly(br);
			closeQuietly(fstream);
		}
		if(line == null){
			throw new IOException("Classifier file is empty: "+input_classifier);
		}
		JSONObject obj = new JSONObject(line);
		JistLogger.logOutput(JistLogger.FINE, obj.toString());
		String type = obj.getString("classifier_type");
		if(!type.equals("Decision Tree")){
			throw new IllegalStateException("Unsupported classifier_type \""+type+"\" in "+input_classifier);
		}
		Classifier loaded = new DecisionTreeClassifier();
		loaded.buildFromJSON(obj);
		return loaded;
	}

	/**
	 * Processes the feature files one after another: reads the rows, trains a
	 * classifier and writes it to the output directory. A file that cannot be
	 * read or parsed is reported and skipped; the remaining files are still
	 * processed.
	 *
	 * @param features gzipped feature files to process
	 */
	private void runSingleThread(List<File> features){
		JistLogger.logOutput(JistLogger.FINE, "Running with Single Thread");
		for(File f: features){
			// The label name is the file name without its fixed prefix and suffix.
			String name = f.getName().replace("Feature_Label_", "").replace(".json.gz","");
			JistLogger.logOutput(JistLogger.FINE, "Starting to process file "+f.getName());
			JistLogger.logFlush();
			// Declared outside the try so the summary below reports what was
			// actually read, even when processing fails part-way.
			List<float[]> positive = new ArrayList<float[]>();
			List<float[]> negative = new ArrayList<float[]>();
			try {
				readFeatureRows(f, positive, negative);
				JistLogger.logOutput(JistLogger.INFO,"There were "+negative.size()+" negative and "+positive.size()+" positive for label "+name);
				float[][] pos = positive.toArray(new float[positive.size()][]);
				float[][] neg = negative.toArray(new float[negative.size()][]);
				Classifier C = this.runTraining(pos,neg,name);
				File o = new File(outdir, String.format("Classifier%s.json", name));
				JistLogger.logOutput(JistLogger.FINE,"Writing to "+o.getAbsolutePath());
				C.writeToFile(o);
				this.files.add(o);
			} catch (IOException e) {
				JistLogger.logOutput(JistLogger.WARNING, "Failed to process file "+f.getName()+": "+e);
				e.printStackTrace();
			} catch (JSONException e) {
				JistLogger.logOutput(JistLogger.WARNING, "Failed to process file "+f.getName()+": "+e);
				e.printStackTrace();
			}
			JistLogger.logOutput(JistLogger.FINE, "Found "+positive.size()+" positive feature rows and "+negative.size()+" negative feature rows");
			JistLogger.logFlush();
		}
	}

	/**
	 * Reads a gzipped, UTF-8 encoded file holding one JSON object per line and
	 * sorts the "features" arrays by the object's "class" value: 1 goes to
	 * {@code positive}, -1 to {@code negative}; any other value is logged and
	 * the row is dropped.
	 *
	 * @param f gzipped feature file
	 * @param positive receives the feature rows of class 1
	 * @param negative receives the feature rows of class -1
	 * @throws IOException if the file cannot be opened, decompressed or read
	 * @throws JSONException if a line is not valid JSON or lacks a required key
	 */
	private static void readFeatureRows(File f, List<float[]> positive, List<float[]> negative) throws IOException, JSONException {
		InputStream fileStream = new FileInputStream(f.getAbsoluteFile());
		BufferedReader br = null;
		try{
			br = new BufferedReader(new InputStreamReader(new GZIPInputStream(fileStream), "utf-8"));
			String line;
			while((line=br.readLine())!=null){
				JSONObject obj = new JSONObject(line);
				if(!obj.has("class"))
					JistLogger.logOutput(JistLogger.INFO, "No class found in this object: "+obj.toString());
				// For a record without "class" this throws, which aborts the
				// current file; the message above identifies the record.
				int label = obj.getInt("class");
				if(label==1){
					positive.add(toFloatArray(obj.getJSONArray("features")));
				}else if(label==-1){
					negative.add(toFloatArray(obj.getJSONArray("features")));
				}else{
					JistLogger.logOutput(JistLogger.WARNING, "Found and object with class "+label+" and this is unexpected");
				}
			}
		}finally{
			// Closing the reader closes the whole chain; the raw stream is
			// closed as well in case the chain was never fully constructed.
			closeQuietly(br);
			closeQuietly(fileStream);
		}
	}

	/**
	 * Copies a JSON array of numbers into a float array.
	 *
	 * @param values JSON array of numeric values
	 * @return the values narrowed to float, in the same order
	 * @throws JSONException if an element is not a number
	 */
	private static float[] toFloatArray(JSONArray values) throws JSONException {
		float[] row = new float[values.length()];
		for(int i=0;i<row.length;i++) row[i] = (float) values.getDouble(i);
		return row;
	}

	/**
	 * Closes a resource, ignoring a null argument and any failure to close.
	 *
	 * @param c resource to close, may be null
	 */
	private static void closeQuietly(Closeable c){
		if(c == null) return;
		try{
			c.close();
		}catch(IOException ignored){
			// Only used for sources that were read from; nothing can be lost here.
		}
	}

	/**
	 * Trains a basic clone of the template classifier. The training set is the
	 * positive rows (target 1) followed by the negative rows (target -1), each
	 * with weight 1.
	 *
	 * @param positive feature rows of class 1
	 * @param negative feature rows of class -1
	 * @param label label name, used for logging only
	 * @return the trained clone
	 */
	private Classifier runTraining(float[][] positive, float[][] negative, String label){

		JistLogger.logOutput(JistLogger.INFO, "Starting Classifier Building for Label "+label);
		Classifier C = this.classifier.basicClone();
		int total = positive.length+negative.length;
		int[] target = new int[total];
		float[] weight = new float[total];
		float[][] features = new float[total][];
		for(int i=0;i<positive.length;i++){
			weight[i] = 1;
			target[i] = 1;
			features[i] = positive[i];
		}
		// Negative rows are stored directly after the positive ones.
		int offset = positive.length;
		for(int i=0;i<negative.length;i++){
			weight[offset+i] = 1;
			target[offset+i] = -1;
			features[offset+i] = negative[i];
		}
		C.train(features, target, weight);
		return C;
	}

	/**
	 * Returns the classifier files written by this runner.
	 *
	 * @return the written files, in processing order
	 */
	public List<File> getFiles(){
		return this.files;
	}

}
