package edu.jhu.ece.iacl.algorithms.manual_label.staple;

/**
 * 
 * @author John Bogovic
 * @date 5/31/2008
 * 
 * Simultaneous Truth and Performance Level Estimation (STAPLE)
 * 
 * Warfield, Zou, and Wells, "Simultaneous Truth and Performance Level Estimation (STAPLE):
 * An Algorithm for the Validation of Image Segmentation," 
 * IEEE Trans. Medical Imaging vol. 23, no. 7, 2004
 */

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

import edu.jhu.ece.iacl.io.CubicVolumeReaderWriter;
import edu.jhu.ece.iacl.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.structures.geom.GridPt;
import edu.jhu.ece.iacl.structures.geom.GridPt.Connectivity;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.structures.image.ImageHeader;

public class STAPLEmulti extends AbstractCalculation{

	/** Rater labels indexed [row][col][slice][rater]; null when raters are supplied as a List. */
	protected int[][][][] imagesArray;
	/** Rater label volumes, one per rater; null when raters are supplied as an array. */
	protected List<ImageData> images;
	/** Array-form truth buffer allocated by the array constructor (same shape as imagesArray); read by computeHardSeg when truth is null. */
	protected float[][][][] truthImage;
	/**
	 * Truth (posterior) estimate images. On the standard path entry n holds label
	 * labels.get(n); on the efficient path with a rater list, entry n is the n-th
	 * local label slot of uniqueLabels at each voxel.
	 */
	protected ArrayList<ImageData> truth;
	/** Hard (arg-max) segmentation derived from truth by computeHardSeg / updateHardSeg. */
	protected ImageDataInt hardseg;
	
	//variables for efficiency: per-voxel bookkeeping used by the *Efficient E/M steps
	/** Largest number of distinct rater labels seen at any single voxel; set by findLabels. */
	protected int MaxUniqueLabels=0;
	/** Distinct labels per voxel, [row][col][slice][slot], padded with -1; built by initUniqueLabelList. */
	protected int[][][][] uniqueLabels;
	/** Number of valid (non -1) local labels at the voxel last passed to getPriorArray(int[]). */
	int labelshere = 0;
	/** Maps a label value to its position in labels; built by findLabels. */
	protected HashMap<Integer,Integer> labelstoindex;
	
	protected PerformanceLevel pl; //performance level estimates, accessed as (rater, observed-label index, true-label index)
	/** Prior of each label (same order as labels): its frequency over all raters and voxels; set by getPriorProb. */
	protected ArrayList<Float> priors;
	/** Distinct label values in order of first appearance. */
	protected ArrayList<Integer> labels;
	/** Sum of all truth values written by the most recent E-step (reported by iterate, not used in its convergence test). */
	protected double convergesum;
	/** Normalized trace of pl from the previous iteration; iterate compares it with the current one to detect convergence. */
	protected float normtrace;
	/** Maximum number of EM iterations run by iterate. */
	protected int maxiters = 1000;
	/** Directory set by setDir; NOTE(review): not read in the visible part of this file. */
	protected String dir;
	/** Convergence tolerance on the change in normalized trace between iterations. */
	protected double eps=.00001;
	/** "Truth" seeds the truth estimates from rater votes; any other value seeds the performance levels. */
	protected String initType;
	/** Volume dimensions shared by all raters. */
	protected int rows, cols, slices;
	
	//MRF parameters
	/** MRF smoothing weight; iterate calls smoothMRF each iteration when this is positive. */
	private float beta =0f;
	/** Neighbourhood used by getNeighbors and distributeBeta. */
	private Connectivity connectivity;
	/** Voxel visitation helper for smoothMRF, created lazily on first use. */
	StapleMRFUtil mrfUtil;
	
	/** Cleared when the EM loop converges or an M-step produces a NaN performance level. */
	private boolean keepgoing = true;
	/** Selects the EstepEfficient/MstepEfficient path; NOTE(review): no setter is visible in this part of the file. */
	private boolean efficient = false;
	/** NOTE(review): not referenced in the visible part of this file. */
	private CubicVolumeReaderWriter rw = CubicVolumeReaderWriter.getInstance();
	
	/**
	 * Creates an empty STAPLE calculation labelled "STAPLE"; rater data must be
	 * supplied afterwards, e.g. through {@link #setImages(int[][][][])}.
	 */
	public STAPLEmulti(){
		super();
		this.setLabel("STAPLE");
	}
	/**
	 * Creates a STAPLE calculation over rater labels stored as a 4-D array indexed
	 * [row][col][slice][rater], allocates the matching truth buffer and computes
	 * the label priors.
	 *
	 * @param img rater labels; every dimension must be non-empty
	 */
	public STAPLEmulti(int[][][][] img){
		super();
		setLabel("STAPLE");
		
		imagesArray=img;
		rows = imagesArray.length;
		cols = imagesArray[0].length;
		slices = imagesArray[0][0].length;
		truthImage = new float[rows][cols][slices][imagesArray[0][0][0].length];
		// getPriorProb() appends to these lists on the array path, so they must exist first
		labels = new ArrayList<Integer>();
		priors = new ArrayList<Float>();
		getPriorProb();
	}
//	public STAPLEmulti(String[] filenames,int[] dim){
//		super();
//		setLabel("STAPLE");
//		
//		imagesArray=StapleReader.readImgs(filenames, dim);
//		truthImage = new float[imagesArray.length][imagesArray[0].length][imagesArray[0][0].length][imagesArray[0][0][0].length];
//		getPriorProb();
//	}
	/**
	 * Creates a STAPLE calculation over a list of rater label volumes.
	 * <p>
	 * The volume dimensions are taken from the first rater. When every rater has
	 * the same dimensions, the label priors are computed and printed; otherwise an
	 * error is printed and the priors are left uncomputed (they are only printed
	 * when they exist, since getPriorArray needs the label list).
	 *
	 * @param img rater label volumes; must contain at least one image
	 */
	public STAPLEmulti(List<ImageData> img){
		super();
		setLabel("STAPLE");
		images = img;
		rows = images.get(0).getRows();
		cols = images.get(0).getCols();
		slices = images.get(0).getSlices();
		if(verifySizes()){
			getPriorProb();
			System.out.println("Priors: ");
			printArray(getPriorArray());
		}else{
			System.err.println("Rater images must have equal dimensions");
		}
	}
	
	/**
	 * Sets the maximum number of EM iterations run by {@link #iterate()}.
	 *
	 * @param max iteration cap
	 */
	public void setmaxIters(int max){
		this.maxiters = max;
	}
	/**
	 * Sets the convergence tolerance: {@link #iterate()} stops once the normalized
	 * trace of the performance levels changes by less than this between iterations.
	 *
	 * @param eps convergence tolerance
	 */
	public void setEps(double eps){
		double tolerance = eps;
		this.eps = tolerance;
	}
	
