package edu.vanderbilt.masi.algorithms.labelfusion;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import java.io.*;
import java.util.Arrays;
import java.util.Locale;

/**
 * Per-rater performance-level parameters used by label fusion.
 *
 * <p>For every rater {@code j} this holds a {@code num_labels x num_labels} matrix
 * {@code theta[j][s][l]}, where the first label index is the observed label and the
 * second is the candidate label (see {@link #get(int, int, short[], float[])}).
 * NOTE(review): this presumably models the probability that rater {@code j} reports
 * {@code s} when the true label is {@code l} (a STAPLE-style confusion matrix) —
 * confirm against the fusion code.
 *
 * <p>{@link #add(int, int, int, float)} accumulates into both {@code theta} and a
 * per-column running total ({@code labelnorms}). The normalize methods divide each
 * column by that total and then reset every total to one.
 */
public class PerformanceParameters extends PerformanceParametersBase {
	
	// theta[rater][observed label][label]
	private final float [][][] theta;
	// labelnorms[rater][label]: running total of values added to column "label"
	private final float [][] labelnorms;
	private final short num_labels;
	
	/**
	 * Allocates zeroed parameters and logs an initialization message.
	 *
	 * @param num_labels_in number of labels
	 * @param num_raters_in number of raters
	 */
	public PerformanceParameters (short num_labels_in, int num_raters_in) {
		this(num_labels_in, num_raters_in, true);
		JistLogger.logOutput(JistLogger.INFO, "-> Initializing new Performance Level Parameters");
	}
	
	/**
	 * Allocates zeroed parameters.
	 *
	 * @param num_labels_in number of labels
	 * @param num_raters_in number of raters
	 * @param quiet if {@code false}, log an initialization message
	 */
	public PerformanceParameters (short num_labels_in, int num_raters_in, boolean quiet) {
		super(num_raters_in);
		setLabel("PerformanceParameters");
		
		if (!quiet) {
			JistLogger.logOutput(JistLogger.INFO, "+++ Initializing new Performance Level Parameters +++");
		}
		
		num_labels = num_labels_in;
		theta = new float[num_raters][num_labels][num_labels];
		labelnorms = new float[num_raters][num_labels];
	}
	
	/** Returns this object; the index argument is ignored. */
	public PerformanceParametersBase getPerformanceParameters(int i) { return(this); }
	
	/**
	 * Resets all parameters and fills every rater's matrix with {@code diagval} on the
	 * diagonal and {@code (1 - diagval) / (num_labels - 1)} off the diagonal, then
	 * normalizes.
	 */
	public void initialize() {
		
		// make sure everything is zero
		reset();
		
		// Loop-invariant off-diagonal value. With a single label no off-diagonal
		// entries exist, so avoid dividing by zero there.
		float offdiag = (num_labels > 1) ? (1 - diagval) / (num_labels - 1) : 0f;
		
		// add the default values
		for (int j = 0; j < num_raters; j++) {
			for (int s = 0; s < num_labels; s++) {
				for (int l = 0; l < num_labels; l++) {
					add(j, s, l, (s == l) ? diagval : offdiag);
				}
			}
		}
		
		// normalize properly
		normalize();
	}
	
	/**
	 * Divides every column of every rater's matrix by its accumulated total. A column
	 * with a zero total is set to the uniform value {@code 1 / num_labels}. Afterwards
	 * all totals are set to one.
	 */
	public void normalize() {
		for (int j = 0; j < num_raters; j++) {
			for (int l = 0; l < num_labels; l++) {
				normalizeColumn(j, l);
			}
			Arrays.fill(labelnorms[j], 1f);
		}
	}
	
	/**
	 * Like {@link #normalize()}, except that for every label {@code l} with
	 * {@code lp[l] == 0} only the diagonal entry {@code theta[j][l][l]} is set to 1.
	 * NOTE(review): the off-diagonal entries of such a column are left as accumulated;
	 * confirm this is intended.
	 *
	 * @param lp per-label values (presumably label priors), indexed by label; must
	 *           have at least {@code num_labels} entries
	 */
	public void normalize(float [] lp) {
		for (int j = 0; j < num_raters; j++) {
			for (int l = 0; l < num_labels; l++) {
				if (lp[l] == 0) {
					theta[j][l][l] = 1;
				} else {
					normalizeColumn(j, l);
				}
			}
			Arrays.fill(labelnorms[j], 1f);
		}
	}
	
	/** Divides column l of rater j by its total, or makes it uniform if the total is zero. */
	private void normalizeColumn(int j, int l) {
		float norm = labelnorms[j][l];
		if (norm == 0) {
			float uniform = 1 / ((float)num_labels);
			for (int s = 0; s < num_labels; s++) {
				theta[j][s][l] = uniform;
			}
		} else {
			for (int s = 0; s < num_labels; s++) {
				theta[j][s][l] /= norm;
			}
		}
	}
	
	/**
	 * Copies the parameters of {@code theta2} into this object and normalizes.
	 * If {@code theta2} is of the same single type, values and column totals are copied
	 * directly. Otherwise this object is reset and the values are re-accumulated via
	 * {@link #add(int, int, int, float)}.
	 *
	 * @param theta2 source parameters; assumed to have the same rater/label counts
	 */
	public void copy(PerformanceParametersBase theta2) {
		if (theta2.get_type() == PERFORMANCE_PARAMETERS_SINGLE_TYPE) {
			for (int j = 0; j < num_raters; j++) {
				for (short s = 0; s < num_labels; s++) {
					labelnorms[j][s] = theta2.get_labelnorm(j, s);
					for (short l = 0; l < num_labels; l++) {
						theta[j][s][l] = theta2.get(j, s, l);
					}
				}
			}
		} else {
			reset();
			for (int j = 0; j < num_raters; j++) {
				for (short s = 0; s < num_labels; s++) {
					for (short l = 0; l < num_labels; l++) {
						add(j, s, l, theta2.get(j, s, l));
					}
				}
			}
		}
		
		// normalize everything
		normalize();
	}
	
