package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.Arrays;
import java.util.List;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * SIMPLE label fusion using one spatially constant weight per rater.
 *
 * <p>Each call to {@link #estimateWeights()} does two things:
 * <ul>
 *   <li>It scores every still-active rater by the fraction of non-consensus voxels at which the
 *       rater's observation equals the current fused {@code estimate}.</li>
 *   <li>It removes (sets to 0) every active rater whose score falls below
 *       {@code mean - 2 * std} of the active raters' scores.</li>
 * </ul>
 * A removed rater never re-enters. The fields {@code maxIter}, {@code epsilon}, {@code obs} and
 * {@code estimate} are inherited from {@link SIMPLEBase}.
 */
public class SIMPLE extends SIMPLEBase {

	/** Number of standard deviations below the mean score at which a rater is removed. */
	private static final double REMOVAL_STD_FACTOR = 2.0;

	/** Current weight of each rater; a weight of 0 marks a removed rater. */
	protected float[] weights;

	/** Copy of {@link #weights} taken at the start of the most recent {@link #estimateWeights()}. */
	protected float[] prevWeights;

	/**
	 * Creates a fusion instance over the given label images.
	 *
	 * @param imageDataList label images wrapped into a {@code SimpleLabelVolume}; presumably one
	 *        image per rater (TODO confirm against SimpleLabelVolume)
	 * @param maxIterations value assigned to the inherited {@code maxIter}
	 * @param convergenceThreshold value assigned to the inherited {@code epsilon}; used by
	 *        {@link #hasConverged()} as the bound on the summed absolute weight change
	 */
	public SIMPLE(List<ImageData> imageDataList, int maxIterations, float convergenceThreshold) {
		super();
		maxIter = maxIterations;
		obs = new SimpleLabelVolume(imageDataList);
		epsilon = convergenceThreshold;
	}

	/**
	 * Reports whether the weights have stopped changing.
	 *
	 * @return {@code true} when the sum of absolute differences between the current and previous
	 *         weights is below {@code epsilon}; {@code false} if no previous weights exist yet
	 */
	protected boolean hasConverged() {
		// Nothing to compare against until estimateWeights() has taken a snapshot.
		if (this.prevWeights == null) {
			return false;
		}
		float diff = 0;
		for (int j = 0; j < this.weights.length; j++) {
			diff += Math.abs(this.weights[j] - this.prevWeights[j]);
		}
		return diff < epsilon;
	}

	/**
	 * Re-scores all active raters, then removes those scoring below
	 * {@code mean - REMOVAL_STD_FACTOR * std}.
	 *
	 * <p>A rater is active if its weight was positive at the start of this call. The mean and STD
	 * are taken over active raters only.
	 */
	protected void estimateWeights() {
		// Snapshot the current weights; hasConverged() compares against this copy.
		this.prevWeights = Arrays.copyOf(this.weights, this.weights.length);

		int numActive = 0;
		for (int i = 0; i < this.weights.length; i++) {
			if (this.prevWeights[i] > 0) {
				estimateWeight(i);
				numActive++;
			} else {
				// A removed rater stays removed.
				this.weights[i] = 0;
			}
		}
		if (numActive == 0) {
			return;
		}

		// Only raters that were active this iteration enter the statistics. Otherwise the zero
		// weights of previously removed raters would pull the mean down and inflate the STD,
		// lowering the cutoff with every removal.
		float[] activeWeights = new float[numActive];
		int k = 0;
		for (int i = 0; i < this.weights.length; i++) {
			if (this.prevWeights[i] > 0) {
				activeWeights[k++] = this.weights[i];
			}
		}
		double mean = calculateMean(activeWeights);
		double std = calculateSTD(activeWeights, mean);
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("Mean: %.3f STD: %.3f", mean, std));

		float cutoff = (float) (mean - REMOVAL_STD_FACTOR * std);
		for (int i = 0; i < this.weights.length; i++) {
			// Skip raters removed in earlier iterations so they are not reported again.
			if (this.prevWeights[i] > 0 && this.weights[i] < cutoff) {
				JistLogger.logOutput(JistLogger.WARNING,
						String.format("Removing %d (weight was %.3f, cutoff %.3f)",
								i, this.weights[i], cutoff));
				this.weights[i] = 0;
			}
		}
	}

	/**
	 * Sets {@code weights[r]} to the fraction of non-consensus voxels at which rater {@code r}'s
	 * observation equals the current estimate.
	 *
	 * <p>Must only be called from {@link #estimateWeights()}, after {@link #prevWeights} has been
	 * refreshed.
	 *
	 * @param r rater index
	 */
	private void estimateWeight(int r) {
		int numCorrect = 0;
		int numTotal = 0;
		for (int x = 0; x < obs.getR(); x++) {
			for (int y = 0; y < obs.getC(); y++) {
				for (int z = 0; z < obs.getS(); z++) {
					if (!obs.isConsensus(x, y, z)) {
						numTotal++;
						// NOTE(review): assumes observations/estimate are primitive labels so
						// == compares values; confirm in SimpleLabelVolume/SIMPLEBase.
						if (obs.getObservation(x, y, z, r) == estimate[x][y][z]) {
							numCorrect++;
						}
					}
				}
			}
		}
		// With no non-consensus voxels there is no evidence to re-score this rater. Keep its
		// previous weight rather than computing 0/0 = NaN.
		this.weights[r] = (numTotal == 0)
				? this.prevWeights[r]
				: (float) numCorrect / (float) numTotal;
	}

	/**
	 * Returns {@code false}.
	 *
	 * <p>NOTE(review): the weights in this class do not vary with voxel position ({@link #getWeight}
	 * ignores x, y, z); confirm against SIMPLEBase what this flag is meant to signal.
	 */
	@Override
	public boolean hasGlobalWeights() {
		return false;
	}

	/** Returns the weight of rater {@code r}; the voxel coordinates are ignored. */
	@Override
	protected float getWeight(int x, int y, int z, int r) {
		return this.weights[r];
	}

	/** Gives every rater (one per {@code obs.getN()}) an initial weight of 1. */
	@Override
	protected void initializeWeights() {
		this.weights = new float[obs.getN()];
		Arrays.fill(this.weights, 1);
	}

}