	/**
	 * Replaces the rater data with a 4-D label array indexed
	 * [row][col][slice][rater] and recomputes the label priors.
	 * <p>
	 * The cached volume dimensions and the label/prior lists are reset so that
	 * they describe the new array rather than any previously supplied data.
	 *
	 * @param img rater labels; when non-null every dimension must be non-empty
	 */
	public void setImages(int[][][][] img){
		imagesArray=img;
		if(img!=null){
			// keep the loop bounds used throughout the class in sync with the new data
			rows = img.length;
			cols = img[0].length;
			slices = img[0][0].length;
		}
		// getPriorProb() appends to these lists on the array path; start from empty ones
		labels = new ArrayList<Integer>();
		priors = new ArrayList<Float>();
		getPriorProb();
	}
	/**
	 * Stores a directory path. NOTE(review): the field is not read in the visible
	 * part of this file.
	 *
	 * @param dir directory path
	 */
	public void setDir(String dir){
		String directory = dir;
		this.dir = directory;
	}
	/**
	 * Chooses the initialization strategy: {@code "Truth"} seeds the truth
	 * estimates from rater votes; any other value seeds the performance levels.
	 *
	 * @param init initialization type
	 */
	public void setInit(String init){
		this.initType = init;
	}
	/**
	 * Sets the MRF smoothing weight; smoothing runs each iteration when it is positive.
	 *
	 * @param beta smoothing weight
	 */
	public void setBeta(float beta){
		float weight = beta;
		this.beta = weight;
	}
	/**
	 * Divides the MRF weight {@code beta} by the neighbourhood size implied by the
	 * current connectivity (6, 18 or 26), spreading the total weight across the
	 * neighbours. Prints a message and leaves beta unchanged when the connectivity
	 * is unset or unrecognised.
	 * <p>
	 * Each call divides again, so call it once after setting beta and connectivity.
	 */
	public void distributeBeta(){
		if(connectivity==Connectivity.SIX){
			beta=beta/6;
		}else if(connectivity==Connectivity.EIGHTEEN){
			beta=beta/18;
		}else if(connectivity==Connectivity.TWENTYSIX){
			// previously a duplicated EIGHTEEN test made this branch unreachable
			beta=beta/26;
		}else{
			System.out.println("Invalid Connectivity!");
		}
	}
	/**
	 * Sets the neighbourhood used by getNeighbors and distributeBeta.
	 *
	 * @param c neighbourhood connectivity
	 */
	public void setConnectivity(Connectivity c){
		this.connectivity = c;
	}
	/**
	 * Returns the in-bounds neighbours of voxel (i,j,k) for the current
	 * connectivity (6, 18 or 26 neighbourhood), filtered by
	 * {@code GridPt.onlyInBounds} against the volume dimensions.
	 *
	 * @param i row
	 * @param j column
	 * @param k slice
	 * @return the neighbour array, or null (after printing a message) when the
	 *         connectivity is unset or unrecognised
	 */
	public GridPt[] getNeighbors(int i, int j, int k){
		if(connectivity==Connectivity.SIX){
			return GridPt.onlyInBounds(GridPt.neighborhood6C(i, j, k), rows, cols, slices);
		}else if(connectivity==Connectivity.EIGHTEEN){
			return GridPt.onlyInBounds(GridPt.neighborhood18C(i, j, k), rows, cols, slices);
		}else if(connectivity==Connectivity.TWENTYSIX){
			// previously a duplicated EIGHTEEN test made this branch unreachable
			return GridPt.onlyInBounds(GridPt.neighborhood26C(i, j, k), rows, cols, slices);
		}else{
			System.out.println("Invalid Connectivity!");
			return null;
		}
	}
	/**
	 * Checks that every rater volume has the same dimensions as the first one.
	 *
	 * @return true when all rater volumes share rows, columns and slices
	 */
	public boolean verifySizes(){
		ImageData first = images.get(0);
		int nRows = first.getRows();
		int nCols = first.getCols();
		int nSlices = first.getSlices();
		for(int r=1; r<images.size(); r++){
			ImageData other = images.get(r);
			boolean sameShape = other.getRows()==nRows && other.getCols()==nCols
					&& other.getSlices()==nSlices;
			if(!sameShape){
				return false;
			}
		}
		return true;
	}
	
	/**
	 * Scans every rater at every voxel and records which labels are in use.
	 * <p>
	 * Rebuilds {@code labels} (label values in order of first appearance,
	 * scanning row, column, slice, then rater), {@code labelstoindex} (label value
	 * to its position in {@code labels}) and {@code MaxUniqueLabels} (largest
	 * number of distinct labels observed at any single voxel). Prints an error and
	 * leaves the collections empty when there is no rater data.
	 */
	public void findLabels(){
		labels = new ArrayList<Integer>();
		labelstoindex = new HashMap<Integer,Integer>();
		MaxUniqueLabels = 0;
		// distinct labels at the voxel currently being scanned
		HashSet<Integer> voxelLabels = new HashSet<Integer>();
		if(imagesArray!=null){
			for(int i=0; i<imagesArray.length; i++){
				for(int j=0; j<imagesArray[0].length; j++){
					for(int k=0; k<imagesArray[0][0].length; k++){
						voxelLabels.clear();
						for(int l=0; l<imagesArray[0][0][0].length; l++){
							recordLabel(imagesArray[i][j][k][l], voxelLabels);
						}
						MaxUniqueLabels = Math.max(MaxUniqueLabels, voxelLabels.size());
					}
				}
			}
		}else if(images!=null){
			for(int i=0; i<rows; i++){
				for(int j=0; j<cols; j++){
					for(int k=0; k<slices; k++){
						voxelLabels.clear();
						for(int l=0; l<images.size(); l++){
							recordLabel(images.get(l).getInt(i,j,k), voxelLabels);
						}
						MaxUniqueLabels = Math.max(MaxUniqueLabels, voxelLabels.size());
					}
				}
			}
		}else{
			System.err.println("No data!");
		}
		labels.trimToSize();
		System.out.println("Found Labels: ");
		System.out.println(labels);
		System.out.println("");
	}

