package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.List;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Input container for SIMPLE label fusion: holds a stack of label volumes,
 * cropped to the bounding box of every voxel that has a positive label in at
 * least one volume, together with a per-voxel map of where all volumes agree.
 *
 * <p>Coordinates taken by the public accessors ({@link #isConsensus},
 * {@link #getConsLab}, {@link #getObservation}) and the {@code estimate} array
 * passed to {@link #getSegmentation} are in the cropped frame of size
 * {@code getR() x getC() x getS()}.
 */
public class SimpleLabelVolume extends AbstractCalculation{

	/** Label observations indexed [x][y][z][volume]; cropped frame after construction. */
	private int[][][][] labVols;
	/** Current (cropped after construction) dimensions, and the number of input volumes. */
	private int r,c,s,n;
	/** True where every input volume assigns the same label. */
	private boolean[][][] consensus;
	/** The agreed-upon label at consensus voxels (left at 0 elsewhere). */
	private int[][][] consLabs;
	/** Label range inside the bounding box; both start at 0, so the range always contains 0. */
	private int minLab, maxLab;
	/** Inclusive bounding box of positive labels in the original frame: [xl xh yl yh zl zh]. */
	private int[] boundingBox;
	/** Dimensions of the original, uncropped volumes. */
	private int origR, origC, origS;
	/** Header of the first input volume, reused for the output segmentation. */
	private ImageHeader header;

	/**
	 * Loads all label volumes, crops them to their joint bounding box and
	 * determines the consensus voxels.
	 *
	 * @param imDatList the input label volumes; all must share the dimensions
	 *                  of the first one
	 * @throws IllegalArgumentException if the list is empty or a volume's
	 *                                  dimensions differ from the first one's
	 */
	public SimpleLabelVolume(List<ImageData> imDatList){
		super();
		if(imDatList.isEmpty()){
			throw new IllegalArgumentException("At least one label volume is required");
		}
		JistLogger.logOutput(JistLogger.WARNING, "Initializing SIMPLE Volume");
		setVolumeInfo(imDatList.get(0));
		this.n = imDatList.size();
		JistLogger.logOutput(JistLogger.WARNING, String.format("There are %d volumes", this.n));
		JistLogger.logOutput(JistLogger.WARNING, String.format("The dimensions are [%d %d %d]",r,c,s));
		JistLogger.logOutput(JistLogger.WARNING, "Loading Data");
		loadData(imDatList);
		JistLogger.logOutput(JistLogger.WARNING, "Done Loading Data");
		JistLogger.logOutput(JistLogger.WARNING, "Determining Bounding Box");
		determineBoundingBox();
		JistLogger.logOutput(JistLogger.WARNING, 
				String.format("The bounding box is [%d:%d %d:%d %d:%d]", 
						boundingBox[0],boundingBox[1],boundingBox[2],
						boundingBox[3],boundingBox[4],boundingBox[5]));
		cropImages();
		JistLogger.logOutput(JistLogger.WARNING, "Determining Consensus Voxels");
		setConsensus();
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("The range of labels is %d to %d",minLab,maxLab));
	}

	/**
	 * Restricts {@link #labVols} to {@link #boundingBox}, records the original
	 * dimensions, updates r/c/s to the cropped size and computes minLab/maxLab
	 * over the cropped data.
	 */
	private void cropImages(){
		final int xl = boundingBox[0], yl = boundingBox[2], zl = boundingBox[4];
		final int cropR = (boundingBox[1] - xl) + 1;
		final int cropC = (boundingBox[3] - yl) + 1;
		final int cropS = (boundingBox[5] - zl) + 1;
		// Starting both extremes at 0 keeps the background label inside the reported range.
		minLab = 0;
		maxLab = 0;
		// The innermost per-voxel arrays are reused rather than copied: the
		// uncropped array is discarded below, so nothing else aliases them.
		int[][][][] croppedVols = new int[cropR][cropC][cropS][];
		for(int x=0;x<cropR;x++){
			for(int y=0;y<cropC;y++){
				for(int z=0;z<cropS;z++){
					int[] labs = labVols[x+xl][y+yl][z+zl];
					croppedVols[x][y][z] = labs;
					for(int l=0;l<n;l++){
						int lab = labs[l];
						if(lab > maxLab){
							maxLab = lab;
						}
						if(lab < minLab){
							minLab = lab;
						}
					}
				}
			}
		}
		this.origR = this.r;
		this.origC = this.c;
		this.origS = this.s;
		this.r = cropR;
		this.c = cropC;
		this.s = cropS;
		labVols = croppedVols;
	}

	/**
	 * Computes, in a single pass, the inclusive bounding box of all voxels
	 * where at least one volume has a label greater than 0. If no such voxel
	 * exists the box is left as all zeros (a single voxel at the origin).
	 */
	private void determineBoundingBox(){
		int xl = r, xh = -1, yl = c, yh = -1, zl = s, zh = -1;
		for(int x=0;x<r;x++){
			for(int y=0;y<c;y++){
				for(int z=0;z<s;z++){
					if(!anyPositive(labVols[x][y][z])){
						continue;
					}
					xl = Math.min(xl, x);
					xh = Math.max(xh, x);
					yl = Math.min(yl, y);
					yh = Math.max(yh, y);
					zl = Math.min(zl, z);
					zh = Math.max(zh, z);
				}
			}
		}
		if(xh < 0){
			// No foreground at all: keep the historical all-zero box.
			JistLogger.logOutput(JistLogger.WARNING,
					"No positive labels found; using a single-voxel bounding box at the origin");
			boundingBox = new int[6];
			return;
		}
		boundingBox = new int[]{xl, xh, yl, yh, zl, zh};
	}

	/** Returns true if any of the given per-volume labels is greater than 0. */
	private static boolean anyPositive(int[] labs){
		for(int lab : labs){
			if(lab > 0){
				return true;
			}
		}
		return false;
	}

	/**
	 * Builds a full-size segmentation image. Inside the bounding box,
	 * consensus voxels take their agreed label and all other voxels take the
	 * value from {@code estimate}. Voxels outside the box are not written
	 * (presumably left at the ImageDataInt default — confirm).
	 *
	 * @param estimate labels in the cropped frame, at least getR() x getC() x getS()
	 * @return a new image with the original dimensions and header
	 */
	public ImageData getSegmentation(int[][][] estimate){
		ImageData im = new ImageDataInt("segmentation",origR,origC,origS,1);
		im.setHeader(this.header);
		for(int x=0;x<r;x++){
			for(int y=0;y<c;y++){
				for(int z=0;z<s;z++){
					int lab = consensus[x][y][z] ? consLabs[x][y][z] : estimate[x][y][z];
					im.set(x+boundingBox[0],y+boundingBox[2],z+boundingBox[4], lab);
				}
			}
		}
		return im;
	}
	
	/** Returns the agreed label at a cropped-frame voxel (0 if it is not a consensus voxel). */
	public int getConsLab(int i, int j, int k){ return this.consLabs[i][j][k]; }

	/**
	 * Copies every input volume into {@link #labVols} at the full (uncropped) size.
	 *
	 * @throws IllegalArgumentException if a volume's dimensions differ from r/c/s
	 */
	private void loadData(List<ImageData> imDatList){
		labVols = new int[r][c][s][n];
		for(int i=0;i<n;i++){
			ImageData imDat = imDatList.get(i);
			if(imDat.getRows() != r || imDat.getCols() != c || imDat.getSlices() != s){
				throw new IllegalArgumentException(String.format(
						"Volume %d has dimensions [%d %d %d] but expected [%d %d %d]",
						i, imDat.getRows(), imDat.getCols(), imDat.getSlices(), r, c, s));
			}
			for(int x=0;x<r;x++){
				for(int y=0;y<c;y++){
					for(int z=0;z<s;z++){
						labVols[x][y][z][i] = imDat.getInt(x, y, z);
					}
				}
			}
		}
	}

	/**
	 * Marks every cropped voxel on which all volumes agree and stores the agreed
	 * label, then logs the share of consensus voxels as a percentage.
	 * The label range has already been computed by {@link #cropImages()}.
	 */
	private void setConsensus(){
		consensus = new boolean[r][c][s];
		consLabs  = new int[r][c][s];
		int nCons = 0;
		for(int x=0; x<r; x++){
			for(int y=0; y<c; y++){
				for(int z=0; z<s; z++){
					int[] labs = labVols[x][y][z];
					int defLab = labs[0];
					boolean agree = true;
					for(int i=1; i<n; i++){
						if(labs[i] != defLab){
							agree = false;
							break;
						}
					}
					consensus[x][y][z] = agree;
					if(agree){
						nCons++;
						consLabs[x][y][z] = defLab;
					}
				}
			}
		}
		long total = (long) r * c * s;
		// Scale to a real percentage: the message says "percent".
		float p = 100f * nCons / total;
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("There are %d consensus voxels out of %d (%.3f percent)",
						nCons,total,p));
	}

	/** Records the dimensions and header of the reference (first) volume. */
	private void setVolumeInfo(ImageData im){
		this.r = im.getRows();
		this.c = im.getCols();
		this.s = im.getSlices();
		this.header = im.getHeader();
	}
	
	/**
	 * Extracts the bounding-box region of a full-size (uncropped) array.
	 *
	 * @param im array in the original frame
	 * @return a new getR() x getC() x getS() array in the cropped frame
	 */
	protected float[][][] cropToBoundingBox(float[][][] im){
		float[][][] res = new float[this.r][this.c][this.s];
		for(int i=0;i<this.r;i++){
			for(int j=0;j<this.c;j++){
				for(int k=0;k<this.s;k++){
					res[i][j][k] = im[i+boundingBox[0]][j+boundingBox[2]][k+boundingBox[4]];
				}
			}
		}
		return res;
	}

	/** Returns true if all volumes agree at the given cropped-frame voxel. */
	public boolean isConsensus(int i,int j,int k){ return consensus[i][j][k]; }
	/** Returns the cropped size along the first axis. */
	public int getR(){ return this.r; }
	/** Returns the cropped size along the second axis. */
	public int getC(){ return this.c; }
	/** Returns the cropped size along the third axis. */
	public int getS(){ return this.s; }
	/** Returns the number of input volumes. */
	public int getN(){ return this.n; }
	/** Returns the largest label in the bounding box (never below 0). */
	public int getMaxLabel(){ return this.maxLab; }
	/** Returns volume {@code j}'s label at the given cropped-frame voxel. */
	public int getObservation(int x, int y, int z, int j){ return this.labVols[x][y][z][j]; }
	/** Returns the local weight of volume {@code n} at a voxel; always 1 in this implementation. */
	public float getLocalWeight(int i,int j,int k,int n){ return 1; }

}