package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.StatUtils;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Intensity normalizer that standardizes a target image over a "consensus"
 * region and then linearly maps other images onto that standardized target.
 *
 * <p>The consensus region is every voxel at which at least one of the supplied
 * label volumes holds a non-zero label.
 *
 * <p>The constructor standardizes the target image in place. It subtracts a center
 * and divides by the standard deviation, both computed over the consensus region.
 *
 * <p>{@link #normalizeImage(float[][][])} fits {@code target = a + b * image} by
 * ordinary least squares between the <em>sorted</em> consensus intensities of the
 * two images, which is a linear quantile match. It then applies that map to every
 * voxel of the image.
 */
public class MSIntensityNormalizer extends AbstractCalculation {

	/** Substitute standard deviation when the measured one is zero or NaN (avoids division by zero). */
	private static final double MIN_STD = 1e-4;

	/**
	 * If |median - mean| exceeds this many standard deviations, the median is used
	 * as the center instead of the mean.
	 */
	private static final double SKEW_THRESHOLD = 0.1;

	private int r, c, s;
	/** True at voxels where any label volume is non-zero. */
	private boolean[][][] cons;
	/** Number of consensus voxels (number of true entries in {@code cons}). */
	private int vecLength;
	/** Standardized target intensities at the consensus voxels, sorted ascending. */
	private double[] tVec;
	private double medianValue, meanValue, stdValue, diffval;

	/**
	 * Builds the consensus mask from {@code labs} and standardizes {@code im} in place.
	 *
	 * @param labs label volumes indexed as {@code labs[i][j][k][l]}; voxel (i,j,k) is in the
	 *             consensus region if any {@code labs[i][j][k][l]} is non-zero
	 * @param im   target image indexed as {@code im[i][j][k]}; assumed to share the first
	 *             three dimensions of {@code labs} (TODO confirm). Modified in place.
	 * @throws IllegalArgumentException if no voxel carries a non-zero label
	 */
	public MSIntensityNormalizer(int[][][][] labs,float[][][] im){
		super();
		log("Initializing new Intensity Normalizer",JistLogger.INFO);
		setConsensus(labs);
		loadImage(im);
	}

	/**
	 * Standardizes the target image in place.
	 *
	 * <p>It also caches the sorted, standardized consensus intensities in {@code tVec}
	 * for later regressions.
	 */
	private void loadImage(float[][][] im){
		tVec = centerVector(collectConsensus(im));
		// Only the sorted order is used (quantile matching), so sort once here
		// rather than on every normalizeImage call.
		Arrays.sort(tVec);
		final float shift = (float) diffval;
		final float scale = (float) stdValue;
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					im[i][j][k] = (im[i][j][k] - shift) / scale;
	}

	/**
	 * Gathers the intensities of {@code im} at the consensus voxels, in i/j/k scan order.
	 *
	 * @return a new array of length {@code vecLength}
	 */
	private double[] collectConsensus(float[][][] im){
		double[] vals = new double[vecLength];
		int count = 0;
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					if(cons[i][j][k])
						vals[count++] = im[i][j][k];
		return vals;
	}

	/**
	 * Standardizes {@code vec} in place and records the statistics used.
	 *
	 * <p>The center ({@code diffval}) is the median when the median and mean differ
	 * by more than {@link #SKEW_THRESHOLD} standard deviations, and the mean
	 * otherwise. The scale is the sample standard deviation, floored to
	 * {@link #MIN_STD} when it is zero or NaN.
	 *
	 * @param vec values to standardize; overwritten
	 * @return {@code vec} itself
	 */
	private double[] centerVector(double[] vec){
		medianValue = StatUtils.percentile(vec, 50.0);
		meanValue   = StatUtils.mean(vec);
		stdValue    = Math.sqrt(StatUtils.variance(vec, meanValue));
		// NaN compares unequal to everything (including Double.NaN), so isNaN is required.
		if (stdValue == 0 || Double.isNaN(stdValue))
			stdValue = MIN_STD;
		diffval = (Math.abs(medianValue - meanValue) / stdValue > SKEW_THRESHOLD)
				? medianValue
				: meanValue;
		for(int i=0;i<vec.length;i++)
			vec[i] = (vec[i] - diffval) / stdValue;
		log(String.format("[IntensityNormalizer] Diff = %.2f, Median = %.2f, Mean = %.2f, Stdev = %.2f", diffval, medianValue, meanValue, stdValue));
		return vec;
	}

	/**
	 * Linearly maps {@code im} onto the standardized target image.
	 *
	 * <p>The intercept and slope are estimated by OLS, regressing the sorted target
	 * consensus intensities on the sorted consensus intensities of {@code im}. The
	 * resulting map is then applied to every voxel of {@code im}.
	 *
	 * <p>NOTE(review): failures raised by the commons-math regression (for example too
	 * few consensus voxels, or a constant image over the consensus region) are not
	 * handled here and propagate to the caller.
	 *
	 * @param im image indexed like the target image; modified in place
	 * @return the same array {@code im}, after normalization
	 */
	public float[][][] normalizeImage(float[][][] im){
		// Sorted consensus intensities of this image form the single regressor column.
		double[] source = collectConsensus(im);
		Arrays.sort(source);
		double[][] design = new double[vecLength][1];
		for(int i=0;i<vecLength;i++)
			design[i][0] = source[i];

		OLSMultipleLinearRegression regressor = new OLSMultipleLinearRegression();
		regressor.setNoIntercept(false);
		regressor.newSampleData(tVec, design);
		double[] param_ests = regressor.estimateRegressionParameters();
		log(String.format("The regression model found on the data is y = %f +%fx", param_ests[0],param_ests[1]));
		final double intercept = param_ests[0];
		final double slope = param_ests[1];
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++)
					im[i][j][k] = (float) (intercept + ((double) im[i][j][k]) * slope);
		return im;
	}

	/**
	 * Builds the consensus mask {@code cons} and counts its voxels into {@code vecLength}.
	 *
	 * @throws IllegalArgumentException if no voxel carries a non-zero label
	 */
	private void setConsensus(int[][][][] labs){
		r = labs.length;
		c = labs[0].length;
		s = labs[0][0].length;
		final int n = labs[0][0][0].length;
		cons = new boolean[r][c][s];
		vecLength = 0;
		// Mark every voxel at which any label volume observes a non-zero label.
		for(int i=0;i<r;i++)
			for(int j=0;j<c;j++)
				for(int k=0;k<s;k++){
					final int[] voxelLabels = labs[i][j][k];
					boolean labelled = false;
					for(int l=0;l<n;l++)
						if(voxelLabels[l] != 0){
							labelled = true;
							break;
						}
					cons[i][j][k] = labelled;
					if(labelled)
						vecLength++;
				}

		// Report what fraction of the volume is consensus; computed in double to
		// avoid int overflow of r*c*s on very large volumes.
		log(String.format("%f Normalizing fraction consensus",
				(double) vecLength / ((double) r * c * s)));

		// With no consensus voxels every statistic is NaN and would silently turn
		// the caller's image into NaN; fail early with a clear message instead.
		if (vecLength == 0)
			throw new IllegalArgumentException(
					"[Intensity Normalizer] no voxel carries a non-zero label; consensus region is empty");
	}

	private void log(String msg){ log(msg,JistLogger.INFO);}

	private void log(String msg,int level){
		msg = "[Intensity Normalizer] " + msg;
		JistLogger.logOutput(level, msg);
		JistLogger.logFlush();
	}

}
