package edu.vanderbilt.masi.algorithms.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Trains the nodes of a binary tree stored in a flat array, where the children
 * of node {@code n} live at {@code 2n + 1} and {@code 2n + 2}.
 *
 * <p>For every split node a number of randomly drawn candidate features is
 * evaluated in parallel batches; the feature/threshold pair with the largest
 * information gain is used to split the samples.
 *
 * <p>Instances are not thread-safe: a single instance must only be driven from
 * one thread at a time. The worker tasks call
 * {@code TrainingContext.GetRandomFeature} and
 * {@code TrainingContext.ComputeInformationGain} concurrently
 * (NOTE(review): this assumes those methods are safe to call from several
 * threads - confirm against the TrainingContext implementation).
 */
public class TreeTrainingOperation {
	/**
	 * Number of candidate features evaluated per batch. Each task of a batch
	 * owns one row of {@link #responsesForCandidate} as scratch space.
	 */
	private static final int maxCountOfThreads = 10;

	private final DataCollection data_;
	private final TrainingContext trainingContext_;
	private final TrainingParameters parameters_;

	/** Sample indices; reordered in place while the tree is grown. */
	private final int[] indices_;

	/** Responses of the winning feature, indexed like {@link #indices_}. */
	private final float[] responses_;

	private final StatisticsAggregator parentStatistics_, leftChildStatistics_,
			rightChildStatistics_;

	// Scratch space for the parallel evaluation. Every row is fully rewritten
	// on [i0, i1) before it is read, so it can be reused across nodes.
	private final float[][] responsesForCandidate;

	// Per-node results, one slot per candidate feature.
	private double[] gainsForCandidate;
	private float[] thresholdsForCandidate;
	private FeatureResponse[] featuresForCandidate;

	/**
	 * Creates a training operation over all samples of {@code data}.
	 *
	 * @param trainingContext source of features, statistics aggregators and
	 *                        the gain/termination criteria
	 * @param parameters      training parameters (candidate counts)
	 * @param data            the training samples
	 */
	public TreeTrainingOperation(TrainingContext trainingContext,
			TrainingParameters parameters, DataCollection data) {
		data_ = data;
		trainingContext_ = trainingContext;
		parameters_ = parameters;

		final int dataCount = data.GetDataCount();

		indices_ = new int[dataCount];
		for (int i = 0; i < indices_.length; i++)
			indices_[i] = i;

		responses_ = new float[dataCount];
		responsesForCandidate = new float[maxCountOfThreads][dataCount];

		parentStatistics_ = trainingContext_.GetStatisticsAggregator();

		leftChildStatistics_ = trainingContext_.GetStatisticsAggregator();
		rightChildStatistics_ = trainingContext_.GetStatisticsAggregator();
	}

	/**
	 * Trains node {@code nodeIndex} on the samples {@code indices_[i0..i1)} and
	 * recurses into its children when a useful split is found.
	 *
	 * @param nodes        flat node array; the trained node is stored at
	 *                     {@code nodes[nodeIndex]}
	 * @param nodeIndex    index of the node to train
	 * @param i0           first sample position (inclusive)
	 * @param i1           last sample position (exclusive)
	 * @param recurseDepth depth of this node; only forwarded to the children
	 * @throws IllegalStateException if evaluating a candidate feature fails or
	 *                               the calling thread is interrupted while
	 *                               waiting for the workers
	 */
	public void TrainNodesRecurse(Node[] nodes, int nodeIndex, int i0, int i1,
			int recurseDepth) {
		nodes[nodeIndex] = new Node();

		// First aggregate statistics over the samples at the parent node
		parentStatistics_.Clear();
		for (int i = i0; i < i1; i++)
			parentStatistics_.Aggregate(data_, indices_[i]);

		// Nodes in the second half of the array have no room for children
		// (2n + 1 would fall outside the array), so they are always leaves.
		if (nodeIndex >= nodes.length / 2) {
			nodes[nodeIndex].InitializeLeaf(parentStatistics_);
			System.out.println(nodeIndex);
			return;
		}

		final int nCandidates = parameters_.NumberOfCandidateFeatures;
		gainsForCandidate = new double[nCandidates];
		thresholdsForCandidate = new float[nCandidates];
		featuresForCandidate = new FeatureResponse[nCandidates];

		evaluateCandidateFeatures(nCandidates, i0, i1);

		// Pick the candidate with the largest strictly positive gain. Slots
		// whose feature produced no usable threshold stay null and are skipped.
		double maxGain = 0.0;
		FeatureResponse bestFeature = null;
		float bestThreshold = 0.0f;
		for (int f = 0; f < nCandidates; f++) {
			if (featuresForCandidate[f] != null
					&& gainsForCandidate[f] > maxGain) {
				maxGain = gainsForCandidate[f];
				bestFeature = featuresForCandidate[f];
				bestThreshold = thresholdsForCandidate[f];
			}
		}

		if (maxGain == 0.0) {
			nodes[nodeIndex].InitializeLeaf(parentStatistics_);
			return;
		}

		// Now compute the responses of the winning feature (needed for the
		// partition below) and recompute the child node statistics.
		leftChildStatistics_.Clear();
		rightChildStatistics_.Clear();

		for (int i = i0; i < i1; i++) {
			responses_[i] = bestFeature.GetResponse(data_, indices_[i]);
			if (responses_[i] < bestThreshold)
				leftChildStatistics_.Aggregate(data_, indices_[i]);
			else
				rightChildStatistics_.Aggregate(data_, indices_[i]);
		}

		if (trainingContext_.ShouldTerminate(maxGain)) {
			nodes[nodeIndex].InitializeLeaf(parentStatistics_);
			return;
		}

		nodes[nodeIndex].InitializeSplit(bestFeature, bestThreshold,
				parentStatistics_);

		// Reorder the samples around the threshold; positions [i0, ii) are
		// handed to the left child and [ii, i1) to the right child.
		int ii = Tree.Partition(responses_, indices_, i0, i1, bestThreshold);

		// This is a new decision node, recurse for children.
		TrainNodesRecurse(nodes, nodeIndex * 2 + 1, i0, ii, recurseDepth + 1);
		TrainNodesRecurse(nodes, nodeIndex * 2 + 2, ii, i1, recurseDepth + 1);
	}

