package edu.masi.hyperadvisor.algorithms;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.masi.hyperadvisor.distancemetrics.*;
import edu.masi.hyperadvisor.kernels.GaussianKernel;
import edu.masi.hyperadvisor.structures.Algo;
import edu.masi.hyperadvisor.structures.AlgoCollection;
import edu.masi.hyperadvisor.util.Log;

/**
 * This is the main estimation class for the service.
 * It finds the nearest neighbors of a proposed algorithm based on past executions and a Euclidean distance metric.
 * It then finds an estimate that depends on a user-determined probability that the algorithm will succeed.
 * @author covingkj
 *
 */
public class KNN {
	
	/**
	 * For each set of neighbors with the same weight (i.e. the same distance from the proposed algorithm),
	 * creates a gaussian curve for each resource using the set's mean and standard deviation use of that resource.
	 * Then creates a weighted sum of the gaussians, where neighbors closer to the proposed algorithm
	 * are weighted more heavily.
	 * If threshold < 1.0, it returns the time and memory usage at the point on the curve below which
	 * threshold percent of the area under the curve lies; e.g.
	 * if threshold = 0.9, the point where 90% of the samples used that amount of the resource or less.
	 * If threshold >= 1, returns the largest known resource consumption multiplied by threshold.
	 * @param timeReqs	for each sample, the time it required to run
	 * @param memReqs	for each sample, the memory it required to run (same length as timeReqs)
	 * @param weights	for each sample, its weight (inverse distance from the proposed algorithm)
	 * @param threshold	the user-defined threshold/success probability
	 * @return			{time, memory} estimates, or null if there are no samples or either estimate is not positive
	 */
	private static int [] getKernelDensityEstimate(Float [] timeReqs, Float [] memReqs, Float [] weights, float threshold){
		if(weights.length == 0){
			return null;
		}
		
		//Neighbors at the same distance share one gaussian; the grouping is identical for both resources
		List<List<Integer>> groups = groupIndicesByWeight(weights);
		
		//Total weight, so that each gaussian can be scaled to its share of the whole curve.
		//This must be a float: weights are 1/distance and usually fractional, and an int
		//accumulator would silently truncate them (often to 0, causing a division by zero).
		float totalWeight = 0;
		for(int i = 0; i < weights.length; i++){
			totalWeight += weights[i];
		}
		
		int time 	= estimateResource(timeReqs, weights, groups, totalWeight, threshold);
		int mem 	= estimateResource(memReqs, weights, groups, totalWeight, threshold);
		
		if(time <= 0 || mem <= 0){
			return null;
		}
		return new int[]{time, mem};
	}
	
	/**
	 * Groups sample indices by identical weight, in order of each weight's first appearance.
	 * @param weights	for each sample, its weight
	 * @return			one list of sample indices per distinct weight
	 */
	private static List<List<Integer>> groupIndicesByWeight(Float [] weights){
		List<List<Integer>> groups = new ArrayList<List<Integer>>();
		boolean [] assigned = new boolean[weights.length];
		for(int i = 0; i < weights.length; i++){
			if(assigned[i]){
				continue;
			}
			List<Integer> group = new ArrayList<Integer>();
			for(int j = i; j < weights.length; j++){
				if(!assigned[j] && weights[j].equals(weights[i])){
					group.add(j);
					assigned[j] = true;
				}
			}
			groups.add(group);
		}
		return groups;
	}
	
	/**
	 * Estimates a single resource (time or memory) from the neighbors' requirements.
	 * @param reqs			for each sample, the amount of the resource it required (non-empty)
	 * @param weights		for each sample, its weight
	 * @param groups		sample indices grouped by identical weight
	 * @param totalWeight	the sum of all weights
	 * @param threshold		the user-defined threshold/success probability
	 * @return				the estimate, or 0 if none could be determined
	 */
	private static int estimateResource(Float [] reqs, Float [] weights, List<List<Integer>> groups, float totalWeight, float threshold){
		//Bounds of the rounded requirements (rounding is monotonic, so this equals rounding the extremes)
		int min = Math.round(reqs[0]);
		int max = min;
		for(int i = 1; i < reqs.length; i++){
			int value = Math.round(reqs[i]);
			if(value < min){
				min = value;
			}
			if(value > max){
				max = value;
			}
		}
		
		//At or above 1 (or NaN, as before) the estimate is a multiple of the largest known consumption
		if(!(threshold < 1.0f)){
			return (int) (max * threshold);
		}
		
		//Every sample used the same amount, so every quantile is that amount
		//(the curve over [min, max) would otherwise be empty and no estimate could ever be found)
		if(min == max){
			return min;
		}
		
		if(!(totalWeight > 0)){
			return 0;
		}
		
		double [] curve = new double[max - min];
		for(List<Integer> group : groups){
			int size = group.size();
			
			float mean = 0;
			for(int idx : group){
				mean += Math.round(reqs[idx]);
			}
			mean /= size;
			
			//Population standard deviation of the group: sqrt(sum((x - mean)^2) / n)
			double sumSq = 0;
			for(int idx : group){
				double diff = Math.round(reqs[idx]) - mean;
				sumSq += diff * diff;
			}
			double std = Math.sqrt(sumSq / size);
			
			//All members share one weight, so the group contributes size*weight/totalWeight of the curve
			float groupWeight = (size * weights[group.get(0)]) / totalWeight;
			
			double [] kernel = GaussianKernel.getKernel(mean, (float)std, min, max);
			//NOTE(review): assumes the kernel spans [min, max) like the curve; bounded in case it does not
			int len = Math.min(kernel.length, curve.length);
			for(int j = 0; j < len; j++){
				curve[j] += kernel[j] * groupWeight;
			}
		}
		
		return valueAtThreshold(curve, min, threshold);
	}
	