	/**
	 * Adds {@code label} to the current voxel's label set and, the first time the
	 * value is seen anywhere, appends it to {@code labels} and {@code labelstoindex}.
	 * Uses the map for membership so each lookup is O(1) rather than a list scan.
	 */
	private void recordLabel(int label, HashSet<Integer> voxelLabels){
		voxelLabels.add(label);
		if(!labelstoindex.containsKey(label)){
			labelstoindex.put(label, labels.size());
			labels.add(label);
		}
	}
	
//	public void initializeSingle(){
//		float init = 0.9999f;
//		try{
//			if(imagesArray!=null){
//				pl=new float[imagesArray[0][0][0].length][2];
//				for(int l=0; l<pl.length; l++){
//					pl[l][0]=init;
//					pl[l][1]=init;
//				}
//			}
//			else if(images!=null){
//				pl = new float[images.size()][2];
//				for(int l=0; l<pl.length; l++){
//					pl[l][0]=init;
//					pl[l][1]=init;
//				}
//			}
//			else{
//				System.err.println("Rater data is null");
//			}
//		}catch(Exception e){
//			e.printStackTrace();
//		}
//	}
	
	/**
	 * Discovers the labels and allocates the performance-level and truth
	 * estimates before the first EM iteration.
	 * <p>
	 * When {@code initType} is {@code "Truth"}, each truth image is seeded with the
	 * fraction of raters voting for its label at each voxel. This path reads the
	 * rater {@code images} list. Otherwise the performance levels are seeded with
	 * {@code pl.initialize2(0.9999f)} and empty truth images are allocated: one per
	 * label, or, in efficient mode with a rater list, one per local label slot
	 * ({@code MaxUniqueLabels}). Exceptions on that path are printed, not thrown.
	 */
	public void initialize(){
		if("Truth".equals(initType)){
			System.out.println("Initializing Truth");
			findLabels();
			if(efficient){
				initUniqueLabelList();
			}
			System.out.println("Labels Found: " + labels.size());
			pl = new PerformanceLevel(labels.size(), images.size());
			truth = new ArrayList<ImageData>();
			System.out.println("Num Rater Images: " + images.size());
			for(int i=0; i<labels.size(); i++){
				truth.add(new ImageDataFloat("TruthEstimate_"+labels.get(i),images.get(0).getRows(),images.get(0).getCols(),images.get(0).getSlices()));
			}
			System.out.println("Num Truth Images: " + truth.size());

			// Seed the truth estimates with each label's share of the rater votes.
			// Must be floating-point division: 1/images.size() in int arithmetic is 0
			// for two or more raters, which left every estimate at zero.
			double d = 1.0/images.size();
			for(int i=0; i<rows; i++){
				for(int j=0; j<cols; j++){
					for(int k=0; k<slices; k++){
						for(int l=0; l<images.size(); l++){
							int t = getIndex(labels,images.get(l).getInt(i, j, k));
							truth.get(t).set(i, j, k, truth.get(t).getFloat(i, j,k)+d);
						}
					}
				}
			}
		}else{
			System.out.println("Initializing Performance Levels");
			float init = 0.9999f;
			try{
				findLabels();
				labels.trimToSize();
				System.out.println("Labels Found: " + labels.size());
				// rater count from whichever representation holds the data
				int numRaters = (images!=null) ? images.size() : imagesArray[0][0][0].length;
				pl = new PerformanceLevel(labels.size(), numRaters);
				//Initialize the Performance Level Estimates
				pl.initialize2(init);
				if(efficient){
					initUniqueLabelList();
				}
				if(images!=null){
					// the efficient path stores truth per local label slot, not per label
					int numTruth = efficient ? MaxUniqueLabels : labels.size();
					truth = new ArrayList<ImageData>(numTruth);
					for(int i=0; i<numTruth; i++){
						truth.add(new ImageDataFloat("TruthEstimate_"+labels.get(i),images.get(0).getRows(),images.get(0).getCols(),images.get(0).getSlices()));
					}
				}else if(imagesArray!=null){
					truth = new ArrayList<ImageData>(labels.size());
					for(int i=0; i<labels.size(); i++){
						truth.add(new ImageDataFloat("TruthEstimate_"+labels.get(i),imagesArray.length,imagesArray[0].length,imagesArray[0][0].length));
					}
				}
				System.out.println("Num Truth Images: " + truth.size());
				printPerformanceLevels();
			}catch(Exception e){
				e.printStackTrace();
			}
		}
	}
	
	/**
	 * Builds {@code uniqueLabels}: for every voxel, the distinct labels assigned by
	 * the raters in rater order, padded with -1 up to {@code MaxUniqueLabels}.
	 * Requires {@link #findLabels()} to have set {@code MaxUniqueLabels}.
	 */
	private void initUniqueLabelList(){
		HashSet<Integer> seen = new HashSet<Integer>();
		if(images!=null){
			System.out.println("initializing unique label lists from raters");
			uniqueLabels = new int[images.get(0).getRows()][images.get(0).getCols()][images.get(0).getSlices()][MaxUniqueLabels];
			int num = images.size();
			for(int i=0; i<rows; i++){
				for(int j=0; j<cols; j++){
					for(int k=0; k<slices; k++){
						seen.clear();
						int index = 0;
						for(int l=0; l<num; l++){
							int v = images.get(l).getInt(i, j, k);
							if(seen.add(v)){
								uniqueLabels[i][j][k][index]=v;
								index++;
							}
						}
						for(int a = index; a<MaxUniqueLabels; a++){
							uniqueLabels[i][j][k][a]=-1;
						}
					}
				}
			}
		}else if(imagesArray!=null){
			System.out.println("initializing unique label lists from rater array");
			uniqueLabels = new int[imagesArray.length][imagesArray[0].length][imagesArray[0][0].length][MaxUniqueLabels];
			// images is null on this path; the rater count is the array's last dimension
			int num = imagesArray[0][0][0].length;
			for(int i=0; i<imagesArray.length; i++){
				for(int j=0; j<imagesArray[0].length; j++){
					for(int k=0; k<imagesArray[0][0].length; k++){
						seen.clear();
						int index = 0;
						for(int l=0; l<num; l++){
							// per-voxel membership test (the global label list contains every label)
							int v = imagesArray[i][j][k][l];
							if(seen.add(v)){
								uniqueLabels[i][j][k][index]=v;
								index++;
							}
						}
						for(int a = index; a<MaxUniqueLabels; a++){
							uniqueLabels[i][j][k][a]=-1;
						}
					}
				}
			}
		}
	}
	
	/**
	 * Placeholder with an empty body; it has no effect.
	 * NOTE(review): not called from the visible part of this file.
	 *
	 * @param lab label value (unused)
	 */
	private void rowsHavingLabel(int lab){
		
	}
	
