package edu.jhu.pami.spring2009;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;
//import edu.jhu.ece.iacl.structures.matrix.Matrix2;
import edu.jhu.bme.smile.commons.math.StatisticsDouble;
import edu.jhu.bme.smile.commons.math.MatrixMath;
import Jama.Matrix;

/**
 * Segmentation Module, developed for PAMI HW2.
 *
 * Segments a volume by intensity into a user-chosen number of clusters, using either K-means or a
 * maximum-likelihood Gaussian mixture model fitted with EM (initialised from K-means). Every
 * component of every voxel is treated as an independent scalar sample. When the mask is enabled,
 * only voxels whose mask value is 1 are segmented; all other voxels are labelled 0.
 *
 * @author Yu "Charlie" Ouyang
 */
public class HW2_Segmentation extends ProcessingAlgorithm{
	//Input Variables
	private ParamVolume vol;	//input volume
	private ParamBoolean useMask;	//when true, only voxels whose mask value is 1 are segmented
	private ParamVolume	mask;	//mask for input
	private ParamOption alg;	//algorithm: index 0 = K-Means, index 1 = ML (Gaussian mixture via EM)
	private ParamInteger k;		//cluster number

	//Output Variables
	private ParamVolume	out;	//output volume

	//CVS
	private static final String rcsid = "$Id: HW2_Segmentation.java,v 1.5 2009/04/14 04:44:18 jhuuser Exp $";
	private static final String cvsversion = "$Revision: 1.5 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "");

	//Algorithm tuning constants
	/** Hard cap on K-means iterations so a cycling assignment can never hang the pipeline. */
	private static final int MAX_KMEANS_ITERATIONS = 500;
	/** Hard cap on EM iterations. */
	private static final int MAX_EM_ITERATIONS = 200;
	/** EM stops once the log-likelihood changes by less than this fraction of max(1, |logL|). */
	private static final double EM_TOLERANCE = 1e-6;
	/** Smallest variance a cluster may take; keeps Gaussian densities finite if a cluster collapses. */
	private static final double MIN_VARIANCE = 1e-6;

