package edu.vanderbilt.masi.algorithms.clasisfication;

import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import java.util.zip.GZIPInputStream;

import org.json.*;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;

public class ErrorClassificationTrainingSetBuilder2 extends AbstractCalculation {

	public int[][][] manual_segmentation;
	public boolean[][][] mask;
	public int[][][] host_segmentation;
	public String[][][] rows;
	public ArrayList<ArrayList<ArrayList<ArrayList<Label>>>> voxel_map;
	public JSONObject header;
	public HashSet<Integer> labels;
	public int r,c,s;
	private int radius=3;
	public ArrayList<VolumeLocation> mask_locations;
	private int max_samples;
	private HashMap<Integer,BufferedWriter> out_files;


	public ErrorClassificationTrainingSetBuilder2(File features,
			ParamVolume manual_segmentation, ParamVolume host_segmentation, 
			ParamVolume mask,HashMap<Integer,BufferedWriter> files,
			ParamInteger max, ParamInteger dilation_distance){
		ImageData man = manual_segmentation.getImageData();
		ImageData mas = mask.getImageData();
		ImageData hos = host_segmentation.getImageData();
		JistLogger.logOutput(JistLogger.INFO, "Image Dimensions: "+man.getRows()+"x"+man.getCols()+"x"+man.getSlices()+" "+mas.getRows()+"x"+mas.getCols()+"x"+mas.getSlices()+" "+hos.getRows()+"x"+hos.getCols()+"x"+hos.getSlices());
		this.max_samples = max.getInt();
		this.radius = dilation_distance.getInt();
		this.r = man.getRows();
		this.c = man.getCols();
		this.s = man.getSlices();
		this.voxel_map = new ArrayList<ArrayList<ArrayList<ArrayList<Label>>>>(r);
		for(int i=0;i<r;i++){
			ArrayList<ArrayList<ArrayList<Label>>> AL1 = new ArrayList<ArrayList<ArrayList<Label>>>(c);
			for(int j=0;j<c;j++){
				ArrayList<ArrayList<Label>> AL2 = new ArrayList<ArrayList<Label>>(s);
				for(int k=0;k<s;k++){
					AL2.add(new ArrayList<Label>());
				}
				AL1.add(AL2);
			}
			this.voxel_map.add(AL1);
		}
		this.out_files = files;
		this.manual_segmentation = new int[man.getRows()][man.getCols()][man.getSlices()];
		this.host_segmentation = new int[man.getRows()][man.getCols()][man.getSlices()];
		this.mask = new boolean[man.getRows()][man.getCols()][man.getSlices()];
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					this.manual_segmentation[i][j][k] = man.getInt(i, j,k);
					this.host_segmentation[i][j][k] = hos.getInt(i,j,k);
					this.mask[i][j][k] = mas.getBoolean(i, j, k);
				}
			}
		}
		man.dispose();
		hos.dispose();
		mas.dispose();
		this.getLabels();
		for(int i: this.labels) this.processLabel(i);
		this.runFileWriting(features);
	}
	private void getLabels(){
		JistLogger.logOutput(JistLogger.INFO, "Determining Labels");
		this.labels = new HashSet<Integer>();
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					this.labels.add(this.host_segmentation[i][j][k]);
				}
			}
		}
		JistLogger.logOutput(JistLogger.FINE, "There were "+this.labels.size()+" labels found");
	}
	private void processLabel(int label){
		JistLogger.logOutput(JistLogger.INFO, "Processing Label "+label);
		JistLogger.logFlush();
		int r = this.manual_segmentation.length;
		int c = this.manual_segmentation[0].length;
		int s = this.manual_segmentation[0][0].length;
		boolean[][][] pos = new boolean[r][c][s];
		boolean[][][] neg = new boolean[r][c][s];
		int n_pos = 0;
		int n_neg = 0;
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					neg[i][j][k]=false;
					pos[i][j][k]=false;
				}
			}
		}
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					if(this.host_segmentation[i][j][k] == label&&this.mask[i][j][k]){
						int xl = Math.max(i- this.radius, 0);
						int xh = Math.min(i+ this.radius, r-1);
						int yl = Math.max(j- this.radius, 0);
						int yh = Math.min(j+ this.radius, c-1);
						int zl = Math.max(k- this.radius, 0);
						int zh = Math.min(k+ this.radius, s-1);
						for(int l=xl;l<=xh;l++){
							for(int m=yl;m<yh;m++){
								for(int n=zl;n<zh;n++){
									if(this.mask[l][m][n]){
										if(this.manual_segmentation[l][m][n] == label){
											if(!pos[l][m][n]){
												pos[l][m][n]=true;
												n_pos++;
											}
										}
										else if(!neg[l][m][n]){
											neg[l][m][n]=true;
											n_neg++;
										}
									}
								}
							}
						}
					}
				}
			}
		}
		JistLogger.logOutput(JistLogger.INFO, "There were "+n_pos+" positive samples and "+n_neg+" negative samples");
		JistLogger.logFlush();
		float num_samps = (float) Math.min((double) n_pos,Math.min((double) n_neg,(double)this.max_samples));
		float neg_p = num_samps / (float) n_neg;
		float pos_p = num_samps / (float) n_pos;
		JistLogger.logOutput(JistLogger.INFO, "The number of samples to keep is "+num_samps+".  The negative percent is "+neg_p+" and the positive percent is "+pos_p+".");		
		for(int i=0;i<r;i++){
			for(int j=0;j<c;j++){
				for(int k=0;k<s;k++){
					if(pos[i][j][k]&&neg[i][j][k]) JistLogger.logError(JistLogger.SEVERE, "There are positive and negative in the same voxel");
					else if(pos[i][j][k]){
						ArrayList<Label> AL1 = this.voxel_map.get(i).get(j).get(k);
						AL1.add(new Label(label,pos_p,1));
						ArrayList<ArrayList<Label>> AL2 = this.voxel_map.get(i).get(j);
						AL2.set(k,AL1);
						ArrayList<ArrayList<ArrayList<Label>>> AL3 = this.voxel_map.get(i);
						AL3.set(j, AL2);
						this.voxel_map.set(i,AL3);
					}
					else if(neg[i][j][k]){
						ArrayList<Label> AL1 = this.voxel_map.get(i).get(j).get(k);
						AL1.add(new Label(label,neg_p,-1));
						ArrayList<ArrayList<Label>> AL2 = this.voxel_map.get(i).get(j);
						AL2.set(k,AL1);
						ArrayList<ArrayList<ArrayList<Label>>> AL3 = this.voxel_map.get(i);
						AL3.set(j, AL2);
						this.voxel_map.set(i,AL3);
					}
				}
			}
		}
	}

	private void runFileWriting(File f){
		Random R = new Random();
		JistLogger.logOutput(JistLogger.INFO, "Starting to read feature file "+f.getAbsolutePath()+" for Error Classification Training Set Builder");
		try {
			InputStream fileStream = new FileInputStream(f.getAbsoluteFile());
			BufferedReader br;
			if(f.getAbsolutePath().endsWith(".gz")){
				InputStream gzipStream;

				gzipStream = new GZIPInputStream(fileStream);

				Reader decoder = new InputStreamReader(gzipStream, "utf-8");
				br = new BufferedReader(decoder);
			}
			else{
				Reader reader = new InputStreamReader(fileStream);
				br = new BufferedReader(reader);
			}
			String line = br.readLine();
			JSONObject obj = new JSONObject(line);
			BufferedWriter wr;
			while((line=br.readLine())!=null){
				obj = new JSONObject(line);
				int x = obj.getInt("x");
				int y = obj.getInt("y");
				int z = obj.getInt("z");
				ArrayList<Label> L = this.voxel_map.get(x).get(y).get(z);
				for(Label l:L){
					int n = R.nextInt(101);
					float p = (float) ((float) n/100.0);
					if(p<=l.p){
						wr = this.out_files.get(l.l);
						obj.put("class", l.c);
						wr.write(obj.toString());
						wr.write("\n");
					}
				}
			}
			br.close();
		} catch (IOException e) {
			e.printStackTrace();
		} catch (JSONException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	private class Label{
		public int l;
		public float p;
		public int c;
		public Label(int l, float p,int c){
			this.l = l;
			this.p = p;
			this.c = c;
		}
	}
}
