package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.*;
import java.util.Arrays;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * A dense {@code numObsLabels x numTargetLabels} matrix of float weights
 * ({@code theta}), indexed as {@code theta[observedLabel][targetLabel]}.
 *
 * <p>The matrix can be read from comma-separated text, accumulated into with
 * {@link #addValue(int, int, double)}, and normalized so that every column
 * (fixed target label) sums to one.
 *
 * <p>A second matrix of the same shape, {@code thetaLog}, is allocated and kept
 * zeroed by this class but is not otherwise read or written here.
 * NOTE(review): presumably reserved for cached log-values -- confirm before relying on it.
 *
 * <p>This class performs no synchronization.
 */
public class MultiClassPerformanceParameterBase {

	private int numObsLabels;
	private int numTargetLabels;

	private float[][] theta;
	private float[][] thetaLog;


	/**
	 * Reads a matrix from comma-separated text.
	 *
	 * <p>The first line holds {@code numObsLabels,numTargetLabels}; it is followed by
	 * {@code numObsLabels} lines that each hold at least {@code numTargetLabels}
	 * float values. Values that parse to NaN are stored as 0. The reader is not
	 * closed by this constructor.
	 *
	 * @param br source positioned at the header line
	 * @throws IllegalArgumentException if the input ends early, a line has too few
	 *         fields, or a field is not a valid number ({@link NumberFormatException})
	 * @throws IllegalStateException if reading from {@code br} fails; the underlying
	 *         {@link IOException} is attached as the cause
	 */
	public MultiClassPerformanceParameterBase(BufferedReader br){
		try {
			String[] header = readFields(br, 2, "header line");
			numObsLabels = Integer.parseInt(header[0].trim());
			numTargetLabels = Integer.parseInt(header[1].trim());
			theta = new float[numObsLabels][numTargetLabels];
			thetaLog = new float[numObsLabels][numTargetLabels];
			for (int i = 0; i < numObsLabels; i++) {
				String[] fields = readFields(br, numTargetLabels, "matrix row " + i);
				for (int j = 0; j < numTargetLabels; j++) {
					float value = Float.parseFloat(fields[j]);
					// NaN never compares equal to anything (itself included), so an
					// "== NaN" test can never fire; Float.isNaN is the only reliable check.
					theta[i][j] = Float.isNaN(value) ? 0f : value;
				}
			}
		} catch (IOException e) {
			// Fail here rather than hand back a half-initialised object that would
			// only break later with an unrelated NullPointerException.
			throw new IllegalStateException("Failed to read performance parameters", e);
		}
	}

	/**
	 * Creates an all-zero matrix of the given shape.
	 *
	 * @param numObs number of observed labels (rows)
	 * @param numTarget number of target labels (columns)
	 */
	public MultiClassPerformanceParameterBase(int numObs,int numTarget){
		numObsLabels = numObs;
		numTargetLabels = numTarget;
		// Java zero-initialises new float arrays, so no explicit fill is needed.
		theta = new float[numObs][numTarget];
		// Must be allocated here as well; reset() writes into it.
		thetaLog = new float[numObs][numTarget];
	}

	/**
	 * Creates an empty instance with no matrix allocated. Call
	 * {@link #setTheta(float[][])} before using any method that touches the matrix.
	 */
	public MultiClassPerformanceParameterBase(){ }

	/**
	 * Returns an independent copy of this object.
	 *
	 * <p>The matrix values and both label counts are copied. If this instance has
	 * no matrix yet, the copy has none either.
	 *
	 * @return a new instance that shares no arrays with this one
	 */
	public MultiClassPerformanceParameterBase copy(){
		MultiClassPerformanceParameterBase P = new MultiClassPerformanceParameterBase();
		if (theta != null) {
			P.setTheta(theta);
		}
		P.setNumObsLabels(numObsLabels);
		P.setNumTargetLabels(numTargetLabels);
		return P;
	}

	/**
	 * Replaces the matrix with a copy of {@code theta0} and zeroes {@code thetaLog}.
	 *
	 * <p>The column count is taken from the first row of {@code theta0}. The label
	 * counts are NOT updated; use {@link #setNumObsLabels(int)} and
	 * {@link #setNumTargetLabels(int)} to keep them in step.
	 *
	 * @param theta0 rectangular source matrix; may have zero rows
	 */
	public void setTheta(float[][] theta0){
		int rows = theta0.length;
		int cols = (rows == 0) ? 0 : theta0[0].length;
		theta = new float[rows][cols];
		thetaLog = new float[rows][cols]; // freshly allocated, hence all zeros
		for (int i = 0; i < rows; i++) {
			System.arraycopy(theta0[i], 0, theta[i], 0, cols);
		}
	}

	/**
	 * Sets the stored observed-label count. The matrix is not resized.
	 *
	 * @param o new observed-label count
	 */
	public void setNumObsLabels(int o){
		numObsLabels = o;
	}

	/**
	 * Sets the stored target-label count. The matrix is not resized.
	 *
	 * @param o new target-label count
	 */
	public void setNumTargetLabels(int o){
		numTargetLabels = o;
	}

	/** @return the stored target-label count */
	public int getNumTargetLabels(){ return numTargetLabels; }

	/** @return the stored observed-label count */
	public int getNumObsLabels(){ return numObsLabels; }

	/**
	 * Zeroes the first {@code numObsLabels x numTargetLabels} entries of both
	 * {@code theta} and {@code thetaLog}.
	 */
	public void reset(){
		for (int i = 0; i < numObsLabels; i++) {
			for (int j = 0; j < numTargetLabels; j++) {
				theta[i][j] = 0;
				thetaLog[i][j] = 0;
			}
		}
	}

	/**
	 * @param i observed-label index (row)
	 * @param j target-label index (column)
	 * @return the matrix entry at {@code [i][j]}
	 */
	public float getValue(int i,int j){
		return theta[i][j];
	}

	/**
	 * Scales every column so that it sums to one.
	 *
	 * <p>Columns whose sum is not strictly positive are left untouched. A matrix
	 * with zero rows is a no-op.
	 */
	public void normalize(){
		if (theta.length == 0) {
			return;
		}
		int cols = theta[0].length;
		for (int col = 0; col < cols; col++) {
			// Accumulate in double to limit rounding error across the column.
			double sum = 0;
			for (int row = 0; row < theta.length; row++) {
				sum += theta[row][col];
			}
			if (sum > 0) {
				for (int row = 0; row < theta.length; row++) {
					theta[row][col] /= sum;
				}
			}
		}
	}

	/**
	 * Adds {@code val} to one matrix entry. The sum is narrowed back to float.
	 *
	 * @param obs observed-label index (row)
	 * @param target target-label index (column)
	 * @param val amount to add
	 */
	public void addValue(int obs, int target, double val){
		theta[obs][target] += val;
	}

	/**
	 * Renders the matrix as text: one line per row, with every value followed by a
	 * comma (including the last one on each line).
	 *
	 * @return the comma-separated matrix text
	 */
	@Override
	public String toString(){
		StringBuilder sb = new StringBuilder();
		for (float[] row : theta) {
			for (float d : row) {
				sb.append(d).append(',');
			}
			sb.append('\n');
		}
		return sb.toString();
	}

	/**
	 * Computes the sum of absolute element-wise differences to another matrix.
	 *
	 * @param theta2 matrix to compare against; must be at least as large as this one
	 * @return the summed absolute difference over this matrix's entries
	 */
	public float calculateDiff(MultiClassPerformanceParameterBase theta2){
		float diff = 0;
		for (int i = 0; i < theta.length; i++) {
			for (int j = 0; j < theta[i].length; j++) {
				diff += Math.abs(theta[i][j] - theta2.theta[i][j]);
			}
		}
		return diff;
	}

	/**
	 * Adds the constant {@code v} to every matrix entry.
	 *
	 * @param v amount added to each entry
	 */
	public void giveLittleWeight(float v){
		for (int i = 0; i < theta.length; i++) {
			for (int j = 0; j < theta[i].length; j++) {
				theta[i][j] += v;
			}
		}
	}

	/**
	 * Adds {@code f * t[i][j]} to every entry {@code [i][j]} of this matrix.
	 *
	 * @param f scale factor applied to the entries of {@code t}
	 * @param t matrix to nudge towards; must be at least as large as this one
	 */
	public void nudge(float f,MultiClassPerformanceParameterBase t){
		for (int i = 0; i < theta.length; i++) {
			for (int j = 0; j < theta[i].length; j++) {
				addValue(i, j, f * t.getValue(i, j)); // Nudge towards known
			}
		}
	}

	/** Reads one line, splits it on commas, and checks it has at least {@code minFields} fields. */
	private static String[] readFields(BufferedReader br, int minFields, String what) throws IOException {
		String line = br.readLine();
		if (line == null) {
			throw new IllegalArgumentException("Unexpected end of input while reading " + what);
		}
		String[] fields = line.split(",");
		if (fields.length < minFields) {
			throw new IllegalArgumentException("Expected at least " + minFields
					+ " comma-separated fields in " + what + " but found " + fields.length);
		}
		return fields;
	}
}