	/**
	 * Declares the plugin metadata and the user inputs: input volume, mask, mask toggle,
	 * algorithm choice and number of clusters (default 3).
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	protected void createInputParameters(ParamCollection inputParams){
	//Plugin Information
		inputParams.setCategory("PAMI");
		inputParams.setPackage("Spring2009");
		inputParams.setName("HW2-Segmentation");
		inputParams.setLabel("Segmentation");
		AlgorithmInformation info=getAlgorithmInformation();
		info.setWebsite("http://sites.google.com/site/jhupami/");
		info.add(new AlgorithmAuthor("Yu Ouyang","couyang@jhu.edu","http://sites.google.com/a/jhu.edu/pami/"));
		info.setDescription("Allows segmentation with k-means or maximum likelihood with Gaussian mixture models.");
		info.setAffiliation("Johns Hopkins University, Department of Biomedical Engineering");
		info.setVersion(revnum);

	//Inputs
		inputParams.add(vol = new ParamVolume("Input Volume"));
		inputParams.add(mask = new ParamVolume("Mask"));
		inputParams.add(useMask = new ParamBoolean ("Use mask", true));
		inputParams.add(alg = new ParamOption("Algorithm", new String[]{"K-Means", "ML"}));
		inputParams.add(k = new ParamInteger("Clusters"));
		k.setValue(Integer.valueOf(3));
	}

	/**
	 * Declares the single output: the segmentation label volume.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(out = new ParamVolume("Output Volume"));
	}

	/**
	 * Segments the input volume by intensity and publishes the label volume.
	 *
	 * Voxels excluded by the mask are labelled 0. An included voxel assigned to cluster c (0-based)
	 * is labelled (c + 1) * max(1, 255 / (clusters + 1)). Initial means are drawn at random from the
	 * intensity range of the included voxels, so repeated runs may differ.
	 *
	 * @param monitor progress monitor supplied by the pipeline (currently unused)
	 * @throws IllegalArgumentException if the requested number of clusters is below 1
	 */
	protected void execute(CalculationMonitor monitor){
		ImageData idVol = vol.getImageData();
		boolean maskUsed = useMask.getValue();
		//the mask is only read when it is enabled; a null mask means "segment every voxel"
		ImageData idMask = maskUsed ? mask.getImageData() : null;
		int rows = idVol.getRows();
		int cols = idVol.getCols();
		int slices = idVol.getSlices();
		int components = idVol.getComponents();
		int clusters = k.getInt();
		if (clusters < 1)
			throw new IllegalArgumentException("Number of clusters must be at least 1, got " + clusters);

		//0-based cluster index of every included voxel; excluded voxels stay 0
		int[][][][] volAssign = new int[rows][cols][slices][components];

		//intensity range of the included voxels, used to seed the means
		float minVal = Float.POSITIVE_INFINITY;
		float maxVal = Float.NEGATIVE_INFINITY;
		int included = 0;
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int s = 0; s < slices; s++) {
					for (int l = 0; l < components; l++) {
						if (isIncluded(idMask, maskUsed, i, j, s, l)) {
							float v = idVol.getFloat(i, j, s, l);
							minVal = Math.min(minVal, v);
							maxVal = Math.max(maxVal, v);
							included++;
						}
					}
				}
			}
		}

		if (included == 0) {
			System.out.println("No voxels selected for segmentation; output will be empty.");
		} else {
			//random initial means drawn uniformly from the observed intensity range
			float[] meansArray = new float[clusters];
			java.util.Random rand = new java.util.Random();
			for (int n = 0; n < clusters; n++)
				meansArray[n] = minVal + (maxVal - minVal) * rand.nextFloat();

			switch (alg.getIndex()) {
				case 0:
					System.out.println("K-Means Algorithm Running...");
					runKMeans(idVol, idMask, maskUsed, volAssign, meansArray);
					break;
				case 1:
					System.out.println("EM Algorithm: Initializing with K-Means...");
					runKMeans(idVol, idMask, maskUsed, volAssign, meansArray);
					runEM(idVol, idMask, maskUsed, volAssign, meansArray);
					break;
				default:
					throw new IllegalStateException("Unknown algorithm index: " + alg.getIndex());
			}
		}

		//rescale cluster indices into distinct grey levels (step >= 1 keeps labels distinct for many clusters)
		int step = Math.max(1, 255 / (clusters + 1));
		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < cols; j++) {
				for (int s = 0; s < slices; s++) {
					for (int l = 0; l < components; l++) {
						if (isIncluded(idMask, maskUsed, i, j, s, l))
							volAssign[i][j][s][l] = (volAssign[i][j][s][l] + 1) * step;
					}
				}
			}
		}
		ImageDataInt tmp = new ImageDataInt(volAssign);
		tmp.setName("Segmentation");
		out.setValue(tmp);
	} //end execute method


