package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Label hierarchy plus a per-voxel "consensus level" derived from a set of observations.
 *
 * <p>The hierarchy is read from a comma-separated text file:
 * <ul>
 *   <li>Header line: {@code <number of labels>,<labels at level 0>,<labels at level 1>,...}.
 *       The first field must equal {@code obs.num_labels()}; the number of remaining
 *       fields defines the number of hierarchy levels.</li>
 *   <li>Every following non-blank line: {@code <label>,<value at level 0>,<value at level 1>,...}.
 *       The label is translated with {@code obs.unmap_label(...)} to select the row of the
 *       mapping table that the line fills in.</li>
 * </ul>
 *
 * <p>Any I/O or format problem in the hierarchy file is reported by the constructor as a
 * {@link RuntimeException} whose cause describes the offending line.
 */
public class HierarchicalModel extends AbstractCalculation {

	/** {@code hierarchy_map[l][i]} is the value read from the file for (unmapped) label {@code l} at level {@code i}. */
	private short [][] hierarchy_map;

	/** Number of labels declared in the file header for each hierarchy level. */
	private short [] h_num_labels;

	/** Number of hierarchy levels (header columns minus the leading label count). */
	private int num_levels;

	/** Consensus level per voxel, indexed {@code [x][y][z][v]}; see {@link #consensus_level}. */
	private int [][][][] h_consensus;

	/**
	 * Reads the hierarchy file and computes the consensus level of every voxel.
	 *
	 * @param hierarchy_file comma-separated hierarchy description (format in the class comment)
	 * @param obs observations the hierarchy is applied to
	 * @throws RuntimeException if the file cannot be read or parsed (the underlying
	 *         {@link IOException} is attached as the cause), or if its label count does not
	 *         match {@code obs.num_labels()}
	 */
	public HierarchicalModel(File hierarchy_file,
							 ObservationBase obs) {

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing new Hierarchical Model +++");

		// read the hierarchy file
		try {
			set_hierarchy_map(hierarchy_file, obs);
		} catch (IOException e) {
			// keep the original exception so the actual reason is not lost
			throw new RuntimeException("Error in hierarchy file.", e);
		}

		// set the hierarchical consensus
		set_hierarchical_consensus(obs);
	}

	/**
	 * Parses the hierarchy file into {@code num_levels}, {@code h_num_labels} and
	 * {@code hierarchy_map}, then logs what was found.
	 *
	 * @throws IOException if the file cannot be read, is empty, or contains a malformed line
	 * @throws RuntimeException if the header label count differs from {@code obs.num_labels()}
	 */
	private void set_hierarchy_map(File file, ObservationBase obs) throws IOException {

		JistLogger.logOutput(JistLogger.INFO, "-> Reading the hierarchy file");

		BufferedReader reader = new BufferedReader(new FileReader(file));
		try {
			int lineNumber = 1;

			// read the first line to get the header information
			String header = reader.readLine();
			if (header == null) {
				throw new IOException("Hierarchy file is empty: " + file);
			}
			String [] headerParts = header.split(",");
			int numLabels = parse_int(headerParts[0], lineNumber);
			num_levels = headerParts.length - 1;

			// do some basic error checking
			if (numLabels != obs.num_labels()) {
				throw new RuntimeException("Hierarchy File does not match observations");
			}

			// set the hierarchical number of labels
			h_num_labels = new short [num_levels];
			for (int i = 0; i < num_levels; i++) {
				h_num_labels[i] = (short) parse_int(headerParts[i + 1], lineNumber);
			}

			// allocate space for the hierarchy map (numLabels was just checked against obs)
			hierarchy_map = new short [numLabels][num_levels];

			// read each of the additional lines
			String line;
			while ((line = reader.readLine()) != null) {
				lineNumber++;

				// tolerate blank lines (e.g. a trailing newline at the end of the file)
				if (line.trim().length() == 0) {
					continue;
				}

				String [] fields = line.split(",");
				if (fields.length < num_levels + 1) {
					throw new IOException(String.format(
							"Hierarchy file line %d: expected %d fields but found %d",
							lineNumber, num_levels + 1, fields.length));
				}

				// get the label we're interested in
				int labelnum = parse_int(fields[0], lineNumber);
				short lnum = obs.unmap_label((short) labelnum);
				if (lnum < 0 || lnum >= numLabels) {
					throw new IOException(String.format(
							"Hierarchy file line %d: label %d is not part of the observations",
							lineNumber, labelnum));
				}

				// set the hierarchical mapping
				for (int i = 0; i < num_levels; i++) {
					hierarchy_map[lnum][i] = (short) parse_int(fields[i + 1], lineNumber);
				}
			}
		} finally {
			// always release the file handle, including when parsing fails
			reader.close();
		}

		// print out some important information
		JistLogger.logOutput(JistLogger.INFO, "-> Found the following hierarchical information:");
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of hierarchical levels: %d", num_levels));
		for (int i = 0; i < num_levels; i++) {
			JistLogger.logOutput(JistLogger.INFO, String.format("Number of labels at %d: %d", i, h_num_labels[i]));
		}
		for (int l = 0; l < hierarchy_map.length; l++) {
			StringBuilder outstr = new StringBuilder(String.format("Labelnum %02d: ", l));
			for (int i = 0; i < num_levels; i++) {
				if (i > 0) {
					outstr.append(", ");
				}
				outstr.append(hierarchy_map[l][i]);
			}
			JistLogger.logOutput(JistLogger.CONFIG, outstr.toString());
		}
	}

