/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package cart;

import database.DatabaseEntry;
import database.SplittingRule;
import database.SplittingRuleComparator;
import database.VectorSet;
import edu.masi.hyperadvisor.util.Log;
import util.OutputType;
import util.ResultStatistic;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Scanner;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

//import com.google.gson.Gson;

/**
 * Regression tree (CART) that estimates a value for a job from its input vector.
 * Which value is modelled is selected by the {@link OutputType} given at construction
 * (saved trees exist per type, see {@link #getCARTFromFile(String, OutputType)}).
 *
 * <p>Trained trees are persisted with Java serialization, so the names and types of the
 * instance fields below are part of the on-disk format: renaming or retyping one breaks
 * loading of previously saved {@code .obj} files even though {@code serialVersionUID}
 * stays the same.
 *
 * @author rbanalagay
 */
public class CART implements Serializable {

    private Node rootNode;
    private OutputType type;
    /** Size of the worker pool used to search candidate split rules in parallel. */
    private static final int NUM_THREADS = 10;
    private String title;
    /** Not read or written by this class itself. */
    public int NUM_NODES = 0;
    /** A node holding at most this many vectors is not split further and becomes a leaf. */
    private int MIN_NODE_SIZE = 1;
    /** Passed to every {@link Node} created while training. */
    private double fudgeFactor = 3;

    /**
     * One row of the training data, kept only so that later estimation and update requests
     * can be run through the same {@link VectorSet} processing as the training data
     * (see {@link #alignWithTrainingData(DatabaseEntry)}).
     */
    private DatabaseEntry sampleEntry = null;
    private static final long serialVersionUID = 42L;

    /**
     * Trains a tree on {@code allVectors}.
     *
     * @param title       name of this tree; printed to standard output once training finishes
     * @param allVectors  training data; must contain at least one row
     * @param type        which output value the tree models
     * @param minNodeSize nodes holding at most this many vectors become leaves
     * @param fudgeFactor passed to every {@link Node} created during training
     * @throws IllegalArgumentException if {@code allVectors} is null or has no rows
     * @throws IllegalStateException    if the calling thread is interrupted while waiting for
     *                                  split searches (the interrupt status is restored)
     */
    public CART(String title, VectorSet allVectors, OutputType type, int minNodeSize,
            double fudgeFactor) {
        // Fail fast: the sample row below is required, and discovering an empty set only
        // after a full training run wastes the whole run.
        if (allVectors == null || allVectors.getDBRows().length == 0) {
            throw new IllegalArgumentException(
                    "CART \"" + title + "\" requires at least one training row");
        }
        this.title = title;
        this.type = type;
        this.MIN_NODE_SIZE = minNodeSize;
        this.fudgeFactor = fudgeFactor;
        this.rootNode = trainTree(allVectors);
        this.sampleEntry = allVectors.getDBRows()[0];
        System.out.println(title);
    }

    /**
     * Estimates the output value for a new job.
     *
     * @param inputs      raw input values of the job
     * @param machineID   machine identifier, passed through to {@link DatabaseEntry}
     * @param isUseGrid   grid flag, passed through to {@link DatabaseEntry}
     * @param fudgeFactor fudge factor handed to the tree for this call only; the value given
     *                    at construction is not used here
     * @return the tree's estimate for the job
     */
    public double getEstimate(String[] inputs, String machineID, String isUseGrid,
            double fudgeFactor) {
        // The trailing 0, 0 fill the value slots that updateValues() fills with the observed
        // value; here that value is unknown (it is what we are estimating).
        DatabaseEntry toEstimate = new DatabaseEntry(inputs, machineID, isUseGrid, 0, 0);
        return rootNode.getEstimate(alignWithTrainingData(toEstimate).getTableInputs(),
                fudgeFactor);
    }

    /**
     * Feeds an observed value for a job back into the tree.
     *
     * @param inputVector raw input values of the job
     * @param machineID   machine identifier, passed through to {@link DatabaseEntry}
     * @param isUseGrid   grid flag, passed through to {@link DatabaseEntry}
     * @param newValue    the observed value, forwarded to {@code Node.updateVals}
     */
    public void updateValues(String[] inputVector, String machineID, String isUseGrid,
            int newValue) {
        DatabaseEntry observed =
                new DatabaseEntry(inputVector, machineID, isUseGrid, newValue, newValue);
        rootNode.updateVals(alignWithTrainingData(observed).getTableInputs(), newValue);
    }

    /**
     * Builds a two-row {@link VectorSet} from {@code entry} and {@link #sampleEntry} and
     * returns the first row of that set. This is how request entries are put through the
     * same processing as the training data. NOTE(review): presumably VectorSet normalizes
     * or encodes rows relative to each other, and that is why the sample row is needed;
     * confirm against VectorSet.
     *
     * @param entry the request entry to align
     * @return the processed row corresponding to {@code entry}
     */
    private DatabaseEntry alignWithTrainingData(DatabaseEntry entry) {
        DatabaseEntry[] rows = new DatabaseEntry[] { entry, sampleEntry };
        return new VectorSet(rows).getDBRows()[0];
    }