	/**
	 * E-step of STAPLE restricted to the labels present at each voxel.
	 * <p>
	 * For every voxel, the unnormalized posterior of each local label (from
	 * {@code uniqueLabels}) is its prior times the product over raters l of
	 * {@code pl.get(l, t, m)}. Here t is the index of the label rater l assigned
	 * and m is the global index of the local label (Eqns 14-15). The values are
	 * normalized (Eqn 16) and written to the local-slot truth images.
	 * {@code convergesum} accumulates the written values. Does nothing when there
	 * is no rater data.
	 */
	public void EstepEfficient(){
		if(imagesArray==null && images==null){
			return;
		}
		boolean fromArray = (imagesArray!=null);
		int nRaters = fromArray ? imagesArray[0][0][0].length : images.size();
		convergesum=0;
		for(int i=0; i<rows; i++){
			for(int j=0; j<cols; j++){
				for(int k=0; k<slices; k++){
					int[] local = uniqueLabels[i][j][k];
					// getPriorArray(int[]) also sets labelshere to the number of valid slots
					float[] a = getPriorArray(local);
					//Compute 'a' (Eqn 14) and 'b' (Eqn 15)
					for(int l=0; l<nRaters; l++){
						int observed = fromArray ? imagesArray[i][j][k][l] : images.get(l).getInt(i, j, k);
						// the rater's label index does not depend on m: look it up once
						int t = getIndex(labels,observed);
						if(t<0){
							System.err.println("Could not find label!");
							continue;
						}
						for(int m=0; m<labelshere; m++){
							if(local[m]<0){
								break;
							}
							int mindex = getIndex(labels,local[m]);
							if(mindex<0){
								break;
							}
							a[m]=a[m]*pl.get(l, t, mindex);
						}
					}
					//Compute weights for truth using Eqn (16)
					double sum=0;
					for(int m=0; m<labelshere; m++){
						sum=sum+a[m];
					}
					for(int n=0; n<labelshere; n++){
						truth.get(n).set(i,j,k, a[n]/sum);
						convergesum=convergesum+truth.get(n).getFloat(i,j,k);
					}
				}
			}
		}
	}
	
	/**
	 * E-step of STAPLE over the full label set.
	 * <p>
	 * For each voxel, the unnormalized posterior of true label m is its prior times
	 * the product over raters l of {@code pl.get(l, t, m)}, where t is the index of
	 * the label rater l assigned there (Eqns 14-15). The posteriors are normalized
	 * (Eqn 16) and written to {@code truth}; {@code convergesum} accumulates the
	 * written values. Does nothing when there is no rater data.
	 */
	public void Estep(){
		if(imagesArray==null && images==null){
			return;
		}
		boolean fromArray = (imagesArray!=null);
		int nRaters = fromArray ? imagesArray[0][0][0].length : images.size();
		convergesum=0;
		for(int i=0; i<rows; i++){
			for(int j=0; j<cols; j++){
				for(int k=0; k<slices; k++){
					float[] a = getPriorArray();
					//Compute 'a' (Eqn 14) and 'b' (Eqn 15)
					for(int l=0; l<nRaters; l++){
						int observed = fromArray ? imagesArray[i][j][k][l] : images.get(l).getInt(i, j, k);
						// the rater's label index does not depend on m: look it up once
						int t = getIndex(labels,observed);
						if(t<0){
							System.err.println("Could not find label!");
							continue;
						}
						for(int m=0; m<labels.size(); m++){
							a[m]=a[m]*pl.get(l, t, m);
						}
					}
					//Compute weights for truth using Eqn (16)
					float sum=0;
					for(int m=0; m<labels.size(); m++){
						sum=sum+a[m];
					}
					for(int n=0; n<truth.size(); n++){
						truth.get(n).set(i,j,k, a[n]/sum);
						convergesum=convergesum+truth.get(n).getFloat(i,j,k);
					}
				}
			}
		}
	}
	
	
	/**
	 * M-step of STAPLE for the per-voxel local-label representation.
	 * <p>
	 * For every label m and every voxel where m is among the local labels, the
	 * voxel's truth weight for m is added to {@code totsum[m]} and to
	 * {@code pl(l, t, m)} for every rater l, where t is the index of the label that
	 * rater assigned (Eqns 18-19). Each label's entries are then divided by its
	 * total. A NaN performance level clears {@code keepgoing}.
	 */
	public void MstepEfficient(){
		if(imagesArray==null && images==null){
			return;
		}
		boolean fromArray = (imagesArray!=null);
		int nRaters = fromArray ? imagesArray[0][0][0].length : images.size();
		float[] totsum = new float[labels.size()];
		// pl accumulates truth-weighted counts below and is normalized afterwards,
		// so the previous iteration's estimates must be discarded first
		pl.clear();
		for(int m=0; m<labels.size(); m++){
			if(!fromArray){
				System.out.println("Working on label " + m + ": " + labels.get(m));
			}
			int label = labels.get(m);
			for(int i=0; i<rows; i++){
				for(int j=0; j<cols; j++){
					for(int k=0; k<slices; k++){
						int mtindex = getIndex(uniqueLabels[i][j][k],label);
						if(mtindex<0){
							// label m is not proposed at this voxel: no contribution here, but
							// the remaining slices must still be visited (previously a break
							// skipped the rest of the column)
							continue;
						}
						float w = truth.get(mtindex).getFloat(i,j,k);
						totsum[m]+=w;
						for(int l=0; l<nRaters; l++){
							int observed = fromArray ? imagesArray[i][j][k][l] : images.get(l).getInt(i, j, k);
							int t = getIndex(labels,observed);
							if(t>-1){
								pl.set(l, t, m, pl.get(l, t, m)+w);
								if(Float.isNaN(pl.get(l, t, m))){ keepgoing = false; }
							}else{
								System.err.println("Could not find label!");
							}
						}
					}
				}
			}
		}
		if(!fromArray){
			System.out.println("Totsum:");
			printArray(totsum);
		}
		// Store performance parameter estimates for this iteration
		for(int n=0; n<labels.size(); n++){
			pl.divideByTots(n, totsum[n]);
		}
		System.out.println(pl);
	}

	/**
	 * M-step of STAPLE over the full label set (Eqns 18-19).
	 * <p>
	 * For every voxel and every label m, the truth weight for m is added to
	 * {@code totsum[m]} and to {@code pl(l, t, m)} for every rater l, where t is the
	 * index of the label that rater assigned. Each label's entries are then divided
	 * by its total. A NaN performance level clears {@code keepgoing}. Does nothing
	 * when there is no rater data.
	 * <p>
	 * Voxels are the outer loop so each rater's label index is looked up once per
	 * voxel rather than once per label. Every accumulator still receives its
	 * contributions in the same voxel order.
	 */
	public void Mstep(){
		if(imagesArray==null && images==null){
			return;
		}
		boolean fromArray = (imagesArray!=null);
		int nRaters = fromArray ? imagesArray[0][0][0].length : images.size();
		int nLabels = labels.size();
		float[] totsum = new float[nLabels];
		// pl accumulates truth-weighted counts below and is normalized afterwards,
		// so the previous iteration's estimates must be discarded first
		pl.clear();
		int[] raterIndex = new int[nRaters];
		for(int i=0; i<rows; i++){
			for(int j=0; j<cols; j++){
				for(int k=0; k<slices; k++){
					for(int l=0; l<nRaters; l++){
						int observed = fromArray ? imagesArray[i][j][k][l] : images.get(l).getInt(i, j, k);
						raterIndex[l] = getIndex(labels,observed);
					}
					for(int m=0; m<nLabels; m++){
						float w = truth.get(m).getFloat(i,j,k);
						totsum[m]=totsum[m]+w;
						for(int l=0; l<nRaters; l++){
							int t = raterIndex[l];
							if(t>-1){
								pl.set(l, t, m, pl.get(l, t, m)+w);
								if(Float.isNaN(pl.get(l, t, m))){ keepgoing = false; }
							}else{
								System.err.println("Could not find label!");
							}
						}
					}
				}
			}
		}
		// Store performance parameter estimates for this iteration
		for(int n=0; n<nLabels; n++){
			pl.divideByTots(n, totsum[n]);
		}
		System.out.println(pl);
	}
	
