package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.List;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.utilities.AndrewUtils;

/**
 * Label volume that exposes a per-atlas, per-voxel weight computed from the
 * intensity difference between each atlas image and the target image.
 *
 * <p>At construction time the target image is normalized using the voxels that
 * the parent volume reports as consensus with a positive consensus label, each
 * atlas image is linearly mapped onto the normalized target over those same
 * voxels, and the weight {@code exp(-(atlas - target)^2 / std)} is stored for
 * every voxel of every atlas.
 */
public class LocalSIMPLELabelVolume extends SimpleLabelVolume {

	private IntensityNormalizer IN;
	private float[][][][] weights; // subject x row x column x slice
	// Denominator applied to the squared intensity difference inside the exponential.
	private float std = .5f;

	/**
	 * Builds the label volume and precomputes the local weights for every atlas.
	 *
	 * @param imLabList atlas label images, forwarded to the parent constructor
	 * @param imImgList atlas intensity images; entry {@code i} is used for atlas {@code i}
	 * @param targetImg intensity image of the target
	 * @throws IllegalArgumentException if {@code imImgList} holds fewer images than
	 *         the number of atlases reported by {@code getN()}
	 */
	public LocalSIMPLELabelVolume(List<ImageData> imLabList, List<ImageData> imImgList, ImageData targetImg) {
		super(imLabList);
		int numAtlases = this.getN();
		// Fail fast, before loading images and allocating the 4-D weight array.
		if(imImgList.size() < numAtlases){
			throw new IllegalArgumentException(String.format(
					"Expected at least %d atlas intensity images but received %d",
					numAtlases, imImgList.size()));
		}
		float[][][] target = loadImageData(targetImg);
		IN = new IntensityNormalizer();
		IN.setTarget(target);
		IN.setConsensus(this);
		target = IN.normalizeTarget();
		float[][][] atlasIm;
		weights = new float[numAtlases][this.getR()][this.getC()][this.getS()];
		for(int i=0;i<numAtlases;i++){
			JistLogger.logOutput(JistLogger.WARNING,String.format("Loading Image %d", i));
			atlasIm = IN.normalizeImage(loadImageData(imImgList.get(i)));
			calculateWeights(target,atlasIm,i);
		}
	}

	/**
	 * Fills {@code weights[r]} with {@code exp(-(im - target)^2 / std)} for every
	 * voxel of {@code im} and logs the mean weight over the voxels for which
	 * {@code isConsensus} is false.
	 *
	 * @param target normalized target intensities
	 * @param im     normalized atlas intensities, indexed like {@code target}
	 * @param r      index of the atlas whose weights are being written
	 */
	private void calculateWeights(float[][][] target, float[][][] im,int r){
		// Accumulate in double: a float running sum loses precision over a large volume.
		double weightSum = 0;
		int numWeight = 0;
		for(int i=0;i<im.length;i++){
			for(int j=0;j<im[i].length;j++){
				for(int k=0;k<im[i][j].length;k++){
					float diff = im[i][j][k] - target[i][j][k];
					// The square of a float is exactly representable in double, so the
					// explicit product replaces Math.pow(diff, 2) in this hot loop.
					float w = (float) Math.exp(-((double) diff * diff) / std);
					weights[r][i][j][k] = w;
					if(!isConsensus(i,j,k)){
						weightSum += w;
						numWeight++;
					}
				}
			}
		}
		// Evaluates to NaN when there are no non-consensus voxels; logged as such.
		double meanWeight = weightSum / numWeight;
		JistLogger.logOutput(JistLogger.WARNING, String.format(
				"Mean weight non-consensus: %f",meanWeight));
		JistLogger.logFlush();
	}

	/**
	 * Copies the voxel values of {@code im} into a float array and passes it
	 * through {@code cropToBoundingBox}.
	 *
	 * @param im image to read
	 * @return the result of {@code cropToBoundingBox} applied to the copied values
	 */
	private float[][][] loadImageData(ImageData im){
		float[][][] res = new float[im.getRows()][im.getCols()][im.getSlices()];
		for(int i=0;i<res.length;i++)
			for(int j=0;j<res[i].length;j++)
				for(int k=0;k<res[i][j].length;k++)
					res[i][j][k] = im.getFloat(i, j, k);
		return cropToBoundingBox(res);
	}

	/**
	 * Normalizes intensities using the voxels selected by {@link #setConsensus}.
	 * {@link #setTarget} must be called before {@link #setConsensus}, and both
	 * before either normalize method.
	 */
	class IntensityNormalizer{
		private float[][][] target;
		private boolean[][][] cons;
		private int ncons;

