package edu.vanderbilt.masi.algorithms.regression;

import org.w3c.dom.*;

/**
 * Accumulates running sums over training samples for a regression node and
 * derives entropy, probability and target statistics from them.
 *
 * <p>Per sample the aggregator adds the feature vector returned by
 * {@code DataCollection.GetVector} and, for every box and each of its
 * {@value #TARGETS_PER_BOX} target slots, the scalar returned by
 * {@code DataCollection.GetTarget}. Targets are stored flat, at index
 * {@code boxIndex * 6 + vecIndex}.
 *
 * <p>The feature accumulators are allocated with length {@value #FEATURE_DIM};
 * feature vectors are presumably expected to have that length (the exact
 * requirement depends on {@code MatrixComputor}, which is not visible here).
 */
public class StatisticsAggregator {
	/** Length of the feature-vector accumulators and of each row of XT_Y_1_. */
	private static final int FEATURE_DIM = 3;
	/** Number of target slots stored per box. */
	private static final int TARGETS_PER_BOX = 6;

	/** Number of samples aggregated so far. */
	private int sampleCount;
	/** Number of boxes; the per-target arrays hold boxCount * TARGETS_PER_BOX entries. */
	private int boxCount;
	/** XT_X_11_: sum of Multiply(datum, datum); XT_X_22_: incremented by 1.0 per sample. */
	private double XT_X_11_, XT_X_22_;
	/** Both hold the sum of the feature vectors (they are always updated identically). */
	private double[] XT_X_12_, XT_X_21_;
	/** Per target slot: sum of (feature vector scaled by the target value). */
	private double[][] XT_Y_1_;
	/** Per target slot: sum of the target values. */
	private double[] XT_Y_2_;

	/** Per target slot: sum of the squared target values. */
	private double[] Y2_;

	/**
	 * Creates an empty aggregator.
	 *
	 * @param c number of boxes; each box owns 6 target slots
	 */
	public StatisticsAggregator(int c) {
		boxCount = c;
		resetAccumulators();
	}

	/**
	 * Entropy-like score computed from the feature accumulators only.
	 *
	 * @return {@code 0.5 * log((2*pi*e)^2 * determinant)}, or positive infinity
	 *         when fewer than 3 samples were aggregated or the determinant is
	 *         not strictly positive
	 */
	public double Entropy() {
		if (sampleCount < 3)
			return Double.POSITIVE_INFINITY;

		double determinant = XT_X_11_ * XT_X_22_ -
				MatrixComputor.Multiply(XT_X_12_, XT_X_21_);

		// Use <= rather than ==: cancellation in the subtraction above can leave
		// a slightly negative value, and Math.log of a negative number is NaN.
		// A NaN determinant still falls through and propagates as NaN.
		if (determinant <= 0.0)
			return Double.POSITIVE_INFINITY;

		return 0.5 * Math.log(Math.pow(2.0 * Math.PI * Math.E, 2.0) * determinant);
	}

	/**
	 * Evaluates a one-dimensional Gaussian density at {@code y} whose mean and
	 * variance are derived from the accumulated sums and the feature vector
	 * {@code x}.
	 *
	 * <p>No guards are applied: with no samples, with {@code ss_x * ss_y == 0},
	 * or with a zero determinant the divisions below yield infinite or NaN
	 * intermediate values.
	 *
	 * @param x        feature vector of the query point
	 * @param y        target value at which the density is evaluated
	 * @param boxIndex box whose target slot is used
	 * @param vecIndex target slot within the box (0..5)
	 * @return the density value
	 */
	public double GetProbability(double[] x, float y, int boxIndex, int vecIndex) {
		// Coefficient of determination and residual variance for this target slot.
		int index = boxIndex * TARGETS_PER_BOX + vecIndex;
		double[] mean_x = MatrixComputor.Scale(XT_X_12_, (1.0 / sampleCount));
		double ss_x = XT_X_11_ -
				sampleCount * MatrixComputor.Multiply(mean_x, mean_x);
		double mean_y = XT_Y_2_[index] / (double)sampleCount;
		double ss_y = Y2_[index] - sampleCount * mean_y * mean_y;

		double[] ss_xy = MatrixComputor.Add(
				XT_Y_1_[index],
				MatrixComputor.Scale(mean_x, -mean_y * (double)sampleCount));

		// NOTE(review): ss_x * ss_y may be zero; the original author left a
		// guard for that case disabled, so it is intentionally not handled here.
		double r2 = MatrixComputor.Multiply(ss_xy, ss_xy) / (ss_x * ss_y);
		double sigma_2 = ss_y * (1.0 - r2) / (double)sampleCount;

		// "Bayes" step (original comment): A_* is sigma_2 times the entries of
		// an inverse formed from the XT_X_* accumulators.
		double determinant = XT_X_11_ * XT_X_22_ -
				MatrixComputor.Multiply(XT_X_12_, XT_X_12_);

		double A_11 = sigma_2 * XT_X_22_ / determinant,
				A_22 = sigma_2 * XT_X_11_ / determinant;
		double[] A_12 = MatrixComputor.Scale(XT_X_12_, -sigma_2 / determinant),
				A_21 = MatrixComputor.Scale(XT_X_21_, -sigma_2 / determinant);

		double mean1 = MatrixComputor.Multiply(x,
						MatrixComputor.Add(
								MatrixComputor.Scale(XT_Y_1_[index], A_11),
								MatrixComputor.Scale(A_12, XT_Y_2_[index]))) /
								sigma_2;
		double mean2 = (MatrixComputor.Multiply(A_21, XT_Y_1_[index]) +
							(A_22 * XT_Y_2_[index])) / sigma_2;
		double mean = mean1 + mean2;
		double variance = MatrixComputor.Multiply(x,
								MatrixComputor.Add(
									MatrixComputor.Scale(x, A_11),
									A_12)) +
								(MatrixComputor.Multiply(A_21, x) + A_22) + sigma_2;
		return Math.pow(2.0 * Math.PI, -0.5) * Math.pow(variance, -0.5) * Math.exp(-0.5 * (y - mean) * (y - mean) / (variance));
	}

