package edu.vanderbilt.masi.algorithms.labelfusion;

import java.util.Arrays;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.commons.math3.stat.StatUtils;

/**
 * Normalizes image intensities using statistics gathered over a region of interest
 * derived from a set of rater observations.
 *
 * <p>The region of interest is every voxel where at least one rater reports a non-zero
 * label whose associated value exceeds {@link #FOREGROUND_VALUE_THRESHOLD}. All other
 * voxels form the background "consensus" and are excluded from the statistics (but are
 * still transformed). Two normalizations are offered: a unit-normal style
 * standardization ({@link #set_image_unit_normal}) and a linear least-squares mapping of
 * one image's sorted intensities onto a target's sorted intensities
 * ({@link #regress_image}).
 */
public class IntensityNormalizer extends AbstractCalculation {

	/** Minimum number of region-of-interest voxels required before regression is attempted. */
	private static final int MIN_REGRESSION_SAMPLES = 50;

	/** A non-zero label counts as foreground only if its observation value exceeds this. */
	private static final double FOREGROUND_VALUE_THRESHOLD = 0.1;

	/** Fallback standard deviation used when the measured one is zero or NaN. */
	private static final float MIN_STDEV = 0.0001f;

	/** Image dimensions as reported by the observations; indexed as [x][y][z][v]. */
	private final int [] dims;
	/** ncon[x][y][z][v] is true where no rater observes foreground (voxel excluded from statistics). */
	private boolean [][][][] ncon;
	/** Number of voxels where ncon is false, i.e. the size of the region of interest. */
	private int vec_length;
	/** Scratch buffer for the target's region-of-interest intensities (regression response). */
	private final double [] tvec;
	/** Scratch single-column design matrix for the atlas's region-of-interest intensities. */
	private final double [][] avec;

	/**
	 * Builds a normalizer whose region of interest is determined from {@code obs}.
	 *
	 * @param obs rater observations; {@code obs.dims()} fixes the [x][y][z][v] extent
	 *            that images passed to the normalization methods are indexed over
	 */
	public IntensityNormalizer(ObservationBase obs) {
		super();

		JistLogger.logOutput(JistLogger.INFO, "\n+++ Initializing New Intensity Normalizer +++");
		dims = obs.dims();

		// decide which voxels participate in the normalization statistics
		set_normalizing_consensus(obs);

		// scratch space reused by every call to regress_image
		tvec = new double [vec_length];
		avec = new double [vec_length][1];
	}

	/**
	 * Marks every voxel as background consensus (excluded) or region of interest (included),
	 * sets {@link #vec_length} to the region size and logs the consensus fraction.
	 *
	 * @param obs rater observations queried per voxel and per rater
	 */
	private void set_normalizing_consensus(ObservationBase obs) {

		ncon = new boolean[dims[0]][dims[1]][dims[2]][dims[3]];
		final int num_raters = obs.num_raters();

		vec_length = 0;
		long num_con = 0;
		for (int x = 0; x < dims[0]; x++) {
			for (int y = 0; y < dims[1]; y++) {
				for (int z = 0; z < dims[2]; z++) {
					for (int v = 0; v < dims[3]; v++) {
						// stop querying raters as soon as one of them observes foreground
						boolean background = true;
						for (int j = 0; j < num_raters && background; j++) {
							background = !has_foreground(obs.get_all(x, y, z, v, j),
							                             obs.get_all_vals(x, y, z, v, j));
						}
						ncon[x][y][z][v] = background;
						if (background) {
							num_con++;
						} else {
							vec_length++;
						}
					}
				}
			}
		}

		// long arithmetic so the voxel count of a large 4D volume cannot overflow
		long total = (long) dims[0] * dims[1] * dims[2] * dims[3];
		float frac_con = (total > 0) ? (float) num_con / total : 0f;
		JistLogger.logOutput(JistLogger.INFO, "[IntensityNormalizer] Normalizing Fraction Consensus: " + frac_con);
	}

