package edu.vanderbilt.masi.algorithms.clasisfication;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.TreeMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.FileUtil;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;

public class ErrorClassificationTrainingSetBuilder extends AbstractCalculation {

	public int[][][] manual_segmentation;
	public boolean[][][] mask;
	public int[][][] host_segmentation;
	public String[][][] rows;
	public JSONObject header;
	public ArrayList<File> files;
	public HashSet<Integer> labels;
	public int r,c,s;
	private int radius=3;
	public ArrayList<VolumeLocation> mask_locations;
	private int max_samples;
	private HashMap<Integer,BufferedWriter> out_files;


	public ErrorClassificationTrainingSetBuilder(File features,
			ParamVolume manual_segmentation, ParamVolume host_segmentation, 
			ParamVolume mask,HashMap<Integer,BufferedWriter> files,
			ParamInteger max, ParamInteger dilation_distance){
		ImageData man = manual_segmentation.getImageData();
		ImageData mas = mask.getImageData();
		ImageData hos = host_segmentation.getImageData();
		this.max_samples = max.getInt();
		this.radius = dilation_distance.getInt();
		this.r = man.getRows();
		this.c = man.getCols();
		this.s = man.getSlices();
		this.out_files = files;
		this.manual_segmentation = new int[man.getRows()][man.getCols()][man.getSlices()];
		this.host_segmentation = new int[man.getRows()][man.getCols()][man.getSlices()];
		this.mask = new boolean[man.getRows()][man.getCols()][man.getSlices()];
		this.rows = new String[r][c][s];
		this.mask_locations = new ArrayList<VolumeLocation>();
		for(int i=0;i<man.getRows();i++){
			for(int j=0;j<man.getCols();j++){
				for(int k=0;k<man.getSlices();k++){
					this.manual_segmentation[i][j][k] = man.getInt(i, j,k);
					this.host_segmentation[i][j][k] = hos.getInt(i,j,k);
					this.mask[i][j][k] = mas.getBoolean(i, j, k);
					if(mas.getBoolean(i,j,k)) this.mask_locations.add(new VolumeLocation(i,j,k));
				}
			}
		}
		man.dispose();
		hos.dispose();
		mas.dispose();
		this.buildFeatures(features);
		this.getLabels();

		for(int i: this.labels) this.processLabel(i);
	}

	private void getLabels(){
		JistLogger.logOutput(JistLogger.INFO, "Determining Labels");
		this.labels = new HashSet<Integer>();
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					this.labels.add(this.host_segmentation[i][j][k]);
				}
			}
		}
		this.files = new ArrayList<File>(this.labels.size());
		JistLogger.logOutput(JistLogger.FINE, "There were "+this.labels.size()+" labels found");
	}

	private void buildFeatures(File f){
		JistLogger.logOutput(JistLogger.INFO, "Starting to read feature file "+f.getAbsolutePath()+" for Error Classification Training Set Builder");
		int n=0;
		try {
			VolumeLocation v;
			float[] b;
			int x,y,z;
			InputStream fileStream = new FileInputStream(f.getAbsoluteFile());
			BufferedReader br;
			if(f.getAbsolutePath().endsWith(".gz")){
				InputStream gzipStream = new GZIPInputStream(fileStream);
				Reader decoder = new InputStreamReader(gzipStream, "utf-8");
				br = new BufferedReader(decoder);
			}
			else{
				Reader reader = new InputStreamReader(fileStream);
				br = new BufferedReader(reader);
			}
			String line = br.readLine();
			JSONObject obj = new JSONObject(line);
			this.header = obj;
			JSONArray a;
			int num = obj.getInt("number");
			n = 0;
			while((line=br.readLine())!=null){
				if(n%10000 == 0){
					JistLogger.logOutput(JistLogger.FINE,"On number " + n + " out of "+num);
					JistLogger.logFlush();
				}
				obj = new JSONObject(line);
				x = obj.getInt("x");
				y = obj.getInt("y");
				z = obj.getInt("z");
				a = obj.getJSONArray("features");
				this.rows[x][y][z]=line;
				n++;
			}
			br.close();
			fileStream.close();
		} catch (IOException e) {
			e.printStackTrace();
		} catch (JSONException e) {
			e.printStackTrace();
		}
		JistLogger.logOutput(JistLogger.FINE, "There were "+n+ " feature rows collected");

	}

	public ArrayList<File> getFileLocations(){
		return this.files;
	}


	private void processLabel(int label){
		JistLogger.logOutput(JistLogger.INFO, "Processing Label "+label);
		JistLogger.logFlush();
		HashSet<VolumeLocation> positive = new HashSet<VolumeLocation>(this.max_samples*5,(float) 0.2);
		HashSet<VolumeLocation> negative = new HashSet<VolumeLocation>(this.max_samples*5,(float) 0.2);
		int i,j,k;

		try {
			BufferedWriter bw = this.out_files.get(label);

			outerloop: // magic
				for(VolumeLocation v:this.mask_locations){
					i = v.getX();
					j = v.getY();
					k = v.getZ();
					if(this.host_segmentation[i][j][k] == label){
						int xl = Math.max(i- this.radius, 0);
						int xh = Math.min(i+ this.radius, r-1);
						int yl = Math.max(j- this.radius, 0);
						int yh = Math.min(j+ this.radius, c-1);
						int zl = Math.max(k- this.radius, 0);
						int zh = Math.min(k+ this.radius, s-1);
						for(int l=xl;l<=xh;l++){
							for(int m=yl;m<yh;m++){
								for(int n=zl;n<zh;n++){
									if(this.mask[l][m][n]){
										if(this.manual_segmentation[l][m][n] == label && positive.size()<this.max_samples) positive.add(new VolumeLocation(l,m,n));
										else if(negative.size()<this.max_samples) negative.add(new VolumeLocation(l,m,n));
										if(negative.size()>=this.max_samples&&positive.size()>=this.max_samples) break outerloop;
									}
								}
							}
						}
					}
				}
			JSONObject o = this.header;
//			o.put("positive",positive.size());
//			o.put("negative", negative.size());
//			bw.write(o.toString());
//			bw.write("\n");
			for(VolumeLocation v:positive){
				String row = this.rows[v.getX()][v.getY()][v.getZ()];
				JSONObject obj = new JSONObject(row);
				obj.put("class", 1);
				bw.write(obj.toString());
				bw.write("\n");

			}
			for(VolumeLocation v:negative){
				String row = this.rows[v.getX()][v.getY()][v.getZ()];
				JSONObject obj = new JSONObject(row);
				obj.put("class", -1);
				bw.write(obj.toString());
				bw.write("\n");
			}
		} catch (IOException e) {
			e.printStackTrace();
		} catch (JSONException e) {
			e.printStackTrace();
		}
		JistLogger.logOutput(JistLogger.FINE, "There are "+positive.size()+" positive samples and "+negative.size()+" negative samples");
		JistLogger.logFlush();
	}
}
