package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Dense matrix of performance parameters indexed by (observed label, target label).
 *
 * <p>When built from dimensions (or via {@link #setMatrixEmpty()} / {@link #copy}), the backing
 * array has {@code numObserved + 1} rows and {@code numTarget + 1} columns. The role of the extra
 * trailing row/column is not visible in this class (NOTE(review): presumably a catch-all label —
 * confirm with callers). When built from an existing {@code Float[][]}, the counts are taken
 * directly from the array dimensions. Whole-matrix operations ({@link #normalize()},
 * {@link #getDiff}) therefore iterate over the actual backing-array dimensions so that both
 * construction paths are safe.
 */
public class MultiSetPerformanceParameterMatrix extends AbstractCalculation{

	/** Backing storage; assumed rectangular (row = observed label, column = target label). */
	private float[][] mat;
	/** Nominal label counts; see the class comment for how they relate to {@code mat}'s size. */
	private int numObserved, numTarget;
	private String className;
	/** Counts of rows/columns with any positive entry; written only by {@link #getDiff}. */
	private int numTarReal, numObsReal;

	/** Creates an instance with no matrix allocated; only sets the calculation label. */
	public MultiSetPerformanceParameterMatrix(){
		super();
		setLabel("MultiSetPerformanceParameterMatrix");
	}

	/**
	 * Creates a matrix holding a primitive copy of {@code matIn}.
	 *
	 * @param matIn source values; must be non-null with non-null rows and elements
	 * @param c     class name stored with this matrix
	 */
	public MultiSetPerformanceParameterMatrix(Float[][] matIn,String c){
		this();
		className = c;
		loadMatrix(matIn);
		numObserved = mat.length;
		// Guard the empty case instead of throwing on mat[0].
		numTarget   = (mat.length == 0) ? 0 : mat[0].length;
	}

	/**
	 * Creates a zero-filled matrix of {@code (numObservedIn + 1) x (numTargetIn + 1)}.
	 *
	 * @param c             class name stored with this matrix
	 * @param numObservedIn number of observed labels
	 * @param numTargetIn   number of target labels
	 */
	public MultiSetPerformanceParameterMatrix(String c,int numObservedIn,int numTargetIn){
		this();
		className = c;
		numObserved = numObservedIn;
		numTarget   = numTargetIn;
		setMatrixEmpty();
		JistLogger.logOutput(JistLogger.INFO, String.format("Num Observered: %d Num Target: %d", numObserved,numTarget));
	}

	/** Replaces the matrix with a new zero-filled {@code (numObserved + 1) x (numTarget + 1)} array. */
	public void setMatrixEmpty(){
		// Java zero-initializes new float arrays, so no explicit fill is needed.
		mat = new float[numObserved+1][numTarget+1];
	}

	/**
	 * Replaces the matrix with an unboxed copy of {@code matIn}, preserving each row's length.
	 *
	 * @param matIn source values; a {@code null} element causes a {@link NullPointerException}
	 */
	public void loadMatrix(Float[][] matIn){
		mat = new float[matIn.length][];
		for(int i=0;i<mat.length;i++){
			mat[i] = new float[matIn[i].length];
			for(int j=0;j<mat[i].length;j++)
				mat[i][j] = matIn[i][j];
		}
	}

	/**
	 * Re-allocates this matrix as {@code (numObserved + 1) x (numTarget + 1)} zeros and copies in
	 * every cell of {@code o} that lies within both matrices (including the trailing row/column).
	 *
	 * @param o matrix to copy values from
	 */
	public void copy(MultiSetPerformanceParameterMatrix o){
		mat = new float[numObserved+1][numTarget+1];
		int rows = Math.min(mat.length, o.mat.length);
		for(int i=0;i<rows;i++){
			int cols = Math.min(mat[i].length, o.mat[i].length);
			System.arraycopy(o.mat[i], 0, mat[i], 0, cols);
		}
	}

	/** Scales each column so it sums to 1; columns whose sum is exactly 0 are left unchanged. */
	public void normalize(){
		int rows = mat.length;
		int cols = (rows == 0) ? 0 : mat[0].length;
		for(int col=0;col<cols;col++){
			float sum = 0f;
			for(int row=0;row<rows;row++){
				sum += mat[row][col];
			}
			if(sum==0f){
				continue;
			}
			for(int row=0;row<rows;row++){
				mat[row][col] = mat[row][col] / sum;
			}
		}
	}

	/** @return the nominal number of observed labels */
	public int getNumObserved(){ return this.numObserved; }
	/** @return the nominal number of target labels */
	public int getNumTarget(){ return this.numTarget; }
	/** @return rows with any positive entry, as counted by the last {@link #getDiff} call */
	public int getNumObservedReal(){ return this.numObsReal; }
	/** @return columns with any positive entry, as counted by the last {@link #getDiff} call */
	public int getNumTargetReal(){ return this.numTarReal; }
	/** @return the entry at row {@code obs}, column {@code tar} */
	public float getVal(int obs,int tar){ return mat[obs][tar]; }

	/**
	 * Logs a 10-segment progress bar whenever item {@code ind} of {@code num} crosses into a new
	 * tenth of the range.
	 *
	 * @param ind zero-based index of the current item
	 * @param num total number of items; nothing is logged when {@code num <= 0}
	 */
	protected void print_status(int ind,
			int num) {
		final int total = 10;
		if (num <= 0)
			return;
		// With num == 1 the original denominator was 0, producing +Infinity and a
		// bar loop of Integer.MAX_VALUE iterations; clamp the denominator to 1.
		float denom = (float)Math.max(num - 1, 1);
		int currval = (int)((total * (float)ind) / denom);
		int prevval = (int)((total * ((float)ind-1)) / denom);

		if (currval > prevval) {
			// Keep the bar exactly `total` characters wide even for out-of-range ind.
			int filled = Math.max(0, Math.min(currval, total));
			StringBuilder msg = new StringBuilder(total + 2).append('[');
			for (int i = 0; i < filled; i++)
				msg.append('=');
			for (int i = filled; i < total; i++)
				msg.append('+');
			msg.append(']');

			JistLogger.logOutput(JistLogger.INFO, msg.toString());
		}
	}

	/**
	 * Returns the sum of absolute element-wise differences between this matrix and {@code o}.
	 * As a side effect, sets {@code numObsReal}/{@code numTarReal} to the number of rows/columns
	 * that contain a positive entry in either matrix.
	 *
	 * @param o matrix to compare against; assumed to have at least this matrix's dimensions
	 * @return the L1 distance between the two matrices
	 */
	protected double getDiff(MultiSetPerformanceParameterMatrix o){
		double diff = 0.0;
		int rows = mat.length;
		int cols = (rows == 0) ? 0 : mat[0].length;
		// Sized to the full matrix: the original arrays were one element short and
		// threw when the trailing row/column held a positive value.
		boolean[] tar = new boolean[cols];
		boolean[] obs = new boolean[rows];
		numTarReal = 0;
		numObsReal = 0;
		for(int i=0;i<rows;i++){
			for(int j=0;j<cols;j++){
				float mine = mat[i][j];
				float theirs = o.mat[i][j];
				diff += Math.abs(mine - theirs);
				if(mine > 0 || theirs > 0){
					if(!obs[i]){
						obs[i] = true;
						numObsReal++;
					}
					if(!tar[j]){
						tar[j] = true;
						numTarReal++;
					}
				}
			}
		}
		return diff;
	}

	/**
	 * Adds {@code val} to the entry at row {@code l}, column {@code s}.
	 *
	 * @param l   row (observed label) index
	 * @param s   column (target label) index
	 * @param val amount to add
	 */
	protected void add(short l, int s, float val){
		mat[l][s] += val;
	}
}
