/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package database;

import util.OutputType;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;

/**
 * A binary split of a set of {@link DatabaseEntry} rows on one input column.
 *
 * <p>A rule is either <em>ordered</em> (numeric): an input goes left when the
 * value parsed from column {@code elemIndex} is {@code <= splitValue}; or
 * <em>categorical</em>: an input goes left when column {@code elemIndex}
 * matches, ignoring case, one of the category strings observed in the left
 * training set.</p>
 *
 * <p>On construction the rule computes the size-weighted sum of the population
 * variances of the selected output (CPU time or memory used, chosen by the
 * {@link OutputType}) over its left and right sets. If either set is empty the
 * division by the set size yields {@code NaN}, and so does the weighted
 * variance; this is left as-is because callers may rely on it
 * (NOTE(review): confirm against the code that ranks candidate rules).</p>
 *
 * @author rbanalagay
 */
public class SplittingRule implements Serializable{
    
    private VectorSet leftSet;
    private VectorSet rightSet;
    private double weightedVariance;
    private OutputType type;
    private int elemIndex;
    private double splitValue;
    // Only populated for categorical rules; stays null for ordered rules.
    private String[] categories;
    private boolean isCategorical = false;
    static final long serialVersionUID = 42L;

    /**
     * Creates a rule for an ordered (numeric) variable.
     *
     * @param elemIndex index of the input column this rule tests
     * @param splitValue threshold; inputs with a value {@code <= splitValue} go left
     * @param leftSet training entries assigned to the left branch
     * @param rightSet training entries assigned to the right branch
     * @param type which output (CPU time or memory used) the variance is computed on
     */
    public SplittingRule(int elemIndex, double splitValue,
            ArrayList<DatabaseEntry> leftSet,
            ArrayList<DatabaseEntry> rightSet,OutputType type)
    {
        this.elemIndex = elemIndex;
        this.splitValue = splitValue;
        isCategorical = false;
        this.type = type;
        calculateWeightedVariance(leftSet,rightSet);
        this.leftSet = toVectorSet(leftSet);
        this.rightSet = toVectorSet(rightSet);
    }

    /**
     * Creates a rule for a categorical variable. The distinct values found in
     * column {@code elemIndex} of the left set, in first-seen order, become the
     * categories that send an input left.
     *
     * @param elemIndex index of the input column this rule tests
     * @param leftSet training entries assigned to the left branch
     * @param rightSet training entries assigned to the right branch
     * @param type which output (CPU time or memory used) the variance is computed on
     */
    public SplittingRule(int elemIndex,
            ArrayList<DatabaseEntry> leftSet,
            ArrayList<DatabaseEntry> rightSet,OutputType type)
    {
        this.elemIndex = elemIndex;
        isCategorical = true;
        this.type = type;
        calculateWeightedVariance(leftSet,rightSet);
        
        ArrayList<String> categoryList = new ArrayList<String>();
        for(DatabaseEntry currentEntry : leftSet)
        {
            String category = currentEntry.getTableInputs()[elemIndex];
            if(!categoryList.contains(category))
            {
                categoryList.add(category);
            }
        }
        categories = categoryList.toArray(new String[categoryList.size()]);
        
        this.leftSet = toVectorSet(leftSet);
        this.rightSet = toVectorSet(rightSet);
    } 

    /** Wraps the entries of a list in a VectorSet via an array snapshot. */
    private static VectorSet toVectorSet(ArrayList<DatabaseEntry> entries)
    {
        return new VectorSet(entries.toArray(new DatabaseEntry[entries.size()]));
    }
    
    /**
     * Decides which branch an input vector follows.
     *
     * @param inputVector the input columns; only index {@code elemIndex} is read
     * @return {@code true} to follow the left branch, {@code false} for the right
     * @throws NumberFormatException if this is an ordered rule and the tested
     *         column does not parse as a double
     */
    public boolean goLeft(String[] inputVector)
    {
        //TODO: THIS WILL FAIL IF THE RULE WAS TRAINED AS A 
        //VALUE BUT A CATEGORICAL VARIABLE COMES IN
        
        if(isCategorical)
        {
            String testString = inputVector[elemIndex];
            for(String currentCategory : categories)
            {
                if(currentCategory.equalsIgnoreCase(testString))
                {
                    return true;
                }
            }
            return false;
        }
        double testVal = Double.parseDouble(inputVector[elemIndex]);
        return testVal<=splitValue;
    }