	/**
	 * Returns true when any entry has a non-zero label paired with a value above
	 * {@link #FOREGROUND_VALUE_THRESHOLD}.
	 *
	 * @param obslabs labels reported by one rater at one voxel
	 * @param obsvals values paired index-by-index with {@code obslabs}; assumed to be at
	 *                least as long as {@code obslabs}
	 */
	private static boolean has_foreground(short [] obslabs, float [] obsvals) {
		for (int i = 0; i < obslabs.length; i++) {
			if (obslabs[i] != 0 && obsvals[i] > FOREGROUND_VALUE_THRESHOLD) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Copies the region-of-interest voxels of {@code im} (those where ncon is false) into
	 * {@code out}, in x/y/z/v scan order.
	 *
	 * @param im  image indexed [x][y][z][v] over dims
	 * @param out destination of length at least {@link #vec_length}
	 */
	private void extract_roi(float [][][][] im, double [] out) {
		int count = 0;
		for (int x = 0; x < dims[0]; x++) {
			for (int y = 0; y < dims[1]; y++) {
				for (int z = 0; z < dims[2]; z++) {
					for (int v = 0; v < dims[3]; v++) {
						if (!ncon[x][y][z][v]) {
							out[count++] = im[x][y][z][v];
						}
					}
				}
			}
		}
	}

	/**
	 * Standardizes {@code im} in place using statistics from the region of interest.
	 *
	 * <p>The center is the region's mean. If the median and mean differ by more than 0.1
	 * standard deviations, the (upper) median is used instead. Every voxel of the image,
	 * inside the region or not, is shifted by the center and divided by the standard
	 * deviation. A zero or NaN standard deviation is replaced by {@link #MIN_STDEV}. If
	 * the region of interest is empty, the image is left unchanged.
	 *
	 * @param im image to normalize in place; indexed [x][y][z][v] over dims
	 */
	public void set_image_unit_normal (float [][][][] im) {

		JistLogger.logOutput(JistLogger.INFO, "[IntensityNormalizer] Normalizing image to unit normal distribution");

		if (vec_length == 0) {
			// without any region voxels there is no median/mean to compute (vec[0] would not exist)
			JistLogger.logOutput(JistLogger.INFO, "[IntensityNormalizer] Empty normalization region; image left unchanged.");
			return;
		}

		// gather the region of interest into a sorted vector
		double [] vec = new double[vec_length];
		extract_roi(im, vec);
		Arrays.sort(vec);

		// statistics over the region; for an even count this picks the upper middle element
		double medianval = vec[vec_length / 2];
		double meanval = StatUtils.mean(vec);
		double stdval = Math.sqrt(StatUtils.variance(vec, meanval));
		// NaN never compares equal to anything, so an explicit isNaN test is required
		if (stdval == 0 || Double.isNaN(stdval)) {
			stdval = MIN_STDEV;
		}

		// prefer the median as the center when the distribution looks skewed
		float diffval = (Math.abs(medianval - meanval) / stdval > 0.1)
				? (float) medianval
				: (float) meanval;

		// normalize the whole image
		final float stdf = (float) stdval;
		for (int x = 0; x < dims[0]; x++) {
			for (int y = 0; y < dims[1]; y++) {
				for (int z = 0; z < dims[2]; z++) {
					for (int v = 0; v < dims[3]; v++) {
						im[x][y][z][v] = (im[x][y][z][v] - diffval) / stdf;
					}
				}
			}
		}

		JistLogger.logOutput(JistLogger.INFO, String.format("[IntensityNormalizer] Diff = %.2f, Median = %.2f, Mean = %.2f, Stdev = %.2f", diffval, medianval, meanval, stdval));

	}

	/**
	 * Linearly maps the intensities of {@code im} onto those of {@code target}, in place.
	 *
	 * <p>The region-of-interest intensities of both images are sorted independently and
	 * paired by rank (a quantile-to-quantile pairing, so no voxel correspondence is
	 * assumed). Ordinary least squares then fits {@code target = b0 + b1 * im}, and the
	 * fitted line is applied to every voxel of {@code im}. If the region has fewer than
	 * {@link #MIN_REGRESSION_SAMPLES} voxels, or the regression throws a RuntimeException,
	 * {@code im} is instead normalized with {@link #set_image_unit_normal}.
	 *
	 * @param im     atlas image to transform in place; indexed [x][y][z][v] over dims
	 * @param target image whose intensity distribution {@code im} is mapped onto
	 */
	public void regress_image(float [][][][] im,
							  float [][][][] target) {

		JistLogger.logOutput(JistLogger.INFO, "[IntensityNormalizer] Regressing atlas image to target.");

		if (vec_length < MIN_REGRESSION_SAMPLES) {
			JistLogger.logOutput(JistLogger.INFO, "[IntensityNormalizer] Unable to use regression, not enough samples.");
			set_image_unit_normal(im);
			return;
		}

		try {

			// collect the region of interest from both images and sort each independently
			double [] tmp_avec = new double [vec_length];
			extract_roi(target, tvec);
			extract_roi(im, tmp_avec);
			Arrays.sort(tvec);
			Arrays.sort(tmp_avec);

			// the regressor expects a design matrix, one row per sample
			for (int i = 0; i < vec_length; i++) {
				avec[i][0] = tmp_avec[i];
			}

			// perform the regression (with intercept): param_ests = {b0, b1}
			OLSMultipleLinearRegression regressor = new OLSMultipleLinearRegression();
			regressor.setNoIntercept(false);
			regressor.newSampleData(tvec, avec);
			double [] param_ests = regressor.estimateRegressionParameters();
			JistLogger.logOutput(JistLogger.INFO, String.format("[IntensityNormalizer] Found Regression: y = %f*x + %f.", param_ests[1], param_ests[0]));

			// apply the regression parameters to the whole image
			final double intercept = param_ests[0];
			final double slope = param_ests[1];
			for (int x = 0; x < dims[0]; x++) {
				for (int y = 0; y < dims[1]; y++) {
					for (int z = 0; z < dims[2]; z++) {
						for (int v = 0; v < dims[3]; v++) {
							im[x][y][z][v] = (float) (intercept + slope * im[x][y][z][v]);
						}
					}
				}
			}

		} catch (RuntimeException e) {
			// String.valueOf keeps the exception type and never passes a null message
			JistLogger.logError(JistLogger.SEVERE, String.valueOf(e));
			JistLogger.logOutput(JistLogger.SEVERE, "[IntensityNormalizer] Unable to use regression, probably an empty image.");
			set_image_unit_normal(im);
		}
	}
}
