package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.File;
import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Hierarchical variant of {@link IdealSTAPLE}.
 * <p>
 * Instead of estimating rater performance iteratively, the constructor
 * accumulates the performance parameters ({@code theta}) directly from the
 * ground-truth labels. It does this at every level of a label hierarchy that
 * is built from a hierarchy file.
 */
public class IdealHierarchicalSTAPLE extends IdealSTAPLE {
	
	/** Label hierarchy built from the hierarchy file and the observations. */
	private HierarchicalModel h_model;
	
	/**
	 * Builds the hierarchical model and computes the ideal performance
	 * parameters from the ground truth.
	 * <p>
	 * Every non-consensus voxel contributes a one-hot label vector. The vector
	 * is derived from {@code truth[x][y][z][v]} and passed to
	 * {@link #run_M_step_voxel}. {@code theta} is then normalized, and the
	 * reference to {@code truth} is released.
	 *
	 * @param truthvol       ground-truth volume, passed to the superclass
	 * @param obs_in         rater observations, passed to the superclass
	 * @param epsilon_in     passed through to the superclass
	 * @param maxiter_in     passed through to the superclass
	 * @param priortype_in   passed through to the superclass
	 * @param hierarchy_file file from which the {@link HierarchicalModel} is built
	 * @param outname        passed through to the superclass
	 */
	public IdealHierarchicalSTAPLE (ParamVolume truthvol,
									ObservationBase obs_in,
				               		float epsilon_in,
				               		int maxiter_in,
				               		int priortype_in,
				               		File hierarchy_file,
				               		String outname) {
		
		// NOTE(review): the meaning of the trailing `true` flag is defined by IdealSTAPLE -- confirm.
		super(truthvol, obs_in, epsilon_in, maxiter_in, priortype_in, outname, true);
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Ideal Hierarchical STAPLE +++");
		setLabel("Ideal_Hierarchical_STAPLE");
		
		// construct the hierarchical model
		h_model = new HierarchicalModel(hierarchy_file, obs);
		
		// one-hot label vector, reused (and cleared) for every voxel
		float [] lp = new float [obs.num_labels()];
		
		// Plain literal: the message has no format arguments, so routing it
		// through String.format would only risk misinterpreting a '%'.
		JistLogger.logOutput(JistLogger.INFO, "-> Calculating Ideal Performance Parameters");
		theta = new PerformanceParametersVectorized(obs.num_raters(), h_model);
		theta.reset();
		for (int x = 0; x < obs.dimx(); x++) {
			if (print_status) {
				print_status(x, obs.dimx());
			}
			for (int y = 0; y < obs.dimy(); y++) {
				for (int z = 0; z < obs.dimz(); z++) {
					for (int v = 0; v < obs.dimv(); v++) {
						// consensus voxels do not contribute to theta
						if (obs.is_consensus(x, y, z, v)) {
							continue;
						}
						// Weight 1 goes to the label index of the true label,
						// 0 to all others. This assumes unmap_label converts the
						// stored truth value to an index into lp -- confirm.
						Arrays.fill(lp, 0f);
						lp[obs.unmap_label(truth[x][y][z][v])] = 1f;
						run_M_step_voxel(x, y, z, v, lp);
					}
				}
			}
		}
		theta.normalize();
		
		// The ground truth is not read again in this class, so drop the reference.
		truth = null;
		
	}
	
	/**
	 * Adds the contribution of a single voxel to {@code theta} (the M-step).
	 * <p>
	 * The loop covers three sets:
	 * <ul>
	 *   <li>every hierarchy level {@code i} with
	 *       {@code h_model.consensus_level(x, y, z, v) < i + 1};</li>
	 *   <li>every label {@code s} with {@code lp[s] > 0};</li>
	 *   <li>every rater {@code j} whose local selection is set at this voxel.</li>
	 * </ul>
	 * For each combination, every label observed by the rater adds
	 * {@code obsvals[l] * lp[s]} to {@code theta} at
	 * {@code (i, j, obslabels[l], s)}.
	 *
	 * @param x  voxel x index
	 * @param y  voxel y index
	 * @param z  voxel z index
	 * @param v  voxel index along the fourth dimension
	 * @param lp per-label weights for this voxel; only positive entries contribute
	 */
	protected void run_M_step_voxel(int x, int y, int z, int v, float [] lp) {
		
		// add the impact to theta (M-step)
		for (int i = 0; i < h_model.num_levels(); i++) {
			// only levels i with consensus_level < i + 1 receive a contribution
			if (h_model.consensus_level(x, y, z, v) < i+1) {
				for (int s = 0; s < obs.num_labels(); s++) {
					if (lp[s] > 0) {
						for (int j = 0; j < obs.num_raters(); j++) {
							
							// use the local selection
							if (!obs.get_local_selection(x, y, z, v, j)) {
								continue;
							}
							
							// get the rater observations
							short [] obslabels = obs.get_all(x, y, z, v, j);
							float [] obsvals = obs.get_all_vals(x, y, z, v, j);
							
							// add the impact to theta
							for (int l = 0; l < obslabels.length; l++) {
								theta.add(i, j, obslabels[l], s, obsvals[l]*lp[s]);
							}
						}
					}
				}
			}
		}
	}

}