//Other methods

	/**
	 * Decides whether voxel (i, j, s, l) takes part in the segmentation.
	 *
	 * @param inMask mask volume, or null for "no mask"
	 * @param m      whether the mask is applied
	 * @return true when no mask is applied (m is false or inMask is null) or the mask value is 1
	 */
	private static boolean isIncluded(ImageData inMask, boolean m, int i, int j, int s, int l) {
		return !m || inMask == null || inMask.getInt(i, j, s, l) == 1;
	}

	/**
	 * Returns the scalar variance stored in a cluster covariance matrix: the mean of its diagonal.
	 * The covariance matrices built by {@link #initCov} and {@link #upCov} are variance * identity.
	 */
	private static double clusterVariance(Matrix cov) {
		return cov.trace() / cov.getRowDimension();
	}

	/**
	 * Builds a size x size covariance matrix variance * identity, with variance floored at
	 * {@link #MIN_VARIANCE}.
	 */
	private static Matrix diagonalCov(int size, double variance) {
		return Matrix.identity(size, size).times(Math.max(variance, MIN_VARIANCE));
	}

	/**
	 * Runs 1-D K-means on the intensities of the included voxels.
	 *
	 * Alternates an assignment step (nearest mean, via {@link #assignMean}) with an update step (each
	 * mean moves to the average intensity of its voxels). It stops when no voxel changes cluster or
	 * after {@link #MAX_KMEANS_ITERATIONS} iterations. A cluster that receives no voxels keeps its
	 * previous mean instead of dividing by zero.
	 *
	 * @param idVol      input volume
	 * @param idMask     mask volume, or null to include every voxel
	 * @param maskUsed   whether idMask is applied
	 * @param volAssign  receives the 0-based cluster index of every included voxel
	 * @param meansArray initial means; refined in place
	 */
	private void runKMeans(ImageData idVol, ImageData idMask, boolean maskUsed, int[][][][] volAssign, float[] meansArray) {
		int clusters = meansArray.length;
		boolean converged = false;
		for (int iter = 0; iter < MAX_KMEANS_ITERATIONS && !converged; iter++) {
			double[] sums = new double[clusters];
			int[] counts = new int[clusters];
			boolean changed = false;

			//Assignment step; the per-cluster sums for the update step are gathered in the same pass
			for (int i = 0; i < volAssign.length; i++) {
				for (int j = 0; j < volAssign[i].length; j++) {
					for (int s = 0; s < volAssign[i][j].length; s++) {
						for (int l = 0; l < volAssign[i][j][s].length; l++) {
							if (!isIncluded(idMask, maskUsed, i, j, s, l))
								continue;
							float v = idVol.getFloat(i, j, s, l);
							int label = assignMean(v, meansArray);
							//on the first pass volAssign holds no previous labels, so always count it as a change
							if (iter == 0 || label != volAssign[i][j][s][l])
								changed = true;
							volAssign[i][j][s][l] = label;
							sums[label] += v;
							counts[label]++;
						}
					}
				}
			}

			//unchanged assignments would reproduce the current means exactly, so the clustering is stable
			if (!changed) {
				converged = true;
			} else {
				//Update step
				for (int h = 0; h < clusters; h++) {
					if (counts[h] > 0)
						meansArray[h] = (float) (sums[h] / counts[h]);
				}
			}
		}
		if (!converged)
			System.out.println("K-Means: stopped after " + MAX_KMEANS_ITERATIONS + " iterations without converging.");
	}

	/**
	 * Fits a 1-D Gaussian mixture to the included intensities with EM, then relabels the voxels.
	 *
	 * The fit starts from the K-means result:
	 * - weights come from cluster sizes ({@link #initWeights});
	 * - means come from meansArray;
	 * - variances come from the within-cluster spread ({@link #initCov}).
	 *
	 * It then alternates M- and E-steps until the log-likelihood stabilises or
	 * {@link #MAX_EM_ITERATIONS} iterations have run. Finally every included voxel is assigned to
	 * the cluster with the largest posterior probability.
	 *
	 * @param idVol      input volume
	 * @param idMask     mask volume, or null to include every voxel
	 * @param maskUsed   whether idMask is applied
	 * @param volAssign  K-means labels on entry; maximum-posterior labels on exit
	 * @param meansArray K-means means used to initialise the mixture
	 */
	private void runEM(ImageData idVol, ImageData idMask, boolean maskUsed, int[][][][] volAssign, float[] meansArray) {
		int clusters = meansArray.length;

		float[] weightsArray = initWeights(idMask, maskUsed, volAssign, clusters);
		System.out.println("EM Algorithm: weightsArray initialized.");

		double[][] tmpMeansArray = new double[clusters][1];
		for (int c = 0; c < clusters; c++)
			tmpMeansArray[c][0] = meansArray[c];
		Matrix emMeans = new Matrix(tmpMeansArray);
		System.out.println("EM Algorithm: emMeans initialized.");

		Matrix[] emCov = initCov(idVol, idMask, maskUsed, volAssign, clusters, emMeans);
		System.out.println("EM Algorithm: emCov initialized.");

		//posterior probability of each cluster (last index) at every voxel
		float[][][][][] volProb = new float[idVol.getRows()][idVol.getCols()][idVol.getSlices()][idVol.getComponents()][clusters];

		double logLik = eStep(idVol, idMask, maskUsed, weightsArray, emMeans, emCov, volProb);
		boolean converged = false;
		int iter = 0;
		while (iter < MAX_EM_ITERATIONS && !converged) {
			iter++;
			//M-STEP - update cluster weights, means, and covariance from the current posteriors
			float[] newWeights = upWeights(idMask, volProb, clusters);
			Matrix newMeans = upMeans(idVol, idMask, volProb, clusters);
			Matrix[] newCov = upCov(idVol, idMask, volProb, clusters, newMeans);
			//a cluster with no posterior mass has no defined mean/variance; keep its previous ones
			for (int c = 0; c < clusters; c++) {
				if (!(newWeights[c] > 0) || Double.isNaN(newMeans.get(c, 0))) {
					newMeans.set(c, 0, emMeans.get(c, 0));
					newCov[c] = emCov[c];
				}
			}
			weightsArray = newWeights;
			emMeans = newMeans;
			emCov = newCov;

			//E-STEP - recompute posteriors (keeps volProb consistent with the final parameters)
			double newLogLik = eStep(idVol, idMask, maskUsed, weightsArray, emMeans, emCov, volProb);
			converged = Math.abs(newLogLik - logLik) <= EM_TOLERANCE * Math.max(1.0, Math.abs(newLogLik));
			logLik = newLogLik;
		}
		System.out.println("EM Algorithm: " + (converged ? "converged" : "stopped") + " after " + iter
				+ " iterations, log-likelihood " + logLik);

		//label each included voxel with its maximum-posterior cluster (ties go to the lowest index)
		for (int i = 0; i < volAssign.length; i++) {
			for (int j = 0; j < volAssign[i].length; j++) {
				for (int s = 0; s < volAssign[i][j].length; s++) {
					for (int l = 0; l < volAssign[i][j][s].length; l++) {
						if (!isIncluded(idMask, maskUsed, i, j, s, l))
							continue;
						float[] post = volProb[i][j][s][l];
						int best = 0;
						for (int c = 1; c < clusters; c++) {
							if (post[c] > post[best])
								best = c;
						}
						volAssign[i][j][s][l] = best;
					}
				}
			}
		}
	}

	/**
	 * E-step: stores the posterior probability of every cluster for each included voxel in inProb,
	 * and returns the log-likelihood of the included intensities. Excluded voxels get all zeros.
	 *
	 * Posteriors are computed in the log domain and normalised with the log-sum-exp trick, so a voxel
	 * far from every cluster still gets valid probabilities instead of 0/0. Clusters whose weight is
	 * not positive get probability 0.
	 *
	 * @param inVol   input volume
	 * @param inMask  mask volume, or null to include every voxel
	 * @param m       whether inMask is applied
	 * @param weights mixture weights
	 * @param emMeans cluster means as a column matrix
	 * @param emCov   per-cluster covariance matrices (variance * identity)
	 * @param inProb  receives the posteriors; the last index is the cluster
	 * @return log-likelihood of the included voxels under the current mixture
	 */
	private double eStep(ImageData inVol, ImageData inMask, boolean m, float[] weights, Matrix emMeans, Matrix[] emCov, float[][][][][] inProb) {
		int numClusters = weights.length;
		//voxel-independent part of each log density: log(w_c) - log(2*pi*sigma_c^2)/2
		double[] logNorm = new double[numClusters];
		for (int c = 0; c < numClusters; c++) {
			logNorm[c] = weights[c] > 0
					? Math.log(weights[c]) - 0.5 * Math.log(2 * Math.PI * clusterVariance(emCov[c]))
					: Double.NEGATIVE_INFINITY;
		}
		double[] logDensity = new double[numClusters];
		double[] rel = new double[numClusters];
		double logLik = 0;

		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						float[] post = inProb[i][j][s][l];
						if (!isIncluded(inMask, m, i, j, s, l)) {
							java.util.Arrays.fill(post, 0f);
							continue;
						}
						float v = inVol.getFloat(i, j, s, l);
						double maxLog = Double.NEGATIVE_INFINITY;
						for (int c = 0; c < numClusters; c++) {
							logDensity[c] = weights[c] > 0
									? logNorm[c] + getExpTerms(v, emMeans, emCov[c], c)
									: Double.NEGATIVE_INFINITY;
							maxLog = Math.max(maxLog, logDensity[c]);
						}
						double sum = 0;
						for (int c = 0; c < numClusters; c++) {
							rel[c] = Math.exp(logDensity[c] - maxLog);
							sum += rel[c];
						}
						for (int c = 0; c < numClusters; c++)
							post[c] = (float) (rel[c] / sum);
						logLik += maxLog + Math.log(sum);
					}
				}
			}
		}
		return logLik;
	}

	/**
	 * Returns the index of the closest mean (absolute intensity difference, i.e. 1-D Euclidean distance).
	 * Ties go to the lowest index.
	 *
	 * @param	ptVal	the value of the input point
	 * @param	means	the array of means; must not be empty
	 * @return			the index of the nearest mean
	 */
	protected int assignMean(float ptVal, float[] means){
		int index = 0;
		float best = Math.abs(ptVal - means[0]);
		for (int i = 1; i < means.length; i++){
			float dist = Math.abs(ptVal - means[i]);
			if (dist < best){
				best = dist;
				index = i;
			}
		}
		return index;
	} //end assignMean method

	/**
	 * Given a boolean array of stopping criteria, returns true if all are true, and false if at least one is false.
	 * An empty array yields true.
	 *
	 * @param 	s	array of booleans
	 * @return		true if all true, false if at least one false
	 */
	protected boolean checkStop(boolean[] s){
		for (int i = 0; i < s.length; i++) {
			if (!s[i])
				return false;
		}
		return true;
	} //end checkStop method

	/**
	 * Evaluates the exponent of the 1-D Gaussian density of cluster c at val:
	 * -(val - mu_c)^2 / (2 * sigma_c^2).
	 *
	 * @param val     intensity to evaluate
	 * @param emMeans cluster means as a column matrix; mu_c = emMeans(c, 0)
	 * @param emCov   covariance matrix of cluster c (variance * identity, as built by initCov/upCov);
	 *                sigma_c^2 is the mean of its diagonal
	 * @param c       cluster index
	 * @return the (non-positive) exponent of the Gaussian pdf
	 */
	protected float getExpTerms(float val, Matrix emMeans, Matrix emCov, int c) {
		double diff = val - emMeans.get(c, 0);
		return (float) (-(diff * diff) / (2 * clusterVariance(emCov)));
	}

	/**
	 * Computes the initial mixture weights as the fraction of included voxels assigned to each cluster.
	 *
	 * @param inMask Image volume of a mask to be used, or null for no mask
	 * @param m whether or not the mask is used
	 * @param assignVol the volume of assigned clusters (0-based indices)
	 * @param numClusters the number of clusters
	 * @return per-cluster weights; all zero when no voxel is included
	 */
	protected float[] initWeights(ImageData inMask, boolean m, int[][][][] assignVol, int numClusters) {
		//integer counts: a float accumulator stops incrementing beyond 2^24 voxels
		int[] counts = new int[numClusters];
		int counter = 0;

		for (int i = 0; i < assignVol.length; i++) {
			for (int j = 0; j < assignVol[i].length; j++) {
				for (int s = 0; s < assignVol[i][j].length; s++) {
					for (int l = 0; l < assignVol[i][j][s].length; l++) {
						if (isIncluded(inMask, m, i, j, s, l)) {
							counts[assignVol[i][j][s][l]]++;
							counter++;
						}
					}
				}
			}
		}

		float[] out = new float[numClusters];
		if (counter > 0) {
			for (int c = 0; c < numClusters; c++)
				out[c] = (float) counts[c] / counter;
		}
		return out;
	} //end initWeights method

	/**
	 * Initializes each cluster's covariance matrix as variance * identity (numClusters x numClusters).
	 * The variance is the mean squared deviation of the cluster's voxels from its mean, floored at
	 * {@link #MIN_VARIANCE}; an empty cluster gets the floor.
	 *
	 * @param inVol input volume
	 * @param inMask mask volume, or null for no mask
	 * @param m whether the mask is used
	 * @param inAssign volume of 0-based cluster assignments (same dimensions as inVol)
	 * @param numClusters number of clusters
	 * @param inMeans cluster means as a column matrix
	 * @return one covariance matrix per cluster
	 */
	protected Matrix[] initCov(ImageData inVol, ImageData inMask, boolean m, int[][][][] inAssign, int numClusters, Matrix inMeans) {
		double[] sumSq = new double[numClusters];
		int[] counts = new int[numClusters];

		for (int i = 0; i < inVol.getRows(); i++) {
			for (int j = 0; j < inVol.getCols(); j++) {
				for (int s = 0; s < inVol.getSlices(); s++) {
					for (int l = 0; l < inVol.getComponents(); l++) {
						if (!isIncluded(inMask, m, i, j, s, l))
							continue;
						int c = inAssign[i][j][s][l];
						double diff = inVol.getDouble(i, j, s, l) - inMeans.get(c, 0);
						sumSq[c] += diff * diff;
						counts[c]++;
					}
				}
			}
		}

		Matrix[] out = new Matrix[numClusters];
		for (int c = 0; c < numClusters; c++)
			out[c] = diagonalCov(numClusters, counts[c] > 0 ? sumSq[c] / counts[c] : MIN_VARIANCE);
		return out;
	} //end initCov

	/**
	 * M-step for the weights: the average posterior probability of each cluster over the included voxels.
	 *
	 * @param inMask mask volume, or null to include every voxel
	 * @param inProb posterior probabilities; the last index is the cluster
	 * @param numClusters number of clusters
	 * @return per-cluster weights; all zero when no voxel is included
	 */
	protected float[] upWeights(ImageData inMask, float[][][][][] inProb, int numClusters) {
		double[] sums = new double[numClusters];
		long count = 0;
		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isIncluded(inMask, true, i, j, s, l))
							continue;
						count++;	//once per voxel, not once per cluster
						for (int c = 0; c < numClusters; c++)
							sums[c] += inProb[i][j][s][l][c];
					}
				}
			}
		} //end scan through volume

		float[] out = new float[numClusters];
		if (count > 0) {
			for (int c = 0; c < numClusters; c++)
				out[c] = (float) (sums[c] / count);
		}
		return out;
	}

	/**
	 * M-step for the means: the posterior-weighted average intensity of each cluster.
	 *
	 * @param inVol input volume
	 * @param inMask mask volume, or null to include every voxel
	 * @param inProb posterior probabilities; the last index is the cluster
	 * @param numClusters number of clusters
	 * @return numClusters x 1 column matrix of means; NaN for a cluster with zero total posterior mass
	 */
	protected Matrix upMeans(ImageData inVol, ImageData inMask, float[][][][][] inProb, int numClusters) {
		double[] denom = new double[numClusters];
		double[] numer = new double[numClusters];

		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isIncluded(inMask, true, i, j, s, l))
							continue;
						double v = inVol.getDouble(i, j, s, l);
						for (int c = 0; c < numClusters; c++) {
							double p = inProb[i][j][s][l][c];
							denom[c] += p;
							numer[c] += p * v;
						}
					}
				}
			}
		} //end scan through volume

		Matrix out = new Matrix(numClusters, 1);
		for (int c = 0; c < numClusters; c++)
			out.set(c, 0, denom[c] > 0 ? numer[c] / denom[c] : Double.NaN);
		return out;
	}

	/**
	 * M-step for the covariances. Each cluster gets variance * identity (numClusters x numClusters),
	 * where the variance is the posterior-weighted mean squared deviation from the cluster mean. The
	 * variance is floored at {@link #MIN_VARIANCE}; a cluster with no posterior mass gets the floor.
	 *
	 * @param inVol input volume
	 * @param inMask mask volume, or null to include every voxel
	 * @param inProb posterior probabilities; the last index is the cluster
	 * @param numClusters number of clusters
	 * @param inMeans cluster means as a column matrix
	 * @return one covariance matrix per cluster
	 */
	protected Matrix[] upCov(ImageData inVol, ImageData inMask, float[][][][][] inProb, int numClusters, Matrix inMeans) {
		double[] mu = new double[numClusters];
		for (int c = 0; c < numClusters; c++)
			mu[c] = inMeans.get(c, 0);
		double[] denom = new double[numClusters];
		double[] numer = new double[numClusters];

		for (int i = 0; i < inProb.length; i++) {
			for (int j = 0; j < inProb[i].length; j++) {
				for (int s = 0; s < inProb[i][j].length; s++) {
					for (int l = 0; l < inProb[i][j][s].length; l++) {
						if (!isIncluded(inMask, true, i, j, s, l))
							continue;
						double v = inVol.getDouble(i, j, s, l);
						for (int c = 0; c < numClusters; c++) {
							float p = inProb[i][j][s][l][c];
							//skipping zero-probability terms keeps a NaN mean (empty cluster) out of the sum
							if (p > 0) {
								double diff = v - mu[c];
								denom[c] += p;
								numer[c] += p * diff * diff;
							}
						}
					}
				}
			}
		} //end scan through volume

		Matrix[] out = new Matrix[numClusters];
		for (int c = 0; c < numClusters; c++)
			out[c] = diagonalCov(numClusters, denom[c] > 0 ? numer[c] / denom[c] : MIN_VARIANCE);
		return out;
	}

} //end HW2_Segmentation
