package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.File;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Voxel-wise voting over a collection of rater label volumes.
 * <p>
 * For each voxel, the vector returned by {@code getProbabilities(rater, label)}
 * is accumulated over all raters and the index with the largest accumulated
 * value becomes the output label. {@code numRaters}, {@code numTarget},
 * {@code getProbabilities} and {@code log} are inherited from
 * {@link AbstractMultiSetVoting}; their exact semantics are defined there
 * (presumably a rater-to-label-set mapping read from the rater map / theta
 * files -- confirm against the superclass).
 */
public class MultiSetMajorityVote extends AbstractMultiSetVoting {
	
	/** Observed labels, indexed as [rater][row][col][slice]. */
	private short[][][][] labs;
	/** Volume size as {rows, cols, slices}, taken from the first rater volume. */
	private int[] dims;
	/** Fused labels, indexed as [row][col][slice]; {@code null} until {@link #run()} is called. */
	private short[][][] seg;

	/**
	 * Builds the voter and eagerly loads every rater volume into memory.
	 *
	 * @param obs       rater label volumes; volume {@code l} is read for rater {@code l}
	 * @param raterMap  forwarded unchanged to the superclass
	 * @param thetaFile forwarded unchanged to the superclass
	 * @param tc        forwarded unchanged to the superclass
	 * @throws IllegalArgumentException if any rater volume differs in size from the first one
	 */
	public MultiSetMajorityVote(ParamVolumeCollection obs,File raterMap, File thetaFile, String tc) {
		super(raterMap, thetaFile, tc);
		loadRaters(obs);
		log(String.format("The image size is [%d %d %d]", dims[0],dims[1],dims[2]));
	}

	/**
	 * Computes the fused label for every voxel and stores the result internally;
	 * retrieve it with {@link #getSegmentation()}.
	 */
	@Override
	public void run() {
		log("Running Segmentation");
		seg = new short[dims[0]][dims[1]][dims[2]];
		for(int i=0;i<dims[0];i++){
			for(int j=0;j<dims[1];j++){
				for(int k=0;k<dims[2];k++){
					seg[i][j][k] = getSegmentation(i,j,k);
				}
			}
		}
	}
	
	/**
	 * Fuses the rater observations at a single voxel.
	 *
	 * @return the index with the largest accumulated value (the lowest index wins
	 *         ties because the comparison is strict), or -1 when no accumulated
	 *         value exceeds the initial -1 sentinel (for example all NaN)
	 */
	private short getSegmentation(int i,int j,int k){
		double[] labProbs = new double[numTarget];
		for(int l=0;l<numRaters;l++){
			short lab = labs[l][i][j][k];
			double[] prob = getProbabilities(l,lab);
			for(int s=0;s<prob.length;s++){
				labProbs[s] += prob[s];
			}
		}
		int maxLab = -1;
		double maxProb = -1f;
		for(int s=0;s<labProbs.length;s++){
			if(labProbs[s] > maxProb){
				maxProb = labProbs[s];
				maxLab = s;
			}
		}
		if(maxProb < 0f){
			JistLogger.logOutput(JistLogger.SEVERE, String.format("WARNING!! Got no valid probabilities at voxels %d %d %d", i,j,k));
		}
		return (short) maxLab;
	}

	/**
	 * Copies the fused labels into a new integer image named "segmentation".
	 *
	 * @return a freshly allocated image holding the fused labels
	 * @throws IllegalStateException if {@link #run()} has not been called yet
	 */
	@Override
	public ImageData getSegmentation() {
		if(seg == null){
			// Fail with a clear message instead of a NullPointerException in the copy loop.
			throw new IllegalStateException("No segmentation available: run() must be called before getSegmentation()");
		}
		ImageData im = new ImageDataInt("segmentation",dims[0],dims[1],dims[2]);
		for(int i=0;i<dims[0];i++){
			for(int j=0;j<dims[1];j++){
				for(int k=0;k<dims[2];k++){
					im.set(i, j, k, seg[i][j][k]);
				}
			}
		}
		return im;
	}

	/**
	 * Reads {@code numRaters} volumes from {@code obs} into {@link #labs} and
	 * records the common volume size in {@link #dims}. Each image is disposed
	 * once its voxels have been copied.
	 *
	 * @throws IllegalArgumentException if a rater volume does not have the same
	 *                                  size as the first volume
	 */
	private void loadRaters(ParamVolumeCollection obs){
		log("Loading Raters");
		dims = new int[3];
		ImageData im = obs.getParamVolume(0).getImageData();
		dims[0] = im.getRows();
		dims[1] = im.getCols();
		dims[2] = im.getSlices();
		labs = new short[numRaters][dims[0]][dims[1]][dims[2]];
		for(int l=0;l<numRaters;l++){
			im = obs.getParamVolume(l).getImageData();
			// A size mismatch would otherwise either fail deep inside getShort or
			// silently crop the volume, so reject it up front.
			int rows = im.getRows();
			int cols = im.getCols();
			int slices = im.getSlices();
			if(rows != dims[0] || cols != dims[1] || slices != dims[2]){
				im.dispose();
				throw new IllegalArgumentException(String.format(
						"Rater volume %d has size [%d %d %d] but expected [%d %d %d]",
						l, rows, cols, slices, dims[0], dims[1], dims[2]));
			}
			for(int i=0;i<dims[0];i++){
				for(int j=0;j<dims[1];j++){
					for(int k=0;k<dims[2];k++){
						labs[l][i][j][k] = im.getShort(i, j,k);
					}
				}
			}
			im.dispose();
		}
	}
	
}
