package edu.vanderbilt.masi.algorithms.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;


/**
 * Trains a {@link Forest} by fitting its trees in parallel on a fixed-size
 * thread pool (at most one thread per available processor).
 *
 * <p>Each tree is trained on its own random sample, drawn with replacement,
 * of the supplied training data. This class holds no static mutable state of
 * its own, so separate {@link #TrainForest} calls do not share results.
 */
public class ForestTrainer {

	/**
	 * Number of data items sampled (with replacement) for each tree.
	 * NOTE(review): kept at the value the code always used; a conventional
	 * bootstrap sample draws as many items as the full training set —
	 * confirm this small fixed size is intended.
	 */
	private static final int SAMPLES_PER_TREE = 3;

	/**
	 * Trains {@code parameters.NumberOfTrees} trees in parallel and gathers
	 * them into a forest.
	 *
	 * @param parameters training parameters; {@code NumberOfTrees} sets the forest size
	 * @param context    training context passed through to {@code TreeTrainer.TrainTree}
	 * @param data       training set from which each tree's sample is drawn
	 * @return a forest containing the trained trees, added in tree-index order
	 * @throws IllegalArgumentException if {@code NumberOfTrees} is negative, or if
	 *         trees are requested but {@code data} contains no items
	 * @throws Exception the exception thrown by the first failing tree (in index
	 *         order), or {@link InterruptedException} if the caller is interrupted
	 *         while waiting for the workers
	 */
	public static Forest TrainForest(
			TrainingParameters parameters,
			TrainingContext context,
			DataCollection data) throws Exception {
		int treeCount = parameters.NumberOfTrees;
		if (treeCount < 0) {
			throw new IllegalArgumentException(
					"NumberOfTrees must be >= 0, got " + treeCount);
		}
		if (treeCount > 0 && data.GetDataItemCount() <= 0) {
			throw new IllegalArgumentException(
					"Cannot sample training data for trees: data collection is empty");
		}

		// Never spin up more threads than there are trees to train (but at least one).
		int nCpus = Runtime.getRuntime().availableProcessors();
		int poolSize = Math.max(1, Math.min(nCpus, treeCount));
		ExecutorService pool = Executors.newFixedThreadPool(poolSize);
		try {
			// Samples are drawn on the calling thread, then each tree trains in the pool.
			List<Future<Tree>> pending = new ArrayList<Future<Tree>>(treeCount);
			for (int t = 0; t < treeCount; t++) {
				DataCollection sample = sampleWithReplacement(data, SAMPLES_PER_TREE);
				pending.add(pool.submit(new TreeTrainingTask(t, context, parameters, sample)));
			}

			// Future.get() blocks until each tree is ready and provides the
			// happens-before edge that the old boxed-Integer lock did not.
			Forest forest = new Forest();
			for (Future<Tree> future : pending) {
				forest.AddTree(awaitTree(future));
			}
			return forest;
		} finally {
			// On success every task has finished, so this is just cleanup; on
			// failure it cancels queued trees and interrupts running ones.
			pool.shutdownNow();
		}
	}

	/**
	 * Draws {@code sampleSize} items uniformly at random, with replacement,
	 * from {@code data} into a new collection with the same box count.
	 */
	private static DataCollection sampleWithReplacement(DataCollection data, int sampleSize) {
		DataCollection sample = new DataCollection(data.GetBoxCount());
		for (int i = 0; i < sampleSize; ++i) {
			// Scale by the SOURCE collection's size, not the (initially empty)
			// sample being built — otherwise every draw lands on index 0.
			int index = (int) (Math.random() * data.GetDataItemCount());
			sample.AddDataItem(data.GetDataItem(index));
		}
		return sample;
	}

	/**
	 * Waits for one training task and returns its tree. On failure, rethrows
	 * the task's own exception rather than the {@link ExecutionException}
	 * wrapper, so callers see the original cause.
	 */
	private static Tree awaitTree(Future<Tree> future) throws Exception {
		try {
			return future.get();
		} catch (ExecutionException e) {
			Throwable cause = e.getCause();
			if (cause instanceof Exception) {
				throw (Exception) cause;
			}
			// Errors (and anything else non-Exception) stay wrapped with their cause.
			throw e;
		}
	}

	/** Trains a single tree on its own data sample; the result is delivered via a {@link Future}. */
	private static final class TreeTrainingTask implements Callable<Tree> {
		private final int index;
		private final TrainingContext context;
		private final TrainingParameters parameters;
		private final DataCollection data;

		TreeTrainingTask(int index, TrainingContext context,
				TrainingParameters parameters, DataCollection data) {
			this.index = index;
			this.context = context;
			this.parameters = parameters;
			this.data = data;
		}

		@Override
		public Tree call() throws Exception {
			// Progress output: prints the tree index as training starts.
			System.out.println(index);
			return TreeTrainer.TrainTree(context, parameters, data);
		}
	}
}