	/**
	 * Computes each label's prior probability: the fraction of all
	 * (voxel, rater) assignments that use that label.
	 * <p>
	 * Rebuilds {@code labels} (in order of first appearance) and {@code priors}
	 * (same order), printing each prior. Prints an error when there is no rater
	 * data.
	 * <p>
	 * Counts are kept as {@code long}. Float counters stop incrementing at
	 * 2^24, and the int product of the dimensions can overflow for large
	 * volumes.
	 */
	public void getPriorProb(){
		if(imagesArray!=null){
			int nRows = imagesArray.length;
			int nCols = imagesArray[0].length;
			int nSlices = imagesArray[0][0].length;
			int nRaters = imagesArray[0][0][0].length;
			labels = new ArrayList<Integer>();
			HashMap<Integer,long[]> counts = new HashMap<Integer,long[]>();
			for(int i=0; i<nRows; i++){
				for(int j=0; j<nCols; j++){
					for(int k=0; k<nSlices; k++){
						for(int l=0; l<nRaters; l++){
							tallyLabel(imagesArray[i][j][k][l], counts);
						}
					}
				}
			}
			storePriors(counts, (float)((long)nRows*nCols*nSlices*nRaters));
		}else if(images!=null){
			labels = new ArrayList<Integer>();
			HashMap<Integer,long[]> counts = new HashMap<Integer,long[]>();
			for(int i=0; i<rows; i++){
				for(int j=0; j<cols; j++){
					for(int k=0; k<slices; k++){
						for(int l=0; l<images.size(); l++){
							tallyLabel(images.get(l).getInt(i, j, k), counts);
						}
					}
				}
			}
			storePriors(counts, (float)((long)images.size()*rows*cols*slices));
		}else{
			System.err.println("Rater data is null");
		}
	}

	/** Increments the count for {@code label}, appending it to {@code labels} on first sight. */
	private void tallyLabel(int label, HashMap<Integer,long[]> counts){
		long[] c = counts.get(label);
		if(c==null){
			c = new long[1];
			counts.put(label, c);
			labels.add(label);
		}
		c[0]++;
	}

	/** Replaces {@code priors} with count/total for each label in {@code labels}, printing each one. */
	private void storePriors(HashMap<Integer,long[]> counts, float total){
		priors = new ArrayList<Float>(labels.size());
		for(int m=0; m<labels.size(); m++){
			long c = counts.get(labels.get(m))[0];
			priors.add(Float.valueOf(c/total));
			System.out.println("Prior Prob of label: " + labels.get(m)+" is: " + priors.get(m));
		}
		priors.trimToSize();
	}
	
	/**
	 * Runs the STAPLE EM loop until convergence or {@code maxiters} iterations.
	 * <p>
	 * After {@link #initialize()}, an initial E-step is run unless the truth itself
	 * was seeded ({@code initType} is {@code "Truth"}). Each iteration then runs an
	 * M-step and an E-step (the efficient variants when enabled), optionally MRF
	 * smoothing when beta is positive. It stops when the normalized trace of the
	 * performance levels changes by less than {@code eps}, or when an M-step
	 * reports NaN.
	 */
	public void iterate(){
		// reset loop state so iterate() can be called more than once
		keepgoing = true;
		normtrace = 0;
		initialize();
		if(!"Truth".equals(initType)){
			// performance levels were seeded, so derive an initial truth estimate
			if(efficient){
				EstepEfficient();
			}else{
				Estep();
			}
		}

		double prevcs = convergesum;
		int iters = 0;
		float ntrace = 0;
		while(keepgoing && iters<maxiters){
			if(efficient){
				MstepEfficient();
				EstepEfficient();
			}else{
				Mstep();
				Estep();
			}

			if(beta>0){
				smoothMRF();
			}
			ntrace = pl.normalizedTrace();
			// normtrace holds the previous iteration's value, ntrace the current one
			if(Math.abs(ntrace-normtrace)<eps){
				System.out.println("Prev Sum: " +normtrace);
				System.out.println("Converge Sum: " +ntrace);
				System.out.println("Diff: " + Math.abs(prevcs-convergesum));
				System.out.println("Converged, Total Iterations: " + iters);
				keepgoing=false;
				printPerformanceLevels();
			}else{
				System.out.println("Iteration: " +iters);
				System.out.println("Prev Sum: " +normtrace);
				System.out.println("Converge Sum: " +ntrace);
				System.out.println("Diff: " + Math.abs(prevcs-convergesum));
				System.out.println("Need to get below: "+eps);
				System.out.println("*****************");
			}

			iters++;
			prevcs = convergesum;
			normtrace=ntrace;
		}
	}
	
	/**
	 * Applies one round of MRF-style smoothing to the truth estimates.
	 * <p>
	 * The hard segmentation is recomputed first. When beta is positive, voxels are
	 * then visited in two passes ordered by {@code StapleMRFUtil}
	 * ({@code setEvenSlice(true)}, then {@code false}). At each voxel, for every
	 * in-bounds neighbour whose hard label differs, an amount
	 * {@code beta*(1-pHere)^2*pNbr^2} of probability is moved from the voxel's
	 * current label to the neighbour's label at this voxel. Here pHere is the
	 * voxel's truth for its own label and pNbr is the neighbour's truth for its
	 * own label. The voxel's hard label is then updated.
	 */
	public void smoothMRF(){
		if(mrfUtil==null){
			mrfUtil=new StapleMRFUtil(rows,cols,slices,true);
		}

		System.out.print("Smoothing ");

		computeHardSeg();
		if(beta>0){
			int iterations = 1;
			for(int l=0; l<iterations; l++){
				int numchanged = 0;
				//first handle even voxels, then the odd ones
				numchanged += smoothMRFPass(true);
				numchanged += smoothMRFPass(false);
				System.out.println("Iteration " + l);
				System.out.println("Voxel changed this iteration " + numchanged);
			}
			System.out.print("...done\n");
		}
	}

