package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.File;
import java.util.List;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Locally weighted vote over multiple registered label sets.
 * <p>
 * For each voxel of the observation volume, every rater's observed labels are
 * mapped through {@code getProbabilities(rater, label)} and summed, each scaled
 * by the per-observation weight reported by the observation object. The target
 * label index with the largest accumulated value is chosen. Voxels that the
 * observation object reports as consensus skip voting and take its consensus
 * estimate directly.
 */
public class MultiSetLocallyWeightedVote extends AbstractMultiSetVoting {
	
	/** Source of per-voxel rater observations and of the crop geometry (dims/offsets). */
	private final ObservationBase2 obs;
	/** Resulting label per voxel, indexed in the observation's (cropped) coordinate frame. */
	private final short[][][] seg;
	/** Observation volume dimensions {x, y, z}, as reported by {@link #obs}. */
	private final int[] dims;

	/**
	 * Builds the observation model and allocates the output segmentation.
	 * <p>
	 * {@code raterMap}, {@code thetaFile} and {@code targetClass} are forwarded to the
	 * superclass. {@code target}, {@code regLabs} and {@code regIms} are handed to an
	 * {@code ObservationVolumePartial2} together with fixed neighbourhood settings.
	 * NOTE(review): the exact meaning of {@code sv}/{@code pv}/{@code sp} and the
	 * positional scalar arguments is defined by ObservationVolumePartial2; the names
	 * suggest search volume / patch volume / spacing-or-sigma — confirm there.
	 *
	 * @param regLabs     label volumes, presumably registered atlas labels — confirm with caller
	 * @param raterMap    rater map file, forwarded to the superclass
	 * @param thetaFile   theta file, forwarded to the superclass
	 * @param targetClass target class name, forwarded to the superclass and the observation model
	 * @param target      target intensity volume
	 * @param regIms      intensity volumes, presumably registered atlas images — confirm with caller
	 */
	public MultiSetLocallyWeightedVote(ParamVolumeCollection regLabs,File raterMap, File thetaFile, String targetClass,ParamVolume target,ParamVolumeCollection regIms) {
		super(raterMap, thetaFile, targetClass);
		
		int[] sv = {3,3,3,0};
		int[] pv = {2,2,2,0};
		float[] sp = {1.5f,1.5f,1.5f,1.5f};
		
		obs = new ObservationVolumePartial2(target,regLabs,regIms, 
				1, sv, pv,sp,
				.5f,1,0.05f,
				-1,1f,-1,true,
				classNames,targetClass);
		
		dims = new int[3];
		dims[0] = obs.dimx();
		dims[1] = obs.dimy();
		dims[2] = obs.dimz();
		seg = new short[dims[0]][dims[1]][dims[2]];
	}

	/**
	 * Computes the voted label for every voxel of the observation volume and
	 * stores it in {@link #seg}.
	 */
	@Override
	public void run() {
		for(int i=0;i<dims[0];i++){
			for(int j=0;j<dims[1];j++){
				for(int k=0;k<dims[2];k++){
					seg[i][j][k] = runSegmentation(i,j,k);
				}
			}
		}
	}
	
	/**
	 * Votes on a single voxel.
	 *
	 * @param i x index in the observation frame
	 * @param j y index in the observation frame
	 * @param k z index in the observation frame
	 * @return the consensus estimate if the voxel is a consensus voxel; otherwise
	 *         the index of the target label with the largest weighted sum (ties
	 *         go to the lowest index), or -1 if there are no target labels
	 */
	private short runSegmentation(int i,int j,int k){
		// Consensus voxels need no vote; take the agreed-upon estimate directly.
		if(obs.is_consensus(i, j, k, 0))
			return obs.get_consensus_estimate(i, j, k, 0);

		double[] labProbs = new double[numTarget];
		for(int l = 0;l <obs.num_raters();l++){
			short[] labs = obs.get_all(i, j, k, 0, l);
			float[] vals = obs.get_all_vals(i, j, k, 0,l);
			for(int m=0;m<labs.length;m++){
				float v = vals[m];
				// Map rater l's observed label into target-label space, scaled by its weight.
				double[] prob = getProbabilities(l,labs[m]);
				for(int s=0;s<prob.length;s++)
					labProbs[s] += prob[s]*v;
			}
		}

		// Arg-max with strict '>' so the lowest index wins ties.
		int maxLab = -1;
		double maxProb = -1.0;
		for(int s=0;s<labProbs.length;s++){
			if(labProbs[s] > maxProb){
				maxProb = labProbs[s];
				maxLab = s;
			}
		}
		// labProbs starts at zero, so "no evidence" shows up as a non-positive maximum
		// (or no label at all), not as a negative one. Warn in both cases.
		if(maxLab < 0 || maxProb <= 0.0){
			JistLogger.logOutput(JistLogger.SEVERE, String.format("WARNING!! Got no valid probabilities at voxels %d %d %d", i,j,k));
		}
		return (short) maxLab;
	}

	/**
	 * Returns the segmentation embedded in an integer image of the original
	 * (un-cropped) dimensions. Each voxel of {@link #seg} is written at its index
	 * plus the observation's offset. Should be called after {@link #run()};
	 * before that, {@link #seg} is all zeros.
	 *
	 * @return a new {@code ImageDataInt} named "segmentation"
	 */
	@Override
	public ImageData getSegmentation() {
		
		ImageData im = new ImageDataInt("segmentation",obs.orig_dimx(),obs.orig_dimy(),obs.orig_dimz());
		for(int i=0;i<dims[0];i++)
			for(int j=0;j<dims[1];j++)
				for(int k=0;k<dims[2];k++)
					im.set(i+obs.offx(), j+obs.offy(), k+obs.offz(), seg[i][j][k]);
		return im;
	}

}