		public IntensityNormalizer(){}

		/** Stores the target volume; it is later normalized in place. */
		public void setTarget(float[][][] t){ this.target = t; }

		/**
		 * Marks the voxels that {@code obs} reports as consensus with a consensus
		 * label greater than zero, and counts them. The mask is sized from the
		 * target set via {@link #setTarget}.
		 *
		 * @param obs label volume queried for consensus information
		 */
		public void setConsensus(SimpleLabelVolume obs){
			ncons = 0;
			cons = new boolean[target.length][target[0].length][target[0][0].length];
			for(int i=0;i<obs.getR();i++)
				for(int j=0;j<obs.getC();j++)
					for(int k=0;k<obs.getS();k++){
						boolean selected = obs.isConsensus(i, j, k) && obs.getConsLab(i,j,k) > 0;
						cons[i][j][k] = selected;
						if(selected)
							ncons++;
					}
			JistLogger.logOutput(JistLogger.WARNING, String.format(
					"[Intensity Normalizer] There are %d voxels for intensity normalization",ncons));
		}

		/**
		 * Centers and scales the target in place using the mean and standard
		 * deviation that {@code AndrewUtils} computes over the selected voxels.
		 * If the standard deviation is not a positive number the target is only
		 * centered, so that the division cannot fill the volume with NaN/Infinity.
		 *
		 * @return the same array that was passed to {@link #setTarget}
		 */
		public float[][][] normalizeTarget(){
			int n=0;
			float[] vec = new float[ncons];
			for(int i=0;i<target.length;i++)
				for(int j=0;j<target[i].length;j++)
					for(int k=0;k<target[i][j].length;k++)
						if(cons[i][j][k]){
							vec[n] = target[i][j][k];
							n++;
						}
			float mean = AndrewUtils.calculateMean(vec);
			// Named targetStd so it does not shadow the enclosing class's std field.
			float targetStd = AndrewUtils.calculateSTD(vec, mean);
			JistLogger.logOutput(JistLogger.WARNING, "[Intensity Normalizer] Normalizing Target Image");
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("[Intensity Normalizer] Target Mean: %.4f Target STD: %.4f",mean,targetStd));
			float scale = targetStd;
			if(!(targetStd > 0)){
				// Zero, negative or NaN deviation: dividing would corrupt every voxel.
				JistLogger.logOutput(JistLogger.WARNING,
						"[Intensity Normalizer] Target STD is not positive; centering target without scaling");
				scale = 1f;
			}
			for(int i=0;i<target.length;i++)
				for(int j=0;j<target[i].length;j++)
					for(int k=0;k<target[i][j].length;k++)
						target[i][j][k] = (target[i][j][k] - mean)/scale;
			return target;
		}

		/**
		 * Fits {@code target = b0 + b1 * im} by ordinary least squares over the
		 * selected voxels and applies that mapping to every voxel of {@code im}
		 * in place.
		 *
		 * @param im atlas intensities, indexed like the target
		 * @return {@code im}, after the in-place mapping
		 */
		public float[][][] normalizeImage(float[][][] im){
			double[] tvec = new double[ncons];
			double[][] avec = new double[ncons][1];
			int n=0;
			for(int i=0;i<cons.length;i++)
				for(int j=0;j<cons[i].length;j++)
					for(int k=0;k<cons[i][j].length;k++)
						if(cons[i][j][k]){
							tvec[n] = target[i][j][k];
							avec[n][0] = im[i][j][k];
							n++;
						}
			// Local on purpose: a field would keep the sample matrices reachable
			// for as long as the normalizer itself is.
			OLSMultipleLinearRegression regressor = new OLSMultipleLinearRegression();
			regressor.setNoIntercept(false);
			regressor.newSampleData(tvec, avec);
			double[] paramEsts = regressor.estimateRegressionParameters();
			double intercept = paramEsts[0];
			double slope = paramEsts[1];
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("[IntensityNormalizer] Found Regression: y = %f*x + %f.", slope, intercept));
			for(int i=0;i<im.length;i++)
				for(int j=0;j<im[i].length;j++)
					for(int k=0;k<im[i][j].length;k++)
						im[i][j][k] = (float) (intercept + slope*im[i][j][k]);
			return im;
		}
	}

	/**
	 * Returns the precomputed weight of atlas {@code n} at voxel {@code (i, j, k)}.
	 */
	@Override
	public float getLocalWeight(int i,int j,int k,int n){ return this.weights[n][i][j][k]; }
}