	/**
	 * Runs one smoothing pass over the voxels selected by
	 * {@code mrfUtil.setEvenSlice(evenSlice)} and returns how many hard labels changed.
	 */
	private int smoothMRFPass(boolean evenSlice){
		int numchanged = 0;
		mrfUtil.reset();
		mrfUtil.setEvenSlice(evenSlice);
		while(mrfUtil.hasNext()){
			mrfUtil.updateCoords();
			int i=mrfUtil.thisrow;
			int j=mrfUtil.thiscol;
			int k=mrfUtil.thisslice;
			// null when the connectivity is unset/invalid (getNeighbors prints why)
			GridPt[] nbrs = getNeighbors(i,j,k);
			if(nbrs!=null){
				int labelhere=hardseg.getInt(i, j, k);
				for(GridPt nbr : nbrs){
					if(nbr==null){
						continue;
					}
					int labelnbr=hardseg.getInt(nbr.x, nbr.y, nbr.z);
					if(labelhere==labelnbr){
						continue;
					}
					ImageData hereImg = truth.get(labelstoindex.get(labelhere));
					ImageData nbrImg = truth.get(labelstoindex.get(labelnbr));
					float truthhere = hereImg.getFloat(i, j, k);
					float truthnbr = nbrImg.getFloat(nbr.x, nbr.y, nbr.z);
					float truthdelta=beta*(1-truthhere)*(1-truthhere)*truthnbr*truthnbr;
					// Move truthdelta from the current label to the neighbour's label at
					// this voxel, so the voxel's label probabilities keep their sum.
					// (Previously the neighbour label received a truncated getInt value
					// plus the neighbour/own truth plus delta.)
					hereImg.set(i,j,k,truthhere-truthdelta);
					nbrImg.set(i,j,k,nbrImg.getFloat(i,j,k)+truthdelta);
				}
			}
			if(updateHardSeg(i,j,k)){numchanged++;}
		}
		return numchanged;
	}
	
	/**
	 * Checks whether the truth values of all labels at voxel (i,j,k) sum to 1
	 * within a tolerance of 0.0001.
	 */
	private boolean isTruthNormalized(int i, int j, int k){
		final float tolerance = 0.0001f;
		float total = 0f;
		int idx = 0;
		while(idx < labels.size()){
			total += truth.get(idx).getFloat(i, j, k);
			idx++;
		}
		float deviation = Math.abs(total-1f);
		return deviation < tolerance;
	}
	
	/**
	 * Re-labels voxel (i,j,k) of the hard segmentation with the label whose truth
	 * value is largest (strictly above 0). If no label qualifies, the voxel is set
	 * to -1 and its coordinates are printed.
	 *
	 * @param i first coord
	 * @param j second coord
	 * @param k third coord
	 * @return true if the hard segmentation has changed at this voxel
	 */
	private boolean updateHardSeg(int i, int j, int k){
		int previous = hardseg.getInt(i,j,k);
		double best = 0;
		int winner = -1;
		int count = truth.size();
		for(int idx=0; idx<count; idx++){
			float p = truth.get(idx).getFloat(i, j, k);
			if(p > best){
				best = p;
				winner = labels.get(idx);
			}
		}
		if(winner == -1){
			System.out.println(i+","+j+","+k);
		}
		hardseg.set(i,j,k,winner);
		return previous != winner;
	}
	
	/**
	 * Placeholder for MRF smoothing on the efficient (per-voxel local label)
	 * representation; the body is empty, so calling it has no effect.
	 */
	public void smoothMRFEfficient(){
		
	}
	
	/** Prints the current performance-level estimates (or "null") to standard output. */
	public void printPerformanceLevels(){
		String text = String.valueOf(pl);
		System.out.println(text);
	}
	/**
	 * Returns the textual form of the current performance-level estimates.
	 *
	 * @return {@code pl.toString()}
	 */
	public String getPerformanceLevels(){
		PerformanceLevel current = pl;
		return current.toString();
	}
		
	/**
	 * Prints the array on one line: a leading space, then each element followed by a space.
	 * Uses a StringBuilder instead of repeated String concatenation in the loop.
	 */
	private static void printArray(float[] a){
		StringBuilder ans = new StringBuilder(" ");
		for(float v : a){
			ans.append(v).append(' ');
		}
		System.out.println(ans);
	}
	/**
	 * Prints the array on one line: a leading space, then each element followed by a space.
	 * Uses a StringBuilder instead of repeated String concatenation in the loop.
	 */
	private static void printArray(int[] a){
		StringBuilder ans = new StringBuilder(" ");
		for(int v : a){
			ans.append(v).append(' ');
		}
		System.out.println(ans);
	}
	
	/**
	 * Copies the priors into a new array, one entry per label (same order as labels).
	 */
	private float[] getPriorArray(){
		float[] result = new float[labels.size()];
		for(int idx=0; idx<result.length; idx++){
			result[idx] = priors.get(idx);
		}
		return result;
	}
	
	/**
	 * Returns the priors of a voxel's local labels, stopping at the first negative
	 * (padding) entry. Remaining slots stay 0.
	 * <p>
	 * Side effect: sets {@code labelshere} to the number of valid local labels.
	 *
	 * @param locallabels a voxel's label slots, padded with -1
	 */
	private float[] getPriorArray(int[] locallabels){
		float[] priorshere = new float[locallabels.length];
		int count = 0;
		for(; count<locallabels.length; count++){
			if(locallabels[count] <= -1){
				break;
			}
			int globalIndex = getIndex(labels,locallabels[count]);
			priorshere[count] = priors.get(globalIndex);
		}
		labelshere = count;
		return priorshere;
	}
	
	/**
	 * Returns the array-form truth buffer (not a copy).
	 */
	public float[][][][] getTruthImage(){
		return this.truthImage;
	}
	
	/**
	 * Returns the truth estimate images (the internal list, trimmed to size).
	 */
	public ArrayList<ImageData> getTruth(){
		ArrayList<ImageData> estimates = truth;
		estimates.trimToSize();
		return estimates;
	}

	/**
	 * Returns the current performance-level estimates. The misspelled name is kept
	 * for compatibility with existing callers.
	 */
	public PerformanceLevel getPeformanceLevel(){
		return this.pl;
	}
	
	/**
	 * Returns the position of the first element of {@code l} numerically equal to
	 * {@code n}, or -1 if there is none.
	 */
	public int getIndex(ArrayList<Integer> l, Integer n){
		int position = 0;
		for(Integer candidate : l){
			if(candidate.intValue()==n.intValue()){
				return position;
			}
			position++;
		}
		return -1;
	}
	
