package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import java.io.*;

/**
 * Hierarchical ("vectorized") performance-level parameters.
 * <p>
 * One {@link PerformanceParameters} instance is kept for every level of the
 * {@link HierarchicalModel}. Whenever {@link #normalize()} runs, a combined matrix over
 * the finest-level labels ({@code theta_final}) is rebuilt. Each of its entries
 * {@code (j, s1, s2)} is the product, over all levels, of the level entry at the labels
 * that {@code s1} and {@code s2} map to. That product is raised to the exponent
 * {@code labelnorms[j][s2]}. The read accessors ({@code get}, {@code get_log},
 * {@code toFile}, {@code toImage}) all delegate to that combined matrix.
 */
public class PerformanceParametersVectorized extends PerformanceParametersBase {
	
	/** Per-level performance parameters, indexed by hierarchy level. */
	private final PerformanceParameters [] theta;
	
	/** Combined parameters over the finest-level labels; rebuilt by normalize(). */
	private final PerformanceParameters theta_final;
	
	/** Hierarchy that defines the levels and the finest-label -> level-label mapping. */
	private final HierarchicalModel h_model;
	
	/** Exponent applied per [rater][true label] when building theta_final. */
	private final float [][] labelnorms;
	
	/** Number of labels at the finest (last) level of the hierarchy. */
	private final short num_labels_full;
	
	/** Number of hierarchy levels, cached from h_model at construction (== theta.length). */
	private final int num_levels;
	
	/**
	 * Allocates per-level parameters for every level of the model, plus the combined
	 * finest-level parameters. Every label-norm exponent starts at 1.
	 *
	 * @param num_raters_in number of raters
	 * @param h_model_in    hierarchical label model defining the levels and label mappings
	 */
	public PerformanceParametersVectorized (int num_raters_in,
											HierarchicalModel h_model_in) {
		
		super(num_raters_in);
		setLabel("VectorizedPerformanceParameters");
		
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing new Vectorized Performance Level Parameters");
		
		h_model = h_model_in;
		num_levels = h_model.num_levels();
		num_labels_full = h_model.num_labels(num_levels - 1);
		
		// an exponent of 1 means theta_final is the plain product of the level thetas
		labelnorms = new float [num_raters][num_labels_full];
		for (int j = 0; j < num_raters; j++) {
			for (int s = 0; s < num_labels_full; s++) {
				labelnorms[j][s] = 1;
			}
		}
		
		// one parameter set per hierarchy level, sized to that level's label count
		theta = new PerformanceParameters [num_levels];
		for (int i = 0; i < num_levels; i++) {
			theta[i] = new PerformanceParameters(h_model.num_labels(i), num_raters, true);
		}
		theta_final = new PerformanceParameters(num_labels_full, num_raters, true);
	}
	
	/*
	 * Implementing Methods
	 */
	
	/**
	 * Initializes every per-level parameter set.
	 * NOTE(review): theta_final and labelnorms are not rebuilt here; they are only
	 * refreshed by normalize().
	 */
	public void initialize() {
		for (int i = 0; i < num_levels; i++) {
			theta[i].initialize();
		}
	}
	
	/**
	 * Normalizes each level, then recomputes the label-norm exponents and rebuilds
	 * the combined finest-level theta from the normalized levels.
	 */
	public void normalize() {
		for (int i = 0; i < num_levels; i++) {
			theta[i].normalize();
		}
		set_labelnorms();
		set_final_theta();
	}
	
	/**
	 * Not supported for vectorized parameters.
	 *
	 * @throws UnsupportedOperationException always
	 */
	public void normalize(float [] lp) {
		throw new UnsupportedOperationException("This should not be called -- not implemented");
	}
	
	/**
	 * Copies every level from {@code theta2} into this instance, then calls normalize().
	 *
	 * @param theta2 source parameters; must be of the vectorized type
	 * @throws IllegalArgumentException if {@code theta2} is not vectorized
	 */
	public void copy(PerformanceParametersBase theta2) {
		if (theta2.get_type() != PERFORMANCE_PARAMETERS_VECTORIZED_TYPE) {
			throw new IllegalArgumentException("Cannot copy single->vectorized performance parameters.");
		}
		for (int i = 0; i < num_levels; i++) {
			theta[i].copy(theta2.getPerformanceParameters(i));
		}
		normalize();
	}
	
	/** Returns the per-level parameters for hierarchy level {@code i}. */
	public PerformanceParametersBase getPerformanceParameters(int i) { return(theta[i]); }
	
	/**
	 * Resets every per-level parameter set.
	 * theta_final is not touched here; it is rebuilt by normalize().
	 */
	public void reset() {
		for (int i = 0; i < num_levels; i++) {
			theta[i].reset();
		}
	}
	
