package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.*;
import java.util.HashMap;

import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Confusion-matrix ("theta") parameters for multi-class label fusion, restricted to one target class.
 *
 * <p>{@code theta[r][o][t]} is indexed by rater {@code r}, observed label {@code o} and target
 * label {@code t}. When loaded from a file, {@code theta[r]} is a copy of the matrix stored for
 * {@code classes[r]}; the loaded matrices are also kept in {@code thetaMap} so they can be
 * re-applied as a prior by {@link #regularize(float)}.
 */
public class SimplifiedMultiClassPerformanceParameters {
	
	/** Matrices indexed as theta[rater][observed][target]. */
	private float[][][] theta;
	
	/** Matrices read from file for the requested target, keyed by observed-class name. */
	private HashMap<String,Float[][]> thetaMap;
	/** Class names; classes[i] selects the thetaMap entry that initialised theta[i]. */
	private String[] classes;
	/** Number of target labels, i.e. the column count of theta[0]. */
	private int numLabels;
	
	/**
	 * Creates an empty instance with no matrices; populate it with {@link #setTheta(float[][][])}.
	 */
	public SimplifiedMultiClassPerformanceParameters(){
		
	}
	
	/**
	 * Loads the matrices for {@code target} from {@code f} and copies the ones for {@code cls} into theta.
	 *
	 * @param f      comma-separated parameter file (layout documented on loadFromFile)
	 * @param cls    class names; the matrix for cls[i] becomes theta[i]
	 * @param target only records whose second header field equals this value are loaded
	 * @throws IllegalArgumentException if {@code cls} is empty or a class has no matrix for {@code target}
	 */
	public SimplifiedMultiClassPerformanceParameters(File f,String[] cls,String target){
		loadFromFile(f,target);
		classes = cls;
		if(classes.length == 0)
			throw new IllegalArgumentException("At least one class is required");
		theta = new float[classes.length][][];
		for(int i=0;i<classes.length;i++){
			Float[][] mat = thetaMap.get(classes[i]);
			if(mat == null)
				throw new IllegalArgumentException("No matrix for class \""+classes[i]
						+"\" with target \""+target+"\" in "+f);
			theta[i] = toPrimitive(mat);
		}
		numLabels = labelCount(theta, 0);
	}

	/**
	 * Reads the parameter file and stores every matrix whose target equals {@code target} in thetaMap.
	 *
	 * <p>Expected layout: one header line (ignored), then repeated records of the form
	 * <pre>
	 * observedClass,targetClass
	 * nObs,nTar
	 * nObs lines, each holding at least nTar comma-separated floats
	 * </pre>
	 * Records for other targets are skipped and blank record-header lines are ignored. I/O errors
	 * and truncated records are reported and leave thetaMap holding whatever was read before the
	 * error; the reader is always closed.
	 */
	private void loadFromFile(File f,String target){
		JistLogger.logOutput(JistLogger.WARNING, "Target class is "+target);
		thetaMap = new HashMap<String,Float[][]>();
		BufferedReader br = null;
		try{
			br = new BufferedReader(new FileReader(f));
			// the first line is a header and carries no data
			br.readLine();
			String header;
			while((header=br.readLine())!=null){
				if(header.trim().length() == 0)
					continue; // tolerate blank lines (e.g. a trailing newline)
				String[] names = header.split(",");
				String obs = names[0];
				String tar = names[1];
				String[] dims = readRequiredLine(br, f).split(",");
				int nObs = Integer.parseInt(dims[0].trim());
				int nTar = Integer.parseInt(dims[1].trim());
				if(tar.equals(target)){
					Float[][] mat = new Float[nObs][nTar];
					for(int i=0;i<nObs;i++){
						String[] vals = readRequiredLine(br, f).split(",");
						for(int j=0;j<nTar;j++)
							mat[i][j] = Float.parseFloat(vals[j]);
					}
					thetaMap.put(obs, mat);
				}else{
					// skip the rows of a record for a different target
					for(int i=0;i<nObs;i++)
						br.readLine();
				}
			}
		}catch(IOException e){
			JistLogger.logOutput(JistLogger.WARNING,
					"Failed to read performance parameters from "+f+": "+e.getMessage());
			e.printStackTrace();
		}finally{
			if(br != null){
				try{
					br.close();
				}catch(IOException e){
					e.printStackTrace();
				}
			}
		}
	}

	/**
	 * Reads the next line of a record.
	 *
	 * @throws IOException if the input ends before the record is complete
	 */
	private static String readRequiredLine(BufferedReader br, File f) throws IOException{
		String line = br.readLine();
		if(line == null)
			throw new IOException("Unexpected end of file in "+f);
		return line;
	}

	/**
	 * Returns a copy with an independent theta plus the same label count, class list and loaded
	 * matrices, so getNumLabels, calculateDifference and regularize behave as on this object.
	 */
	public SimplifiedMultiClassPerformanceParameters copy(){
		SimplifiedMultiClassPerformanceParameters c = new SimplifiedMultiClassPerformanceParameters();
		c.setTheta(theta);
		c.numLabels = numLabels;
		// classes and thetaMap are only read after loading, so sharing them is safe
		c.classes = classes;
		c.thetaMap = thetaMap;
		return c;
	}
	
	/**
	 * Replaces theta with a deep copy of {@code th}; later changes to {@code th} do not affect this
	 * object. numLabels is updated from th[0][0] when that row exists.
	 */
	public void setTheta(float[][][] th){
		theta = deepCopy(th);
		numLabels = labelCount(theta, numLabels);
	}

	/** Returns an independent copy of a 3-D float array (rows may differ in length). */
	private static float[][][] deepCopy(float[][][] src){
		float[][][] dst = new float[src.length][][];
		for(int i=0;i<src.length;i++){
			dst[i] = new float[src[i].length][];
			for(int j=0;j<src[i].length;j++)
				dst[i][j] = src[i][j].clone();
		}
		return dst;
	}

	/** Unboxes a Float matrix read from file into a new float matrix of the same shape. */
	private static float[][] toPrimitive(Float[][] mat){
		float[][] out = new float[mat.length][];
		for(int j=0;j<mat.length;j++){
			out[j] = new float[mat[j].length];
			for(int k=0;k<mat[j].length;k++)
				out[j][k] = mat[j][k];
		}
		return out;
	}

	/** Column count of th[0], or {@code fallback} when th has no first row to measure. */
	private static int labelCount(float[][][] th, int fallback){
		return (th.length > 0 && th[0].length > 0) ? th[0][0].length : fallback;
	}
	
	/** Normalizes every matrix in theta column-wise, in place. */
	public void normalize(){
		for(int i=0;i<theta.length;i++){
			theta[i] = normalize(theta[i]);
		}
	}
	
	/** Sets every entry of theta to zero, keeping its shape. */
	public void clear(){
		for(int i=0;i<theta.length;i++)
			for(int j=0;j<theta[i].length;j++)
				for(int k=0;k<theta[i][j].length;k++)
					theta[i][j][k] = 0;
		
	}
	
	/**
	 * Divides each column (fixed target label) of {@code th} in place by its sum; columns whose sum
	 * is not positive are left unchanged. An empty matrix is returned as-is.
	 *
	 * @return {@code th}
	 */
	private float[][] normalize(float[][] th){
		if(th.length == 0)
			return th;
		int nTar = th[0].length;
		for(int t=0;t<nTar;t++){
			float sum = 0f;
			for(int o=0;o<th.length;o++)
				sum += th[o][t];
			if(sum > 0){
				for(int o=0;o<th.length;o++)
					th[o][t] = th[o][t]/sum;
			}
		}
		return th;
	}
	
	/** @return number of target labels (column count of the first matrix) */
	public int getNumLabels(){ return numLabels; }
	
	/** Natural log of theta[rater][obs][tar]; a zero entry yields negative infinity. */
	public double getLog(int rater, int obs, int tar){
		return Math.log(theta[rater][obs][tar]);
	}
	
	/** Adds {@code val} to theta[rater][obs][tar]. */
	public void add(int rater, int obs, int tar, float val){
		theta[rater][obs][tar] += val;
	}

	/**
	 * Sum of absolute element-wise differences between this theta and {@code thetaOld}'s, divided by
	 * the number of matrices and by numLabels. Both objects are assumed to have the same shape.
	 */
	public float calculateDifference(
			SimplifiedMultiClassPerformanceParameters thetaOld) {
		float diff = 0f;
		for(int i=0;i<theta.length;i++){
			for(int j=0;j<theta[i].length;j++){
				for(int k=0;k<theta[i][j].length;k++){
					diff += Math.abs(theta[i][j][k] - thetaOld.theta[i][j][k]);
				}
			}
		}
		diff = diff/theta.length;
		diff = diff/numLabels;
		return diff;
	}
	
	/**
	 * Adds {@code val} times the matrix loaded from file for classes[i] to theta[i], for every class.
	 *
	 * @throws IllegalStateException if this object was not loaded from a file (or copied from one)
	 */
	public void regularize(float val){
		if(classes == null || thetaMap == null)
			throw new IllegalStateException("regularize requires parameters loaded from file");
		for(int i=0;i<classes.length;i++){
			regularize(thetaMap.get(classes[i]),i,val);
		}
	}

	/** Adds val*mat element-wise into theta[r] in place; mat must cover theta[r]'s shape. */
	private void regularize(Float[][] mat, int r,float val){
		float[][] t = theta[r];
		for(int i=0;i<t.length;i++){
			for(int j=0;j<t[i].length;j++){
				t[i][j] += val*mat[i][j];
			}
		}
	}
}