	/** @return the number of samples aggregated so far */
	public int GetSampleCount() {
		return sampleCount;
	}

	/** Resets every accumulator to its empty state; the box count is kept. */
	public void Clear() {
		resetAccumulators();
	}

	/**
	 * Adds one sample to the accumulators.
	 *
	 * @param data  collection the sample is read from
	 * @param index position of the sample inside {@code data}
	 */
	public void Aggregate(DataCollection data, int index) {
		double[] datum = data.GetVector(index);
		XT_X_11_ += MatrixComputor.Multiply(datum, datum);
		XT_X_12_ = MatrixComputor.Add(XT_X_12_, datum);
		XT_X_21_ = MatrixComputor.Add(XT_X_21_, datum);
		XT_X_22_ += 1.0;

		for (int i = 0; i < boxCount; ++i) {
			for (int j = 0; j < TARGETS_PER_BOX; ++j) {
				AggregateY(data, index, i, j);
			}
		}
		sampleCount++;
	}

	/**
	 * Merges the sums of another aggregator into this one.
	 *
	 * @param aggregator source of the sums; presumably it must use the same
	 *                   box count (array shapes are not checked here)
	 */
	public void Aggregate(StatisticsAggregator aggregator) {
		XT_X_11_ += aggregator.XT_X_11_;
		XT_X_12_ = MatrixComputor.Add(XT_X_12_, aggregator.XT_X_12_);
		XT_X_21_ = MatrixComputor.Add(XT_X_21_, aggregator.XT_X_21_);
		XT_X_22_ += aggregator.XT_X_22_;
		XT_Y_1_ = MatrixComputor.Add(XT_Y_1_, aggregator.XT_Y_1_);
		XT_Y_2_ = MatrixComputor.Add(XT_Y_2_, aggregator.XT_Y_2_);
		Y2_ = MatrixComputor.Add(Y2_, aggregator.Y2_);
		sampleCount += aggregator.sampleCount;
	}

	/**
	 * @return a new aggregator with the same box count whose accumulators are
	 *         produced by {@code MatrixComputor.Copy} from this one
	 */
	public StatisticsAggregator DeepClone() {
		StatisticsAggregator result = new StatisticsAggregator(boxCount);
		result.sampleCount = sampleCount;
		result.XT_X_11_ = XT_X_11_;
		result.XT_X_22_ = XT_X_22_;
		result.XT_X_12_ = MatrixComputor.Copy(XT_X_12_);
		result.XT_X_21_ = MatrixComputor.Copy(XT_X_21_);
		result.XT_Y_1_ = MatrixComputor.Copy(XT_Y_1_);
		result.XT_Y_2_ = MatrixComputor.Copy(XT_Y_2_);
		result.Y2_ = MatrixComputor.Copy(Y2_);

		return result;
	}