    /**
     * Reads the output this rule is scored on from an entry: CPU time for
     * CPU_TYPE, memory used for MEM_TYPE, and 0 for any other type (such
     * entries therefore contribute nothing to means or deviations).
     */
    private double targetValue(DatabaseEntry entry)
    {
        switch(type)
        {
            case CPU_TYPE:
                return entry.getCPUTime();
            case MEM_TYPE:
                return entry.getMemUsed();
            default:
                return 0.0;
        }
    }
    
    /** Arithmetic mean of the target output over the set; NaN when the set is empty. */
    private double calculateMean(ArrayList<DatabaseEntry>set)
    {
        double sum = 0.0;
        for(DatabaseEntry currentEntry : set)
        {
            sum += targetValue(currentEntry);
        }
        return sum / (double)set.size();
    }

    /** Population variance (divides by n) of the target output; NaN when the set is empty. */
    private double calculateVariance(ArrayList<DatabaseEntry>set)
    {
        double mean = calculateMean(set);
        double squaredError = 0.0;
        for(DatabaseEntry currentEntry : set)
        {
            double deviation = mean - targetValue(currentEntry);
            squaredError += deviation * deviation;
        }
        return squaredError / (double)set.size();
    }

    /**
     * Stores in {@code weightedVariance} the variances of the two sets, each
     * weighted by its share of the combined entry count.
     */
    private void calculateWeightedVariance(ArrayList<DatabaseEntry>leftSet, 
            ArrayList<DatabaseEntry>rightSet)
    {
        int count = leftSet.size()+rightSet.size();
        double leftWeighting = (double)leftSet.size()/(double)count;
        double rightWeighting = (double)rightSet.size()/(double)count;
        
        // calculateAbsDeviation is an alternative impurity measure that can
        // be substituted for calculateVariance in the expression below.
        weightedVariance =  leftWeighting*calculateVariance(leftSet)+
                rightWeighting*calculateVariance(rightSet);
    }

    /** Mean absolute deviation of the target output from the set's median. */
    private double calculateAbsDeviation(ArrayList<DatabaseEntry>set)
    {
        double median = findMedian(set);
        double error = 0.0;
        for(DatabaseEntry currentEntry : set)
        {
            error += Math.abs(targetValue(currentEntry)-median);
        }
        return error/((double)set.size());
    }

    /**
     * Returns the target output of the element at index {@code size/2} after
     * ordering by that output (the upper middle for even sizes), or 0 for an
     * empty set. A copy is sorted so the caller's list keeps its order.
     */
    private double findMedian(ArrayList<DatabaseEntry>set)
    {
        if(set.isEmpty())
        {
            return 0.0;
        }
        ArrayList<DatabaseEntry> sorted = new ArrayList<DatabaseEntry>(set);
        Collections.sort(sorted,new Comparator<DatabaseEntry>() {
            @Override
            public int compare(DatabaseEntry t, DatabaseEntry t1) {
                return Double.compare(targetValue(t), targetValue(t1));
            }
        });
        return targetValue(sorted.get(sorted.size()/2));
    }
    
    
    /** @return the training entries assigned to the left branch */
    public VectorSet getLeftSet()
    {
        return leftSet;
    }

    /** @return the training entries assigned to the right branch */
    public VectorSet getRightSet()
    {
        return rightSet;
    }
    
    /** @return the size-weighted variance computed when the rule was built */
    public double getWeightedVariance()
    {
        return weightedVariance;
    }
    
    /**
     * Describes the rule: its kind, the tested column index, and either the
     * split value or the comma-separated list of left-branch categories.
     */
    @Override
    public String toString()
    {
        StringBuilder text = new StringBuilder(
                isCategorical ? "Categorical," : "Numbered");
        text.append(" Index: ").append(elemIndex);
        if(!isCategorical)
        {
            text.append(", Value:").append(splitValue);
            return text.toString();
        }
        text.append(" Categories:");
        for(int i = 0; i < categories.length; i++)
        {
            // Separator goes between items only, so nothing trails the last one.
            if(i > 0)
            {
                text.append(", ");
            }
            text.append(categories[i]);
        }
        return text.toString();
    }
}