	/**
	 * Evaluates {@code nCandidates} random features on samples [i0, i1) and
	 * stores the best gain/threshold of each in the per-candidate arrays.
	 * Candidates run in batches of at most {@link #maxCountOfThreads} because
	 * every task of a batch needs its own scratch row.
	 */
	private void evaluateCandidateFeatures(int nCandidates, int i0, int i1) {
		if (nCandidates <= 0)
			return;

		int nWorkers = Math.max(1, Math.min(maxCountOfThreads,
				Runtime.getRuntime().availableProcessors()));
		ExecutorService pool = Executors.newFixedThreadPool(nWorkers);
		try {
			for (int f = 0; f < nCandidates; f += maxCountOfThreads) {
				// The last batch may be smaller than maxCountOfThreads.
				int batchSize = Math.min(maxCountOfThreads, nCandidates - f);
				List<Future<?>> pending = new ArrayList<Future<?>>(batchSize);
				for (int i = 0; i < batchSize; i++)
					pending.add(pool.submit(new TestRunnable(f + i, i, i0, i1)));

				// Barrier: the scratch rows are reused by the next batch, and
				// Future.get() also makes the workers' writes visible here.
				for (Future<?> task : pending)
					awaitCompletion(task);
			}
		} finally {
			// Nothing is left to run on success; on failure this cancels
			// whatever is still queued so no worker threads are leaked.
			pool.shutdownNow();
		}
	}

	/** Blocks until {@code task} is done, surfacing any failure to the caller. */
	private static void awaitCompletion(Future<?> task) {
		try {
			task.get();
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt(); // preserve the interrupt status
			throw new IllegalStateException(
					"Interrupted while evaluating candidate features", e);
		} catch (ExecutionException e) {
			throw new IllegalStateException(
					"Evaluation of a candidate feature failed", e.getCause());
		}
	}

	/**
	 * Evaluates one random candidate feature on samples [i0, i1) and records
	 * its best gain and threshold in slot {@code index} of the per-candidate
	 * arrays of the enclosing operation.
	 */
	class TestRunnable implements Runnable {
		private int _index;
		private int _indexOfThread;
		private StatisticsAggregator _leftChildStatistics,
				_rightChildStatistics;
		private StatisticsAggregator[] _partitionStatistics;
		private int i0, i1;

		/**
		 * @param index         slot of the per-candidate result arrays to fill
		 * @param indexOfThread row of the scratch response buffer owned by
		 *                      this task; must be unique among running tasks
		 * @param i0            first sample position (inclusive)
		 * @param i1            last sample position (exclusive)
		 */
		public TestRunnable(int index, int indexOfThread, int i0, int i1) {
			this._index = index;
			this._indexOfThread = indexOfThread;
			this._leftChildStatistics = trainingContext_.GetStatisticsAggregator();
			this._rightChildStatistics = trainingContext_.GetStatisticsAggregator();

			int nBins = parameters_.NumberOfCandidateThresholdsPerFeature + 1;
			this._partitionStatistics = new StatisticsAggregator[nBins];
			for (int i = 0; i < nBins; i++)
				this._partitionStatistics[i] = trainingContext_
						.GetStatisticsAggregator();

			this.i0 = i0;
			this.i1 = i1;
		}