	/**
	 * Serializes the aggregator into {@code e}: scalars become attributes of
	 * {@code e}, arrays become child elements.
	 *
	 * @param doc document used to create the child elements
	 * @param e   element that receives the attributes and children
	 */
	public void CreateStatisticsElement(Document doc, Element e) {
		e.setAttribute("BoxCount", Integer.toString(boxCount));
		e.setAttribute("SampleCount", Integer.toString(sampleCount));
		e.setAttribute("XT_X_11_", Double.toString(XT_X_11_));
		e.setAttribute("XT_X_22_", Double.toString(XT_X_22_));

		Element xNode12 = doc.createElement("XT_X_12_");
		Element xNode21 = doc.createElement("XT_X_21_");
		for (int i = 0; i < FEATURE_DIM; ++i) {
			xNode12.setAttribute("v" + Integer.toString(i),
					Double.toString(XT_X_12_[i]));
			xNode21.setAttribute("v" + Integer.toString(i),
					Double.toString(XT_X_21_[i]));
		}
		e.appendChild(xNode12);
		e.appendChild(xNode21);

		Element xyNode1 = doc.createElement("XT_Y_1_");
		Element xyNode2 = doc.createElement("XT_Y_2_");
		Element yNode = doc.createElement("Y2_");
		for (int i = 0; i < boxCount * TARGETS_PER_BOX; ++i) {
			// Rows of XT_Y_1_ are stored positionally as <vector v0= v1= v2=>.
			Element xyItem1 = doc.createElement("vector");
			for (int j = 0; j < FEATURE_DIM; ++j) {
				xyItem1.setAttribute("v" + Integer.toString(j),
						Double.toString(XT_Y_1_[i][j]));
			}
			xyNode1.appendChild(xyItem1);

			// Scalar slots are stored as <vI value=...> named by their index.
			Element xyItem2 = doc.createElement("v" + Integer.toString(i));
			xyItem2.setAttribute("value", Double.toString(XT_Y_2_[i]));
			xyNode2.appendChild(xyItem2);

			Element yItem = doc.createElement("v" + Integer.toString(i));
			yItem.setAttribute("value", Double.toString(Y2_[i]));
			yNode.appendChild(yItem);
		}
		e.appendChild(xyNode1);
		e.appendChild(xyNode2);
		e.appendChild(yNode);
	}

	/**
	 * Restores the aggregator from an element written by
	 * {@link #CreateStatisticsElement(Document, Element)}.
	 *
	 * <p>The per-target arrays are reallocated to match the serialized box
	 * count, so the element may describe a different number of boxes than this
	 * instance was constructed with.
	 *
	 * @param e element to read from
	 * @throws NumberFormatException    if a numeric attribute cannot be parsed
	 * @throws IllegalArgumentException if the serialized box count is negative
	 * @throws NullPointerException     if an expected child element is missing
	 */
	public void XmlParse(Element e) {
		int parsedBoxCount = Integer.parseInt(e.getAttribute("BoxCount"));
		if (parsedBoxCount < 0)
			throw new IllegalArgumentException(
					"BoxCount must be non-negative: " + parsedBoxCount);
		boxCount = parsedBoxCount;
		sampleCount = Integer.parseInt(e.getAttribute("SampleCount"));
		XT_X_11_ = Double.parseDouble(e.getAttribute("XT_X_11_"));
		XT_X_22_ = Double.parseDouble(e.getAttribute("XT_X_22_"));

		Element xNode12 = (Element) e.getElementsByTagName("XT_X_12_").item(0);
		Element xNode21 = (Element) e.getElementsByTagName("XT_X_21_").item(0);
		for (int i = 0; i < FEATURE_DIM; ++i) {
			XT_X_12_[i] = Double.parseDouble(xNode12.getAttribute(
					"v" + Integer.toString(i)));
			XT_X_21_[i] = Double.parseDouble(xNode21.getAttribute(
					"v" + Integer.toString(i)));
		}

		// The arrays were sized for the previous box count; resize them before
		// filling, otherwise a larger serialized count indexes out of bounds.
		allocateTargetAccumulators();

		Element xyNode1 = (Element) e.getElementsByTagName("XT_Y_1_").item(0);
		Element xyNode2 = (Element) e.getElementsByTagName("XT_Y_2_").item(0);
		Element yNode = (Element) e.getElementsByTagName("Y2_").item(0);
		NodeList xyChildren1 = xyNode1.getElementsByTagName("vector");
		for (int i = 0; i < boxCount * TARGETS_PER_BOX; ++i) {
			Element xyItem1 = (Element) xyChildren1.item(i);
			for (int j = 0; j < FEATURE_DIM; ++j) {
				XT_Y_1_[i][j] = Double.parseDouble(
						xyItem1.getAttribute("v" + Integer.toString(j)));
			}

			Element xyItem2 =
					(Element) xyNode2.getElementsByTagName("v" + Integer.toString(i)).item(0);
			XT_Y_2_[i] = Double.parseDouble(xyItem2.getAttribute("value"));

			Element yItem =
					(Element) yNode.getElementsByTagName("v" + Integer.toString(i)).item(0);
			Y2_[i] = Double.parseDouble(yItem.getAttribute("value"));
		}
	}

