package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.Arrays;
import java.util.List;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * SIMPLE-style label fusion in which every rater has its own weight at every
 * voxel instead of a single weight for the whole volume.
 *
 * <p>A rater's weight at a non-consensus voxel is the fraction of non-consensus
 * voxels inside a cubic window around it where the rater agrees with the
 * current estimate. Raters whose weight falls more than two standard
 * deviations below the voxel's mean weight are switched off (weight 0) and are
 * not re-estimated at that voxel in later calls.
 *
 * <p>The fields {@code obs}, {@code estimate}, {@code maxIter} and
 * {@code epsilon} are not declared here; they are presumably inherited from
 * {@link SIMPLEBase}.
 */
public class SpatialSIMPLE extends SIMPLEBase{

	/** Current rater weights, indexed {@code [x][y][z][rater]}. */
	protected float weights[][][][]; //x y z rater
	/** Rater weights as they were before the latest {@link #estimateWeights()} call. */
	private float prevWeights[][][][];
	/** Half-width (in voxels) of the cubic window used to pool agreement counts. */
	private int poolRegion;

	/**
	 * Builds the observation volume and stores the run parameters.
	 *
	 * @param imageDataList rater label volumes, wrapped in a {@code SimpleLabelVolume}
	 * @param int1 pooling half-width in voxels (stored as {@code poolRegion})
	 * @param int2 value assigned to {@code maxIter}
	 * @param float1 value assigned to {@code epsilon}
	 */
	public SpatialSIMPLE(List<ImageData> imageDataList, int int1, int int2, float float1){
		super();
		maxIter = int2;
		obs = new SimpleLabelVolume(imageDataList);
		epsilon = float1;
		poolRegion = int1;
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("Running Spatial SIMPLE with pool size of %d",poolRegion));
		JistLogger.logFlush();
	}

	/**
	 * {@inheritDoc}
	 *
	 * NOTE(review): this reports {@code true} although the weights stored here
	 * vary per voxel; confirm against the contract in the base class.
	 */
	@Override
	public boolean hasGlobalWeights() {
		return true;
	}

	/**
	 * Returns the current weight of rater {@code r} at voxel {@code (x, y, z)}.
	 */
	@Override
	protected float getWeight(int x, int y, int z, int r) {
		return weights[x][y][z][r];
	}

	/**
	 * Always reports that the estimate has not converged.
	 *
	 * NOTE(review): no convergence test is implemented, so {@code epsilon} and
	 * {@code prevWeights} are not consulted here; presumably the caller stops
	 * on {@code maxIter} instead - confirm in the base class.
	 */
	@Override
	protected boolean hasConverged() {
		return false;
	}

	/**
	 * Re-estimates every rater weight from the current {@code estimate}.
	 *
	 * <p>Consensus voxels and raters that were already switched off at a voxel
	 * (previous weight not greater than 0) keep a weight of 0. For every other
	 * voxel/rater pair the locally pooled agreement rate is stored, and once
	 * all raters of a voxel are done the low outliers are zeroed.
	 */
	@Override
	protected void estimateWeights() {
		final int nx = obs.getR();
		final int ny = obs.getC();
		final int nz = obs.getS();
		final int nRaters = obs.getN();

		// The old array becomes the history as-is; a fresh zero-filled array
		// receives the new weights, so no element-wise copy is needed.
		this.prevWeights = this.weights;
		this.weights = new float[nx][ny][nz][nRaters];

		for(int x=0;x<nx;x++)
			for(int y=0;y<ny;y++)
				for(int z=0;z<nz;z++){
					if(obs.isConsensus(x, y, z))
						continue;
					boolean anyActive = false;
					for(int j=0;j<nRaters;j++)
						if(this.prevWeights[x][y][z][j]>0){
							calculateWeight(x,y,z,j);
							anyActive = true;
						}
					// Reject outliers only after every rater's weight at this
					// voxel is known, so the mean/std threshold is not skewed
					// by weights that simply have not been computed yet.
					if(anyActive)
						removeWeights(x,y,z);
				}
		printWeights();
	}

	/**
	 * Logs, for each rater, at how many non-consensus voxels its weight is 0.
	 */
	private void printWeights(){
		final int nRaters = obs.getN();
		int[] numRemoved = new int[nRaters]; // zero-initialised by the JVM
		for(int x=0;x<obs.getR();x++)
			for(int y=0;y<obs.getC();y++)
				for(int z=0;z<obs.getS();z++)
					if(!obs.isConsensus(x, y, z))
						for(int j=0;j<nRaters;j++)
							if(weights[x][y][z][j]==0)
								numRemoved[j]++;
		for(int j=0;j<nRaters;j++)
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("Rater %d: %d removed", j,numRemoved[j]));
	}

	/**
	 * Zeroes every rater weight at {@code (x, y, z)} that lies more than two
	 * standard deviations below the mean of all rater weights at that voxel.
	 * The statistics come from {@code calculateMean} / {@code calculateSTD},
	 * which are defined outside this class, and include raters whose weight
	 * is already 0.
	 */
	private void removeWeights(int x,int y,int z){
		float[] vec = this.weights[x][y][z];
		float mean = calculateMean(vec);
		float std  = (float) calculateSTD(vec, (double) mean);
		float minVal = mean - 2*std;
		for(int j=0;j<vec.length;j++)
			if(vec[j] < minVal)
				vec[j] = 0;
	}

	/**
	 * Stores the weight of rater {@code j} at {@code (x, y, z)}: the fraction of
	 * non-consensus voxels, within {@code poolRegion} voxels along each axis
	 * (clipped to the volume), at which the rater's label equals the estimate.
	 */
	private void calculateWeight(int x,int y,int z, int j){
		int numTotal = 0;
		int numCorrect = 0;
		int xl = Math.max(0, x-poolRegion);
		int xh = Math.min(obs.getR()-1, x+poolRegion);
		int yl = Math.max(0, y-poolRegion);
		int yh = Math.min(obs.getC()-1, y+poolRegion);
		int zl = Math.max(0, z-poolRegion);
		int zh = Math.min(obs.getS()-1, z+poolRegion);
		for(int xi=xl;xi<=xh;xi++)
			for(int yi=yl;yi<=yh;yi++)
				for(int zi=zl;zi<=zh;zi++)
					if(!obs.isConsensus(xi, yi, zi)){
						numTotal++;
						if(obs.getObservation(xi, yi, zi, j)==estimate[xi][yi][zi])
							numCorrect++;
					}
		// An empty window (only possible with a negative pool size) would give
		// 0/0 = NaN; store 0 instead so the NaN cannot leak into later steps.
		this.weights[x][y][z][j] = (numTotal == 0)
				? 0f
				: (float) numCorrect / (float) numTotal;
	}

	/**
	 * Allocates the weight array and gives every rater a weight of 1 at every voxel.
	 */
	@Override
	protected void initializeWeights() {
		this.weights = new float[obs.getR()][obs.getC()][obs.getS()][obs.getN()];
		for(float[][][] mat: weights)
			for(float[][] im: mat)
				for(float[] row: im)
					Arrays.fill(row, 1f);
	}

}
