package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.File;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Hierarchical variant of the simplified Spatial STAPLE label-fusion estimator.
 * <p>
 * Performance parameters ({@code theta}) are allocated per level of a
 * {@link HierarchicalModel} built from a hierarchy file. During the M-step each
 * voxel contributes to every level whose index is greater than or equal to that
 * voxel's consensus level.
 */
public class HierarchicalSpatialSTAPLESimple extends SpatialSTAPLESimple {
	
	/** Label hierarchy built from the hierarchy file; assigned once in the constructor. */
	private final HierarchicalModel h_model;
	
	/**
	 * Builds the hierarchical estimator and allocates its per-level performance
	 * parameters.
	 * <p>
	 * All numeric arguments and {@code outname} are forwarded unchanged to
	 * {@link SpatialSTAPLESimple}. The trailing {@code true} passed to the super
	 * constructor is defined by the superclass.
	 * NOTE(review): presumably it tells the superclass not to allocate its own
	 * theta, since theta is replaced below. TODO confirm.
	 *
	 * @param obs_in         rater observations
	 * @param hws_in         forwarded to the superclass (presumably half-window sizes)
	 * @param epsilon_in     forwarded to the superclass (presumably the convergence threshold)
	 * @param bias_in        forwarded to the superclass
	 * @param maxiter_in     forwarded to the superclass (presumably the maximum number of iterations)
	 * @param priortype_in   forwarded to the superclass
	 * @param hierarchy_file file describing the label hierarchy, passed to {@link HierarchicalModel}
	 * @param outname        forwarded to the superclass
	 */
	public HierarchicalSpatialSTAPLESimple (ObservationBase obs_in,
										    int [] hws_in,
					               		    float epsilon_in,
					                        float bias_in,
					               		    int maxiter_in,
					                        int priortype_in,
					                        File hierarchy_file,
					                        String outname) {
		
		super(obs_in, hws_in, epsilon_in, bias_in, maxiter_in, priortype_in, outname, true);
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Hierarchical Simplified Spatial STAPLE +++");
		setLabel("Hierarchical Simplified Spatial STAPLE");
		
		// construct the hierarchical model
		h_model = new HierarchicalModel(hierarchy_file, obs);
		
		// allocate space for the hierarchical thetas (current and previous iteration)
		theta = get_initial_theta(h_model);
		theta_prev = new PerformanceParametersVectorized(obs.num_raters(), h_model);
		
	}
	
	/**
	 * Adds one voxel's contribution to the hierarchical performance parameters
	 * (M-step).
	 * <p>
	 * The loops run over:
	 * <ul>
	 * <li>every level {@code i} with {@code i >= h_model.consensus_level(x, y, z, v)};</li>
	 * <li>every label {@code s} with {@code lp[s] > 0};</li>
	 * <li>every rater {@code j} that is locally selected at this voxel.</li>
	 * </ul>
	 * For each combination, every observed label {@code obslabels[l]} of that
	 * rater adds {@code obsvals[l] * lp[s]} to {@code theta} at
	 * {@code (i, j, obslabels[l], s)}.
	 * <p>
	 * NOTE(review): presumably overrides
	 * {@code SpatialSTAPLESimple.run_M_step_voxel}. Confirm, then add {@code @Override}.
	 *
	 * @param x  voxel x index
	 * @param y  voxel y index
	 * @param z  voxel z index
	 * @param v  voxel v (fourth-dimension) index
	 * @param lp per-label weights for this voxel, indexed by label (presumably the
	 *           E-step label probabilities). Labels with non-positive weight are skipped.
	 */
	protected void run_M_step_voxel(int x, int y, int z, int v, float [] lp) {
		
		// The consensus level is constant for this voxel, so query it only once.
		// For integers "consensus < i + 1" is equivalent to "i >= consensus",
		// so the level loop can start directly at the first qualifying level.
		final int first_level = Math.max(0, h_model.consensus_level(x, y, z, v));
		final int num_levels = h_model.num_levels();
		final int num_labels = obs.num_labels();
		final int num_raters = obs.num_raters();
		
		// add the impact to theta (M-step)
		for (int i = first_level; i < num_levels; i++) {
			for (int s = 0; s < num_labels; s++) {
				if (lp[s] > 0) {
					for (int j = 0; j < num_raters; j++) {
						
						// use the local selection
						if (!obs.get_local_selection(x, y, z, v, j)) {
							continue;
						}
						
						// get the rater observations
						// NOTE(review): these do not depend on i or s and could be fetched
						// once per rater, provided get_all/get_all_vals do not return a
						// shared, reused buffer. TODO confirm before hoisting.
						short [] obslabels = obs.get_all(x, y, z, v, j);
						float [] obsvals = obs.get_all_vals(x, y, z, v, j);
						
						// add the impact to theta
						for (int l = 0; l < obslabels.length; l++) {
							theta.add(i, j, obslabels[l], s, obsvals[l]*lp[s]);
						}
					}
				}
			}
		}
	}

}
