package edu.jhu.pami.spring2009;
import java.util.Arrays;

import edu.jhu.ece.iacl.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;
//import edu.jhu.ece.iacl.structures.matrix.Matrix2;
import edu.jhu.bme.smile.commons.math.StatisticsDouble;
import edu.jhu.bme.smile.commons.math.MatrixMath;
import Jama.Matrix;

/**
 * Segmentation module developed for PAMI HW2.
 * <p>
 * Labels every voxel of the input volume (optionally restricted to a mask) with one of
 * {@code k} clusters. Two algorithms are available:
 * <ul>
 * <li>k-means on the voxel intensities, or</li>
 * <li>expectation-maximization (maximum likelihood) for a one-dimensional Gaussian mixture
 * model, initialized by k-means.</li>
 * </ul>
 * The output is an integer volume of cluster indices in {@code [0, k)}. Voxels outside the
 * mask are labeled 0.
 *
 * @author Yu "Charlie" Ouyang
 */
public class CopyOfHW2_Segmentation extends ProcessingAlgorithm{
	//Input Variables
	private ParamVolume vol;	//input volume
	private ParamVolume	mask;	//optional mask; only voxels whose mask value is 1 are segmented
	private ParamOption alg;	//algorithm selector: 0 = K-Means, 1 = EM (ML w/ GMM)
	private ParamInteger k;		//cluster number

	//Output Variables
	private ParamVolume	out;	//output label volume

	/** Upper bound on k-means / EM iterations so a non-converging run cannot hang the pipeline. */
	private static final int MAX_ITERATIONS = 1000;
	/** Iteration stops once the summed absolute change of all cluster means falls below this. */
	private static final double CONVERGENCE_TOLERANCE = 1.0;
	/** Initial standard deviation assigned to every EM cluster. */
	private static final float INITIAL_SIGMA = 100f;
	/** Lower bound on an EM cluster's standard deviation; prevents collapse onto one intensity. */
	private static final float MIN_SIGMA = 10f;

	//CVS
	private static final String rcsid = "$Id: CopyOfHW2_Segmentation.java,v 1.1 2009/03/05 16:45:44 bennett Exp $";
	private static final String cvsversion = "$Revision: 1.1 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "");

	/**
	 * Registers plugin metadata and the user-facing input parameters: the volume, an optional
	 * mask, the algorithm choice and the number of clusters (default 3).
	 */
	protected void createInputParameters(ParamCollection inputParams){
		//Plugin Information
		inputParams.setCategory("PAMI");
		inputParams.setPackage("Spring2009");
		inputParams.setName("HW2-Segmentation(BL)");
		inputParams.setLabel("Segmentation(BL)");
		AlgorithmInformation info=getAlgorithmInformation();
		info.setWebsite("http://sites.google.com/site/jhupami/");
		info.add(new AlgorithmAuthor("Yu Ouyang","couyang@jhu.edu","http://sites.google.com/a/jhu.edu/pami/"));
		info.setDescription("Allows segmentation with k-means or maximum likelihood with Gaussian mixture models.");
		info.setAffiliation("Johns Hopkins University, Department of Biomedical Engineering");
		info.setVersion(revnum);

		//Inputs
		inputParams.add(vol = new ParamVolume("Input Volume"));
		inputParams.add(mask = new ParamVolume("Mask"));
		mask.setMandatory(false);
		inputParams.add(alg = new ParamOption("Algorithm", new String[]{"K-Means", "EM (ML w/ GMM)"}));
		inputParams.add(k = new ParamInteger("Clusters"));
		k.setValue(Integer.valueOf(3));
	}

