package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.FileReader;
import java.util.HashMap;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Per-rater performance parameters ("thetas") used by multi-set label fusion.
 * <p>
 * One {@link MultiSetPerformanceParameterMatrix} is kept per rater, and rater {@code i}
 * is associated with {@code classes[i]}. The initial matrices are read from a
 * comma-separated text file laid out as:
 * <pre>
 * numClasses
 * observedClass,targetClass
 * numRows,numCols
 * v,v,...            (numRows lines, each with at least numCols floats)
 * observedClass,targetClass
 * ...
 * </pre>
 * Only sections whose target class equals {@code targetClass} are loaded, keyed by
 * their observed class. All other sections are skipped. If the same observed class
 * appears more than once for the target class, the last section wins.
 */
public class MultiSetPerformanceParameters extends AbstractCalculation {
	
	private String targetClass;
	private int numRaters;
	private String[] classes;
	// observed class -> initial matrix read from the theta file (target class sections only)
	private HashMap<String,Float[][]> initMap;
	private int numClasses;
	private MultiSetPerformanceParameterMatrix[] thetas;
	private int numTarget;
	
	/**
	 * Creates a parameter set with the same target class, rater classes and matrix
	 * dimensions as {@code o}.
	 * <p>
	 * Each matrix is newly constructed from its class and its observed/target sizes.
	 * Its contents are not taken from {@code o}; call {@link #copy} to transfer values.
	 * The {@code classes} array is shared with {@code o}, not cloned.
	 *
	 * @param o the parameter set whose shape is reproduced
	 */
	public MultiSetPerformanceParameters(MultiSetPerformanceParameters o){
		super();
		setLabel("MultiSetPerformanceParameters");
		// Without this the copy's isTargetClass() would throw a NullPointerException.
		targetClass = o.targetClass;
		classes = o.classes;
		numClasses = o.numClasses;
		numRaters  = o.numRaters;
		thetas = new MultiSetPerformanceParameterMatrix[numRaters];
		for(int i=0;i<numRaters;i++){
			thetas[i] = new MultiSetPerformanceParameterMatrix(classes[i],o.thetas[i].getNumObserved(),o.thetas[i].getNumTarget());
		}
		numTarget = thetas[0].getNumTarget();
	}
	
	
	/**
	 * Loads the initial parameters for every rater class from a theta file.
	 *
	 * @param targetClass_in the target class whose sections are read from the file
	 * @param classes_in     the class of each rater; one matrix is built per entry
	 * @param f              the theta file (format described on the class)
	 * @throws IllegalArgumentException if {@code classes_in} is null or empty
	 * @throws RuntimeException         if the file cannot be read or is malformed
	 *                                  (the underlying {@link IOException} is the cause)
	 */
	public MultiSetPerformanceParameters(String targetClass_in,String[] classes_in,File f){
		super();
		setLabel("MultiSetPerformanceParameters");
		if(classes_in==null || classes_in.length==0)
			throw new IllegalArgumentException("At least one rater class is required");
		targetClass = targetClass_in;
		classes = classes_in;
		numRaters = classes.length;
		loadFromFile(f);
		loadTheta();
		numTarget = thetas[0].getNumTarget();
	}
	
	/** Builds one matrix per rater from the matrices loaded into {@code initMap}. */
	private void loadTheta(){
		thetas = new MultiSetPerformanceParameterMatrix[numRaters];
		for(int i=0;i<numRaters;i++){
			String c = classes[i];
			// NOTE(review): initMap.get(c) is null if the file had no section for c;
			// how the matrix constructor handles null is not visible here.
			thetas[i] = new MultiSetPerformanceParameterMatrix(initMap.get(c),c);
		}
	}
	
