package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;
import java.util.List;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Label fusion that estimates each rater's confusion matrix ({@code theta}) directly from a
 * supplied truth segmentation (one M-step), then performs one E-step that fuses the raters'
 * labels into a maximum-likelihood label estimate.
 *
 * <p>All images are read with the dimensions of {@code truth}; labels are expected to be
 * non-negative integers. Work is restricted to the bounding box of voxels at which the raters'
 * labels sum to a positive value; voxels outside that box are labelled 0 in the estimate.
 */
public class LabelSetGenericIdealSTAPLE {

	/** Largest label value found in the truth image or any rater image (never below 0). */
	int max;
	private List<ImageData> raters;
	private ImageData truth;
	/** Image dimensions taken from {@code truth}: rows, cols, slices, components. */
	private int[] dims;
	/**
	 * Per-rater confusion matrices indexed as {@code theta[observedLabel][trueLabel][rater]}.
	 * After normalization each column {@code theta[*][trueLabel][rater]} sums to 1, or stays
	 * all zero when that true label never occurs inside the cropping region.
	 */
	private float[][][] theta;
	/** Bounding box {@code [xLo, xHi, yLo, yHi, zLo, zHi]}; lower bounds inclusive, upper exclusive. */
	private int[] croppingRegion;
	private int numRaters;
	private ImageData estimate;
	/** Largest label in the truth image; candidate true labels are {@code 0..numTarget}. */
	private int numTarget;

	/**
	 * Builds the model and immediately runs the truth-based M-step and the E-step.
	 *
	 * @param raters rater segmentations; must be non-empty and share {@code truth}'s dimensions
	 * @param truth  reference segmentation used to estimate the raters' confusion matrices
	 * @throws IllegalArgumentException if {@code raters} is empty
	 */
	public LabelSetGenericIdealSTAPLE(List<ImageData> raters, ImageData truth){
		if(raters.isEmpty()){
			throw new IllegalArgumentException("At least one rater image is required");
		}
		max = 0;
		this.raters = raters;
		this.truth = truth;
		numRaters = raters.size();
		dims = new int[]{truth.getRows(), truth.getCols(), truth.getSlices(), truth.getComponents()};
		determineMax();
		JistLogger.logOutput(JistLogger.INFO, String.format("The maximum value in any image was %d.\nThe number of labels in the target segmentation is %d.", max,numTarget));
		// Java zero-initializes new arrays; no explicit fill is needed.
		theta = new float[max+1][numTarget+1][numRaters];
		determineCroppingRegion();
		runEM();
	}

	/** Sets {@link #numTarget} to the largest truth label and {@link #max} to the largest label in any image. */
	private void determineMax(){
		numTarget = getMax(truth);
		max = numTarget;
		for(ImageData im : raters){
			max = Math.max(max, getMax(im));
		}
	}

	/**
	 * Returns the largest value in {@code im} over the extent given by {@link #dims},
	 * or 0 if every value is non-positive.
	 */
	private int getMax(ImageData im){
		int m = 0;
		for(int i=0;i<dims[0];i++){
			for(int j=0;j<dims[1];j++){
				for(int k=0;k<dims[2];k++){
					for(int l=0;l<dims[3];l++){
						m = Math.max(m, im.getInt(i,j,k,l));
					}
				}
			}
		}
		return m;
	}

	/**
	 * Runs the M-step against the truth image inside the cropping region, then the E-step
	 * over the whole image, storing the result in {@link #estimate}.
	 *
	 * <p>NOTE(review): only voxels inside the raters' bounding box contribute to theta, so truth
	 * labels outside that box are ignored — confirm this is intended.
	 */
	private void runEM(){
		JistLogger.logOutput(JistLogger.INFO, "Starting EM Algorithm");
		JistLogger.logOutput(JistLogger.INFO, "Running M-Step with Truth");
		JistLogger.logFlush();
		// Count in long: a float accumulator stops changing once a count reaches 2^24.
		long[][][] counts = new long[max+1][numTarget+1][numRaters];
		for(int i=croppingRegion[0];i<croppingRegion[1];i++){
			for(int j=croppingRegion[2];j<croppingRegion[3];j++){
				for(int k=croppingRegion[4];k<croppingRegion[5];k++){
					for(int l=0;l<dims[3];l++){
						runMStep(counts,i,j,k,l,truth.getInt(i,j,k,l));
					}
				}
			}
		}
		normalizeTheta(counts);
		JistLogger.logOutput(JistLogger.INFO, "Running E-Step");
		JistLogger.logFlush();
		estimate = new ImageDataInt("Label Estimate", dims[0], dims[1],dims[2],dims[3]);
		for(int i=0;i<dims[0];i++){
			for(int j=0;j<dims[1];j++){
				for(int k=0;k<dims[2];k++){
					for(int l=0;l<dims[3];l++){
						estimate.set(i,j,k,l,getEstimate(i,j,k,l));
					}
				}
			}
		}
	}