	/**
	 * Returns the position of the first element of {@code a} equal to {@code n},
	 * or -1 if there is none.
	 */
	public int getIndex(int[] a, int n){
		int pos = 0;
		while(pos < a.length){
			if(a[pos]==n){
				return pos;
			}
			pos++;
		}
		return -1;
	}
	/**
	 * Returns the hard segmentation, computing it on first request.
	 */
	public ImageData getHardSeg(){
		if(hardseg!=null){
			return hardseg;
		}
		computeHardSeg();
		return hardseg;
	}
	/**
	 * (Re)computes the hard segmentation as the arg-max label of the truth
	 * estimates at each voxel.
	 * <p>
	 * Allocates {@code hardseg} on first use. It reads the {@code truth} images when
	 * present, otherwise the array-form {@code truthImage}. Prints an error and
	 * returns when there is no rater data to size the segmentation from, instead
	 * of dereferencing a null segmentation. Diagnostic messages are printed when
	 * no valid label wins at a voxel.
	 */
	public void computeHardSeg(){
		if(hardseg==null){
			if(images!=null){
				hardseg = new ImageDataInt(rows, cols, slices);
				hardseg.setName(images.get(0).getName()+"_StapleLabeling");
			}else if(imagesArray!=null){
				hardseg = new ImageDataInt(rows, cols, slices);
				hardseg.setName("StapleLabeling");
			}else{
				System.err.println("Rater data is null");
				return;
			}
		}

		if(truth!=null){
			for(int k =0; k<hardseg.getSlices(); k++){
				for(int j =0; j<hardseg.getCols(); j++){
					for(int i =0; i<hardseg.getRows(); i++){
						double max = -1;
						int label = -1;
						for(int l=0; l<labels.size(); l++){
							float p = truth.get(l).getFloat(i, j, k);
							if(p>max){
								max = p;
								label= labels.get(l);
							}
						}
						if(!labels.contains(label)){
							System.out.println("Problem!");
							System.out.println("label: " + label);
							System.out.println("label list: " + labels);
						}
						hardseg.set(i, j, k, label);

						if(!labels.contains(hardseg.getInt(i, j, k)) || hardseg.getInt(i, j, k)!=label){
							System.out.println("Problem!");
							System.out.println("set label: " + label);
							System.out.println("label: " + hardseg.getInt(i, j, k));
							System.out.println("label list: " + labels);
						}
						if(hardseg.getInt(i, j, k)==-1){
							System.out.println("Negative one");
						}
					}
				}
			}
		}else if(truthImage!=null){
			for(int k =0; k<hardseg.getSlices(); k++){
				for(int j =0; j<hardseg.getCols(); j++){
					for(int i =0; i<hardseg.getRows(); i++){
						double max = 0;
						int label = -1;
						for(int l=0; l<truthImage[0][0][0].length; l++){
							if(truthImage[i][j][k][l]>max){
								max = truthImage[i][j][k][l];
								label= labels.get(l);
							}
						}
						hardseg.set(i, j, k, label);
					}
				}
			}
		}
	}
	
	/**
	 * Computes the hard (maximum truth value) labeling as a plain array.
	 *
	 * <p>The output size is taken from the first entry of {@code images} when
	 * that list is set, otherwise from the first three dimensions of
	 * {@code imagesArray}. Each voxel receives the label with the largest truth
	 * value, read from the per-label {@code truth} images when present, otherwise
	 * from {@code truthImage}. The selection rules match {@link #computeHardSeg()}:
	 * for {@code truth} the search starts at -1, and for {@code truthImage} it
	 * starts at 0, so an all-zero voxel is labeled -1.
	 *
	 * @return the labeling indexed {@code [row][col][slice]}, or {@code null} if
	 *         the output size is unknown and there is no truth estimate
	 * @throws IllegalStateException if a truth estimate exists but neither
	 *         {@code images} nor {@code imagesArray} is set
	 */
	public int[][][] getHardSegArray(){

		int[][][] segout = null;
		if(images!=null){
			ImageData first = images.get(0);
			segout = new int[first.getRows()][first.getCols()][first.getSlices()];
		}else if(imagesArray!=null){
			segout = new int[imagesArray.length][imagesArray[0].length][imagesArray[0][0].length];
		}
		if(segout==null){
			// Without a known output size the loops below would throw a NullPointerException.
			if(truth!=null || truthImage!=null){
				throw new IllegalStateException(
						"Cannot compute hard segmentation array: neither images nor imagesArray is set");
			}
			return null;
		}
		// Loops run row-outermost so zero-sized dimensions are handled without
		// indexing segout[0].
		if(truth!=null){
			for(int i =0; i<segout.length; i++){
				for(int j =0; j<segout[i].length; j++){
					for(int k =0; k<segout[i][j].length; k++){
						// Start below 0, as computeHardSeg() does, so getHardSeg() and
						// getHardSegArray() agree for the same truth images.
						double max = -1;
						int label = -1;
						for(int l=0; l<truth.size(); l++){
							float value = truth.get(l).getFloat(i, j, k);
							if(value>max){
								max = value;
								label = labels.get(l);
							}
						}
						segout[i][j][k]=label;
					}
				}
			}
		}else if(truthImage!=null){
			for(int i =0; i<segout.length; i++){
				for(int j =0; j<segout[i].length; j++){
					for(int k =0; k<segout[i][j].length; k++){
						float[] values = truthImage[i][j][k];
						double max = 0;
						int label = -1;
						for(int l=0; l<truthImage[0][0][0].length; l++){
							if(values[l]>max){
								max = values[l];
								label = labels.get(l);
							}
						}
						segout[i][j][k]=label;
					}
				}
			}
		}
		return segout;
	}
	
	/**
	 * Writes every per-label truth image to the given location using the
	 * shared reader/writer.
	 *
	 * @param dir output location passed to {@code rw.write}
	 */
	public void writeTruth(File dir){
		for(ImageData img : truth){
			rw.write(img, dir);
		}
	}
	
	/**
	 * Writes each per-label truth image to {@code <dir>/<name>_<index>.xml},
	 * where index is the image's position in {@code truth}.
	 *
	 * @param dir  output directory
	 * @param name file-name prefix
	 */
	public void writeTruth(File dir, String name){
		int idx = 0;
		for(ImageData img : truth){
			String path = dir.getAbsolutePath() + "/" + name + "_" + idx + ".xml";
			rw.write(img, new File(path));
			idx++;
		}
	}