	/** Registers the single output: the segmentation (label) volume. */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(out = new ParamVolume("Output Volume"));
	}

	/**
	 * Runs the selected segmentation and publishes the label volume named "Segmentation".
	 *
	 * @throws IllegalArgumentException if the requested number of clusters is less than 1
	 */
	protected void execute(CalculationMonitor monitor){
		ImageData idVol = vol.getImageData();
		ImageData idMask = mask.getImageData();
		// The mask is optional; without one every voxel is segmented.
		boolean maskUsed = idMask != null;

		int rows = idVol.getRows();
		int cols = idVol.getCols();
		// 2-D or scalar volumes may report 0 slices/components; treat those as 1.
		int slices = Math.max(1, idVol.getSlices());
		int components = Math.max(1, idVol.getComponents());

		int clusters = k.getInt();
		if (clusters < 1) {
			throw new IllegalArgumentException("Number of clusters must be at least 1, got " + clusters);
		}

		// Largest intensity inside the mask; initial means are drawn uniformly from [0, maxVal).
		float maxVal = 0;
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int s = 0; s < slices; s++) {
					for (int l = 0; l < components; l++) {
						if (isInMask(idMask, maskUsed, i, j, s, l)) {
							maxVal = Math.max(maxVal, idVol.getFloat(i, j, s, l));
						}
					}
				}
			}
		}

		float[] meansArray = new float[clusters];
		java.util.Random rand = new java.util.Random();
		for (int n = 0; n < clusters; n++) {
			meansArray[n] = maxVal * rand.nextFloat();
		}

		// Label volume; voxels outside the mask keep label 0.
		int[][][][] volAssign = new int[rows][cols][slices][components];

		switch (alg.getIndex()) {
		case 0: // K-Means
			doKMeans(volAssign, idVol, idMask, maskUsed, meansArray, clusters);
			break;
		case 1: // EM, seeded with the k-means means
			System.out.println("EM Algorithm: Initializing with K-Means...");
			doKMeans(volAssign, idVol, idMask, maskUsed, meansArray, clusters);
			doEM(volAssign, idVol, idMask, maskUsed, meansArray, clusters);
			break;
		}

		ImageDataInt tmp = new ImageDataInt(volAssign);
		tmp.setName("Segmentation");
		out.setValue(tmp);
	}

	/**
	 * Returns whether voxel (i, j, s, l) should be processed.
	 *
	 * @return true when masking is off, when no mask is present, or when the mask value is 1
	 */
	private static boolean isInMask(ImageData mask, boolean maskUsed, int i, int j, int s, int l) {
		return !maskUsed || mask == null || mask.getInt(i, j, s, l) == 1;
	}

	/**
	 * Fits a 1-D Gaussian mixture by EM and writes the maximum-posterior label of every
	 * in-mask voxel into {@code volAssign}. Out-of-mask voxels get label 0.
	 * <p>
	 * Means start from {@code meansArray}, which is not modified. Priors start as random
	 * normalized weights, and every standard deviation starts at {@link #INITIAL_SIGMA}.
	 * Iteration stops when the summed change of the means is below
	 * {@link #CONVERGENCE_TOLERANCE}, or after {@link #MAX_ITERATIONS} iterations.
	 */
	private void doEM(int[][][][] volAssign, ImageData idVol, ImageData idMask,
			boolean maskUsed, float[] meansArray, int clusters) {
		int rows = volAssign.length;
		int cols = volAssign[0].length;
		int slices = volAssign[0][0].length;
		int components = volAssign[0][0][0].length;

		float[] theta_muk = meansArray.clone();	// cluster means
		float[] theta_pik = new float[clusters];	// mixing priors, random then normalized
		float sum = 0;
		for (int c = 0; c < clusters; c++) {
			theta_pik[c] = (float) Math.random();
			sum += theta_pik[c];
		}
		for (int c = 0; c < clusters; c++) {
			theta_pik[c] /= sum;
		}
		float[] theta_sigmaK = new float[clusters];	// cluster standard deviations
		Arrays.fill(theta_sigmaK, INITIAL_SIGMA);
		float[] last_theta_muk = new float[clusters];

		// Posterior p(cluster | x) per voxel; out-of-mask voxels stay all zero.
		float[][][][][] volProb = new float[rows][cols][slices][components][clusters];

		boolean converged = false;
		for (int iter = 0; iter < MAX_ITERATIONS && !converged; iter++) {
			System.out.println("EM-means:");
			for (int c = 0; c < clusters; c++)
				System.out.println("\t"+c+":"+theta_muk[c]+" pi:"+theta_pik[c]+" sig:"+theta_sigmaK[c]);

			// E-step: posterior responsibilities at every in-mask voxel.
			System.out.println("EM Algorithm: E-Step");
			for (int i = 0; i < rows; i++) {
				for (int j = 0; j < cols; j++) {
					for (int s = 0; s < slices; s++) {
						for (int l = 0; l < components; l++) {
							if (isInMask(idMask, maskUsed, i, j, s, l)) {
								float[] post = GMMprobKgivXTheta(idVol.getFloat(i, j, s, l),
										theta_pik, theta_muk, theta_sigmaK);
								System.arraycopy(post, 0, volProb[i][j][s][l], 0, clusters);
							}
						}
					}
				}
			}

			// M-step: one pass accumulates the weighted sums for all clusters.
			// Double accumulators avoid float precision loss over large volumes.
			System.out.println("EM Algorithm: M-Step");
			double[] sumP = new double[clusters];
			double[] sumPX = new double[clusters];
			double[] sumPSq = new double[clusters];
			int[] count = new int[clusters];
			for (int i = 0; i < rows; i++) {
				for (int j = 0; j < cols; j++) {
					for (int s = 0; s < slices; s++) {
						for (int l = 0; l < components; l++) {
							if (!isInMask(idMask, maskUsed, i, j, s, l)) {
								continue;
							}
							float x = idVol.getFloat(i, j, s, l);
							float[] probs = volProb[i][j][s][l];
							for (int c = 0; c < clusters; c++) {
								float p = probs[c];
								if (Float.isNaN(p)) {
									continue; // posterior undefined at this voxel
								}
								// Variance uses the previous iteration's mean (as in the reference MATLAB).
								double diff = x - theta_muk[c];
								count[c]++;
								sumP[c] += p;
								sumPX[c] += (double) p * x;
								sumPSq[c] += p * diff * diff;
							}
						}
					}
				}
			}

			for (int c = 0; c < clusters; c++) {
				last_theta_muk[c] = theta_muk[c];
				if (count[c] > 0) {
					theta_pik[c] = (float) (sumP[c] / count[c]);
				}
				// With no responsibility mass the update is 0/0; keep the previous parameters
				// instead of propagating NaN, which would make the stopping test never pass.
				if (sumP[c] > 0) {
					theta_muk[c] = (float) (sumPX[c] / sumP[c]);
					theta_sigmaK[c] = (float) Math.max(MIN_SIGMA, Math.sqrt(sumPSq[c] / sumP[c]));
				}
			}

			double diff = 0;
			for (int c = 0; c < clusters; c++)
				diff += Math.abs(last_theta_muk[c] - theta_muk[c]);
			System.out.println("Stopping criteria:"+diff);
			converged = diff < CONVERGENCE_TOLERANCE;
		}
		if (!converged) {
			System.out.println("EM Algorithm: stopped after " + MAX_ITERATIONS + " iterations without converging");
		}

		// Label each in-mask voxel with its most probable cluster (ties keep the lower index).
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int s = 0; s < slices; s++) {
					for (int l = 0; l < components; l++) {
						int label = 0;
						if (isInMask(idMask, maskUsed, i, j, s, l)) {
							float best = 0;
							float[] probs = volProb[i][j][s][l];
							for (int c = 0; c < clusters; c++) {
								if (probs[c] > best) {
									best = probs[c];
									label = c;
								}
							}
						}
						volAssign[i][j][s][l] = label;
					}
				}
			}
		}
	}

	/**
	 * Computes the posterior p(k | x, theta) of every mixture component for one sample.
	 * <p>
	 * The work is done in log space with max-subtraction, so the result stays well defined
	 * even when every component density would underflow in float.
	 *
	 * @return normalized posteriors, or all NaN when every joint log-likelihood is infinite or
	 *         undefined (e.g. all priors are zero)
	 */
	private float[] GMMprobKgivXTheta(float x, float[] theta_pik,
			float[] theta_muk, float[] theta_sigmaK) {
		int n = theta_pik.length;
		double[] logJoint = new double[n];
		double maxLog = Double.NEGATIVE_INFINITY;
		for (int c = 0; c < n; c++) {
			logJoint[c] = Math.log(theta_pik[c]) + logGaussian(x, theta_muk[c], theta_sigmaK[c]);
			if (logJoint[c] > maxLog) {
				maxLog = logJoint[c];
			}
		}

		float[] result = new float[n];
		if (Double.isNaN(maxLog) || Double.isInfinite(maxLog)) {
			Arrays.fill(result, Float.NaN);
			return result;
		}

		// exp(log - max) puts the largest term at 1, so the normalizer is >= 1.
		double[] weight = new double[n];
		double norm = 0;
		for (int c = 0; c < n; c++) {
			weight[c] = Math.exp(logJoint[c] - maxLog);
			norm += weight[c];
		}
		for (int c = 0; c < n; c++) {
			result[c] = (float) (weight[c] / norm);
		}
		return result;
	}

	/**
	 * Log of the normal density N(x; mu, sigma^2), where {@code sigma} is the standard deviation:
	 * -0.5*((x-mu)/sigma)^2 - ln(sigma) - 0.5*ln(2*pi).
	 */
	private static double logGaussian(double x, double mu, double sigma) {
		double z = (x - mu) / sigma;
		return -0.5 * z * z - Math.log(sigma) - 0.5 * Math.log(2 * Math.PI);
	}

	/**
	 * Runs k-means on the in-mask voxel intensities.
	 * <p>
	 * Writes labels into {@code volAssign} and updates {@code meansArray} in place. A cluster
	 * that receives no voxels keeps its previous mean. Iteration stops when the summed mean
	 * change is below {@link #CONVERGENCE_TOLERANCE}, or after {@link #MAX_ITERATIONS}
	 * iterations.
	 */
	private void doKMeans(int[][][][] volAssign, ImageData idVol,
			ImageData idMask, boolean maskUsed, float[] meansArray, int clusters) {
		System.out.println("K-Means Algorithm Running...");

		int rows = volAssign.length;
		int cols = volAssign[0].length;
		int slices = volAssign[0][0].length;
		int components = volAssign[0][0][0].length;

		for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
			System.out.println("K-means:");
			for (int n = 0; n < meansArray.length; n++)
				System.out.println("\t"+n+":"+meansArray[n]);

			// Assignment step, accumulating per-cluster sums in the same pass.
			double[] sums = new double[meansArray.length];
			int[] counts = new int[meansArray.length];
			for (int i = 0; i < rows; i++) {
				for (int j = 0; j < cols; j++) {
					for (int s = 0; s < slices; s++) {
						for (int l = 0; l < components; l++) {
							if (isInMask(idMask, maskUsed, i, j, s, l)) {
								float x = idVol.getFloat(i, j, s, l);
								int label = assignMean(x, meansArray);
								volAssign[i][j][s][l] = label;
								sums[label] += x;
								counts[label]++;
							}
						}
					}
				}
			}

			// Update step.
			double deltaKM = 0;
			for (int h = 0; h < clusters; h++) {
				if (counts[h] == 0) {
					continue; // empty cluster: keep its mean rather than collapsing it to 0
				}
				float updated = (float) (sums[h] / counts[h]);
				deltaKM += Math.abs(meansArray[h] - updated);
				meansArray[h] = updated;
			}
			System.out.println("Stopping criteria:"+deltaKM);
			if (deltaKM < CONVERGENCE_TOLERANCE) {
				return;
			}
		}
		System.out.println("K-Means: stopped after " + MAX_ITERATIONS + " iterations without converging");
	}

	/**
	 * Returns the index of the mean closest to a point (absolute difference, which is the 1-D
	 * Euclidean distance). Ties go to the lowest index.
	 *
	 * @param	ptVal	the value of the input point
	 * @param	means	the array of means; must be non-empty
	 * @return			the index of the nearest mean
	 */
	protected int assignMean(float ptVal, float[] means){
		int index = 0;
		float best = Math.abs(ptVal - means[0]);
		for (int i = 1; i < means.length; i++) {
			float dist = Math.abs(ptVal - means[i]);
			if (dist < best) {
				best = dist;
				index = i;
			}
		}
		return index;
	}

	/**
	 * Returns true only if every stopping flag is set.
	 *
	 * @param 	s	array of booleans
	 * @return		true if all are true (or the array is empty), false if at least one is false
	 */
	protected boolean checkStop(boolean[] s){
		for (int i = 0; i < s.length; i++) {
			if (!s[i]) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Computes -(d^T * emCov * d) / 2, where d = (val, ..., val)^T - emMeans.
	 * <p>
	 * NOTE(review): a Gaussian exponent would use the inverse covariance; the caller is
	 * presumably expected to pass it as {@code emCov} — confirm before use.
	 *
	 * @param val     sample value, broadcast to a column vector shaped like {@code emMeans}
	 * @param emMeans column vector of means
	 * @param emCov   square matrix used in the quadratic form
	 * @param c       unused
	 * @return float value of the evaluated expression
	 */
	protected float getExpTerms(float val, Matrix emMeans, Matrix emCov, int c) {
		Matrix valMatrix = new Matrix(emMeans.getRowDimension(), 1, val);
		Matrix diffMatrix = valMatrix.minus(emMeans);
		return (- (float) diffMatrix.transpose().times(emCov).times(diffMatrix).trace()) / 2;
	}

	/**
	 * Computes each cluster's share of the in-mask voxels from a label volume.
	 *
	 * @param inMask      mask volume (value 1 = included); may be null
	 * @param m           whether the mask is applied
	 * @param assignVol   the volume of assigned cluster labels, each in [0, numClusters)
	 * @param numClusters the number of clusters
	 * @return fraction of included voxels per cluster (all zeros if nothing is included)
	 */
	protected float[] initWeights(ImageData inMask, boolean m, int[][][][] assignVol, int numClusters) {
		float[] out = new float[numClusters];
		int counter = 0;
		for (int i = 0; i < assignVol.length; i++) {
			for (int j = 0; j < assignVol[i].length; j++) {
				for (int s = 0; s < assignVol[i][j].length; s++) {
					for (int l = 0; l < assignVol[i][j][s].length; l++) {
						if (isInMask(inMask, m, i, j, s, l)) {
							out[assignVol[i][j][s][l]] += 1;
							counter++;
						}
					}
				}
			}
		}
		if (counter > 0) {
			for (int c = 0; c < numClusters; c++) {
				out[c] /= counter;
			}
		}
		return out;
	}

	/**
	 * Builds, for each cluster r, a numClusters x numClusters matrix equal to the identity
	 * scaled by the intensity variance of the voxels labeled r (relative to
	 * {@code inMeans.get(r, 0)}).
	 *
	 * @param inVol       intensity volume
	 * @param inMask      mask volume (value 1 = included); may be null
	 * @param m           whether the mask is applied
	 * @param inAssign    label volume; its dimensions define the voxels visited
	 * @param numClusters the number of clusters
	 * @param inMeans     column vector of cluster means
	 * @return one matrix per cluster; a cluster with no voxels gets a zero matrix
	 */
	protected Matrix[] initCov(ImageData inVol, ImageData inMask, boolean m, int[][][][] inAssign, int numClusters, Matrix inMeans) {
		double[] numerator = new double[numClusters];
		int[] counts = new int[numClusters];

		for (int i = 0; i < inAssign.length; i++) {
			for (int j = 0; j < inAssign[i].length; j++) {
				for (int s = 0; s < inAssign[i][j].length; s++) {
					for (int l = 0; l < inAssign[i][j][s].length; l++) {
						int c = inAssign[i][j][s][l];
						if (c < 0 || c >= numClusters || !isInMask(inMask, m, i, j, s, l)) {
							continue;
						}
						double d = inVol.getDouble(i, j, s, l) - inMeans.get(c, 0);
						numerator[c] += d * d;
						counts[c]++;
					}
				}
			}
		}

		Matrix[] out = new Matrix[numClusters];
		for (int r = 0; r < numClusters; r++) {
			double variance = counts[r] > 0 ? numerator[r] / counts[r] : 0;
			out[r] = Matrix.identity(numClusters, numClusters).times(variance);
		}
		return out;
	}

	/**
	 * Updates the mixture weights: the mean posterior of each cluster over in-mask voxels.
	 *
	 * @param inMask      mask volume (value 1 = included); null means every voxel is included
	 * @param inProb      posterior volume with a trailing cluster dimension
	 * @param numClusters the number of clusters
	 * @return per-cluster weights (all zeros if no voxel is included)
	 */
	protected float[] upWeights(ImageData inMask, float[][][][][] inProb, int numClusters) {
		float[] out = new float[numClusters];
		int count = 0;
		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isInMask(inMask, true, i, j, s, l)) {
							continue;
						}
						count++; // once per voxel, not once per cluster
						for (int c = 0; c < numClusters; c++) {
							out[c] += inProb[i][j][s][l][c];
						}
					}
				}
			}
		}
		if (count > 0) {
			for (int c = 0; c < numClusters; c++) {
				out[c] /= count;
			}
		}
		return out;
	}

	/**
	 * Updates the means: the posterior-weighted average intensity of each cluster.
	 *
	 * @param inVol       intensity volume
	 * @param inMask      mask volume (value 1 = included); null means every voxel is included
	 * @param inProb      posterior volume with a trailing cluster dimension
	 * @param numClusters the number of clusters
	 * @return numClusters x 1 column vector of means (NaN for a cluster with zero total posterior)
	 */
	protected Matrix upMeans(ImageData inVol, ImageData inMask, float[][][][][] inProb, int numClusters) {
		double[] denom = new double[numClusters];
		double[] numer = new double[numClusters];
		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isInMask(inMask, true, i, j, s, l)) {
							continue;
						}
						float x = inVol.getFloat(i, j, s, l);
						for (int c = 0; c < numClusters; c++) {
							float p = inProb[i][j][s][l][c];
							denom[c] += p;
							numer[c] += (double) p * x;
						}
					}
				}
			}
		}

		Matrix out = new Matrix(numClusters, 1);
		for (int c = 0; c < numClusters; c++) {
			out.set(c, 0, numer[c] / denom[c]); // column 0: out is a column vector
		}
		return out;
	}

	/**
	 * Updates the covariances. For each cluster c it computes the posterior-weighted average
	 * of d*d^T, where d = (x, ..., x)^T - inMeans is formed for every included voxel x.
	 *
	 * @param inVol       intensity volume
	 * @param inMask      mask volume (value 1 = included); null means every voxel is included
	 * @param inProb      posterior volume with a trailing cluster dimension
	 * @param numClusters the number of clusters
	 * @param inMeans     numClusters x 1 column vector of means
	 * @return one numClusters x numClusters matrix per cluster
	 */
	protected Matrix[] upCov(ImageData inVol, ImageData inMask, float[][][][][] inProb, int numClusters, Matrix inMeans) {
		double[] denom = new double[numClusters];
		Matrix[] numer = new Matrix[numClusters];
		for (int c = 0; c < numClusters; c++)
			numer[c] = new Matrix(numClusters, numClusters);

		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isInMask(inMask, true, i, j, s, l)) {
							continue;
						}
						// The outer product is the same for every cluster; compute it once per voxel.
						Matrix diff = new Matrix(numClusters, 1, inVol.getFloat(i, j, s, l)).minus(inMeans);
						Matrix outer = diff.times(diff.transpose());
						for (int c = 0; c < numClusters; c++) {
							double p = inProb[i][j][s][l][c];
							denom[c] += p;
							numer[c].plusEquals(outer.times(p));
						}
					}
				}
			}
		}

		for (int c = 0; c < numClusters; c++)
			numer[c] = numer[c].times(1.0 / denom[c]);
		return numer;
	}

} //end HW2_Segmentation