	/**
	 * Returns the maximum-likelihood true label at one voxel given every rater's observation.
	 * A uniform prior over labels {@code 0..numTarget} is assumed; being constant, it does not
	 * change the argmax and is omitted. Voxels outside the cropping region get label 0.
	 */
	private int getEstimate(int i,int j,int k,int l){
		if(!inCroppingRegion(i,j,k)){
			return 0;
		}
		// Candidate true labels are 0..numTarget inclusive, matching theta's second dimension.
		double[] logProb = new double[numTarget+1];
		for(int x=0;x<numRaters;x++){
			int obs = raters.get(x).getInt(i,j,k,l);
			for(int y=0;y<logProb.length;y++){
				// Math.log(0) is -Infinity, which correctly rules out impossible labels.
				logProb[y] += Math.log(theta[obs][y][x]);
			}
		}
		// Argmax in the log domain; ties and the all -Infinity case resolve to the lowest label.
		int label = 0;
		double best = Double.NEGATIVE_INFINITY;
		for(int y=0;y<logProb.length;y++){
			if(logProb[y] > best){
				best = logProb[y];
				label = y;
			}
		}
		return label;
	}

	/** Returns true if (i,j,k) lies inside {@link #croppingRegion} (upper bounds exclusive). */
	private boolean inCroppingRegion(int i,int j, int k){
		return i>=croppingRegion[0] && i<croppingRegion[1]
				&& j>=croppingRegion[2] && j<croppingRegion[3]
				&& k>=croppingRegion[4] && k<croppingRegion[5];
	}

	/**
	 * Fills {@link #theta} with {@code counts} normalized so that, for each rater and true
	 * label, the probabilities over observed labels sum to 1. Columns with no observations
	 * are left at zero.
	 */
	private void normalizeTheta(long[][][] counts){
		for(int r=0;r<numRaters;r++){
			for(int t=0;t<counts[0].length;t++){
				long total = 0;
				for(int o=0;o<counts.length;o++){
					total += counts[o][t][r];
				}
				if(total == 0){
					continue; // true label t never seen for this rater: column stays all zero
				}
				for(int o=0;o<counts.length;o++){
					theta[o][t][r] = (float) ((double) counts[o][t][r] / total);
				}
			}
		}
	}

	/** Adds one observation per rater at voxel (i,j,k,l) to the confusion counts for true label t. */
	private void runMStep(long[][][] counts,int i,int j,int k, int l, int t){
		for(int x=0;x<numRaters;x++){
			int obs = raters.get(x).getInt(i,j,k,l);
			counts[obs][t][x]++;
		}
	}

	/** Returns the fused label image produced by the E-step. */
	public ImageData getEstimate(){ return estimate; }

	/**
	 * Returns {@link #theta} as a 3-D float image of size
	 * (max+1) x (numTarget+1) x numRaters, indexed [observed][true][rater].
	 */
	public ImageData getThetaImage(){
		ImageData t = new ImageDataFloat("Theta",theta.length, theta[0].length, theta[0][0].length);
		for(int i=0;i<theta.length;i++)
			for(int j=0;j<theta[i].length;j++)
				for(int k=0;k<theta[i][j].length;k++)
					t.set(i,j,k,theta[i][j][k]);
		return t;
	}

	/**
	 * Computes {@link #croppingRegion} in a single pass: the bounding box (upper bounds
	 * exclusive) of every voxel where, for some component, the raters' labels sum to a
	 * positive value. If there is no such voxel, the whole image is used.
	 */
	private void determineCroppingRegion(){
		JistLogger.logOutput(JistLogger.INFO, "Determining cropping region");
		int r = dims[0];
		int c = dims[1];
		int s = dims[2];
		int xLo = r, xHi = -1, yLo = c, yHi = -1, zLo = s, zHi = -1;
		for(int x=0;x<r;x++){
			for(int y=0;y<c;y++){
				for(int z=0;z<s;z++){
					if(hasPositiveRaterSum(x,y,z)){
						xLo = Math.min(xLo, x);
						xHi = Math.max(xHi, x);
						yLo = Math.min(yLo, y);
						yHi = Math.max(yHi, y);
						zLo = Math.min(zLo, z);
						zHi = Math.max(zHi, z);
					}
				}
			}
		}
		if(xHi < 0){
			croppingRegion = new int[]{0, r, 0, c, 0, s};
		} else {
			croppingRegion = new int[]{xLo, xHi+1, yLo, yHi+1, zLo, zHi+1};
		}
		JistLogger.logOutput(JistLogger.INFO, String.format("Cropping region is x: [%d %d] y: [%d %d] z: [%d %d]",croppingRegion[0],croppingRegion[1],croppingRegion[2],croppingRegion[3],croppingRegion[4],croppingRegion[5]));
	}

	/** Returns true if, for some component, the raters' labels at (x,y,z) sum to a positive value. */
	private boolean hasPositiveRaterSum(int x,int y,int z){
		for(int a=0;a<dims[3];a++){
			long sum = 0;
			for(ImageData im : raters){
				sum += im.getInt(x,y,z,a);
			}
			if(sum > 0){
				return true;
			}
		}
		return false;
	}
}
