package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;

public class ObservationVolumePartial extends ObservationBase {
	
	/** Non-local label observations, indexed by (x, y, z, v, rater, label); built by the constructor. */
	protected SparseMatrix6D obs;
	// Voxelwise observations read from the label volumes; only used while constructing, then released.
	private ObservationBase tmpobs;
	// Cropped target intensities (normalized when intensity normalization is enabled).
	private float [][][][] target;
	// Spatial weight of each search-window offset; filled by set_dist_factor, released after construction.
	private float [][][][] df;
	// Local patch means / standard deviations of the target and of the atlas currently being processed.
	// Only allocated when patch selection or an LNCC-based metric is used; released after construction.
	private float [][][][] target_means;
	private float [][][][] target_stds;
	private float [][][][] atlas_means;
	private float [][][][] atlas_stds;
	// True where no patch comparison can reach (outside every non-consensus search + patch window).
	private boolean [][][][] tcon;
	// Only created when intensity normalization is enabled.
	private IntensityNormalizer normalizer;
	// Set to true by create_atlas_selection_matrix.
	private boolean use_atlas_selection = false;
	// Turned off when the selection type is PATCH_SELECTION_TYPE_NONE.
	private boolean use_patch_selection = true;
	private boolean use_intensity_normalization;

	// variables that are only initialized by the optional create_* post-processing calls
	private SparseMatrix5D weighted_prior;         // built by create_weighted_prior
	private boolean [][][][][] atlas_selection;    // [x][y][z][v][rater], built by create_atlas_selection_matrix
	
	// define some constants
	// patch selection criteria (constructor argument selection_type)
	public static final int PATCH_SELECTION_TYPE_SSIM = 0;
	public static final int PATCH_SELECTION_TYPE_JACCARD = 1;
	public static final int PATCH_SELECTION_TYPE_NONE = 2;
	// patch difference metrics (constructor argument weight_type)
	public static final int WEIGHT_TYPE_LNCC = 0;
	public static final int WEIGHT_TYPE_MSD = 1;
	public static final int WEIGHT_TYPE_MIXED = 2;
	
