package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.utilities.AndrewUtils;

/**
 * Label/intensity volume used by non-local SIMPLE label fusion.
 *
 * <p>On construction this class loads every rater's label volume and intensity
 * image together with the target image, crops all of them to the bounding box
 * of voxels where at least one rater assigns a positive label, marks consensus
 * voxels (all raters agree), linearly maps each rater's intensities onto the
 * normalized target intensities, and finally, for every non-consensus voxel and
 * rater, collects the highest-weighted patches found inside a cubic search
 * window.
 *
 * <p>All per-voxel accessors take coordinates in the <em>cropped</em> volume;
 * {@link #getCroppingRegion()} gives the offset into the original volume and
 * {@link #getOrigDims()} its size.
 */
public class NonLocalSIMPLELabelVolume extends AbstractCalculation {

	/** Observed labels, indexed [x][y][z][c][rater]. */
	private int[][][][][] labs;
	/** Rater intensity images, indexed [x][y][z][c][rater]; set to null after the patch search. */
	private float[][][][][] imgs;
	/** Target intensity image, indexed [x][y][z][c]. */
	private float[][][][] target;
	/** Current volume dimensions (the cropped dimensions once cropData has run). */
	private int[] dims;
	/** Volume dimensions before cropping. */
	private int[] origDims;
	private int numRaters;
	/** Largest label observed in the cropped volume. */
	private int maxLab;
	/** Per axis (x, y, z, c): {start, end} in original coordinates, end exclusive. */
	private int[][] croppingRegion;
	/** True where every rater assigns the same label. */
	private boolean[][][][] consensus;
	/** The agreed label at consensus voxels (0 elsewhere). */
	private int[][][][] consensusLabs;
	private int numCons, numTotal;
	private ImageHeader header;
	private int patchRadius;
	private int searchRadius;
	private int maxNumPatches;
	private float searchSTD;
	private float minPatchWeight;
	/**
	 * Selected patches indexed [x][y][z][c][rater][k]; the per-rater entries stay
	 * null at consensus voxels unless replaced via setPatches.
	 */
	private Patch[][][][][][] patches;

