/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package database;

import util.OutputType;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Set;

/**
 *
 * Represents a class of vectors from a database table
 * @author rbanalagay
 */
public class VectorSet implements Serializable{
    
    private DatabaseEntry[] dbRows;
    private boolean[] typeLabels = new boolean[0];
    static final long serialVersionUID = 42L;
    public static VectorSet getCopy(VectorSet vectorSetToCopy)
    {
        DatabaseEntry[] dbRowsCopy = 
                new DatabaseEntry[vectorSetToCopy.dbRows.length];
        for(int x=0; x<dbRowsCopy.length; x++)
        {
            dbRowsCopy[x] = DatabaseEntry.getCopy(vectorSetToCopy.dbRows[x]);
        }
        return new VectorSet(dbRowsCopy);
    }
    
    public VectorSet(DatabaseEntry[] dbRows)
    {
        this.dbRows = dbRows;
        if(dbRows!=null && dbRows.length>0)
        {
            DatabaseEntry.alignInputs(this.dbRows);
            createTypeLabels();
        }
    }
    
    public VectorSet getVectorSubset(int startIndex, int endIndex)
    {
        DatabaseEntry[] copy = new DatabaseEntry[endIndex-startIndex];
        for(int x=startIndex; x<endIndex; x++)
        {
            copy[x-startIndex] = DatabaseEntry.getCopy(dbRows[x]);
        }
        return new VectorSet(copy);
    }

    public DatabaseEntry[] getDBRows()
    {
        return dbRows;
    }
    private void createTypeLabels()
    {
        typeLabels = new boolean[dbRows[0].getTableInputs().length];
        for(int x=0; x<typeLabels.length; x++)
        {
            boolean isCategorical = false;
            for(DatabaseEntry currentEntry : dbRows)
            {
                try
                {
                    Double.parseDouble(currentEntry.getTableInputs()[x]);
                }
                catch(NumberFormatException e)
                {
                    isCategorical = true;
                    break;
                }
            }
            //Hack to make numeric categories stay as categories
            if(isCategorical)
            {
                for(DatabaseEntry currentEntry : dbRows)
                {
                    if(!currentEntry.getTableInputs()[x].contains("CateTag"))
                        currentEntry.getTableInputs()[x]+="CateTag";
                }                
            }
            typeLabels[x] = isCategorical;
        }
    }
    