	/**
	 * Builds the non-local (patch-based) label observations of every rater.
	 *
	 * <p>For each non-consensus target voxel and each atlas, every voxel of the search
	 * window contributes its label observation, weighted by a spatial Gaussian factor and
	 * by an intensity-difference factor (LNCC, MSD or mixed). Patches can additionally be
	 * filtered by a selection score (SSIM or Jaccard) and/or limited to the
	 * {@code num_keep} best-scoring ones. Patches with fewer than 25 voxels force the MSD
	 * metric and disable patch selection.
	 *
	 * @param targetim                       target intensity image
	 * @param obsvols                        atlas label volumes, one per rater
	 * @param imsvols                        atlas intensity volumes, in the same order as {@code obsvols}
	 * @param weight_type                    one of the {@code WEIGHT_TYPE_*} constants
	 * @param sv_in                          search-volume radius per dimension (voxels)
	 * @param pv_in                          patch-volume radius per dimension (voxels)
	 * @param sp_stdevs                      spatial standard deviations (4 entries): positive values are
	 *                                       divided by the voxel resolution, negative values are used
	 *                                       as voxels, and 0 (or a 0 resolution) selects 2 voxels.
	 *                                       The caller's array is left unmodified.
	 * @param int_stdev                      intensity standard deviation; the difference metric is
	 *                                       scaled by {@code 1 / int_stdev^2}
	 * @param selection_type                 one of the {@code PATCH_SELECTION_TYPE_*} constants
	 * @param sel_thresh                     minimum selection score for a patch to be used
	 * @param num_keep                       number of best patches kept per voxel; values {@code <= 0}
	 *                                       or above the search-volume size keep all selected patches
	 * @param cons_thresh_in                 consensus threshold, forwarded to the base class
	 * @param probability_index              if {@code > 0}, the label volumes are read through
	 *                                       {@code ObservationVolumeProbability}
	 * @param use_intensity_normalization_in whether an {@code IntensityNormalizer} is applied to the
	 *                                       target and atlas intensities
	 * @throws RuntimeException if an intensity image does not match the observation dimensions
	 */
	public ObservationVolumePartial (ParamVolume targetim,
									 ParamVolumeCollection obsvols,
									 ParamVolumeCollection imsvols,
									 int weight_type,
									 int [] sv_in,
									 int [] pv_in,
									 float [] sp_stdevs,
									 float int_stdev,
									 int selection_type,
									 float sel_thresh,
									 int num_keep,
									 float cons_thresh_in,
									 int probability_index,
									 boolean use_intensity_normalization_in) {
		super(cons_thresh_in, sv_in, pv_in);
		setLabel("ObservationVolumePartial");
		
		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing Label Observations (Partial Type)  +++");

		// initialize the observation volumes
		if (probability_index > 0)
			tmpobs = new ObservationVolumeProbability(obsvols, cons_thresh_in, probability_index, sv_in, pv_in, true);
		else
			tmpobs = new ObservationVolume(obsvols, cons_thresh_in, sv_in, pv_in, true);
		this.copy_data(tmpobs);
		
		use_intensity_normalization = use_intensity_normalization_in;
		
		// Set the Non-Local Correspondence Parameters (spatial standard deviations in voxels).
		// Work on a copy: converting in place would rescale the caller's array, and a caller
		// that re-uses it (e.g. for a second run) would get it divided by the resolution twice.
		sp_stdevs = sp_stdevs.clone();
		for (int i = 0; i < 4; i++) {
			if (dimres[i] == 0 || sp_stdevs[i] == 0)
				sp_stdevs[i] = 2;
			else
				sp_stdevs[i] = (sp_stdevs[i] >= 0) ? sp_stdevs[i] / dimres[i] : -sp_stdevs[i];
		}
		float ff = (float) (1 / Math.pow(int_stdev, 2));
		
		// small patches give unreliable statistics: fall back to MSD without selection
		int num_vox_patch = (2*pv[0]+1) * (2*pv[1]+1) * (2*pv[2]+1) * (2*pv[3]+1);
		if (num_vox_patch < 25) {
			JistLogger.logOutput(JistLogger.INFO, "-> Not enough voxels in patch -- forcing parameters.");
			weight_type = WEIGHT_TYPE_MSD;
			selection_type = PATCH_SELECTION_TYPE_NONE;
		}
		
		// determine whether or not we're going to keep only the best patches (-1 = keep all)
		int num_vox_search = (2*sv[0]+1) * (2*sv[1]+1) * (2*sv[2]+1) * (2*sv[3]+1);
		if (num_keep > num_vox_search || num_keep <= 0)
			num_keep = -1;
		
		if (selection_type == PATCH_SELECTION_TYPE_NONE) {
			use_patch_selection = false;
			num_keep = -1;
		}
		
		// set the weight string
		String weight_str = "";
		if (weight_type == WEIGHT_TYPE_LNCC)
			weight_str = "LNCC";
		else if (weight_type == WEIGHT_TYPE_MSD)
			weight_str = "MSD";
		else if (weight_type == WEIGHT_TYPE_MIXED)
			weight_str = "Mixed";
		
		// set the selection string
		String selection_str = "";
		if (selection_type == PATCH_SELECTION_TYPE_SSIM)
			selection_str = "SSIM";
		else if (selection_type == PATCH_SELECTION_TYPE_JACCARD)
			selection_str = "Jaccard";
		else if (selection_type == PATCH_SELECTION_TYPE_NONE)
			selection_str = "None";
		
		// set the consensus values
		determine_initial_consensus();
		set_atlas_consensus();
		
		// create the intensity normalizer
		if (use_intensity_normalization)
			normalizer = new IntensityNormalizer(tmpobs);
				
		// print out some information
		JistLogger.logOutput(JistLogger.INFO, "-> Determined the following information");
		JistLogger.logOutput(JistLogger.INFO, "Number of Raters: " + num_raters);
		JistLogger.logOutput(JistLogger.INFO, "Number of Labels: " + num_labels);
		JistLogger.logOutput(JistLogger.INFO, String.format("Original Dimensions (voxels): [%d %d %d %d]", orig_dims[0], orig_dims[1], orig_dims[2], orig_dims[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Cropped Dimensions (voxels): [%d %d %d %d]", dims[0], dims[1], dims[2], dims[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Voxel Resolutions (Image Units): [%f %f %f %f]", dimres[0], dimres[1], dimres[2], dimres[3]));
		JistLogger.logOutput(JistLogger.INFO, "Ignore Consensus: " + ignore_consensus());
		JistLogger.logOutput(JistLogger.INFO, "Difference Metric: " + weight_str);
		JistLogger.logOutput(JistLogger.INFO, String.format("Search Volume Radius (voxels): [%d %d %d %d]", sv[0], sv[1], sv[2], sv[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Search Volume Dimensions (voxels): [%d %d %d %d]", 2*sv[0]+1, 2*sv[1]+1, 2*sv[2]+1, 2*sv[3]+1));
		JistLogger.logOutput(JistLogger.INFO, String.format("Patch Volume Radius (voxels): [%d %d %d %d]", pv[0], pv[1], pv[2], pv[3]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Patch Volume Dimensions (voxels): [%d %d %d %d]", 2*pv[0]+1, 2*pv[1]+1, 2*pv[2]+1, 2*pv[3]+1));
		JistLogger.logOutput(JistLogger.INFO, String.format("Search Volume Standard Deviations (voxels): [%f %f %f %f]", sp_stdevs[0], sp_stdevs[1], sp_stdevs[2], sp_stdevs[3]));
		JistLogger.logOutput(JistLogger.INFO, "Intensity Standard Deviation: " + int_stdev);
		JistLogger.logOutput(JistLogger.INFO, "Using Intensity Normalization: " + use_intensity_normalization);
		JistLogger.logOutput(JistLogger.INFO, "Selection Type: " + selection_str);
		JistLogger.logOutput(JistLogger.INFO, "Using Patch Selection: " + use_patch_selection);
		JistLogger.logOutput(JistLogger.INFO, "Patch Selection Threshold: " + sel_thresh);
		JistLogger.logOutput(JistLogger.INFO, "Number of Patches to Keep: " + num_keep);

		// allocate some space
		boolean need_patch_stats = use_patch_selection || weight_type == WEIGHT_TYPE_LNCC || weight_type == WEIGHT_TYPE_MIXED;
		target = new float [dims[0]][dims[1]][dims[2]][dims[3]];
		float [][][][] im = new float [dims[0]][dims[1]][dims[2]][dims[3]];
		df = new float [2*sv[0]+1][2*sv[1]+1][2*sv[2]+1][2*sv[3]+1];
		obs = new SparseMatrix6D(dims[0], dims[1], dims[2], dims[3], num_raters, num_labels);
		if (need_patch_stats) {
			target_means = new float [dims[0]][dims[1]][dims[2]][dims[3]];
			target_stds = new float [dims[0]][dims[1]][dims[2]][dims[3]];
			atlas_means = new float [dims[0]][dims[1]][dims[2]][dims[3]];
			atlas_stds = new float [dims[0]][dims[1]][dims[2]][dims[3]];
		}
		
		// set the distance factor
		set_dist_factor(sp_stdevs);
		
		// load the target volume
		JistLogger.logOutput(JistLogger.INFO, "-> Loading the target image");
		load_intensity_image(targetim, target);
		if (use_intensity_normalization)
			normalizer.set_image_unit_normal(target);
		if (need_patch_stats)
			set_LNCC_parameters(target, target_means, target_stds);
		
		// run the Non-Local Correspondence Model on each atlas image
		for (int i = 0; i < num_raters; i++) {
			
			// load the image volume
			JistLogger.logOutput(JistLogger.INFO, "");
			JistLogger.logOutput(JistLogger.INFO, String.format("*** Processing Atlas Image Volume: %d of %d ***", i+1, num_raters));
			
			JistLogger.logOutput(JistLogger.INFO, "-> Loading the atlas image");
			load_intensity_image(imsvols.getParamVolume(i), im);
			if (use_intensity_normalization)
				normalizer.regress_image(im, target);
			if (need_patch_stats)
				set_LNCC_parameters(im, atlas_means, atlas_stds);
			
			// iterate over every non-consensus voxel
			JistLogger.logOutput(JistLogger.INFO, "-> Calculating the Non-Local Correspondence Model");
			for (int x = 0; x < dims[0]; x++) {
				print_status(x, dims[0]);
				for (int y = 0; y < dims[1]; y++) {
					for (int z = 0; z < dims[2]; z++) {
						for (int v = 0; v < dims[3]; v++) {
							if (is_consensus(x, y, z, v))
								continue;
							if (num_keep == -1) {
								iterate_patch(target, im, weight_type, ff, sel_thresh,
											  selection_type, x, y, z, v, i);
							} else {
								iterate_patch_keep(target, im, weight_type, ff, sel_thresh,
												   selection_type, num_keep, x, y, z, v, i);
							}
						}
					}
				}
			}
		}
		JistLogger.logOutput(JistLogger.INFO, "");
		
		// free the temporary stuff
		tmpobs = null;
		df = null;
		target_means = null;
		target_stds = null;
		atlas_means = null;
		atlas_stds = null;
		tcon = null;
		
		// we're done.
		JistLogger.logOutput(JistLogger.INFO, "-> Finished in initialization of ObservationVolumePartial");
		JistLogger.logOutput(JistLogger.INFO, "");
	}
	
	// consensus-based functions
	/**
	 * Estimates the per-voxel majority label and the consensus mask.
	 *
	 * <p>Pass 1 sums the votes of all raters at each voxel. It stores the label with the most
	 * votes in {@code cons_estimate} (lowest label on ties, 0 if no label gets a positive vote).
	 * It flags the voxel as unanimous when {@code ignore_consensus()} is true and the winning vote
	 * total reaches {@code num_raters} (within 1e-4).
	 *
	 * <p>Pass 2 keeps a flagged voxel as consensus only if at least half of its search window
	 * (clipped to the volume) is flagged as well.
	 */
	private void determine_initial_consensus() {

		final int nx = dims[0], ny = dims[1], nz = dims[2], nv = dims[3];
		final int total = nx * ny * nz * nv;

		boolean [][][][] unanimous = new boolean [nx][ny][nz][nv];
		consensus = new boolean[nx][ny][nz][nv];
		cons_estimate = new short[nx][ny][nz][nv];
		float [] votes = new float [num_labels];

		JistLogger.logOutput(JistLogger.INFO, "-> Estimating initial consensus voxels");

		// pass 1: majority vote and unanimity flag
		int numFlagged = 0;
		for (int x = 0; x < nx; x++) {
			for (int y = 0; y < ny; y++) {
				for (int z = 0; z < nz; z++) {
					for (int v = 0; v < nv; v++) {
						Arrays.fill(votes, 0f);
						for (int j = 0; j < num_raters; j++)
							tmpobs.add_to_probabilities(x, y, z, v, j, votes, 1f);

						short winner = 0;
						float winnerVotes = 0;
						for (short s = 0; s < num_labels; s++) {
							if (votes[s] > winnerVotes) {
								winnerVotes = votes[s];
								winner = s;
							}
						}
						cons_estimate[x][y][z][v] = winner;

						boolean flag = ignore_consensus()
								&& winnerVotes >= ((float)num_raters - 0.0001f);
						unanimous[x][y][z][v] = flag;
						if (flag)
							numFlagged++;
					}
				}
			}
		}
		float initialFrac = ((float)numFlagged) / total;
		JistLogger.logOutput(JistLogger.INFO, "Initial Fraction Consensus: " + initialFrac);

		// pass 2: a flagged voxel stays consensus only if at least half its search window is flagged
		int numConsensus = 0;
		for (int x = 0; x < nx; x++) {
			for (int y = 0; y < ny; y++) {
				for (int z = 0; z < nz; z++) {
					for (int v = 0; v < nv; v++) {
						if (!unanimous[x][y][z][v]) {
							consensus[x][y][z][v] = false;
							continue;
						}

						final int xFrom = Math.max(x - sv[0], 0), xTo = Math.min(x + sv[0], nx-1);
						final int yFrom = Math.max(y - sv[1], 0), yTo = Math.min(y + sv[1], ny-1);
						final int zFrom = Math.max(z - sv[2], 0), zTo = Math.min(z + sv[2], nz-1);
						final int vFrom = Math.max(v - sv[3], 0), vTo = Math.min(v + sv[3], nv-1);

						float windowSize = 0;
						float numUnflagged = 0;
						for (int xi = xFrom; xi <= xTo; xi++)
							for (int yi = yFrom; yi <= yTo; yi++)
								for (int zi = zFrom; zi <= zTo; zi++)
									for (int vi = vFrom; vi <= vTo; vi++) {
										windowSize++;
										if (!unanimous[xi][yi][zi][vi])
											numUnflagged++;
									}

						boolean keep = !(numUnflagged / windowSize > 0.5);
						consensus[x][y][z][v] = keep;
						if (keep)
							numConsensus++;
					}
				}
			}
		}
		float finalFrac = ((float)numConsensus) / total;
		JistLogger.logOutput(JistLogger.INFO, "Final Fraction Consensus: " + finalFrac);
	}

	/**
	 * Builds {@code tcon}: true for voxels lying outside the search window, grown by the patch
	 * radius, of every non-consensus voxel. Only voxels where {@code tcon} is false can take part
	 * in a patch comparison, so local patch statistics are computed only there.
	 */
	private void set_atlas_consensus() {
		final int nx = dims[0], ny = dims[1], nz = dims[2], nv = dims[3];
		tcon = new boolean [nx][ny][nz][nv];

		// start with every voxel marked
		for (boolean [][][] slab : tcon)
			for (boolean [][] plane : slab)
				for (boolean [] row : plane)
					Arrays.fill(row, true);

		// unmark everything reachable from a non-consensus voxel (search radius + patch radius)
		final int rx = sv[0] + pv[0];
		final int ry = sv[1] + pv[1];
		final int rz = sv[2] + pv[2];
		final int rv = sv[3] + pv[3];
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++)
					for (int v = 0; v < nv; v++) {
						if (consensus[x][y][z][v])
							continue;

						final int xFrom = Math.max(x - rx, 0), xTo = Math.min(x + rx, nx-1);
						final int yFrom = Math.max(y - ry, 0), yTo = Math.min(y + ry, ny-1);
						final int zFrom = Math.max(z - rz, 0), zTo = Math.min(z + rz, nz-1);
						final int vFrom = Math.max(v - rv, 0), vTo = Math.min(v + rv, nv-1);

						for (int xi = xFrom; xi <= xTo; xi++)
							for (int yi = yFrom; yi <= yTo; yi++)
								for (int zi = zFrom; zi <= zTo; zi++)
									Arrays.fill(tcon[xi][yi][zi], vFrom, vTo + 1, false);
					}

		int numcon = 0;
		for (boolean [][][] slab : tcon)
			for (boolean [][] plane : slab)
				for (boolean [] row : plane)
					for (boolean marked : row)
						if (marked)
							numcon++;

		float frac_con = ((float)numcon) / (nx*ny*nz*nv);
		JistLogger.logOutput(JistLogger.INFO, "LNCC Fraction Consensus: " + frac_con);
	}
	
	// pre-processing functions
	/**
	 * Computes, at every voxel where {@code tcon} is false, the mean and population standard
	 * deviation of {@code im} over the patch centred on that voxel (clipped to the volume).
	 * A zero standard deviation is stored as 0.001 so that later divisions by it are safe.
	 *
	 * @param im    intensity image
	 * @param means output: local patch means
	 * @param stds  output: local patch standard deviations
	 */
	private void set_LNCC_parameters(float [][][][] im,
									 float [][][][] means,
									 float [][][][] stds) {

		JistLogger.logOutput(JistLogger.INFO, "-> Calculating image LNCC parameters");

		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++)
					for (int v = 0; v < dims[3]; v++) {
						if (tcon[x][y][z][v])
							continue;

						// patch bounds, clipped to the volume
						final int x0 = Math.max(x-pv[0], 0), x1 = Math.min(x+pv[0], dimx()-1);
						final int y0 = Math.max(y-pv[1], 0), y1 = Math.min(y+pv[1], dimy()-1);
						final int z0 = Math.max(z-pv[2], 0), z1 = Math.min(z+pv[2], dimz()-1);
						final int v0 = Math.max(v-pv[3], 0), v1 = Math.min(v+pv[3], dimv()-1);

						float sum = 0;
						float n = 0;
						for (int xi = x0; xi <= x1; xi++)
							for (int yi = y0; yi <= y1; yi++)
								for (int zi = z0; zi <= z1; zi++)
									for (int vi = v0; vi <= v1; vi++) {
										sum += im[xi][yi][zi][vi];
										n++;
									}
						final float mu = sum / n;

						float sumSq = 0;
						for (int xi = x0; xi <= x1; xi++)
							for (int yi = y0; yi <= y1; yi++)
								for (int zi = z0; zi <= z1; zi++)
									for (int vi = v0; vi <= v1; vi++) {
										final float d = im[xi][yi][zi][vi] - mu;
										sumSq += d*d;
									}
						final float sigma = (float)Math.sqrt(sumSq / n);

						means[x][y][z][v] = mu;
						stds[x][y][z][v] = (sigma == 0) ? 0.001f : sigma;
					}
	}

	/**
	 * Reads an intensity volume into {@code im}, cropped to the working region.
	 *
	 * <p>Voxel (x, y, z, v) of {@code im} receives the image value at
	 * (x + offx(), y + offy(), z + offz(), v + offv()). The prefs debug level is lowered to
	 * SEVERE while reading. It is restored in a {@code finally} block, and the volume is disposed
	 * once its data has been obtained. Both happen even when the dimension check fails.
	 *
	 * @param paramim intensity volume to read
	 * @param im      destination array of size {@code dims}
	 * @throws RuntimeException if the image dimensions differ from {@code orig_dims}
	 */
	private void load_intensity_image(ParamVolume paramim,
									  float [][][][] im) {
		
		int orig_level = prefs.getDebugLevel();
		prefs.setDebugLevel(JistLogger.SEVERE);
		try {
			ImageData img = paramim.getImageData(true);
			try {
				// make sure that the dimensions match
				if (orig_dims[0] != Math.max(img.getRows(), 1) ||
					orig_dims[1] != Math.max(img.getCols(), 1) ||
					orig_dims[2] != Math.max(img.getSlices(), 1) ||
					orig_dims[3] != Math.max(img.getComponents(), 1)) {
					JistLogger.logOutput(JistLogger.SEVERE, "Error: Target Image Dimensions do not match");
					throw new RuntimeException("Error: Rater Dimensions do not match");
				}

				// copy the cropped region
				for (int x = 0; x < dims[0]; x++)
					for (int y = 0; y < dims[1]; y++)
						for (int z = 0; z < dims[2]; z++)
							for (int v = 0; v < dims[3]; v++)
								im[x][y][z][v] = img.getFloat(x + offx(),
															  y + offy(),
															  z + offz(),
															  v + offv());
			} finally {
				paramim.dispose();
			}
		} finally {
			// restore even on failure so a bad image does not leave the debug level at SEVERE
			prefs.setDebugLevel(orig_level);
		}
	}
	
	/**
	 * Fills {@code df} with the separable spatial Gaussian weight of every search-window offset:
	 * {@code df[x][y][z][v] = prod_d exp(-(sv[d] - i_d)^2 / sp_stdevs[d]^2)}.
	 *
	 * @param sp_stdevs spatial standard deviations per dimension, in voxels
	 */
	private void set_dist_factor(float [] sp_stdevs) {

		// each axis contributes an independent factor; evaluate them once per axis
		final double [] wx = axis_gaussian(sv[0], sp_stdevs[0]);
		final double [] wy = axis_gaussian(sv[1], sp_stdevs[1]);
		final double [] wz = axis_gaussian(sv[2], sp_stdevs[2]);
		final double [] wv = axis_gaussian(sv[3], sp_stdevs[3]);

		for (int x = 0; x < wx.length; x++)
			for (int y = 0; y < wy.length; y++)
				for (int z = 0; z < wz.length; z++)
					for (int v = 0; v < wv.length; v++)
						df[x][y][z][v] = (float) (wx[x] * wy[y] * wz[z] * wv[v]);
	}

	/** Returns exp(-(radius - i)^2 / stdev^2) for i = 0 .. 2*radius. */
	private static double [] axis_gaussian(int radius, float stdev) {
		final float factor = (float) (-1 / Math.pow(stdev, 2));
		final double [] profile = new double [2*radius+1];
		for (int i = 0; i < profile.length; i++)
			profile[i] = Math.exp(factor * Math.pow((float)(radius-i), 2));
		return profile;
	}
	
	/**
	 * Accumulates the non-local observation of rater {@code j} at target voxel (x, y, z, v).
	 *
	 * <p>Every voxel of the search window (clipped to the volume) contributes its label
	 * observation with weight {@code df(offset) * exp(ff * diff)}. Here {@code diff} is the patch
	 * difference selected by {@code weight_type}. When patch selection is on, a patch is used
	 * only if its SSIM/Jaccard score is at least {@code sel_thresh}. Every patch passes when
	 * {@code sel_thresh <= 0}, an unknown selection type scores 0, and a NaN score never passes.
	 * The accumulated label weights initialize the voxel in {@code obs}.
	 *
	 * @throws RuntimeException if a patch is used and {@code weight_type} is not a
	 *         {@code WEIGHT_TYPE_*} constant
	 */
	private void iterate_patch(float [][][][] target,
							   float [][][][] im,
							   int weight_type,
							   float ff,
							   float sel_thresh,
							   int selection_type,
							   int x,
							   int y,
							   int z,
							   int v,
							   int j) {

		float [] lp = new float [num_labels];

		// unclipped corner of the search window; df is indexed relative to it
		final int x0 = x - sv[0];
		final int y0 = y - sv[1];
		final int z0 = z - sv[2];
		final int v0 = v - sv[3];

		// search window clipped to the volume
		final int xFrom = Math.max(x0, 0), xTo = Math.min(x + sv[0], dimx()-1);
		final int yFrom = Math.max(y0, 0), yTo = Math.min(y + sv[1], dimy()-1);
		final int zFrom = Math.max(z0, 0), zTo = Math.min(z + sv[2], dimz()-1);
		final int vFrom = Math.max(v0, 0), vTo = Math.min(v + sv[3], dimv()-1);

		for (int xi = xFrom; xi <= xTo; xi++)
			for (int yi = yFrom; yi <= yTo; yi++)
				for (int zi = zFrom; zi <= zTo; zi++)
					for (int vi = vFrom; vi <= vTo; vi++) {

						if (use_patch_selection) {
							final float score;
							if (sel_thresh <= 0)
								score = 1;
							else if (selection_type == PATCH_SELECTION_TYPE_SSIM)
								score = get_ssim(x, y, z, v, xi, yi, zi, vi);
							else if (selection_type == PATCH_SELECTION_TYPE_JACCARD)
								score = get_jaccard(x, y, z, v, xi, yi, zi, vi);
							else
								score = 0;
							// written as !(>=) so that a NaN score also rejects the patch
							if (!(score >= sel_thresh))
								continue;
						}

						final float diff;
						switch (weight_type) {
							case WEIGHT_TYPE_LNCC:
								diff = get_LNCC_diff(x, y, z, v, xi, yi, zi, vi, target, im);
								break;
							case WEIGHT_TYPE_MSD:
								diff = get_MSD_diff(x, y, z, v, xi, yi, zi, vi, target, im);
								break;
							case WEIGHT_TYPE_MIXED:
								diff = get_mixed_diff(x, y, z, v, xi, yi, zi, vi, target, im);
								break;
							default:
								throw new RuntimeException("Error: Invalid Weight Type");
						}

						final float weight = (float) (df[xi-x0][yi-y0][zi-z0][vi-v0] * Math.exp(ff * diff));
						tmpobs.add_to_probabilities(xi, yi, zi, vi, j, lp, weight);
					}

		obs.init_voxel(x, y, z, v, j, lp, tmpobs.get(x, y, z, v, j));
	}
	
	/**
	 * Accumulates the non-local observation of rater {@code j} at target voxel (x, y, z, v),
	 * using only the best-scoring patches of the search window (clipped to the volume).
	 *
	 * <p>Pass 1 scores every candidate patch with SSIM or Jaccard. Candidates scoring at least
	 * {@code sel_thresh} are offered to a buffer of the {@code num_keep} highest scores.
	 * Pass 2 adds each retained patch with a positive score, weighted by
	 * {@code df(offset) * exp(ff * diff)}. The accumulated label weights initialize the voxel
	 * in {@code obs}.
	 *
	 * @throws RuntimeException if {@code selection_type} is neither SSIM nor Jaccard, or if
	 *         {@code weight_type} is not a {@code WEIGHT_TYPE_*} constant
	 */
	private void iterate_patch_keep(float [][][][] target,
									float [][][][] im,
									int weight_type,
									float ff,
									float sel_thresh,
									int selection_type,
									int num_keep,
									int x,
									int y,
									int z,
									int v,
									int j) {

		float [] lp = new float [num_labels];

		// bounded "best num_keep" buffer maintained by add_to_keepinfo
		int [][] keeploc = new int [num_keep][4];
		float [] keepval = new float [num_keep];
		int [] currnum = new int [1];
		int [] currindex = new int [1];

		// unclipped corner of the search window; df is indexed relative to it
		final int x0 = x - sv[0];
		final int y0 = y - sv[1];
		final int z0 = z - sv[2];
		final int v0 = v - sv[3];

		// search window clipped to the volume
		final int xFrom = Math.max(x0, 0), xTo = Math.min(x + sv[0], dimx()-1);
		final int yFrom = Math.max(y0, 0), yTo = Math.min(y + sv[1], dimy()-1);
		final int zFrom = Math.max(z0, 0), zTo = Math.min(z + sv[2], dimz()-1);
		final int vFrom = Math.max(v0, 0), vTo = Math.min(v + sv[3], dimv()-1);

		// pass 1: score every candidate and retain the best ones
		for (int xi = xFrom; xi <= xTo; xi++)
			for (int yi = yFrom; yi <= yTo; yi++)
				for (int zi = zFrom; zi <= zTo; zi++)
					for (int vi = vFrom; vi <= vTo; vi++) {
						final float score;
						if (selection_type == PATCH_SELECTION_TYPE_SSIM)
							score = get_ssim(x, y, z, v, xi, yi, zi, vi);
						else if (selection_type == PATCH_SELECTION_TYPE_JACCARD)
							score = get_jaccard(x, y, z, v, xi, yi, zi, vi);
						else
							throw new RuntimeException("Error: Invalid Selection Type");

						if (score >= sel_thresh)
							add_to_keepinfo(keeploc, keepval, num_keep, score, currnum, currindex, xi, yi, zi, vi);
					}

		// pass 2: accumulate the weighted observations of the retained patches
		for (int k = 0; k < num_keep; k++) {
			if (!(keepval[k] > 0))
				continue;   // unused slot (or a non-positive score)

			final int xi = keeploc[k][0];
			final int yi = keeploc[k][1];
			final int zi = keeploc[k][2];
			final int vi = keeploc[k][3];

			final float diff;
			switch (weight_type) {
				case WEIGHT_TYPE_LNCC:
					diff = get_LNCC_diff(x, y, z, v, xi, yi, zi, vi, target, im);
					break;
				case WEIGHT_TYPE_MSD:
					diff = get_MSD_diff(x, y, z, v, xi, yi, zi, vi, target, im);
					break;
				case WEIGHT_TYPE_MIXED:
					diff = get_mixed_diff(x, y, z, v, xi, yi, zi, vi, target, im);
					break;
				default:
					throw new RuntimeException("Error: Invalid Weight Type");
			}

			final float weight = (float) (df[xi-x0][yi-y0][zi-z0][vi-v0] * Math.exp(ff * diff));
			tmpobs.add_to_probabilities(xi, yi, zi, vi, j, lp, weight);
		}

		obs.init_voxel(x, y, z, v, j, lp, tmpobs.get(x, y, z, v, j));
	}
	
	/**
	 * Returns the local normalized cross-correlation between the target patch at (x, y, z, v)
	 * and the atlas patch at (xi, yi, zi, vi), clamped to [-1, 1] and shifted by -1, so that 0
	 * means perfectly correlated.
	 *
	 * <p>The covariance is summed over the largest symmetric patch (radius at most {@code pv})
	 * that keeps both patches inside the volume. It is normalized with the precomputed local
	 * means and standard deviations.
	 */
	private float get_LNCC_diff(int x,
								int y,
								int z,
								int v,
								int xi,
								int yi,
								int zi,
								int vi,
								float [][][][] target,
								float [][][][] im) {

		final float tmean = target_means[x][y][z][v];
		final float tstd = target_stds[x][y][z][v];
		final float imean = atlas_means[xi][yi][zi][vi];
		final float istd = atlas_stds[xi][yi][zi][vi];

		// shrink the patch radius so that both patches stay inside the volume
		final int rx = Math.min(pv[0], Math.min(Math.min(x, xi), dims[0] - Math.max(x, xi) - 1));
		final int ry = Math.min(pv[1], Math.min(Math.min(y, yi), dims[1] - Math.max(y, yi) - 1));
		final int rz = Math.min(pv[2], Math.min(Math.min(z, zi), dims[2] - Math.max(z, zi) - 1));
		final int rv = Math.min(pv[3], Math.min(Math.min(v, vi), dims[3] - Math.max(v, vi) - 1));

		float cov = 0;
		float n = 0;
		for (int dx = -rx; dx <= rx; dx++)
			for (int dy = -ry; dy <= ry; dy++)
				for (int dz = -rz; dz <= rz; dz++)
					for (int dv = -rv; dv <= rv; dv++) {
						final float t = target[x+dx][y+dy][z+dz][v+dv];
						final float a = im[xi+dx][yi+dy][zi+dz][vi+dv];
						cov += (t - tmean)*(a - imean);
						n++;
					}
		float ncc = cov / (n * (tstd * istd));

		// guard against round-off pushing the correlation outside [-1, 1]
		if (ncc < -1)
			ncc = -1;
		else if (ncc > 1)
			ncc = 1;

		return ncc - 1;
	}
	
	/**
	 * Returns minus the mean squared intensity difference between the target patch at
	 * (x, y, z, v) and the atlas patch at (xi, yi, zi, vi). The patch radius (at most
	 * {@code pv}) is shrunk per dimension so that both patches stay inside the volume.
	 */
	private float get_MSD_diff(int x,
								int y,
								int z,
								int v,
								int xi,
								int yi,
								int zi,
								int vi,
								float [][][][] target,
								float [][][][] im) {

		final int rx = Math.min(pv[0], Math.min(Math.min(x, xi), dims[0] - Math.max(x, xi) - 1));
		final int ry = Math.min(pv[1], Math.min(Math.min(y, yi), dims[1] - Math.max(y, yi) - 1));
		final int rz = Math.min(pv[2], Math.min(Math.min(z, zi), dims[2] - Math.max(z, zi) - 1));
		final int rv = Math.min(pv[3], Math.min(Math.min(v, vi), dims[3] - Math.max(v, vi) - 1));

		float ssd = 0;
		for (int dx = -rx; dx <= rx; dx++)
			for (int dy = -ry; dy <= ry; dy++)
				for (int dz = -rz; dz <= rz; dz++)
					for (int dv = -rv; dv <= rv; dv++) {
						final float d = target[x+dx][y+dy][z+dz][v+dv] - im[xi+dx][yi+dy][zi+dz][vi+dv];
						ssd += d * d;
					}

		final float msd = ssd / ((2*rx+1) * (2*ry+1) * (2*rz+1) * (2*rv+1));
		return -msd;
	}
	
	/**
	 * Returns the product of the shifted LNCC and the mean squared difference between the target
	 * patch at (x, y, z, v) and the atlas patch at (xi, yi, zi, vi).
	 *
	 * <p>The shifted LNCC is clamped to [-1, 1] and then reduced by 1, giving 0 for good and
	 * -2 for bad matches. Both terms are computed over the same shrunken symmetric patch used by
	 * {@code get_LNCC_diff} and {@code get_MSD_diff}.
	 */
	private float get_mixed_diff(int x,
								int y,
								int z,
								int v,
								int xi,
								int yi,
								int zi,
								int vi,
								float [][][][] target,
								float [][][][] im) {

		final float tmean = target_means[x][y][z][v];
		final float tstd = target_stds[x][y][z][v];
		final float imean = atlas_means[xi][yi][zi][vi];
		final float istd = atlas_stds[xi][yi][zi][vi];

		final int rx = Math.min(pv[0], Math.min(Math.min(x, xi), dims[0] - Math.max(x, xi) - 1));
		final int ry = Math.min(pv[1], Math.min(Math.min(y, yi), dims[1] - Math.max(y, yi) - 1));
		final int rz = Math.min(pv[2], Math.min(Math.min(z, zi), dims[2] - Math.max(z, zi) - 1));
		final int rv = Math.min(pv[3], Math.min(Math.min(v, vi), dims[3] - Math.max(v, vi) - 1));

		float cov = 0;
		float ssd = 0;
		float n = 0;
		for (int dx = -rx; dx <= rx; dx++)
			for (int dy = -ry; dy <= ry; dy++)
				for (int dz = -rz; dz <= rz; dz++)
					for (int dv = -rv; dv <= rv; dv++) {
						final float t = target[x+dx][y+dy][z+dz][v+dv];
						final float a = im[xi+dx][yi+dy][zi+dz][vi+dv];
						cov += (t - tmean)*(a - imean);
						ssd += (t - a) * (t - a);
						n++;
					}
		final float msd = ssd / n;
		float ncc = cov / (n * (tstd * istd));

		// guard against round-off pushing the correlation outside [-1, 1]
		if (ncc < -1)
			ncc = -1;
		else if (ncc > 1)
			ncc = 1;

		return (ncc - 1) * msd;
	}
		
	/**
	 * Returns the Jaccard overlap of the intervals mean +/- 2 std of the target patch at
	 * (x, y, z, v) and of the atlas patch at (xi, yi, zi, vi). It is the length of their
	 * intersection divided by the length of the interval spanning both, and 0 when they are
	 * disjoint.
	 */
	private float get_jaccard(int x,
							  int y,
							  int z,
							  int v,
							  int xi,
							  int yi,
							  int zi,
							  int vi) {

		final float halfWidth = 2;   // in standard deviations

		final float tMean = target_means[x][y][z][v];
		final float tStd = target_stds[x][y][z][v];
		final float aMean = atlas_means[xi][yi][zi][vi];
		final float aStd = atlas_stds[xi][yi][zi][vi];

		final float targetLo = tMean - halfWidth * tStd;
		final float targetHi = tMean + halfWidth * tStd;
		final float atlasLo = aMean - halfWidth * aStd;
		final float atlasHi = aMean + halfWidth * aStd;

		final float unionLo = Math.min(targetLo, atlasLo);
		final float unionHi = Math.max(targetHi, atlasHi);
		float interLo = Math.max(targetLo, atlasLo);
		final float interHi = Math.min(targetHi, atlasHi);

		// disjoint intervals: collapse the intersection to zero length
		if (interLo > interHi)
			interLo = interHi;

		return (interHi - interLo) / (unionHi - unionLo);
	}
		
	/**
	 * Returns an SSIM-style similarity of the target patch at (x, y, z, v) and the atlas patch at
	 * (xi, yi, zi, vi), using only their means (m) and standard deviations (s):
	 * {@code |(2 t_m a_m / (t_m^2 + a_m^2)) * (2 t_s a_s / (t_s^2 + a_s^2))|}, a value in [0, 1].
	 *
	 * <p>Each ratio is taken as 1 when its denominator is zero, i.e. when both inputs are zero
	 * and therefore identical. Without this, two patches that both have a zero mean (such as
	 * all-zero patches) evaluate to 0/0 = NaN and are always rejected by the selection threshold.
	 */
	private float get_ssim(int x,
						   int y,
						   int z,
						   int v,
						   int xi,
						   int yi,
						   int zi,
						   int vi) {
		
		float t_m = target_means[x][y][z][v];
		float t_s = target_stds[x][y][z][v];
		float a_m = atlas_means[xi][yi][zi][vi];
		float a_s = atlas_stds[xi][yi][zi][vi];
		
		float ssim = Math.abs(ssim_ratio(t_m, a_m) * ssim_ratio(t_s, a_s));
		return(ssim);
	}

	/** Returns {@code 2ab / (a^2 + b^2)}, or 1 when {@code a^2 + b^2} is zero. */
	private static float ssim_ratio(float a, float b) {
		float denom = a*a + b*b;
		if (denom == 0)
			return 1f;
		return (2*a*b) / denom;
	}
	
	/**
	 * Offers a candidate patch to a fixed-capacity buffer that keeps the {@code num_keep}
	 * highest values.
	 *
	 * <p>While the buffer is not full, the candidate is appended. Once it is full, the candidate
	 * replaces the current minimum, but only if it is strictly larger. After every change to a
	 * full buffer the minimum slot is recomputed. Skipping this when the buffer first fills
	 * would leave {@code currindex} pointing at slot 0 even when slot 0 is not the minimum,
	 * which evicts the wrong entry or rejects better candidates.
	 *
	 * @param keeploc   [num_keep][4] retained locations
	 * @param keepval   [num_keep] retained values (0 for unused slots)
	 * @param num_keep  buffer capacity; nothing is done when {@code <= 0}
	 * @param val       candidate value
	 * @param currnum   single-element counter of candidates offered so far (updated)
	 * @param currindex single-element slot of the smallest retained value (updated)
	 * @param x         candidate location, dimension 0
	 * @param y         candidate location, dimension 1
	 * @param z         candidate location, dimension 2
	 * @param v         candidate location, dimension 3
	 */
	private void add_to_keepinfo(int [][] keeploc,
								 float [] keepval,
								 int num_keep,
								 float val,
								 int [] currnum,
								 int [] currindex,
								 int x,
								 int y,
								 int z,
								 int v) {

		if (num_keep <= 0)
			return;

		// choose the slot to write: append while filling, else evict the minimum if beaten
		int slot;
		if (currnum[0] < num_keep)
			slot = currnum[0];
		else if (val > keepval[currindex[0]])
			slot = currindex[0];
		else
			slot = -1;

		if (slot >= 0) {
			keepval[slot] = val;
			keeploc[slot][0] = x;
			keeploc[slot][1] = y;
			keeploc[slot][2] = z;
			keeploc[slot][3] = v;
		}

		currnum[0]++;

		// once the buffer is full, any write may move the minimum: recompute its slot
		if (slot >= 0 && currnum[0] >= num_keep) {
			int minSlot = 0;
			for (int i = 1; i < num_keep; i++)
				if (keepval[i] < keepval[minSlot])
					minSlot = i;
			currindex[0] = minSlot;
		}
	}
		
	// public functions
	/**
	 * Calls {@code obs.normalize} for every rater at every non-consensus voxel
	 * (rater-major order).
	 */
	public void normalize_all() {
		JistLogger.logOutput(JistLogger.INFO, "-> Normalizing All Observations");
		for (int rater = 0; rater < num_raters; rater++) {
			for (int x = 0; x < dims[0]; x++) {
				for (int y = 0; y < dims[1]; y++) {
					for (int z = 0; z < dims[2]; z++) {
						for (int v = 0; v < dims[3]; v++) {
							if (is_consensus(x, y, z, v))
								continue;
							obs.normalize(x, y, z, v, rater);
						}
					}
				}
			}
		}
	}
	/**
	 * Builds the voxelwise atlas-selection mask returned by {@code get_local_selection}.
	 *
	 * <p>Global selection ({@code global_thresh > 0}): at each non-consensus voxel, the share of
	 * the total observation weight carried by each rater is computed. These shares are summed
	 * over all voxels and normalized, which gives each rater's average share. Raters whose
	 * average share is below {@code global_thresh / num_raters} are deselected everywhere.
	 *
	 * <p>Local selection ({@code local_thresh > 0}): at each non-consensus voxel, a rater that is
	 * still selected is deselected there if its share of the weight of the selected raters is
	 * below {@code local_thresh / (number of selected raters)}.
	 *
	 * <p>Voxels whose (selected) raters carry no positive total weight are skipped. Their shares
	 * are 0/0 = NaN, and a single NaN share would otherwise turn every global average into NaN,
	 * silently disabling global selection.
	 *
	 * @param global_thresh global selection threshold (disabled when {@code <= 0})
	 * @param local_thresh  local selection threshold (disabled when {@code <= 0})
	 */
	public void create_atlas_selection_matrix(float global_thresh,
											  float local_thresh) {
		// allocate the selection matrix with every rater selected everywhere
		JistLogger.logOutput(JistLogger.INFO, "-> Constructing Local Selection Matrix");
		JistLogger.logOutput(JistLogger.INFO, String.format("Global Atlas Selection Threshold: %f", global_thresh));
		JistLogger.logOutput(JistLogger.INFO, String.format("Local Atlas Selection Threshold: %f", local_thresh));
		use_atlas_selection = true;
		atlas_selection = new boolean[dims[0]][dims[1]][dims[2]][dims[3]][num_raters];
		for (boolean [][][][] slab : atlas_selection)
			for (boolean [][][] plane : slab)
				for (boolean [][] row : plane)
					for (boolean [] voxel : row)
						Arrays.fill(voxel, true);
		
		float [] local_ratervals = new float [num_raters];
		float [] global_ratervals = new float [num_raters];
		int [] ratercounts = new int [num_raters];

		/*
		 * First, do global atlas selection
		 */
		if (global_thresh > 0) {
			for (int x = 0; x < dims[0]; x++)
				for (int y = 0; y < dims[1]; y++)
					for (int z = 0; z < dims[2]; z++) 
						for (int v = 0; v < dims[3]; v++) {
							if (is_consensus(x, y, z, v))
								continue;

							// add up the votes from each rater
							float ratersum = 0;
							for (int j = 0; j < num_raters; j++) {
								local_ratervals[j] = total_obs_weight(x, y, z, v, j);
								ratersum += local_ratervals[j];
							}

							// shares are undefined without positive weight (this also rejects NaN)
							if (!(ratersum > 0))
								continue;

							// keep running track of the global fraction contributed by rater j
							for (int j = 0; j < num_raters; j++)
								global_ratervals[j] += local_ratervals[j] / ratersum;
						}
			
			// normalize the values
			float total = 0;
			for (int j = 0; j < num_raters; j++)
				total += global_ratervals[j];

			// remove raters that are not selected globally (nothing to do if no voxel had weight)
			if (total > 0) {
				for (int j = 0; j < num_raters; j++)
					global_ratervals[j] /= total;

				for (int j = 0; j < num_raters; j++)
					if (global_ratervals[j] < (global_thresh / num_raters))
						for (int x = 0; x < dims[0]; x++)
							for (int y = 0; y < dims[1]; y++)
								for (int z = 0; z < dims[2]; z++) 
									for (int v = 0; v < dims[3]; v++)
										atlas_selection[x][y][z][v][j] = false;
			}
		}

		/*
		 * second, do local atlas selection
		 */
		if (local_thresh > 0) {
			for (int x = 0; x < dims[0]; x++)
				for (int y = 0; y < dims[1]; y++)
					for (int z = 0; z < dims[2]; z++) 
						for (int v = 0; v < dims[3]; v++) {
							if (is_consensus(x, y, z, v))
								continue;

							boolean [] selected = atlas_selection[x][y][z][v];

							// add up the votes from each still-selected rater
							float ratersum = 0;
							int num_selected = 0;
							for (int j = 0; j < num_raters; j++) {
								local_ratervals[j] = 0;
								if (selected[j]) {
									num_selected++;
									local_ratervals[j] = total_obs_weight(x, y, z, v, j);
									ratersum += local_ratervals[j];
								}
							}

							// shares are undefined without selected raters or positive weight
							if (num_selected == 0 || !(ratersum > 0))
								continue;

							// drop raters contributing less than their share of the threshold
							float cutoff = local_thresh / num_selected;
							for (int j = 0; j < num_raters; j++)
								if (selected[j] && (local_ratervals[j] / ratersum) < cutoff)
									selected[j] = false;
						}
		}
		
		// calculate the fraction observed for each rater
		int count = 0;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++) 
					for (int v = 0; v < dims[3]; v++)
						if (!is_consensus(x, y, z, v)) {
							count++;
							for (int j = 0; j < num_raters; j++)
								if (atlas_selection[x][y][z][v][j])
									ratercounts[j]++;
						}
		
		for (int j = 0; j < num_raters; j++) {
			float frac_used = (count > 0) ? (float)ratercounts[j] / (float)count : 0f;
			JistLogger.logOutput(JistLogger.INFO, String.format("-> Rater %02d: %f", j, frac_used));
		}
	}

	/** Returns the sum of the observation values stored for rater {@code j} at (x, y, z, v). */
	private float total_obs_weight(int x, int y, int z, int v, int j) {
		short [] obslabels = obs.get_all_labels(x, y, z, v, j);
		float [] obsvals = obs.get_all_vals(x, y, z, v, j);
		float total = 0;
		for (int i = 0; i < obslabels.length; i++)
			total += obsvals[i];
		return total;
	}
	/**
	 * Builds {@code weighted_prior}. At every non-consensus voxel, the observation values of the
	 * raters accepted by {@code get_local_selection} are summed per label. The sums are then
	 * stored with {@code init_voxel} and passed to {@code normalize}. Also sets
	 * {@code use_weighted_prior}.
	 */
	public void create_weighted_prior() {

		JistLogger.logOutput(JistLogger.INFO, "-> Constructing Weighted Voxelwise Prior");
		use_weighted_prior = true;
		weighted_prior = new SparseMatrix5D(dims[0], dims[1], dims[2], dims[3], num_labels);

		float [] labelWeights = new float [num_labels];

		for (int x = 0; x < dims[0]; x++) {
			print_status(x, dims[0]);
			for (int y = 0; y < dims[1]; y++) {
				for (int z = 0; z < dims[2]; z++) {
					for (int v = 0; v < dims[3]; v++) {
						if (is_consensus(x, y, z, v))
							continue;

						Arrays.fill(labelWeights, 0f);
						for (int j = 0; j < num_raters; j++) {
							if (!get_local_selection(x, y, z, v, j))
								continue;
							short [] labels = obs.get_all_labels(x, y, z, v, j);
							float [] vals = obs.get_all_vals(x, y, z, v, j);
							for (int k = 0; k < labels.length; k++)
								labelWeights[labels[k]] += vals[k];
						}

						weighted_prior.init_voxel(x, y, z, v, labelWeights);
						weighted_prior.normalize(x, y, z, v);
					}
				}
			}
		}
	}
	/**
	 * Copies the weighted prior at (x, y, z, v) into {@code lp}. Labels absent from the sparse
	 * prior are set to 0.
	 *
	 * @throws RuntimeException if {@code use_weighted_prior} is false (it is set by
	 *         {@code create_weighted_prior})
	 */
	public void get_weighted_prior(int x, int y, int z, int v, float [] lp) {
		if (!use_weighted_prior) {
			throw new RuntimeException("Trying to use uninitialized weighted prior");
		}

		Arrays.fill(lp, 0f);

		final short [] labels = weighted_prior.get_all_labels(x, y, z, v);
		final float [] probs = weighted_prior.get_all_vals(x, y, z, v);
		for (int k = 0; k < labels.length; k++) {
			lp[labels[k]] = probs[k];
		}
	}
	/** Returns whether rater {@code j} is selected at (x, y, z, v); always true before atlas selection is built. */
	public boolean get_local_selection(int x, int y, int z, int v, int j) {
		return !use_atlas_selection || atlas_selection[x][y][z][v][j];
	}
	/** Returns the label with the largest observation value for rater {@code j} at (x, y, z, v). */
	public short get(int x, int y, int z, int v, int j) {
		return obs.get_max_label(x, y, z, v, j);
	}
	/** Returns the observation value of the maximum-value label for rater {@code j} at (x, y, z, v). */
	public float get_val(int x, int y, int z, int v, int j) {
		return obs.get_val(x, y, z, v, j, obs.get_max_label(x, y, z, v, j));
	}
	/** Returns every label stored for rater {@code j} at (x, y, z, v). */
	public short [] get_all(int x, int y, int z, int v, int j) {
		return obs.get_all_labels(x, y, z, v, j);
	}
	/** Returns the observation values stored for rater {@code j} at (x, y, z, v), parallel to {@code get_all}. */
	public float [] get_all_vals(int x, int y, int z, int v, int j) {
		return obs.get_all_vals(x, y, z, v, j);
	}
	/** Delegates to {@code obs.free} for voxel (x, y, z, v). */
	public void free_obs(int x, int y, int z, int v) {
		obs.free(x, y, z, v);
	}
}
