package edu.vanderbilt.masi.algorithms.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Trains regression forests/trees and applies a trained forest to test data to
 * produce six boundary values per box index.
 *
 * <p>Not safe for concurrent use of one instance: {@link #RunTesting} keeps its
 * per-call results in the {@code boxes} field, which its worker tasks fill in.
 */
public class Regression {
    /**
     * Per-box results of the most recent {@link #RunTesting} call. Row {@code i}
     * is written by the {@link TestRunnable} for box {@code i}.
     */
    private double[][] boxes;

    /** Parameters handed to the trainers; {@link #RunTraining} overwrites {@code NumberOfTrees}. */
    public TrainingParameters TrainingParameters = new TrainingParameters();

    /**
     * Trains a forest on {@code trainingData} and prints the elapsed time in milliseconds.
     *
     * <p>Side effect: sets {@code TrainingParameters.NumberOfTrees} to {@code numberOfTree}.
     *
     * @param trainingData data to train on; its box count sizes the training context
     * @param numberOfTree number of trees to grow
     * @return the trained forest, or {@code null} if training threw (the stack trace is printed)
     */
    public Forest RunTraining(DataCollection trainingData, int numberOfTree) {
        // Train the forest
        System.out.println("Training the forest...");

        TrainingContext regressionTrainingContext = new TrainingContext(
                trainingData.GetBoxCount());
        TrainingParameters.NumberOfTrees = numberOfTree;

        try {
            long tBegin = System.currentTimeMillis();

            Forest forest = ForestTrainer.TrainForest(
                    TrainingParameters,
                    regressionTrainingContext,
                    trainingData);

            System.out.println("Training cost: " + Long.toString(System.currentTimeMillis() - tBegin));

            return forest;
        } catch (Exception e) {
            // Callers (e.g. RunTesting) treat a null forest as "training failed".
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Trains a single tree on {@code data} using the current {@code TrainingParameters}
     * and prints the elapsed time in milliseconds.
     *
     * @param data data to train on; its box count sizes the training context
     * @return the trained tree, or {@code null} if training threw (the stack trace is printed)
     */
    public Tree RunTrainingForTree(DataCollection data) {
        // Train a single tree
        System.out.println("Training the tree...");

        TrainingContext regressionTrainingContext = new TrainingContext(
                data.GetBoxCount());
        try {
            long tBegin = System.currentTimeMillis();

            Tree tree = TreeTrainer.TrainTree(regressionTrainingContext, TrainingParameters, data);

            System.out.println("Training cost: " + Long.toString(System.currentTimeMillis() - tBegin));

            return tree;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Applies {@code forest} to {@code testData} and computes six values per box,
     * one box per pool task, on a thread pool sized to the available processors.
     *
     * @param testData samples to evaluate
     * @param forest   trained forest; if {@code null} an error is printed and {@code null} returned
     * @param boxCount number of boxes to predict
     * @return array of length {@code boxCount * 6}; box {@code i} occupies indices
     *         {@code [i * 6, i * 6 + 6)}, or {@code null} when {@code forest} is null
     * @throws InterruptedException  if interrupted while waiting for the per-box tasks
     * @throws IllegalStateException if any per-box task throws (the task's exception is the cause)
     */
    public double[] RunTesting(DataCollection testData, Forest forest, int boxCount) throws InterruptedException {
        // Apply the trained forest to the test data
        System.out.println("\nApplying the forest to test data...");

        if (forest == null) {
            System.out.println("Error: Forest construction error!");
            return null;
        }

        long tBegin = System.currentTimeMillis();
        // Indexed [tree][sample] (see TestRunnable.run).
        int[][] leafNodeIndeices = forest.Apply(testData);
        System.out.println("Got leaf node cost: " +
                Long.toString(System.currentTimeMillis() - tBegin));

        tBegin = System.currentTimeMillis();
        // Indexed [tree][box][0 or 1] (see TestRunnable.run).
        int[][][] ids = forest.GetNodeIdForBoxes(boxCount);
        System.out.println("Got confident leaf node cost: " +
                Long.toString(System.currentTimeMillis() - tBegin));

        tBegin = System.currentTimeMillis();
        boxes = new double[boxCount][6];
        int nCpus = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(nCpus);
        try {
            List<Future<?>> pending = new ArrayList<Future<?>>(boxCount);
            for (int i = 0; i < boxCount; ++i) {
                pending.add(pool.submit(new TestRunnable(i, leafNodeIndeices, ids, testData, forest)));
            }
            // Waiting on each Future (instead of polling a shared flag array) gives a
            // happens-before edge for the task's write to boxes[i] and surfaces task
            // failures instead of waiting forever on a result that will never arrive.
            for (int i = 0; i < boxCount; ++i) {
                try {
                    pending.get(i).get();
                } catch (ExecutionException e) {
                    throw new IllegalStateException("Box prediction failed for box " + i, e.getCause());
                }
            }
        } finally {
            // On success every task has finished; on failure this also drops queued tasks.
            pool.shutdownNow();
        }

        double[] boxBoundary = new double[boxCount * 6];
        for (int i = 0; i < boxCount; ++i) {
            System.arraycopy(boxes[i], 0, boxBoundary, i * 6, 6);
        }
        System.out.println("Got boxBoundary cost: " +
                Long.toString(System.currentTimeMillis() - tBegin));

        return boxBoundary;
    }

    /**
     * Computes the result for one box and stores it in {@code boxes[index]} of the
     * enclosing {@link Regression}.
     */
    class TestRunnable implements Runnable {
        private int _index;
        private int[][] _leafNodeIndeices;
        private int[][][] _ids;
        private DataCollection _data;
        private Forest _forest;

        /**
         * @param index            box index this task computes
         * @param leafNodeIndeices leaf reached per [tree][sample], from {@code Forest.Apply}
         * @param ids              selected node ids per [tree][box][0 or 1], from {@code Forest.GetNodeIdForBoxes}
         * @param data             the samples that were applied to the forest
         * @param forest           the forest that produced {@code leafNodeIndeices}
         */
        public TestRunnable(int index, int[][] leafNodeIndeices, int[][][] ids,
                DataCollection data, Forest forest) {
            _index = index;
            _leafNodeIndeices = leafNodeIndeices;
            _ids = ids;
            _data = data;
            _forest = forest;
        }

        /**
         * For each tree, averages the vectors of samples that reached the box's selected
         * leaf, offsets that average by the leaf's box mean, and combines the trees with
         * weights {@code p}, normalized by their sum. Writes the result to {@code boxes[_index]}.
         */
        @Override
        public void run() {
            int treeCount = _forest.GetTreeCount();
            int dataCount = _data.GetDataCount();

            // Slot t*2 accumulates samples that reached node _ids[t][_index][0],
            // slot t*2+1 those that reached node _ids[t][_index][1].
            // Assumes MatrixComputor.Add/Scale are element-wise and GetVector has
            // (at least) 3 components -- TODO confirm.
            int[] sampleCounts = new int[treeCount * 2];
            double[][] meanVoxels = new double[treeCount * 2][3];
            for (int i = 0; i < dataCount; ++i) {
                for (int t = 0; t < treeCount; ++t) {
                    if (_leafNodeIndeices[t][i] == _ids[t][_index][0]) {
                        meanVoxels[t * 2] = MatrixComputor.Add(meanVoxels[t * 2],
                                _data.GetVector(i));
                        sampleCounts[t * 2]++;
                    } else if (_leafNodeIndeices[t][i] == _ids[t][_index][1]) {
                        meanVoxels[t * 2 + 1] = MatrixComputor.Add(meanVoxels[t * 2 + 1],
                                _data.GetVector(i));
                        sampleCounts[t * 2 + 1]++;
                    }
                }
            }

            // Turn sums into means (an empty slot keeps its all-zero sum) and expand each
            // 3-component mean (a, b, c) into 6 values (a, a, b, b, c, c).
            double[][] voxel1 = new double[treeCount * 2][6];
            for (int i = 0; i < treeCount * 2; ++i) {
                if (sampleCounts[i] == 0) sampleCounts[i] = 1;
                meanVoxels[i] = MatrixComputor.Scale(meanVoxels[i],
                        1.0 / (double) (sampleCounts[i]));
                for (int j = 0; j < 3; ++j) {
                    voxel1[i][j * 2] = meanVoxels[i][j];
                    voxel1[i][j * 2 + 1] = meanVoxels[i][j];
                }
            }

            // NOTE(review): only voxel1[t * 2] (the first node id) is read below; the
            // statistics for the second node id (odd slots) are computed but unused.
            // Confirm whether the second node was meant to contribute.
            double[] bc = new double[6];
            double totalProbability = 0.0;
            for (int t = 0; t < treeCount; ++t) {
                Node node = _forest.GetTree(t).GetNode(_ids[t][_index][0]);
                double[] dc = node.GetBoxMean(_index);
                double[] negdc = MatrixComputor.Scale(dc, -1.0);
                double var = node.GetBoxVariance(_index);
                // NOTE(review): 1/(sqrt(2*pi)*var) treats var as a standard deviation; if
                // GetBoxVariance returns a variance this likely needs Math.sqrt(var) --
                // confirm. A zero value makes p infinite and the result non-finite.
                double p = 1.0 / (Math.sqrt(2 * Math.PI) * var);
                totalProbability += p;
                bc = MatrixComputor.Add(bc,
                        MatrixComputor.Scale(MatrixComputor.Add(voxel1[t * 2], negdc), p));
            }

            // NOTE(review): with zero trees totalProbability is 0 and this divides by zero.
            bc = MatrixComputor.Scale(bc, 1.0 / totalProbability);
            boxes[_index] = bc;
        }
    }
}