    /**
     * Loads a previously serialized tree from
     * {@code ~/trainingFiles/<algName>_CPUTIME.obj} when {@code type} is
     * {@link OutputType#CPU_TYPE}, and from {@code ~/trainingFiles/<algName>_MEMUSED.obj}
     * otherwise.
     *
     * <p>This uses Java native deserialization, which can execute code from classes named in
     * the stream. Only load files from a trusted location.
     *
     * @param algName algorithm name used to build the file name
     * @param type    selects which of the two files is read
     * @return the deserialized tree
     * @throws IOException            if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class cannot be found
     * @throws ClassCastException     if the file does not contain a {@code CART}
     */
    public static CART getCARTFromFile(String algName, OutputType type)
            throws IOException, ClassNotFoundException {
        String suffix = type.equals(OutputType.CPU_TYPE) ? "_CPUTIME.obj" : "_MEMUSED.obj";
        String path = System.getProperty("user.home") + "/trainingFiles/" + algName + suffix;

        FileInputStream fileIn = new FileInputStream(path);
        try {
            ObjectInputStream objectIn = new ObjectInputStream(fileIn);
            try {
                return (CART) objectIn.readObject();
            } finally {
                objectIn.close();
            }
        } finally {
            // Also covers a failure inside the ObjectInputStream constructor (e.g. a corrupt
            // header), where objectIn never existed. Closing twice is harmless.
            fileIn.close();
        }
    }

    /**
     * Trains a tree on {@code allVectors} using one worker pool for the whole run. The pool
     * is always shut down, including when training fails.
     */
    private Node trainTree(VectorSet allVectors) {
        ExecutorService executor = Executors.newFixedThreadPool(NUM_THREADS);
        try {
            return trainTree(allVectors, executor);
        } finally {
            // On normal completion every task has already finished. On failure this also
            // cancels any searches still queued or running.
            executor.shutdownNow();
        }
    }

    /**
     * Recursively builds the subtree for {@code vectors}. A node is split with the best
     * split rule when it holds more than {@link #MIN_NODE_SIZE} vectors and at least one
     * rule was found. Otherwise it becomes a leaf built from the set's output values.
     */
    private Node trainTree(VectorSet vectors, ExecutorService executor) {
        if (vectors.getNumVectors() > MIN_NODE_SIZE) {
            SplittingRule bestSplit = findBestSplit(vectors, executor);
            if (bestSplit != null) {
                Node node = new Node(bestSplit, fudgeFactor);
                node.setLeftChild(trainTree(bestSplit.getLeftSet(), executor));
                node.setRightChild(trainTree(bestSplit.getRightSet(), executor));
                return node;
            }
        }
        return new Node(vectors.getOutputValues(type), fudgeFactor);
    }

    /**
     * Searches every vector element (feature) of {@code vectors} for a split rule in
     * parallel. It returns the first rule in {@link SplittingRuleComparator} order, or
     * {@code null} if no search produced a rule.
     */
    private SplittingRule findBestSplit(VectorSet vectors, ExecutorService executor) {
        int numFeatures = vectors.getNumVectorElements();
        ArrayList<Future<SplittingRule>> pending =
                new ArrayList<Future<SplittingRule>>(numFeatures);
        for (int feature = 0; feature < numFeatures; feature++) {
            // Each search gets its own copy of the set, so concurrent searches never share one
            // VectorSet instance.
            pending.add(executor.submit(
                    new RuleSearcher(VectorSet.getCopy(vectors), feature, type)));
        }

        ArrayList<SplittingRule> candidates = new ArrayList<SplittingRule>(numFeatures);
        for (Future<SplittingRule> future : pending) {
            SplittingRule rule;
            try {
                rule = future.get();
            } catch (ExecutionException e) {
                // Best effort: a failed search for one feature just removes that feature
                // from consideration.
                e.printStackTrace();
                continue;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "Interrupted while training CART \"" + title + "\"", e);
            }
            if (rule != null) {
                candidates.add(rule);
            }
        }

        if (candidates.isEmpty()) {
            return null;
        }
        // Collections.sort is stable, so ties go to the lowest feature index.
        Collections.sort(candidates, SplittingRuleComparator.getInstance());
        return candidates.get(0);
    }

    /** Prints the tree by delegating to the root node's {@code print("", true)}. */
    public void printTree() {
        rootNode.print("", true);
    }

    /**
     * Task that asks a {@link VectorSet} for its split rule on one vector element. It is
     * static because it needs nothing from the enclosing tree.
     */
    private static final class RuleSearcher implements Callable<SplittingRule> {
        private final VectorSet vectors;
        private final int searchIndex;
        private final OutputType type;

        RuleSearcher(VectorSet vectors, int searchIndex, OutputType type) {
            this.vectors = vectors;
            this.searchIndex = searchIndex;
            this.type = type;
        }

        @Override
        public SplittingRule call() throws Exception {
            return vectors.getSplitRule(searchIndex, type);
        }
    }
}
