package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.algorithms.labelfusion.ObservationBase;

/**
 * STAPLE variant whose rater performance parameters ({@code theta}) are
 * derived from a supplied truth labeling instead of being estimated
 * iteratively.
 * <p>
 * Unless constructed with {@code quiet == true}, the constructor does the following:
 * <ul>
 *   <li>runs the STAPLE M-step once per non-consensus voxel, using one-hot label
 *       probabilities built from the truth volume;</li>
 *   <li>normalizes the accumulated parameters.</li>
 * </ul>
 * {@link #run()} then makes a single E-step pass with those fixed parameters
 * to produce the label estimate.
 */
public class IdealSTAPLE extends STAPLE {

	/**
	 * Truth labels indexed as {@code [x][y][z][v]}, sized to the observation's
	 * dimensions. The constructor sets this to {@code null} after computing
	 * theta. When {@code quiet} is true, it is kept.
	 */
	protected short [][][][] truth;

	/**
	 * Builds an Ideal STAPLE instance and computes its performance parameters
	 * from the truth volume. Same as the seven-argument constructor with
	 * {@code quiet == false}.
	 *
	 * @param truthvol     volume holding the truth labels
	 * @param obs_in       rater observations
	 * @param epsilon_in   forwarded to {@link STAPLE}
	 * @param maxiter_in   forwarded to {@link STAPLE}
	 * @param priortype_in forwarded to {@link STAPLE}
	 * @param outname      forwarded to {@link STAPLE}
	 */
	public IdealSTAPLE (ParamVolume truthvol,
						ObservationBase obs_in,
						float epsilon_in,
						int maxiter_in,
						int priortype_in,
						String outname) {
		this(truthvol, obs_in, epsilon_in, maxiter_in, priortype_in, outname, false);
	}

	/**
	 * Builds an Ideal STAPLE instance and loads the truth labels.
	 *
	 * @param truthvol     volume holding the truth labels; loaded through
	 *                     {@code load_truth_labels} into {@link #truth}
	 * @param obs_in       rater observations
	 * @param epsilon_in   forwarded to {@link STAPLE}
	 * @param maxiter_in   forwarded to {@link STAPLE}
	 * @param priortype_in forwarded to {@link STAPLE}
	 * @param outname      forwarded to {@link STAPLE}
	 * @param quiet        if true, theta is NOT computed here and the truth
	 *                     array is retained. If false, theta is computed from
	 *                     the truth labels and the truth array is released.
	 */
	public IdealSTAPLE (ParamVolume truthvol,
						ObservationBase obs_in,
						float epsilon_in,
						int maxiter_in,
						int priortype_in,
						String outname,
						boolean quiet) {

		// NOTE(review): the superclass always receives quiet == true here,
		// regardless of this constructor's own quiet argument. Confirm intent.
		super(obs_in, epsilon_in, maxiter_in, priortype_in, outname, true);

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Ideal STAPLE +++");
		setLabel("Ideal_STAPLE");

		final int nx = obs.dimx();
		final int ny = obs.dimy();
		final int nz = obs.dimz();
		final int nv = obs.dimv();

		truth = new short [nx][ny][nz][nv];
		load_truth_labels(truthvol, truth);
		print_status = true;

		// Quiet mode: leave theta untouched and keep the truth labels around.
		if (quiet) {
			return;
		}

		final int numLabels = obs.num_labels();
		theta = new PerformanceParameters(numLabels, obs.num_raters());
		theta.reset();

		// One-hot label-probability vector, refilled for every voxel visited.
		final float [] onehot = new float [numLabels];

		JistLogger.logOutput(JistLogger.INFO, "-> Calculating Ideal Performance Parameters");
		for (int xi = 0; xi < nx; xi++) {
			if (print_status) {
				print_status(xi, nx);
			}
			for (int yi = 0; yi < ny; yi++) {
				for (int zi = 0; zi < nz; zi++) {
					for (int vi = 0; vi < nv; vi++) {
						// consensus voxels are skipped entirely
						if (obs.is_consensus(xi, yi, zi, vi)) {
							continue;
						}
						Arrays.fill(onehot, 0f);
						// truth values are mapped back through unmap_label to index the vector
						onehot[obs.unmap_label(truth[xi][yi][zi][vi])] = 1f;
						run_M_step_voxel(xi, yi, zi, vi, onehot);
					}
				}
			}
		}
		theta.normalize();

		// The truth labels are not used again; drop the reference.
		truth = null;
	}

	/**
	 * Produces the label estimate using the fixed performance parameters.
	 * <p>
	 * The method proceeds in four steps:
	 * <ol>
	 *   <li>Initializes a label-probability buffer via {@code initialize_priors}.</li>
	 *   <li>Runs one E-step per non-consensus voxel with {@code theta}.</li>
	 *   <li>Writes {@code get_estimate_voxel}'s result into {@code estimate} at
	 *       the voxel position shifted by the observation offsets. Consensus
	 *       voxels are not written here; presumably they are filled by
	 *       {@link STAPLE} (confirm).</li>
	 *   <li>Remaps the estimate to the original label space.</li>
	 * </ol>
	 *
	 * @return the {@code estimate} volume after remapping
	 */
	public ImageData run () {

		final float [] probs = new float [obs.num_labels()];

		// prime the per-label probability buffer with the priors
		initialize_priors(probs);

		JistLogger.logOutput(JistLogger.INFO, "\n-> Running the Estimation Algorithm");

		final int nx = obs.dimx();
		final int ny = obs.dimy();
		final int nz = obs.dimz();
		final int nv = obs.dimv();

		// offsets translating observation coordinates into estimate coordinates
		final int ox = obs.offx();
		final int oy = obs.offy();
		final int oz = obs.offz();
		final int ov = obs.offv();

		for (int xi = 0; xi < nx; xi++) {
			if (print_status) {
				print_status(xi, nx);
			}
			for (int yi = 0; yi < ny; yi++) {
				for (int zi = 0; zi < nz; zi++) {
					for (int vi = 0; vi < nv; vi++) {
						if (obs.is_consensus(xi, yi, zi, vi)) {
							continue;
						}
						run_E_step_voxel(xi, yi, zi, vi, probs, theta);
						estimate.set(xi + ox, yi + oy, zi + oz, vi + ov, get_estimate_voxel(probs));
					}
				}
			}
		}

		// back to the original label space
		obs.remap_estimate(estimate);

		return estimate;
	}
}