	/**
	 * Walks the curve from its upper end, removing area until at most threshold of it remains.
	 * @param curve		the aggregate density curve, where curve[i] describes value offset + i
	 * @param offset	the value represented by curve[0]
	 * @param threshold	the fraction of area that should lie below the returned value
	 * @return			the value at the threshold, or 0 if the threshold is never reached
	 */
	private static int valueAtThreshold(double [] curve, int offset, float threshold){
		double remaining = 1.0;
		for(int i = curve.length - 1; i >= 0; i--){
			remaining -= curve[i];
			if(remaining <= threshold){
				return i + offset;
			}
		}
		return 0;
	}
	
	/**
	 * Called by getEstimate.
	 * Gets the k-nearest neighbors (from aggregate data) based on a Euclidean distance metric.
	 * Returns their resource requirements and weights (inverse distances from the proposed algorithm).
	 * Every computed distance is also written to the "distance.txt" log.
	 * @param alg	the proposed algorithm
	 * @param algC	the aggregate data of previous runs and their resource requirements
	 * @param k		the number of neighbors to retrieve
	 * @return		Float[3][k] containing the time, memory, and weight Float[k] arrays respectively,
	 * 				or null if there are fewer than k previous runs
	 */
	private static Float[][] getNearestNeighborsAndDistances(Algo alg, AlgoCollection algC, int k){
		Algo [] prevData = algC.instances.toArray(new Algo[0]);
		
		//Distances from the proposed algorithm to each previous run
		float [] distances = new float[prevData.length];
		for(int i = 0; i < prevData.length; i++){
			distances[i] = EuclideanDistance.getDistance(prevData[i], alg);
		}
		Log.writeToLog("Algorithm", "distance.txt");
		for(int i = 0; i < distances.length; i++){
			Log.writeToLog(distances[i]+"", "distance.txt");
		}
		
		//Not enough previous runs to pick k neighbors
		if(prevData.length < k){
			return null;
		}
		
		//Copy in run order, used to map each sorted distance back to the run it came from
		List<Float> unsortedDistances = new ArrayList<Float>(distances.length);
		for(float d : distances){
			unsortedDistances.add(d);
		}
		Arrays.sort(distances);
		
		Float [][] results = new Float[3][k];
		for(int i = 0; i < k; i++){
			//Take the most recent occurrence to deal with system changes over time
			//NOTE(review): assumes algC.instances is in chronological order
			int whichOut = unsortedDistances.lastIndexOf(distances[i]);
			
			//Do not count a neighbor more than once (Euclidean distances are non-negative, so -1 never matches)
			unsortedDistances.set(whichOut, -1f);
			
			float dist = distances[i];
			float thisWeight = (dist == 0) ? 1 : 1/dist;
			
			results[0][i] = (float) prevData[whichOut].usedTime;
			results[1][i] = (float) prevData[whichOut].usedMem;
			results[2][i] = thisWeight;
		}
		return results;
	}
	
	/**
	 * The public method for finding the estimated resource requirements.
	 * Gets the nearest neighbors and their distances, then gets a kernel density estimate
	 * and chooses an estimate based on the user-defined success probability (threshold).
	 * 
	 * @param alg			an Algo object containing input and class information for the proposed execution
	 * @param algC			an AlgoCollection object containing input and output information for previous executions
	 * @param k				the number of neighbors to average (must be positive)
	 * @param threshold		the percent of samples which would run under the returned estimate
	 * 
	 * @return				{time, memory} estimates based on the preferred threshold, or null if no estimate
	 * 						can be made (no previous runs, k <= 0, fewer than k previous runs,
	 * 						or a non-positive estimate)
	 */
	public static int[] getEstimate(Algo alg, AlgoCollection algC, int k, float threshold){
		
		/* If there are no successful runs return null.
		 * This should never happen, 
		 * but we check in case something awful happens or if the code changes in the future.
		 */
		if(algC == null || algC.instances == null || algC.instances.size() == 0){
			return null;
		}
		
		//At least one neighbor is needed to build an estimate
		if(k <= 0){
			return null;
		}
		
		//Get the k nearest neighbors and their weights relative to the proposed algorithm alg
		Float[][] results = getNearestNeighborsAndDistances(alg, algC, k);
		if(results == null){
			return null;
		}
		
		/*
		 * Use these neighbors to create a kernel density estimator for each resource and
		 * return the value at the preferred threshold on each curve
		 */
		return getKernelDensityEstimate(results[0], results[1], results[2], threshold);
	}

}