	/**
	 * Returns the mean, over all levels, of the per-level convergence factors
	 * between this instance and {@code theta2}.
	 * Presumes {@code theta2} exposes at least as many levels through
	 * getPerformanceParameters(i) as this instance has; confirm at call sites.
	 */
	public float get_convergence_factor(PerformanceParametersBase theta2) {
		float convergence_factor = 0;
		for (int i = 0; i < num_levels; i++) {
			convergence_factor += theta[i].get_convergence_factor(theta2.getPerformanceParameters(i));
		}
		return(convergence_factor / num_levels);
	}
	
	/**
	 * Accumulates {@code val} into level {@code lvl}. The finest-level observed label
	 * {@code s1} and true label {@code s2} are first mapped to that level, and the value
	 * is scaled by the current exponent {@code labelnorms[j][s2]}.
	 */
	public void add(int lvl, int j, int s1, int s2, float val) {
		short o_label = h_model.map(s1, lvl);
		short t_label = h_model.map(s2, lvl);
		theta[lvl].add(j, o_label, t_label, labelnorms[j][s2] * val);
	}
	
	/**
	 * Not supported: vectorized parameters require a level, so use
	 * {@link #add(int, int, int, int, float)} instead.
	 *
	 * @throws UnsupportedOperationException always
	 */
	public void add(int j, int s1, int s2, float val) { 
		throw new UnsupportedOperationException("this type of add should not be called from vectorized theta.");
	}
	
	public float get(int j, int s1, int s2) { return(theta_final.get(j, s1, s2)); }
	public double get(int j, int s, short [] obslabels, float [] obsvals) { return(theta_final.get(j, s, obslabels, obsvals)); }
	public double get_log(int j, int s, short [] obslabels, float [] obsvals) { return(theta_final.get_log(j, s, obslabels, obsvals)); }
	public File toFile (int j, File outdir) { return(theta_final.toFile(j, outdir)); }
	public ImageData toImage (String name) { return(theta_final.toImage(name)); }
	public float get_labelnorm(int j, int s) { return(labelnorms[j][s]); }
	public int get_type() { return(PERFORMANCE_PARAMETERS_VECTORIZED_TYPE); }
	
	/*
	 * Private Methods
	 */
	
	/**
	 * Recomputes labelnorms[j][s] for every rater j and true label s.
	 * First the un-exponentiated level product is materialized over the finest labels.
	 * Then, for each column (j, *, s), the strictly positive entries are passed to
	 * PerformanceParametersBase.get_exponential_labelnorm.
	 */
	private void set_labelnorms() {
		
		// full-resolution product of the level thetas (no exponent applied yet)
		PerformanceParametersBase tmptheta = new PerformanceParameters(num_labels_full, num_raters, true);
		tmptheta.reset();
		for (int j = 0; j < num_raters; j++) {
			for (int s1 = 0; s1 < num_labels_full; s1++) {
				for (int s2 = 0; s2 < num_labels_full; s2++) {
					tmptheta.add(j, s1, s2, (float)get_level_product(j, s1, s2));
				}
			}
		}
		
		// scratch buffer; only the first veclength slots are meaningful on each pass
		double [] vec = new double [num_labels_full];
		
		for (int j = 0; j < num_raters; j++) {
			for (int s = 0; s < num_labels_full; s++) {
				// single pass: collect the strictly positive entries of column s
				int veclength = 0;
				for (int l = 0; l < num_labels_full; l++) {
					float entry = tmptheta.get(j, l, s);
					if (entry > 0) {
						vec[veclength++] = entry;
					}
				}
				labelnorms[j][s] = (float)PerformanceParametersBase.get_exponential_labelnorm(vec, veclength);
			}
		}
	}
	
	/** Rebuilds theta_final from the level thetas and the current labelnorms. */
	private void set_final_theta() {
		theta_final.reset();
		for (int j = 0; j < num_raters; j++) {
			for (int s1 = 0; s1 < num_labels_full; s1++) {
				for (int s2 = 0; s2 < num_labels_full; s2++) {
					theta_final.add(j, s1, s2, get_vectorized(j, s1, s2));
				}
			}
		}
	}
	
	/** Combined entry: level product raised to the exponent of the true label s2. */
	private float get_vectorized(int j, int s1, int s2) {
		return((float)Math.pow(get_level_product(j, s1, s2), labelnorms[j][s2]));
	}
	
	/**
	 * Product over all levels i of theta[i](j, map(s1, i), map(s2, i)). Here s1 is the
	 * finest-level observed label and s2 is the finest-level true label.
	 */
	private double get_level_product(int j, int s1, int s2) {
		double product = 1;
		for (int i = 0; i < num_levels; i++) {
			short o_label = h_model.map(s1, i);
			short t_label = h_model.map(s2, i);
			product *= theta[i].get(j, o_label, t_label);
		}
		return product;
	}
	
}