	/** Sets every parameter and every column total to zero. */
	public void reset() {
		for (int j = 0; j < num_raters; j++) {
			Arrays.fill(labelnorms[j], 0f);
			for (int s = 0; s < num_labels; s++) {
				Arrays.fill(theta[j][s], 0f);
			}
		}
	}
	
	/**
	 * Writes rater {@code j}'s matrix to {@code outdir/Performance_Parameters_%04d.txt}.
	 * The first line is the label count, followed by one line per observed label with
	 * space-separated values. Numbers are formatted with {@link Locale#ROOT} so the
	 * decimal separator is always '.'. I/O errors are printed and otherwise ignored.
	 *
	 * @param j rater index
	 * @param outdir output directory
	 * @return the file that was written
	 */
	public File toFile (int j, File outdir) {
		
		File f = new File(outdir, String.format(Locale.ROOT, "Performance_Parameters_%04d.txt", j));
		
		// Single buffer: the previous repeated String.format concatenation was quadratic.
		StringBuilder out = new StringBuilder();
		out.append(num_labels).append('\n');
		for (int s1 = 0; s1 < num_labels; s1++) {
			for (int s2 = 0; s2 < num_labels; s2++) {
				out.append(String.format(Locale.ROOT, "%f ", get(j, s1, s2)));
			}
			out.append('\n');
		}
		
		// try-with-resources closes the writer even if write() fails.
		try (BufferedWriter b = new BufferedWriter(new FileWriter(f))) {
			b.write(out.toString());
		} catch (IOException e) {
			e.printStackTrace();
		}

		return(f);
	}
	
	/**
	 * Packs all raters' matrices into a {@code num_labels x num_labels x num_raters x 1}
	 * float image, with voxel {@code (l1, l2, j, 0)} = {@code theta[j][l1][l2]}.
	 *
	 * @param name name of the created image
	 * @return the new image
	 */
	public ImageData toImage (String name) {
		
		// allocate space for the performance parameters
		ImageData thetaimg = new ImageDataFloat(name, 
												num_labels,
												num_labels,
												num_raters,
												1);
		
		// set the values in the volume
		for (int l1 = 0; l1 < num_labels; l1++) {
			for (int l2 = 0; l2 < num_labels; l2++) {
				for (int j = 0; j < num_raters; j++) {
					thetaimg.set(l1, l2, j, 0, theta[j][l1][l2]);
				}
			}
		}
		
		// return the performance level parameters
		return(thetaimg);
	}
	
	/**
	 * Returns the absolute difference between the summed matrix traces of this object
	 * and {@code theta2}, divided by {@code num_raters * num_labels}. Returns 0 when that
	 * product is zero (instead of NaN, which would never satisfy a convergence test).
	 *
	 * @param theta2 previous-iteration parameters with the same dimensions
	 * @return the per-entry mean change of the diagonal sums
	 */
	public float get_convergence_factor(PerformanceParametersBase theta2) {
		float current_sum = 0;
		float previous_sum = 0;
		
		for (int j = 0; j < num_raters; j++) {
			float theta_trace = 0;
			float theta_prev_trace = 0;
			
			for (int s = 0; s < num_labels; s++) {
				theta_trace += theta[j][s][s];
				theta_prev_trace += theta2.get(j, s, s);
			}
			
			current_sum += theta_trace;
			previous_sum += theta_prev_trace;
		}
		
		int count = num_raters * num_labels;
		if (count == 0) {
			return 0f;
		}
		return Math.abs(current_sum - previous_sum) / count;
	}
	
	/**
	 * Adds {@code val} to {@code theta[j][s1][s2]} and to the running total of
	 * column {@code s2}.
	 */
	public void add(int j, int s1, int s2, float val) {
		theta[j][s1][s2] += val;
		labelnorms[j][s2] += val;
	}
	
	/** Same as {@link #add(int, int, int, float)}; {@code lvl} is ignored. */
	public void add(int lvl, int j, int s1, int s2, float val) { add(j, s1, s2, val); }
	
	/** Returns {@code theta[j][s1][s2]}. */
	public float get(int j, int s1, int s2) { return(theta[j][s1][s2]); }
	
	/** Returns the running total of column {@code s} for rater {@code j}. */
	public float get_labelnorm(int j, int s) { return(labelnorms[j][s]); }
	
	/**
	 * Returns {@code sum_l obsvals[l] * theta[j][obslabels[l]][s]}.
	 *
	 * @param j rater index
	 * @param s candidate label (column index)
	 * @param obslabels observed labels (row indices)
	 * @param obsvals weights, one per entry of {@code obslabels}
	 */
	public double get(int j, int s, short [] obslabels, float [] obsvals) {
		double val = 0;
		for (int l = 0; l < obslabels.length; l++) {
			val += obsvals[l] * theta[j][obslabels[l]][s];
		}
		return(val);
	}
	
	/**
	 * Returns the natural log of {@link #get(int, int, short[], float[])}. A zero sum
	 * yields negative infinity.
	 */
	public double get_log(int j, int s, short [] obslabels, float [] obsvals) {
		return(Math.log(get(j, s, obslabels, obsvals)));
	}
	
	/** Identifies this implementation as the single (non-hierarchical) parameter type. */
	public int get_type() { return(PERFORMANCE_PARAMETERS_SINGLE_TYPE); }
}