		@Override
		public void run() {
			final float[] scratch = responsesForCandidate[_indexOfThread];
			float[] thresholds = new float[parameters_.NumberOfCandidateThresholdsPerFeature + 1];

			FeatureResponse feature = trainingContext_.GetRandomFeature(
					data_.GetxDim(), data_.GetyDim(), data_.GetzDim());

			for (int b = 0; b < _partitionStatistics.length; b++)
				_partitionStatistics[b].Clear(); // reset statistics

			// Compute feature response per sample at this node
			for (int i = i0; i < i1; i++)
				scratch[i] = feature.GetResponse(data_, indices_[i]);

			int nThresholds = ChooseCandidateThresholds(indices_, i0, i1,
					scratch, thresholds);
			if (nThresholds == 0)
				return; // no usable threshold: slot keeps gain 0 / null feature

			// Aggregate statistics over sample partitions. Bin b collects the
			// responses in [thresholds[b - 1], thresholds[b]).
			for (int i = i0; i < i1; i++) {
				// Linear scan; the number of thresholds is expected to be small
				int b = 0;
				while (b < nThresholds && scratch[i] >= thresholds[b])
					b++;

				_partitionStatistics[b].Aggregate(data_, indices_[i]);
			}

			// Keep the best threshold of this feature rather than the last one.
			double bestGain = 0.0;
			float bestThreshold = 0.0f;
			for (int t = 0; t < nThresholds; t++) {
				_leftChildStatistics.Clear();
				_rightChildStatistics.Clear();
				for (int p = 0; p < nThresholds + 1 /* i.e. nBins */; p++) {
					if (p <= t)
						_leftChildStatistics.Aggregate(_partitionStatistics[p]);
					else
						_rightChildStatistics
								.Aggregate(_partitionStatistics[p]);
				}

				// Compute gain over sample partitions
				double gain = trainingContext_.ComputeInformationGain(
						parentStatistics_, _leftChildStatistics,
						_rightChildStatistics);

				if (gain >= bestGain) {
					bestGain = gain;
					bestThreshold = thresholds[t];
				}
			}

			// Each task owns exactly one slot, so no locking is required; the
			// submitting thread reads the slot only after Future.get().
			gainsForCandidate[_index] = bestGain;
			thresholdsForCandidate[_index] = bestThreshold;
			featuresForCandidate[_index] = feature;
		}

		/**
		 * Fills {@code thresholds} with candidate thresholds drawn between
		 * approximate quantiles of {@code responses[i0..i1)}.
		 *
		 * @param indices    sample indices (unused, kept for signature parity)
		 * @param i0         first sample position (inclusive)
		 * @param i1         last sample position (exclusive)
		 * @param responses  feature responses, valid on [i0, i1)
		 * @param thresholds output buffer with room for at least
		 *                   {@code NumberOfCandidateThresholdsPerFeature + 1}
		 *                   values
		 * @return the number of thresholds written, or 0 when no threshold can
		 *         separate the samples (fewer than two samples, or all sampled
		 *         responses equal)
		 * @throws IllegalArgumentException if {@code thresholds} is null or too
		 *                                  small
		 */
		private int ChooseCandidateThresholds(int[] indices, int i0, int i1,
				float[] responses, float[] thresholds) {
			final int maxThresholds = parameters_.NumberOfCandidateThresholdsPerFeature;
			if (thresholds == null || thresholds.length < maxThresholds + 1)
				throw new IllegalArgumentException(
						"thresholds must hold at least " + (maxThresholds + 1)
								+ " values");

			final int sampleCount = i1 - i0;
			if (sampleCount < 2)
				return 0; // nothing to split

			// Form approximate quantiles by sorting a random draw of response
			// values
			int nThresholds;
			float[] quantiles;
			if (sampleCount > maxThresholds) {
				nThresholds = maxThresholds;
				quantiles = thresholds; // reuse same block of memory to avoid allocation
				for (int i = 0; i < nThresholds + 1; i++)
					quantiles[i] = responses[(int) (i0 + Math.random() * sampleCount)]; // sample randomly from all responses
			} else {
				// Few samples: use every response as a quantile.
				nThresholds = sampleCount - 1;
				quantiles = Arrays.copyOfRange(responses, i0, i1);
			}
			Arrays.sort(quantiles);

			if (quantiles[0] == quantiles[nThresholds])
				return 0; // all sampled response values were the same

			// We form n candidate thresholds by sampling in between n+1
			// approximate quantiles. When quantiles aliases thresholds, slot i
			// is overwritten only after quantiles[i] has been read, and
			// quantiles[i + 1] is still untouched.
			for (int i = 0; i < nThresholds; i++)
				thresholds[i] = quantiles[i]
						+ (float) (Math.random() * (quantiles[i + 1] - quantiles[i]));

			return nThresholds;
		}
	}
}