	/**
	 * Writes the {@code uniqueLabels} volume as an image named "UniqueValues".
	 * Does nothing if {@code uniqueLabels} has not been computed.
	 *
	 * <p>The header is cloned from the first truth image when one exists, and
	 * otherwise from the first input image. If neither exists, no header is set.
	 * Previously this method dereferenced {@code truth.get(0)} unconditionally
	 * and failed whenever {@code truth} was null or empty.
	 *
	 * @param dir output location passed to {@code rw.write}
	 */
	public void writeUniqueLabels(File dir){
		if(uniqueLabels==null){
			return;
		}
		ImageDataInt unique = new ImageDataInt(uniqueLabels);
		ImageData headerSource = null;
		if(truth!=null && !truth.isEmpty()){
			headerSource = truth.get(0);
		}else if(images!=null && !images.isEmpty()){
			headerSource = images.get(0);
		}
		if(headerSource!=null){
			ImageHeader hdr = headerSource.getHeader().clone();
			unique.setHeader(hdr);
		}
		unique.setName("UniqueValues");
		rw.write(unique, dir);
	}
	
	/**
	 * Writes the hard segmentation to {@code <dir>/<name>_.xml}. The
	 * segmentation is computed first if that has not happened yet; previously an
	 * uncomputed (null) {@code hardseg} was passed straight to the writer.
	 *
	 * NOTE(review): the trailing "_" in the file name looks like a leftover from
	 * the writeTruth naming pattern; it is kept so existing outputs keep their names.
	 *
	 * @param dir  output directory
	 * @param name file-name prefix
	 */
	public void writeHardSeg(File dir, String name){
		getHardSeg(); // lazily computes hardseg when it is still null
		rw.write(hardseg, new File(dir.getAbsolutePath()+"/"+name+"_"+".xml"));
	}
	
	/**
	 * Coordinate iterator that visits one colour of a 3-D checkerboard over a
	 * rows x cols x slices grid. The current position is exposed through the
	 * public fields thisrow, thiscol and thisslice.
	 *
	 * Visiting order:
	 * - slices ascending;
	 * - rows ascending within a slice;
	 * - columns ascending in steps of 2 within a row.
	 * The first column of each row alternates between 0 and 1, and the
	 * alternation flips again at every new slice. For cols >= 2, every visited
	 * voxel therefore has the same parity of (row + col + slice): even when
	 * evenslice is true at the start of a pass, odd otherwise.
	 *
	 * Presumably driven as "reset(); while(hasNext()){ updateCoords(); ... }" so
	 * that an MRF update touches one half of the voxels at a time. This is a
	 * TODO to confirm against the callers, which are outside this view.
	 *
	 * NOTE(review): when cols == 1, odd-parity rows get thiscol = 1, which is
	 * outside [0, cols).
	 * NOTE(review): updateCoords() toggles evenslice at each new slice, and
	 * reset() does not restore it. Call setEvenSlice() before reusing the
	 * iterator for another pass.
	 */
	private class StapleMRFUtil{
		// Grid dimensions; these shadow the outer class's fields of the same names.
		private int rows, cols, slices;
		// Parity used for the slice currently being traversed; toggled per slice.
		private boolean evenslice;
		// True if the current row started at column 0, false if it started at column 1.
		private boolean evencol;
		
		// Current coordinates; all -1 before the first updateCoords() after reset().
		public int thisrow = -1;
		public int thiscol = -1;
		public int thisslice =-1;
		
		/**
		 * @param rows      number of rows in the grid
		 * @param cols      number of columns in the grid
		 * @param slices    number of slices in the grid
		 * @param evenslice starting parity: true visits voxels with even
		 *                  row+col+slice, false visits the odd ones (for cols >= 2)
		 */
		public StapleMRFUtil(int rows, int cols, int slices, boolean evenslice){
			this.rows=rows;
			this.cols=cols;
			this.slices=slices;	
			setEvenSlice(evenslice);
		}
		
		/**
		 * Sets the parity used for the next slice that is started.
		 *
		 * @param evenslice true to start rows of that slice at column 0, false at column 1
		 */
		public void setEvenSlice(boolean evenslice){
			this.evenslice=evenslice;
		}
		
		/**
		 * Rewinds to the before-first position (all coordinates -1).
		 * Does not restore evenslice.
		 */
		public void reset(){
			thisrow = -1;
			thiscol = -1;
			thisslice =-1;
		}
		
		/**
		 * Advances to the next coordinate of the pass.
		 * After reset() this positions on the first voxel of slice 0. When the
		 * current slice is exhausted, it moves to the next slice and flips
		 * evenslice. No bounds check is done here; callers should test hasNext() first.
		 */
		public void updateCoords(){

			if(thisslice<0){
				// First call after reset(). The two branches differ only in evencol,
				// which updateCoords2d() immediately overwrites because thisrow is -1.
				if(evenslice){
					thisrow=-1;
					thiscol=0;
					thisslice=0;
					evencol = true;
				}else{
					thisrow=-1;
					thiscol=0;
					thisslice=0;
					evencol = false;
				}
				updateCoords2d();
			}else{
				if(hasNext2d()){
//					System.out.println("Working on slice: " + thisslice);
					updateCoords2d();
				}else{
					// Slice exhausted: start the next one with the opposite parity.
					thisslice++;
					thisrow=-1;
					evenslice=!evenslice;
					updateCoords2d();
				}
			}
		}
		
		
		/**
		 * Advances within the current slice.
		 * If thisrow < 0, starts the slice at row 0, column 0 (evenslice) or 1.
		 * Otherwise steps two columns. When fewer than two columns remain, it wraps
		 * to the next row and switches the row's starting column between 0 and 1.
		 */
		public void updateCoords2d(){
			if(thisrow<0){
				if(evenslice){
					thisrow=0;
					thiscol=0;
					evencol = true;
				}else{
					thisrow=0;
					thiscol=1;
					evencol = false;
				}
			}else{

				// The last column of this row has been reached (thiscol is cols-1 or
				// cols-2), so wrap to the next row with the other starting column.
				if(thiscol+1==cols){
					thisrow++;
					if(evencol){
						thiscol=1;
					}else{
						thiscol=0;
					}
					evencol = !evencol;
				}else if(thiscol+2==cols){
					thisrow++;
					if(evencol){
						thiscol=1;
					}else{
						thiscol=0;
					}
					evencol = !evencol;
				}else{
					thiscol+=2;
				}
			}
		}
		
		/**
		 * @return false when the current position is on the last row of the last
		 *         slice with fewer than two columns remaining; true otherwise
		 *         (including the reset state)
		 */
		public boolean hasNext(){
			if(thisrow+1>=rows && thiscol+2>=cols && thisslice+1==slices){
				return false;
			}
			return true;
		}
		/**
		 * @return false when the current position is on the last row of the
		 *         current slice with fewer than two columns remaining; true otherwise
		 */
		public boolean hasNext2d(){
			if(thisrow+1>=rows && thiscol+2>=cols){
				return false;
			}
			return true;
		}
	}

}