	/**
	 * Loads and preprocesses the rater data, then runs the non-local patch search.
	 *
	 * <p>NOTE: every {@link ImageData} in {@code obsLabsList} and
	 * {@code obsImgsList} is disposed once read, and both lists are cleared.
	 *
	 * @param obsLabsList      one label volume per rater
	 * @param obsImgsList      one intensity image per rater, same order as {@code obsLabsList}
	 * @param targetImage      target intensity image; defines the dimensions and header
	 * @param patchRadiusIn    patch radius (voxels) used when comparing intensities
	 * @param searchRadiusIn   search-window radius (voxels) around each voxel
	 * @param searchSTDIn      divisor applied to the patch distance before exponentiation
	 * @param minPatchWeightIn patches whose weight is not above this value are discarded
	 * @param maxNumPatchesIn  maximum number of patches kept per voxel and rater
	 * @throws IllegalArgumentException if no raters are given, the two lists differ in
	 *         size, or no rater contains a positive label
	 */
	public NonLocalSIMPLELabelVolume(List<ImageData> obsLabsList,
			List<ImageData> obsImgsList, ImageData targetImage,
			int patchRadiusIn, int searchRadiusIn, float searchSTDIn,
			float minPatchWeightIn, int maxNumPatchesIn) throws Exception {
		super();
		// Fail fast with a clear message instead of an index error deep in loading.
		if(obsLabsList.isEmpty())
			throw new IllegalArgumentException("At least one rater label volume is required");
		if(obsImgsList.size() != obsLabsList.size())
			throw new IllegalArgumentException(String.format(
					"Expected one intensity image per rater: got %d label volumes and %d images",
					obsLabsList.size(), obsImgsList.size()));
		patchRadius = patchRadiusIn;
		searchRadius = searchRadiusIn;
		maxNumPatches = maxNumPatchesIn;
		minPatchWeight = minPatchWeightIn;
		searchSTD = searchSTDIn;
		setDims(targetImage);
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("The image dimensions are [%d %d %d %d]", 
						dims[0],dims[1],dims[2],dims[3]));
		numRaters = obsLabsList.size();
		JistLogger.logOutput(JistLogger.WARNING, "Loading Images");
		JistLogger.logFlush();
		loadLabelData(obsLabsList);
		loadImageData(obsImgsList);
		loadTargetImage(targetImage);
		JistLogger.logOutput(JistLogger.WARNING, "Calculating cropping region");
		JistLogger.logFlush();
		determineCroppingRegion();
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("The cropping region is [%d:%d %d:%d %d:%d %d:%d]", croppingRegion[0][0], croppingRegion[0][1]
						, croppingRegion[1][0], croppingRegion[1][1], croppingRegion[2][0], croppingRegion[2][1]
								, croppingRegion[3][0], croppingRegion[3][1]));
		JistLogger.logFlush();
		JistLogger.logOutput(JistLogger.WARNING, "Cropping Data");
		JistLogger.logFlush();
		cropData();
		JistLogger.logOutput(JistLogger.WARNING, "Determining consensus voxels");
		JistLogger.logFlush();
		determineConsensusVoxels();
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("Starting non-local correspondence"));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("\tPatch Radius: %d", patchRadius));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("\tSearch Radius: %d", searchRadius));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("\tMin Patch Weight: %.3f", minPatchWeight));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("\tMax Number of Patches: %d", maxNumPatches));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("\tSearch Standard Deviation: %.3f", this.searchSTD));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("The maximum label is %d",maxLab));
		normalizeIntensity();
		runNonLocalCorrespondence();
	}

	/**
	 * Z-scores the target (in place) and regresses every rater image onto it,
	 * using only consensus voxels with a positive consensus label.
	 */
	private void normalizeIntensity(){
		IntensityNormalizer IN = new IntensityNormalizer();
		IN.setTarget(this.target);
		IN.setConsensus(this);
		IN.normalizeTarget();
		for(int j=0;j<this.numRaters;j++){
			float[][][][] im = getRaterImage(j);
			im = IN.normalizeImage(im);
			setRaterImage(im,j);
		}
	}

	/** Writes a 4-D image back into rater slot {@code j} of {@link #imgs}. */
	private void setRaterImage(float[][][][] im, int j){
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++)
						this.imgs[x][y][z][c][j] = im[x][y][z][c];
	}

	/** Returns a copy of rater {@code j}'s intensity image as a 4-D array. */
	private float[][][][] getRaterImage(int j){
		float[][][][] im = new float[dims[0]][dims[1]][dims[2]][dims[3]];
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++)
						im[x][y][z][c] = this.imgs[x][y][z][c][j];
		return im;
	}

	/** @return the current (cropped) dimensions {x, y, z, c}; the internal array, not a copy */
	public int[] getDims(){ return this.dims; }

	/** @return the dimensions before cropping; the internal array, not a copy */
	public int[] getOrigDims(){ return this.origDims; }

	/** @return per axis {start, end} of the crop in original coordinates (end exclusive) */
	public int[][] getCroppingRegion(){ return this.croppingRegion; }

	/** @return whether all raters agree at the given (cropped) voxel */
	public boolean isConsensus(int x, int y, int z, int c){ return this.consensus[x][y][z][c]; }

	/** @return the agreed label at a consensus voxel, 0 at non-consensus voxels */
	public int getConsLab(int x, int y, int z, int c){ return this.consensusLabs[x][y][z][c]; }

	/** @return the header of the target image */
	public ImageHeader getHeader(){ return this.header; }

	/** @return the label rater {@code j} assigns to the given (cropped) voxel */
	public int getObsLab(int x, int y, int z, int c, int j){
		return this.labs[x][y][z][c][j];
	}

	/**
	 * Marks voxels where every rater agrees, records the agreed label, tracks the
	 * maximum label and logs the consensus percentage.
	 */
	private void determineConsensusVoxels(){
		numCons = 0;
		numTotal = 0;
		consensus = new boolean[dims[0]][dims[1]][dims[2]][dims[3]];
		consensusLabs = new int[dims[0]][dims[1]][dims[2]][dims[3]];
		maxLab = 0;
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++){
						int[] voxelLabs = labs[x][y][z][c];
						for(int i: voxelLabs)
							if(i > maxLab)
								maxLab = i;
						boolean cons = isConsensus(voxelLabs);
						consensus[x][y][z][c] = cons;
						if(cons){
							numCons++;
							consensusLabs[x][y][z][c] = voxelLabs[0];
						}
						numTotal++;
					}
		// The message says "percent", so scale the fraction by 100.
		float percent = numTotal > 0 ? 100f * numCons / numTotal : 0f;
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("There are %d/%d consensus voxels (%.3f percent)", numCons,numTotal,percent));
	}

	/** @return true when every entry of {@code vec} equals its first entry (vec must be non-empty) */
	private boolean isConsensus(int[] vec){
		int lab1 = vec[0];
		for(int i: vec)
			if(i != lab1)
				return false;
		return true;
	}

	/** Copies the target image into {@link #target} using the current dims. */
	private void loadTargetImage(ImageData im){
		target = new float[dims[0]][dims[1]][dims[2]][dims[3]];
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++)
						target[x][y][z][c] = im.getFloat(x,y,z,c);
	}

	/**
	 * Computes, in a single pass, the bounding box of all voxels where at least
	 * one rater assigns a positive label. Each axis gets {min, max + 1}.
	 *
	 * @throws IllegalArgumentException if no rater contains a positive label
	 *         (previously this produced negative crop sizes)
	 */
	private void determineCroppingRegion(){
		int[] lo = {dims[0], dims[1], dims[2], dims[3]};
		int[] hi = {-1, -1, -1, -1};
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++)
						if(anyRaterLabeled(labs[x][y][z][c])){
							lo[0] = Math.min(lo[0], x);
							hi[0] = Math.max(hi[0], x);
							lo[1] = Math.min(lo[1], y);
							hi[1] = Math.max(hi[1], y);
							lo[2] = Math.min(lo[2], z);
							hi[2] = Math.max(hi[2], z);
							lo[3] = Math.min(lo[3], c);
							hi[3] = Math.max(hi[3], c);
						}
		if(hi[0] < 0)
			throw new IllegalArgumentException(
					"No rater contains a positive label; cannot determine a cropping region");
		croppingRegion = new int[4][];
		for(int d=0;d<4;d++)
			croppingRegion[d] = new int[]{lo[d], hi[d] + 1};
	}

	/** @return true if any rater's label in {@code voxelLabs} is positive */
	private boolean anyRaterLabeled(int[] voxelLabs){
		for(int lab : voxelLabs)
			if(lab > 0)
				return true;
		return false;
	}

	/**
	 * Replaces target, labels and rater images with their sub-volumes inside
	 * {@link #croppingRegion}; saves the previous dims in {@link #origDims}.
	 */
	private void cropData(){
		int[][][][][] oldLabs = this.labs;
		float[][][][][] oldImgs = this.imgs;
		float[][][][] oldTarget = this.target;
		origDims = dims;
		dims = new int[4];
		for(int d=0;d<4;d++)
			dims[d] = croppingRegion[d][1] - croppingRegion[d][0];
		target = new float[dims[0]][dims[1]][dims[2]][dims[3]];
		labs = new int[dims[0]][dims[1]][dims[2]][dims[3]][];
		imgs = new float[dims[0]][dims[1]][dims[2]][dims[3]][];
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++){
						int xCoord = x + croppingRegion[0][0];
						int yCoord = y + croppingRegion[1][0];
						int zCoord = z + croppingRegion[2][0];
						int cCoord = c + croppingRegion[3][0];
						target[x][y][z][c] = oldTarget[xCoord][yCoord][zCoord][cCoord];
						// Per-voxel rater vectors have length numRaters; copy them whole.
						labs[x][y][z][c] = oldLabs[xCoord][yCoord][zCoord][cCoord].clone();
						imgs[x][y][z][c] = oldImgs[xCoord][yCoord][zCoord][cCoord].clone();
					}
	}

	/** Records the target's dimensions {rows, cols, slices, components} and header. */
	private void setDims(ImageData im){
		dims = new int[4];
		dims[0] = im.getRows();
		dims[1] = im.getCols();
		dims[2] = im.getSlices();
		dims[3] = im.getComponents();
		header = im.getHeader();
	}

	/** Reads every rater label volume into {@link #labs}, disposing each image and clearing the list. */
	private void loadLabelData(List<ImageData> ims){
		labs = new int[dims[0]][dims[1]][dims[2]][dims[3]][numRaters];
		for(int j=0;j<numRaters;j++){
			ImageData im = ims.get(j);
			for(int x=0;x<dims[0];x++)
				for(int y=0;y<dims[1];y++)
					for(int z=0;z<dims[2];z++)
						for(int c=0;c<dims[3];c++){
							int lab = im.getInt(x,y,z,c);
							labs[x][y][z][c][j] = lab;
						}
			im.dispose();
		}
		ims.clear();
	}

	/** Reads every rater intensity image into {@link #imgs}, disposing each image and clearing the list. */
	private void loadImageData(List<ImageData> ims){
		imgs = new float[dims[0]][dims[1]][dims[2]][dims[3]][numRaters];
		for(int j=0;j<numRaters;j++){
			ImageData im = ims.get(j);
			for(int x=0;x<dims[0];x++)
				for(int y=0;y<dims[1];y++)
					for(int z=0;z<dims[2];z++)
						for(int c=0;c<dims[3];c++){
							float val = im.getFloat(x,y,z,c);
							imgs[x][y][z][c][j] = val;
						}
			im.dispose();
		}
		ims.clear();
	}

	/** Runs the patch search for every rater, then releases the rater intensity images. */
	private void runNonLocalCorrespondence() {
		patches = new Patch[dims[0]][dims[1]][dims[2]][dims[3]][numRaters][];
		for(int j=0;j<numRaters;j++)
			runNonLocalCorrespondence(j);
		this.imgs = null;
	}

	/** Runs the patch search for rater {@code j} at every non-consensus voxel. */
	private void runNonLocalCorrespondence(int j){
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("Starting non-local correspondence on %d", j));
		int n=0;
		for(int x=0;x<dims[0];x++)
			for(int y=0;y<dims[1];y++)
				for(int z=0;z<dims[2];z++)
					for(int c=0;c<dims[3];c++)
						if(!isConsensus(x, y, z, c)){
							n++;
							printNLCStatusBar(n);
							patches[x][y][z][c][j] = runNonLocalCorrespondence(x,y,z,c,j);
						}
	}

	/**
	 * Scores every candidate centre in the search window around (x,y,z) for
	 * rater {@code j} and returns up to {@link #maxNumPatches} patches whose
	 * weight exceeds {@link #minPatchWeight}, taken from the end of the sorted list.
	 */
	private Patch[] runNonLocalCorrespondence(int x,int y,int z,int c,int j){
		// Search window, clamped to the cropped volume (inclusive bounds).
		int xl = Math.max(0, x - this.searchRadius);
		int xh = Math.min(dims[0]-1, x + this.searchRadius);
		int yl = Math.max(0, y - this.searchRadius);
		int yh = Math.min(dims[1]-1, y + this.searchRadius);
		int zl = Math.max(0, z - this.searchRadius);
		int zh = Math.min(dims[2]-1, z + this.searchRadius);
		ArrayList<Patch> currPatches = new ArrayList<Patch>(1000);
		for(int xi=xl;xi<=xh;xi++)
			for(int yi=yl;yi<=yh;yi++)
				for(int zi=zl;zi<=zh;zi++){
					Patch currPatch = calculateNonLocalCorrespondence(x,y,z,c,j,xi,yi,zi);
					if(currPatch.getWeight() > this.minPatchWeight)
						currPatches.add(currPatch);
				}
		// Sort by Patch's natural order (presumably ascending weight -- TODO confirm)
		// and read from the end so the first entries are the "largest" patches.
		Collections.sort(currPatches);
		Patch[] p = new Patch[Math.min(currPatches.size(), this.maxNumPatches)];
		for(int i=0;i<p.length;i++)
			p[i] = currPatches.get(currPatches.size() - (i+1));
		return p;
	}

	/**
	 * Compares the target patch centred at (x,y,z) with rater {@code j}'s patch
	 * centred at (xi,yi,zi) in component {@code c}.
	 *
	 * @return a Patch at (xi,yi,zi,c) carrying rater j's label there and weight
	 *         {@code exp(-sqrt(SSD) / searchSTD)}, where SSD is the sum of squared
	 *         intensity differences over the overlapping patch
	 */
	private Patch calculateNonLocalCorrespondence(int x, int y, int z, int c, int j, int xi, int yi, int zi) {
		// Patch extents, shrunk so both patches stay inside the volume.
		int dxl = getDXLow(x,xi);
		int dxh = getDXHigh(x,xi);
		int dyl = getDYLow(y,yi);
		int dyh = getDYHigh(y,yi);
		int dzl = getDZLow(z,zi);
		int dzh = getDZHigh(z,zi);
		double ssd = 0;
		for(int dx=-dxl;dx<=dxh;dx++)
			for(int dy=-dyl;dy<=dyh;dy++)
				// Inclusive bound, like dx/dy. The former exclusive bound dropped the
				// +z face and made every weight 1 when both z extents were 0.
				for(int dz=-dzl;dz<=dzh;dz++){
					double diff = (double) this.target[x+dx][y+dy][z+dz][c]
							- (double) this.imgs[xi+dx][yi+dy][zi+dz][c][j];
					ssd += diff * diff;
				}
		double weight = Math.exp(-(Math.sqrt(ssd) / searchSTD));
		return new Patch((float) weight, xi, yi, zi, c, labs[xi][yi][zi][c][j]);
	}

	/** @return how far a patch may extend toward lower x for centres x1 and x2 */
	public int getDXLow(int x1, int x2){
		return getLowerBound(x1, x2);
	}

	/** @return how far a patch may extend toward higher x for centres x1 and x2 */
	public int getDXHigh(int x1, int x2){
		return getUpperBound(x1, x2, dims[0]-1);
	}

	/** @return how far a patch may extend toward lower y for centres y1 and y2 */
	public int getDYLow(int y1, int y2){
		return getLowerBound(y1, y2);
	}

	/** @return how far a patch may extend toward higher y for centres y1 and y2 */
	public int getDYHigh(int y1, int y2){
		return getUpperBound(y1, y2, dims[1]-1);
	}

	/** @return how far a patch may extend toward lower z for centres z1 and z2 */
	public int getDZLow(int z1, int z2){
		return getLowerBound(z1, z2);
	}

	/** @return how far a patch may extend toward higher z for centres z1 and z2 */
	public int getDZHigh(int z1, int z2){
		return getUpperBound(z1, z2, dims[2]-1);
	}

	/** Lower extent (at most patchRadius) keeping both centres' patches at index >= 0. */
	private int getLowerBound(int a1, int a2){
		return getLowerBoundRadius(a1, a2, this.patchRadius);
	}
	
	/** Lower extent (at most {@code r}) keeping both centres' patches at index >= 0. */
	public int getLowerBoundRadius(int a1, int a2, int r){
		int a = Math.min(a1,a2);
		int l = Math.max(a - r, 0);
		return Math.abs(l - a);
	}
	
	/** Upper x extent (at most {@code r}) keeping both centres' patches inside the volume. */
	public int getDXHighRadius(int x1, int x2, int r){
		return getUpperBoundRadius(x1, x2 , dims[0]-1, r);
	}
	
	/** Upper y extent (at most {@code r}) keeping both centres' patches inside the volume. */
	public int getDYHighRadius(int y1, int y2, int r){
		return getUpperBoundRadius(y1, y2 , dims[1]-1, r);
	}
	
	/** Upper z extent (at most {@code r}) keeping both centres' patches inside the volume. */
	public int getDZHighRadius(int z1, int z2, int r){
		return getUpperBoundRadius(z1, z2 , dims[2]-1, r);
	}
	
	/** Upper extent (at most {@code r}) keeping both centres' patches at index <= h. */
	private int getUpperBoundRadius(int a1, int a2, int h, int r){
		int a = Math.max(a1, a2);
		int top = Math.min(a + r, h);
		return Math.abs(top - a);
	}

	/** Upper extent (at most patchRadius) keeping both centres' patches at index <= h. */
	private int getUpperBound(int a1, int a2, int h){
		return getUpperBoundRadius(a1, a2, h, this.patchRadius);
	}

	/**
	 * Logs a 10-segment progress bar each time {@code n} crosses a tenth of the
	 * non-consensus voxel count.
	 */
	private void printNLCStatusBar(int n){
		// With fewer than 10 non-consensus voxels the integer step would be 0 and
		// n % step would throw ArithmeticException; clamp it to 1.
		int step = Math.max(1, (numTotal-numCons)/10);
		if(n % step != 0)
			return;
		int filled = Math.min(10, n / step);
		StringBuilder output = new StringBuilder("[");
		for(int i=0;i<filled;i++)
			output.append('=');
		for(int i=filled;i<10;i++)
			output.append('+');
		output.append(']');
		JistLogger.logOutput(JistLogger.WARNING, output.toString());
		JistLogger.logFlush();
	}

	/** @return the per-rater patch arrays for the given (cropped) voxel */
	public Patch[][] getVoxelPatches(int x, int y, int z, int c){
		return this.patches[x][y][z][c];
	}

	/** @return the largest label observed in the cropped volume */
	public int getMaxLabel(){ return this.maxLab; }
	
	/**
	 * Returns rater {@code j}'s label at the given (cropped) voxel. Logs a warning
	 * first if any index is negative; the array access then throws.
	 */
	public int getLabel(int x, int y, int z, int c, int j){
		if(x < 0 || y < 0 || z < 0 || c < 0 || j < 0)
			JistLogger.logOutput(JistLogger.WARNING, String.format(
				"About to get rekt [%d %d %d %d %d]",x,y,z,c,j)); 
		return this.labs[x][y][z][c][j];
	}

	/**
	 * Intensity normalization driven by consensus voxels carrying a positive
	 * consensus label. The target is z-scored in place with the mean/std of those
	 * voxels, and each rater image is mapped onto the (already normalized) target
	 * with an ordinary least-squares line fitted on the same voxels.
	 *
	 * <p>Call order: setTarget, setConsensus, normalizeTarget, then normalizeImage.
	 */
	class IntensityNormalizer {
		private float[][][][] target;
		private boolean[][][][] cons;
		private int ncons;
		private OLSMultipleLinearRegression regressor;

		public IntensityNormalizer() {
		}

		/** Sets the target array; it is modified in place by normalizeTarget. */
		public void setTarget(float[][][][] t) {
			this.target = t;
		}

		/** Selects the voxels used for normalization: consensus voxels with label > 0. */
		public void setConsensus(NonLocalSIMPLELabelVolume obs) {
			ncons = 0;
			cons = new boolean[target.length][target[0].length][target[0][0].length][target[0][0][0].length];
			int[] workingDims = obs.getDims();
			for(int x=0;x<workingDims[0];x++)
				for(int y=0;y<workingDims[1];y++)
					for(int z=0;z<workingDims[2];z++)
						for(int c=0;c<workingDims[3];c++)
							if(obs.isConsensus(x, y, z, c) && obs.getConsLab(x, y, z, c)>0){
								ncons++;
								cons[x][y][z][c] = true;
							}
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("[Intensity Normalizer] There are %d voxels for intensity normalization",ncons));
		}

		/** Z-scores the whole target in place using statistics of the selected voxels. */
		public float[][][][] normalizeTarget() {
			int n = 0;
			float[] vec = new float[ncons];
			for (int i = 0; i < target.length; i++)
				for (int j = 0; j < target[i].length; j++)
					for (int k = 0; k < target[i][j].length; k++)
						for(int l=0;l <target[i][j][k].length;l++)
							if (cons[i][j][k][l]) {
								vec[n] = target[i][j][k][l];
								n++;
							}
			float mean = AndrewUtils.calculateMean(vec);
			float std = AndrewUtils.calculateSTD(vec, mean);
			JistLogger.logOutput(JistLogger.WARNING,
					"[Intensity Normalizer] Normalizing Target Image");
			JistLogger
			.logOutput(
					JistLogger.WARNING,
					String.format(
							"[Intensity Normalizer] Target Mean: %.4f Target STD: %.4f",
							mean, std));
			for (int i = 0; i < target.length; i++)
				for (int j = 0; j < target[i].length; j++)
					for (int k = 0; k < target[i][j].length; k++)
						for(int l=0; l<target[i][j][k].length;l++)
							target[i][j][k][l] = (target[i][j][k][l] - mean) / std;
			return target;

		}

		/**
		 * Fits target = b0 + b1 * im over the selected voxels and applies that line
		 * to every voxel of {@code im} in place.
		 */
		public float[][][][] normalizeImage(float[][][][] im) {

			double[] tvec = new double[ncons];
			double[][] avec = new double[ncons][1];
			int n = 0;
			for (int i = 0; i < cons.length; i++)
				for (int j = 0; j < cons[i].length; j++)
					for (int k = 0; k < cons[i][j].length; k++)
						for(int l=0;l<cons[i][j][k].length;l++)
							if (cons[i][j][k][l]) {
								tvec[n] = (double) target[i][j][k][l];
								avec[n][0] = (double) im[i][j][k][l];
								n++;
							}
			regressor = new OLSMultipleLinearRegression();
			regressor.setNoIntercept(false);
			regressor.newSampleData(tvec, avec);
			// With an intercept, element 0 is the intercept and element 1 the slope.
			double[] param_ests = regressor.estimateRegressionParameters();
			JistLogger.logOutput(JistLogger.WARNING, String.format(
					"[IntensityNormalizer] Found Regression: y = %f*x + %f.",
					param_ests[1], param_ests[0]));
			for (int i = 0; i < im.length; i++)
				for (int j = 0; j < im[i].length; j++)
					for (int k = 0; k < im[i][j].length; k++)
						for(int l=0;l<im[i][j][k].length;l++)
							im[i][j][k][l] = (float) (param_ests[0] + param_ests[1]
									* im[i][j][k][l]);
			return im;
		}
	}
	
	/**
	 * Majority vote over the raters at the given (cropped) voxel.
	 *
	 * @return the label with the most votes; ties resolve to the smallest label
	 */
	public int getMajorityVoteEstimate(int x, int y, int z, int c){
		float[] weights = new float[this.maxLab+1];
		for(int j=0;j<this.numRaters;j++){
			int l = this.getObsLab(x, y, z, c, j);
			weights[l] += 1;
		}
		int l = 0;
		for(int i=0;i<weights.length;i++)
			if(weights[i] > weights[l])
				l = i;
		return l;
	}
	
	/** Replaces the per-rater patch arrays for the given (cropped) voxel. */
	public void setPatches(int x,int y,int z,int c,Patch[][] np){ this.patches[x][y][z][c] = np; }
}