	/**
	 * Parses the theta file into {@code initMap} and records the class count from
	 * its first line.
	 * <p>
	 * The reader is always closed. Read or format errors are rethrown unchecked, so
	 * construction never continues with a partially populated map.
	 */
	private void loadFromFile(File f){
		initMap = new HashMap<String,Float[][]>();
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(f));
			numClasses = Integer.parseInt(readRequiredLine(br, "the class count").trim());
			JistLogger.logOutput(JistLogger.INFO, String.format("There are %d classes found in the theta file", numClasses));
			String line;
			while((line=br.readLine())!=null){
				// Tolerate blank lines between sections, e.g. trailing newlines.
				if(line.trim().length()==0)
					continue;
				String[] lineBits = line.split(",");
				if(lineBits.length<2)
					throw new IOException("Malformed section header in theta file: \"" + line + "\"");
				String obsClass = lineBits[0];
				String tarClass = lineBits[1];
				if(isTargetClass(tarClass)){
					loadMatrix(br,obsClass);
				}
				else{
					skipClass(br);
				}
			}
		} catch (IOException e) {
			throw new RuntimeException("Failed to read theta file " + f, e);
		} finally {
			if(br!=null){
				try {
					br.close();
				} catch (IOException ignored) {
					// A close failure must not mask the read result or an earlier exception.
				}
			}
		}
	}
	
	/**
	 * Returns the next line of {@code br}.
	 *
	 * @param what description of the expected content, used in the error message
	 * @throws IOException if the end of the file is reached, or on a read error
	 */
	private static String readRequiredLine(BufferedReader br, String what) throws IOException{
		String line = br.readLine();
		if(line==null)
			throw new IOException("Unexpected end of theta file while reading " + what);
		return line;
	}
	
	/**
	 * Reads one matrix section (a "numRows,numCols" line followed by numRows rows)
	 * and stores it in {@code initMap} under observed class {@code c}.
	 */
	private void loadMatrix(BufferedReader br,String c) throws IOException{
		String dims = readRequiredLine(br, "the dimensions of class " + c);
		String[] lineBits = dims.split(",");
		if(lineBits.length<2)
			throw new IOException("Malformed dimension line for class " + c + ": \"" + dims + "\"");
		int numRows = Integer.parseInt(lineBits[0].trim());
		int numCols = Integer.parseInt(lineBits[1].trim());
		JistLogger.logOutput(JistLogger.INFO, String.format("Loading class %s had %d labels", c,numRows));
		Float[][] mat = new Float[numRows][];
		for(int i=0;i<numRows;i++){
			mat[i] = parseRow(readRequiredLine(br, "row " + i + " of class " + c),numCols);
		}
		initMap.put(c, mat);
	}
	
	/**
	 * Parses the first {@code n} comma-separated floats of {@code line}.
	 *
	 * @throws IOException if the line has fewer than {@code n} values
	 */
	private Float[] parseRow(String line,int n) throws IOException{
		String[] lineBits = line.split(",");
		if(lineBits.length<n)
			throw new IOException(String.format("Expected %d values but found %d in theta row \"%s\"", n, lineBits.length, line));
		Float[] row = new Float[n];
		for(int i=0;i<n;i++)
			row[i] = Float.parseFloat(lineBits[i]);
		return row;
	}
	
	/**
	 * Skips a non-target section: reads its dimension line, then discards as many
	 * lines as its first field. Stops quietly if the file ends first.
	 */
	private void skipClass(BufferedReader br) throws IOException{
		String[] lineBits = readRequiredLine(br, "the dimensions of a skipped class").split(",");
		int numLines = Integer.parseInt(lineBits[0].trim());
		for(int i=0;i<numLines;i++){
			if(br.readLine()==null)
				break;
		}
	}
	
	/** @return true if {@code c} equals the target class this set was loaded for */
	public boolean isTargetClass(String c){
		return targetClass.equals(c);
	}
	
	/** Copies every rater matrix of {@code o} into the corresponding matrix of this set. */
	public void copy(MultiSetPerformanceParameters o){
		for(int i=0;i<numRaters;i++){
			thetas[i].copy(o.thetas[i]);
		}
	}
	
	/** Calls {@code setMatrixEmpty()} on every rater matrix. */
	public void reset(){
		for(int i=0;i<numRaters;i++){
			thetas[i].setMatrixEmpty();
		}
	}

	/** Calls {@code normalize()} on every rater matrix. */
	public void normalize(){
		for(int i=0;i<numRaters;i++)
			thetas[i].normalize();
	}
	
	/** @return the target dimension of the first rater matrix, captured at construction */
	public int getNumTarget(){ return numTarget; }
	
	/**
	 * Returns the summed {@code getDiff} against {@code o} over all raters, divided by
	 * the summed {@code getNumObservedReal() * getNumTargetReal()}.
	 * <p>
	 * This is presumably a mean per-entry change, used as a convergence measure;
	 * confirm against {@code MultiSetPerformanceParameterMatrix.getDiff}.
	 */
	public double get_convergence_factor(MultiSetPerformanceParameters o){
		double diff = 0f;
		int count  = 0;
		for(int i=0;i<numRaters;i++){
			diff += thetas[i].getDiff(o.thetas[i]);
			count += (thetas[i].getNumObservedReal() * thetas[i].getNumTargetReal());
		}
		return diff/count;
	}
	
	/**
	 * Returns {@code ln(sum_l obsVals[l] * theta_j(obsLabs[l], s))} for rater {@code j}.
	 * <p>
	 * The result is -Infinity if the sum is 0 and NaN if it is negative, following
	 * {@link Math#log}.
	 *
	 * @param j       rater index
	 * @param s       column index passed to {@code getVal} (presumably the target label)
	 * @param obsLabs observed labels (row indices passed to {@code getVal})
	 * @param obsVals weight for each observed label; must be at least as long as
	 *                {@code obsLabs}
	 */
	public double getLog(int j,int s,short[] obsLabs,float[] obsVals){
		double val = 0;
		for(int l=0;l<obsLabs.length;l++)
			val += obsVals[l] * thetas[j].getVal(obsLabs[l],s);
		
		return Math.log(val);
	}
	
	/** Adds {@code val} to entry ({@code l}, {@code s}) of rater {@code j}'s matrix. */
	public void add(int j, short l, int s, float val){
		thetas[j].add(l,s,val);
	}
}