	/** Adds one sample's contribution for a single target slot. */
	private void AggregateY(DataCollection data, int index,
			int boxIndex, int vecIndex) {
		int slot = boxIndex * TARGETS_PER_BOX + vecIndex;
		double[] datum = data.GetVector(index);
		double target = data.GetTarget(index, boxIndex, vecIndex);
		datum = MatrixComputor.Scale(datum, target);
		XT_Y_1_[slot] = MatrixComputor.Add(XT_Y_1_[slot], datum);
		XT_Y_2_[slot] += target;
		Y2_[slot] += target * target;
	}

	/** Zeroes the sample count and every accumulator for the current box count. */
	private void resetAccumulators() {
		sampleCount = 0;
		XT_X_11_ = 0.0;
		XT_X_22_ = 0.0;
		XT_X_12_ = new double[FEATURE_DIM];
		XT_X_21_ = new double[FEATURE_DIM];
		allocateTargetAccumulators();
	}

	/** Allocates zero-filled per-target arrays sized for the current box count. */
	private void allocateTargetAccumulators() {
		int slots = boxCount * TARGETS_PER_BOX;
		XT_Y_1_ = new double[slots][FEATURE_DIM];
		XT_Y_2_ = new double[slots];
		Y2_ = new double[slots];
	}

	//////////////////////////////////////////////////////////
	/**
	 * Finds the two boxes with the smallest sum of squared deviations, looking
	 * only at the first target slot of each box.
	 *
	 * <p>Entries that are never assigned stay 0: this happens for the second
	 * entry when there are fewer than two boxes, and for both entries when the
	 * candidates are NaN (for example with no samples).
	 *
	 * @return {@code {best, secondBest}} box indices
	 */
	public int[] GetTwoBoxId() {
		double[] syCandidate = new double[boxCount];
		int[] boxIds = new int[2];
		double syMin = Double.MAX_VALUE;
		for (int i = 0; i < boxCount; ++i) {
			double mean_y = XT_Y_2_[i * TARGETS_PER_BOX] / (double)sampleCount;
			syCandidate[i] = Y2_[i * TARGETS_PER_BOX] - sampleCount * mean_y * mean_y;
			if (syMin > syCandidate[i]) {
				syMin = syCandidate[i];
				boxIds[0] = i;
			}
		}
		// Second pass: smallest candidate among the remaining boxes.
		syMin = Double.MAX_VALUE;
		for (int i = 0; i < boxCount; ++i) {
			if (i == boxIds[0]) continue;
			if (syMin > syCandidate[i]) {
				syMin = syCandidate[i];
				boxIds[1] = i;
			}
		}

		return boxIds;
	}

	/**
	 * @param boxIndex box to query
	 * @return the 6 target sums of that box, each scaled by
	 *         {@code 1 / sampleCount} (no guard for zero samples)
	 */
	public double[] GetTargetMean(int boxIndex) {
		double coef = (double)1 / (double)sampleCount;
		double[] mean_y = MatrixComputor.Scale(XT_Y_2_, coef);
		double[] mean_target = new double[TARGETS_PER_BOX];
		for (int i = 0; i < TARGETS_PER_BOX; ++i) {
			mean_target[i] = mean_y[boxIndex * TARGETS_PER_BOX + i];
		}
		return mean_target;
	}

	/** @return {@link #GetTargetVar(int)} evaluated for every box, in box order */
	public double[] GetTargetsVariance() {
		double[] variances = new double[boxCount];
		for (int i = 0; i < boxCount; ++i) {
			variances[i] = GetTargetVar(i);
		}
		return variances;
	}

	/**
	 * Variance of the first target slot of a box, computed as the mean of
	 * squares minus the squared mean.
	 *
	 * @param index box index
	 * @return the variance (NaN when no samples were aggregated)
	 */
	public double GetTargetVar(int index) {
		double mean_y = XT_Y_2_[index * TARGETS_PER_BOX] / (double)sampleCount;
		double variance = Y2_[index * TARGETS_PER_BOX] / (double) sampleCount - mean_y * mean_y;
		return variance;
	}
}