package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Base class for SIMPLE label fusion: an iterative weighted majority vote
 * over rater observations.
 * <p>
 * {@link #run()} calls {@link #initializeWeights()} once, then repeatedly
 * performs a weighted vote over the observations in {@link #obs} (writing
 * {@link #estimate}) followed by {@link #estimateWeights()}. The loop stops
 * when {@link #hasConverged()} returns {@code true} or after
 * {@link #maxIter} iterations.
 */
public abstract class SIMPLEBase extends AbstractCalculation {

	/** Maximum number of vote / re-weight iterations performed by {@link #run()}. */
	protected int maxIter;
	/** Rater observations being fused. */
	protected SimpleLabelVolume obs;
	/**
	 * Current label estimate, indexed {@code [x][y][z]} with extents
	 * {@code obs.getR()}, {@code obs.getC()}, {@code obs.getS()}.
	 * Allocated (all zeros) at the start of {@link #run()}; {@code null} before that.
	 */
	protected int[][][] estimate;
	/** Not read in this class; presumably a convergence tolerance used by subclasses — TODO confirm. */
	protected float epsilon;

	public SIMPLEBase() { super(); }

	/**
	 * Returns {@code true} when iteration should stop. Checked by {@link #run()}
	 * after each weight re-estimation.
	 */
	protected abstract boolean hasConverged();

	/**
	 * Not used in this class. Presumably reports whether the rater weights are
	 * independent of voxel position — confirm against subclasses.
	 */
	public abstract boolean hasGlobalWeights();

	/**
	 * Returns the vote weight of rater {@code r} at voxel ({@code x}, {@code y}, {@code z}).
	 * The weighted vote calls this for every rater index {@code r} in
	 * {@code [0, obs.getN())}.
	 */
	protected abstract float getWeight(int x, int y, int z, int r);

	/**
	 * Re-estimates the rater weights. Called by {@link #run()} after each
	 * weighted vote, presumably using the updated {@link #estimate}.
	 */
	protected abstract void estimateWeights();

	/** Sets the initial rater weights. Called once by {@link #run()} before the first vote. */
	protected abstract void initializeWeights();

	/**
	 * Builds the fused segmentation by passing {@link #estimate} to
	 * {@code obs.getSegmentation}.
	 * <p>
	 * NOTE(review): {@link #estimate} is {@code null} until {@link #run()} has
	 * been called. How {@code obs.getSegmentation} handles that is not visible here.
	 */
	public ImageData getSegmentation() {
		ImageData segmentation = obs.getSegmentation(estimate);
		return segmentation;
	}

	/**
	 * Recomputes {@link #estimate} at every voxel where
	 * {@code obs.isConsensus(x, y, z)} is {@code false}. Consensus voxels are
	 * left untouched.
	 */
	protected void runWeightedVote() {
		// One vote-accumulation buffer reused for every voxel. Allocating a new
		// array per voxel is pure garbage-collector churn on large volumes.
		float[] votes = new float[obs.getMaxLabel() + 1];
		int rows = obs.getR();
		int cols = obs.getC();
		int slices = obs.getS();
		for (int x = 0; x < rows; x++) {
			for (int y = 0; y < cols; y++) {
				for (int z = 0; z < slices; z++) {
					if (!obs.isConsensus(x, y, z)) {
						estimateVoxel(x, y, z, votes);
					}
				}
			}
		}
	}

	/**
	 * Sets {@code estimate[x][y][z]} to the label with the largest summed rater
	 * weight at that voxel.
	 * <p>
	 * Ties go to the smaller label. If no label accumulates a strictly positive
	 * weight, label 0 is assigned.
	 *
	 * @param votes scratch buffer of length {@code obs.getMaxLabel() + 1}.
	 *              It is cleared and overwritten by this call.
	 */
	private void estimateVoxel(int x, int y, int z, float[] votes) {
		Arrays.fill(votes, 0f);
		for (int j = 0; j < obs.getN(); j++) {
			int l = obs.getObservation(x, y, z, j);
			votes[l] += getWeight(x, y, z, j);
		}
		int maxLab = 0;
		float maxWeight = 0f;
		for (int l = 0; l < votes.length; l++) {
			// Strict '>' keeps the first (lowest) label on ties.
			if (votes[l] > maxWeight) {
				maxLab = l;
				maxWeight = votes[l];
			}
		}
		estimate[x][y][z] = maxLab;
	}

	/**
	 * Computes the sample standard deviation (divisor {@code n - 1}) of
	 * {@code arr} around the supplied {@code mean}.
	 * <p>
	 * The sample standard deviation is undefined for fewer than two values. In
	 * that case 0 is returned, instead of dividing by zero or a negative count,
	 * so that thresholds derived from the result stay finite.
	 *
	 * @param arr  values (must not be {@code null})
	 * @param mean mean of {@code arr}, e.g. from {@link #calculateMean(float[])}
	 * @return the sample standard deviation, or 0 when {@code arr.length < 2}
	 */
	protected double calculateSTD(float[] arr, double mean) {
		if (arr.length < 2) {
			return 0.0;
		}
		double sumSq = 0;
		for (float f : arr) {
			sumSq += Math.pow(f - mean, 2);
		}
		return Math.sqrt(sumSq / (arr.length - 1));
	}

	/**
	 * Computes the arithmetic mean of {@code arr}, accumulated in {@code float}.
	 * Returns NaN for an empty array (0 / 0).
	 */
	protected float calculateMean(float[] arr) {
		float mean = 0;
		for (float f : arr) {
			mean += f;
		}
		mean = mean / arr.length;
		return mean;
	}

	/**
	 * Runs the fusion.
	 * <p>
	 * It allocates a zeroed {@link #estimate}, initializes the weights, then
	 * alternates weighted voting and weight re-estimation until
	 * {@link #hasConverged()} or {@link #maxIter} iterations. When
	 * {@code maxIter <= 0} no vote is performed and {@link #estimate} stays all zeros.
	 */
	public void run() {
		this.estimate = new int[obs.getR()][obs.getC()][obs.getS()];
		int numIter = 0;
		initializeWeights();
		while (numIter < maxIter) {
			numIter++;
			// NOTE(review): progress is logged at WARNING level, presumably so it
			// shows at default verbosity — confirm this is intended.
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("Starting iteration %d", numIter));
			JistLogger.logFlush();
			runWeightedVote();
			estimateWeights();
			if (hasConverged()) {
				break;
			}
		}
	}

}