	/**
	 * Parses one comma-separated field as an integer, ignoring surrounding whitespace.
	 *
	 * @throws IOException if the field is not a valid integer; the message names the line
	 */
	private static int parse_int(String field, int lineNumber) throws IOException {
		try {
			return Integer.parseInt(field.trim());
		} catch (NumberFormatException e) {
			throw new IOException(String.format(
					"Hierarchy file line %d: '%s' is not an integer", lineNumber, field), e);
		}
	}

	/**
	 * Fills {@code h_consensus} for every voxel and logs, per level, the fraction of voxels
	 * counted as consensus at that level.
	 *
	 * <p>Per voxel:
	 * <ul>
	 *   <li>{@code obs.ignore_consensus()} is false: level 0.</li>
	 *   <li>{@code obs.is_consensus(x, y, z, v)} is true: level {@code num_levels}, and the
	 *       voxel is counted at every level.</li>
	 *   <li>otherwise: the values returned by {@code obs.get_all_vals} are accumulated per
	 *       mapped label at each level; a level is counted when its largest accumulated
	 *       value divided by the number of raters reaches {@code obs.consensus_threshold()}
	 *       (with a 1e-4 tolerance). The stored level is {@code i + 1} for the last such
	 *       level {@code i}, or 0 when no level qualifies.</li>
	 * </ul>
	 */
	private void set_hierarchical_consensus(ObservationBase obs) {

		// the dimensions do not change while iterating, so read them once
		final int dimx = obs.dimx();
		final int dimy = obs.dimy();
		final int dimz = obs.dimz();
		final int dimv = obs.dimv();

		h_consensus = new int[dimx][dimy][dimz][dimv];

		// per-level accumulators, reused for every voxel
		float [][] lp = new float [num_levels][];
		for (int i = 0; i < num_levels; i++) {
			lp[i] = new float [h_num_labels[i]];
		}
		int [] numcon = new int [num_levels];

		for (int x = 0; x < dimx; x++) {
			for (int y = 0; y < dimy; y++) {
				for (int z = 0; z < dimz; z++) {
					for (int v = 0; v < dimv; v++) {

						if (!obs.ignore_consensus()) {

							// set it to zero if we're not ignoring consensus
							h_consensus[x][y][z][v] = 0;

						} else if (obs.is_consensus(x, y, z, v)) {

							// if it is consensus, then it is fully consensus
							for (int i = 0; i < num_levels; i++) {
								numcon[i]++;
							}
							h_consensus[x][y][z][v] = num_levels;

						} else {

							// reset the accumulators for this voxel
							for (int i = 0; i < num_levels; i++) {
								Arrays.fill(lp[i], 0f);
							}

							// accumulate every rater's values onto the mapped label of each level
							for (int j = 0; j < obs.num_raters(); j++) {
								short [] obslabels = obs.get_all(x, y, z, v, j);
								float [] obsvals = obs.get_all_vals(x, y, z, v, j);
								for (int l = 0; l < obslabels.length; l++) {
									short [] mapped = hierarchy_map[obslabels[l]];
									for (int i = 0; i < num_levels; i++) {
										lp[i][mapped[i]] += obsvals[l];
									}
								}
							}

							// determine the consensus level
							int level = 0;
							for (int i = 0; i < num_levels; i++) {
								// largest accumulated value at this level, normalised by the rater count
								float best = 0f;
								for (int s = 0; s < lp[i].length; s++) {
									if (lp[i][s] > best) {
										best = lp[i][s];
									}
								}
								best /= (float) obs.num_raters();

								// small tolerance so values equal to the threshold up to rounding still count
								if (best >= (obs.consensus_threshold() - 0.0001f)) {
									level = i + 1;
									numcon[i]++;
								}
							}
							h_consensus[x][y][z][v] = level;
						}
					}
				}
			}
		}

		// print out the fraction consensus at each level
		for (int i = 0; i < num_levels; i++) {
			float frac_con = ((float)numcon[i]) / (obs.num_vox());
			JistLogger.logOutput(JistLogger.INFO, String.format("Fraction Consensus at level %d: %f", i, frac_con));
		}
	}

	/** @return the number of hierarchy levels read from the file header */
	public int num_levels() { return(num_levels); }

	/**
	 * @param label row index into the mapping table (the index produced by
	 *        {@code obs.unmap_label} while the file was read)
	 * @param lvl hierarchy level, {@code 0 <= lvl < num_levels()}
	 * @return the value the hierarchy file assigns to {@code label} at level {@code lvl}
	 */
	public short map(int label, int lvl) { return(hierarchy_map[label][lvl]); }

	/**
	 * @param i hierarchy level, {@code 0 <= i < num_levels()}
	 * @return the number of labels the file header declares for level {@code i}
	 */
	public short num_labels(int i) { return(h_num_labels[i]); }

	/**
	 * @return the consensus level stored for the voxel: 0, or {@code i + 1} for a level
	 *         index {@code i}, up to {@code num_levels()}
	 */
	public int consensus_level(int x, int y, int z, int v) { return(h_consensus[x][y][z][v]); }
}