    /**
     * returns a splitting rule based on the specified element index to 
     * search through
     * @param index
     * @return 
     */
    public SplittingRule getSplitRule(int index,OutputType type)
    {
        if(index>=typeLabels.length)
            return null;
        else
        {
            //if categorical variable
            if(typeLabels[index])
            {
                return findCategoricalSplit(index,type);
            }
            else
            {
                return findOrderedSplit(index,type);
            }
        }
    }
    
    
    private SplittingRule findOrderedSplit(int index,OutputType type)
    {
        final int searchIndex = index;
        Arrays.sort(dbRows,new Comparator<DatabaseEntry>() {

            @Override
            public int compare(DatabaseEntry t, DatabaseEntry t1) {
                return Double.compare(Double.
                        parseDouble(t.getTableInputs()[searchIndex]),
                        Double.parseDouble(t1.getTableInputs()[searchIndex]));
            }
        });
        
        
        String currentVal = dbRows[0].getTableInputs()[searchIndex];
        ArrayList<SplittingRule> rules = new ArrayList<SplittingRule>();
        for(int x=1; x<dbRows.length; x++)
        {
            if(!dbRows[x].getTableInputs()[searchIndex].equals(currentVal))
            {
                ArrayList<DatabaseEntry> leftSet = 
                        new ArrayList<DatabaseEntry>();
                ArrayList<DatabaseEntry> rightSet = 
                        new ArrayList<DatabaseEntry>();
                for(int y=0; y<x; y++)
                {
                    leftSet.add(dbRows[y]);
                }
                for(int y=x; y<dbRows.length; y++)
                {
                    rightSet.add(dbRows[y]);
                }
                currentVal = dbRows[x].getTableInputs()[searchIndex];
                double splitVal = Double.parseDouble
                        (dbRows[x].getTableInputs()[searchIndex]);
                SplittingRule currentRule = 
                        new SplittingRule(searchIndex,splitVal,
                        leftSet,rightSet,type);
                rules.add(currentRule);
            }
        }
        Collections.sort(rules,SplittingRuleComparator.getInstance());
        if(!rules.isEmpty())
            return rules.get(0);
        
        
        /*
        SplittingRule[] rules = new SplittingRule[dbRows.length-1];
        SplittingRule medianRule = null;
        for(int splitIndex=1; splitIndex<dbRows.length; splitIndex++)
        {
            ArrayList<DatabaseEntry> leftSet = new ArrayList<DatabaseEntry>();
            ArrayList<DatabaseEntry> rightSet = new ArrayList<DatabaseEntry>();
            for(int x=0; x<splitIndex; x++)
            {
                leftSet.add(dbRows[x]);
            }
            for(int x=splitIndex; x<dbRows.length; x++)
            {
                rightSet.add(dbRows[x]);
            }
            double splitVal = Double.parseDouble(dbRows[splitIndex]
                    .getTableInputs()[index]);
            SplittingRule currentRule = new SplittingRule(searchIndex,splitVal,
                    leftSet,rightSet,type);
            rules[splitIndex-1] = currentRule;
            //if(rules.length>3 && splitIndex == dbRows.length/2)
            //{
            //    medianRule = currentRule;
            //}
        }
        if(rules.length>0)
        {
            Arrays.sort(rules,SplittingRuleComparator.getInstance());   
            /*if((rules[0].getLeftSet().getNumVectors()==1 
                    || rules[0].getRightSet().getNumVectors() == 1)
                    && medianRule != null)
            {
                return medianRule;
            }
            return rules[0];
        }*/
        return null;
    }
    
    
    private SplittingRule findCategoricalSplit(int index,OutputType type)
    {
        HashMap<String,ArrayList<DatabaseEntry>> categoryMap = 
                new HashMap<String,ArrayList<DatabaseEntry>>();
        for(DatabaseEntry currentEntry : dbRows)
        {
            if(categoryMap.containsKey(
                currentEntry.getTableInputs()[index]))
            {
                categoryMap.get(currentEntry.getTableInputs()[index])
                        .add(currentEntry);
            }
            else
            {
                ArrayList<DatabaseEntry> entries = 
                        new ArrayList<DatabaseEntry>();
                entries.add(currentEntry);
                categoryMap.put(currentEntry.getTableInputs()[index], entries);
            }
        }
        Set<String> keys = categoryMap.keySet();
        CategoryAverage[] averages = new CategoryAverage[keys.size()];
        int averagesIndex = 0;
        for(String currentCategory : keys)
        {
            ArrayList<DatabaseEntry> categoryEntries = 
                    categoryMap.get(currentCategory);
            averages[averagesIndex] = 
                    new CategoryAverage(currentCategory, categoryEntries,type);
            averagesIndex++;
        }
        Arrays.sort(averages, new Comparator<CategoryAverage>() {

            @Override
            public int compare(CategoryAverage t, CategoryAverage t1) {
                return Double.compare(t.averageVal, t1.averageVal);
            }
        });
        
        SplittingRule[] rules = new SplittingRule[averages.length-1];
        for(int splitPoint=1; splitPoint<averages.length; splitPoint++)
        {
            ArrayList<DatabaseEntry> leftSet = new ArrayList<DatabaseEntry>();
            ArrayList<DatabaseEntry> rightSet = new ArrayList<DatabaseEntry>();
            for(int x=0; x<splitPoint; x++)
            {
                leftSet.addAll(categoryMap.get(averages[x].id));
            }
            for(int x=splitPoint; x<averages.length; x++)
            {
                rightSet.addAll(categoryMap.get(averages[x].id));
            }
            rules[splitPoint-1] = new SplittingRule(index,leftSet,
                    rightSet,type);
        }
        if(rules.length>0)
        {
            Arrays.sort(rules,SplittingRuleComparator.getInstance());
            return rules[0];
        }
        return null;
    }
    
    
    private class CategoryAverage
    {
        public String id;
        public double averageVal;
        public CategoryAverage(String id,
                ArrayList<DatabaseEntry> categoryEntries,OutputType type)
        {
            this.id = id;
            averageVal = 0.0;
            for(DatabaseEntry currentEntry : categoryEntries)
            {
                switch(type)
                {
                    case CPU_TYPE:
                        averageVal+=currentEntry.getCPUTime();
                        break;
                    case MEM_TYPE:
                        averageVal+=currentEntry.getMemUsed();
                        break;
                }
            }
            averageVal = averageVal/((double)(categoryEntries.size()));
        }

    }
    public double[] getOutputValues(OutputType type)
    {
        
        double[] outputValues = new double[dbRows.length];
        int count = 0;
        for(DatabaseEntry currentEntry : dbRows)
        {
            switch(type)
            {
                case CPU_TYPE:
                    outputValues[count] = currentEntry.getCPUTime();
                    break;
                case MEM_TYPE:
                    outputValues[count] = currentEntry.getMemUsed();
                    break;
            }
            count++;
        }
        return outputValues;        
    }
    
    public int getNumVectorElements()
    {
        if(dbRows!=null && dbRows.length>0)
            return dbRows[0].getTableInputs().length;
        return 0;
    }
    public int getNumVectors()
    {
        return dbRows.length;
    }
}
