package bl.coe.BigSparseMath;

import java.util.Arrays;

/**
 * The Class BigSparseMatrixTools.
 * 
 * Provides static methods for operating on data structures in the BigSparseMath package.
 * 
 * @author Bennett Landman, bennett.landman@vanderbilt.edu
 */
public class BigSparseMatrixTools {

	/**
	 * Writes a short size/occupancy summary of a matrix to standard output.
	 * 
	 * Three lines are printed: the total number of cells (rows times columns,
	 * multiplied in {@code long} arithmetic) followed by the dimensions, the
	 * value reported by {@code getNonEmptyCount()}, and the ratio of the two
	 * expressed as a percentage.
	 * 
	 * @param mat the matrix to summarize
	 */
	public static void printStat(BigMatrix mat) {
		final long cellCount = (long) mat.getNRows() * (long) mat.getMCols();
		final long nonEmpty = mat.getNonEmptyCount();

		final String dimensions = "(" + mat.getNRows() + "x" + mat.getMCols() + ")";
		final double densityPercent = 100.0 * (double) nonEmpty / (double) cellCount;

		System.out.println("Total elements:    " + cellCount + " " + dimensions);
		System.out.println("Non-zero elements: " + nonEmpty);
		System.out.println("Density:           " + densityPercent + "%");
	}

	/**
	 * Expands a triangular matrix into a symmetric one.
	 * 
	 * Every stored entry (r, c, v) of the input is written to the output at
	 * (r, c) and, when c differs from r, also at (c, r).
	 * 
	 * Side effects: progress text is printed to standard output, and any row of
	 * {@code inMat} for which {@code getRowDirect} returns {@code null} is
	 * replaced through {@code setRow} by a freshly constructed
	 * {@code BigSparseVector(1)}.
	 * 
	 * NOTE(review): the mirrored write uses a column index as an output row
	 * index and the pre-allocation indexes both count arrays by row, so the
	 * input is presumably square - confirm against callers.
	 * 
	 * @param inMat the triangular matrix to mirror
	 * 
	 * @return a new matrix holding the input entries and their mirror images
	 */
	public static BigSparseMatrix symmeterizeTriangularMatrix(
			BigSparseMatrix inMat) {
		int N=(int)inMat.getNRows();
		int M=(int)inMat.getMCols();
		BigSparseMatrix outMat=new BigSparseMatrix(N,M);

		System.out.println("Computing row/column sparsity calculations");
		// Sum of both count arrays; this counts a non-zero diagonal entry twice,
		// so the estimate below is slightly generous.
		long count=0;
		// presumably [0] holds per-row counts and [1] per-column counts - TODO confirm
		int [][]rowcolCnt = inMat.computeRowColumnCounts();
		for(int i=0;i<2;i++)
			for(int j=0;j<rowcolCnt[0].length;j++)
				count+=rowcolCnt[i][j];

		// Rough estimate: 8 bytes per stored element plus 30 bytes per column.
		// The per-column term is computed in double arithmetic so that a very
		// large M cannot overflow int before the conversion.
		double estMemSize = (8.*(double)count+30.*(double)M)/1024./1024.;
		System.out.println("Allocating Memory: "+estMemSize+" MB ");
		// Pre-size every output row for its own entries plus the mirrored ones.
		for(int r=0;r<N;r++) {
			outMat.setRowDirect(r,new BigSparseVector(rowcolCnt[0][r]+rowcolCnt[1][r]));
		}
		System.out.println("Performing Symmeterization");
		long tic = System.currentTimeMillis();
		for(int r=0;r<N;r++) {
			BigSparseVector row = inMat.getRowDirect(r);
			if(row==null) {
				// Fill in the missing row on the input so it can be iterated like any other.
				row = new BigSparseVector(1);
				inMat.setRow(r, row);
			}
			for(int i=0;i<row.countNonEmpty();i++) {
				int index = row.getIndexAt(i);
				float value = row.getValueAt(i);
				outMat.set(r, index, value);
				if(index!=r) {
					// Off-diagonal entry: also store its mirror image.
					outMat.set(index, r, value);
				}
			}
			// Progress report every 1000 rows; r is at least 5 here, so tdiff/r is well defined.
			if(r%1000==5 && BigMathPreferences.verbose) {
				float tdiff =(System.currentTimeMillis()-tic)/1000.f;
				float trem = tdiff/r*(N-r);
				int hh = (int)Math.floor(trem/3600.f);
				int mm = (int)Math.floor(trem/60.f-hh*60);
				float ss = (float)Math.floor(trem-hh*3600-mm*60);
				System.out.println("Proc. row: "+r+" \tElapsed Time:"+(tdiff)+" s\tEsti. time remaining:"+hh+":"+mm+":"+ss);
				System.out.flush();
			}
		}
		return outMat;
	}

	/**
	 * Multi-threaded k-nearest-neighbor labeling of matrix rows without
	 * similarity weighting.
	 * 
	 * Delegates to
	 * {@link #runThreadedSimilarityWeightedKNNLabeledBigSparseMatrixRows(int, int, LabeledBigSparseMatrix, boolean)}
	 * with the weighting flag set to {@code false}.
	 * 
	 * @param nThreads the number of worker threads requested
	 * @param k the number of neighbors
	 * @param mat the labeled matrix whose rows are classified
	 * 
	 * @return the per-row result produced by the delegate
	 */
	public static BigDenseArrayVector runThreadedKNNLabeledBigSparseMatrixRows(int nThreads, int k, LabeledBigSparseMatrix mat){
		// Plain k-NN is the similarity-weighted variant with weighting switched off.
		final boolean useSimilarityWeighting = false;
		BigDenseArrayVector labelsPerRow =
				runThreadedSimilarityWeightedKNNLabeledBigSparseMatrixRows(nThreads, k, mat, useSimilarityWeighting);
		return labelsPerRow;
	}

	/**
	 * Multi-threaded k-nearest-neighbor labeling of matrix rows, optionally
	 * weighting each neighbor's vote by its similarity value.
	 * 
	 * With {@code nThreads <= 1} the work is done on the calling thread via
	 * {@link #runKNNonLabeledBigSparseMatrixRows(int, LabeledBigSparseMatrix, boolean)}.
	 * Otherwise the rows are split into {@code nThreads} contiguous ranges of
	 * {@code ceil(nRows / nThreads)} rows, each clamped to the matrix so that
	 * no worker is handed a row index past the end (trailing ranges may be
	 * empty). One {@code KNNWorkerThread} per range is handed to a
	 * {@code WorkerThreadWatcher}, which is polled until it reports completion.
	 * 
	 * If the calling thread is interrupted while waiting, its interrupt flag is
	 * restored and the method returns immediately; the returned vector may then
	 * be only partially filled.
	 * 
	 * @param nThreads the number of worker threads requested
	 * @param k the number of neighbors
	 * @param mat the labeled matrix whose rows are classified
	 * @param useSimilarityWeighting the flag to use similarity weighting in deciding k-nn
	 * 
	 * @return the vector the workers write their per-row results into
	 */
	public static BigDenseArrayVector runThreadedSimilarityWeightedKNNLabeledBigSparseMatrixRows(int nThreads, int k, LabeledBigSparseMatrix mat, boolean useSimilarityWeighting){
		if(nThreads<=1)
			return runKNNonLabeledBigSparseMatrixRows(k, mat,useSimilarityWeighting);
		int nRows = mat.getNRows();
		BigDenseArrayVector result = new BigDenseArrayVector(nRows);
		int threadChunk = (int)Math.ceil((double)nRows/(double)nThreads);
		int startRow = 0;
		WorkerThreadWatcher myMonitor = new WorkerThreadWatcher();
		for(int i=0;i<nThreads;i++){
			// Clamp in long arithmetic: rounding the chunk size up can push the
			// unclamped end past the last row when nRows is not a multiple of nThreads.
			int endRow = (i==(nThreads-1))
					? nRows
					: (int)Math.min((long)startRow+(long)threadChunk, (long)nRows);

			myMonitor.watch(new KNNWorkerThread(k,mat,result,startRow,endRow,useSimilarityWeighting));
			startRow=endRow;
		}
		while(myMonitor.isNotDone()) {
			try {
				// Sleep one update period in 100 slices so completion is noticed promptly.
				for(int i=0;i<100 && myMonitor.isNotDone();i++) {
					Thread.sleep(BigMathPreferences.updateRateMilliseconds/100);
				}
			} catch (InterruptedException e) {
				// Stop waiting, but keep the interruption visible to the caller.
				Thread.currentThread().interrupt();
				break;
			}
			myMonitor.reportStatus();
		}
		myMonitor.reportStatus();
		return result;
	}

	/**
	 * Single-threaded k-nearest-neighbor labeling of every row of a matrix.
	 * 
	 * Allocates a result vector with one slot per row and fills it by running
	 * the row-range variant over rows {@code 0} (inclusive) to
	 * {@code mat.getNRows()} (exclusive) with no status array.
	 * 
	 * @param k the number of neighbors
	 * @param mat the labeled matrix whose rows are classified
	 * @param useSimWeighting whether neighbor votes are weighted by similarity
	 * 
	 * @return the newly allocated per-row result vector
	 */
	public static BigDenseArrayVector runKNNonLabeledBigSparseMatrixRows(int k, LabeledBigSparseMatrix mat, boolean useSimWeighting){
		BigDenseArrayVector labelsPerRow = new BigDenseArrayVector(mat.getNRows());

		final int firstRow = 0;
		final int lastRow = mat.getNRows();
		final int[] noStatusSink = null; // nobody is monitoring progress here

		runKNNonLabeledBigSparseMatrixRowsSubsetRows(
				k, mat, labelsPerRow, firstRow, lastRow, noStatusSink, useSimWeighting);
		return labelsPerRow;
	}

	/**
	 * k-nearest-neighbor labeling of a range of rows without similarity
	 * weighting.
	 * 
	 * Convenience overload that forwards to the seven-argument variant with
	 * the weighting flag set to {@code false}.
	 * 
	 * @param k the number of neighbors
	 * @param mat the labeled matrix whose rows are classified
	 * @param result the vector that receives one entry per processed row
	 * @param firstRow the first row to process (inclusive)
	 * @param lastRow the row at which to stop (exclusive)
	 * @param statusCurrentRow optional progress array, forwarded unchanged; may be {@code null}
	 * 
	 * @return whatever the seven-argument variant returns
	 */
	public static BigDenseArrayVector runKNNonLabeledBigSparseMatrixRowsSubsetRows(int k, LabeledBigSparseMatrix mat,BigDenseArrayVector result,int firstRow, int lastRow, int[] statusCurrentRow){
		final boolean withoutSimilarityWeighting = false;
		return runKNNonLabeledBigSparseMatrixRowsSubsetRows(
				k, mat, result, firstRow, lastRow, statusCurrentRow, withoutSimilarityWeighting);
	}

	/**
	 * k-nearest-neighbor labeling of a range of rows.
	 * 
	 * For each row in {@code [firstRow, lastRow)} the stored values are treated
	 * as similarities. The threshold is the value at order statistic
	 * {@code length-k+1} (clamped to 1, so a row with fewer than {@code k}
	 * values uses all of them); every value greater than or equal to the
	 * threshold is a neighbor, which means ties can yield more than {@code k}.
	 * Each neighbor votes for the label of its column, either with weight 1 or,
	 * when {@code useSimilarityWeighting} is set, with its similarity value.
	 * 
	 * The entry stored in {@code result} for a row is a {@code float[]} whose
	 * element 0 is the number of neighbors used, followed by every label tied
	 * for the highest score, in reverse order of first appearance. Rows for
	 * which no threshold can be computed (row is {@code null}, row has no
	 * values, {@code k < 1}, or the threshold is NaN) receive {@code {-1, -1}}.
	 * 
	 * A line announcing the run is printed to standard output on every call.
	 * 
	 * NOTE(review): neighbors are taken from the full length of the array
	 * returned by {@code getDataDirect()}; if that array can be longer than
	 * {@code countNonEmpty()}, unused slots would be treated as data - confirm.
	 * 
	 * @param k the number of neighbors
	 * @param mat the labeled matrix whose rows are classified
	 * @param result the vector that receives one entry per processed row
	 * @param firstRow the first row to process (inclusive)
	 * @param lastRow the row at which to stop (exclusive)
	 * @param statusCurrentRow optional progress array; when non-null, element 0 is
	 *        set to the row being processed and to {@code lastRow} on completion
	 * @param useSimilarityWeighting the flag to indicate if similarity weighting should be used
	 * 
	 * @return the {@code result} argument
	 */
	public static BigDenseArrayVector runKNNonLabeledBigSparseMatrixRowsSubsetRows(int k, LabeledBigSparseMatrix mat,BigDenseArrayVector result,int firstRow, int lastRow, int[] statusCurrentRow, boolean useSimilarityWeighting){
		System.out.println("Running KNN w/Sim?="+useSimilarityWeighting); System.out.flush();
		int M = mat.getMCols();
		// Scratch buffers shared by all rows. Typically only about k slots are
		// used, but ties at the threshold can select more.
		int []labels = new int[M];
		float []values = new float[M];
		int []labelVal = new int[M];
		int []labelCount = new int[M];
		float []labelWeight = new float[M];
		for(int r=firstRow;r<lastRow;r++) {
			if(statusCurrentRow!=null)
				statusCurrentRow[0]=r;

			BigSparseVector row = mat.getRowDirect(r);
			if(row==null) {
				// A missing row has no neighbors at all.
				result.set(r, new float[]{-1, -1});
				continue;
			}
			float []valuesUnsafe = row.getDataDirect();
			// Fall back to a lower-order k-NN if k (non-zero) neighbors are not available.
			int orderStat = Math.max(valuesUnsafe.length-k+1, 1);
			float kthOrderStat = kthOrderStat(valuesUnsafe, orderStat); // length of the row, not M: rows need not be full
			// NOTE(review): the original test was (!isNaN || kthOrderStat<0); the second
			// clause could never change the outcome, so negative thresholds are still
			// processed. Confirm whether negative similarities were meant to be rejected.
			if(Float.isNaN(kthOrderStat)) {
				result.set(r, new float[]{-1, -1});
				continue;
			}

			// Collect every value at or above the threshold together with its column's label.
			int countNN = 0;
			for(int i=0;i<valuesUnsafe.length;i++) {
				if(valuesUnsafe[i]>=kthOrderStat) {
					labels[countNN] = mat.getLabelForColumn(row.getIndexAt(i));
					values[countNN] = valuesUnsafe[i];
					countNN++;
				}
			}

			// Tally votes per distinct label. Slots are reset as they are claimed,
			// so the buffers never need to be cleared between rows.
			int currentLabelCount = 0;
			for(int i=0;i<countNN;i++) {
				int slot = 0;
				while(slot<currentLabelCount && labelVal[slot]!=labels[i]) {
					slot++;
				}
				if(slot==currentLabelCount) {
					labelVal[slot]=labels[i];
					labelCount[slot]=0;
					labelWeight[slot]=0;
					currentLabelCount++;
				}
				labelCount[slot]++;
				labelWeight[slot]+=values[i];
			}

			result.set(r, selectWinningLabels(labelVal, labelCount, labelWeight,
					currentLabelCount, useSimilarityWeighting, countNN));
		}
		if(statusCurrentRow!=null)
			statusCurrentRow[0]=lastRow;
		return result;
	}

	/**
	 * Packs the neighbor count and all top-scoring labels into one array.
	 * 
	 * @param labelVal the distinct labels, valid in {@code [0, nLabels)}
	 * @param labelCount the number of votes per label
	 * @param labelWeight the summed similarity per label
	 * @param nLabels the number of distinct labels; must be at least 1
	 * @param byWeight {@code true} to rank by {@code labelWeight}, {@code false} to rank by {@code labelCount}
	 * @param countNN the number of neighbors, stored in element 0 of the result
	 * 
	 * @return {@code {countNN, labels tied for the maximum...}}, with the labels
	 *         in reverse order of their position in {@code labelVal}
	 */
	private static float[] selectWinningLabels(int[] labelVal, int[] labelCount, float[] labelWeight,
			int nLabels, boolean byWeight, int countNN) {
		int maxCount = labelCount[0];
		float maxWeight = labelWeight[0];
		for(int i=1;i<nLabels;i++) {
			if(labelCount[i]>maxCount)
				maxCount = labelCount[i];
			if(labelWeight[i]>maxWeight)
				maxWeight = labelWeight[i];
		}

		int winners = 0;
		for(int i=0;i<nLabels;i++) {
			if(byWeight ? labelWeight[i]==maxWeight : labelCount[i]==maxCount)
				winners++;
		}

		float[] packed = new float[winners+1];
		packed[0] = countNN; // the 1st element is the number of neighbors used
		int slot = winners;  // fill from the back, matching the established output order
		for(int i=0;i<nLabels;i++) {
			if(byWeight ? labelWeight[i]==maxWeight : labelCount[i]==maxCount) {
				packed[slot] = labelVal[i];
				slot--;
			}
		}
		return packed;
	}

	/**
	 * Returns the i-th smallest value of an array (1 = smallest,
	 * {@code x.length} = largest).
	 * 
	 * The input array is left untouched; a sorted copy is made only when the
	 * requested rank is valid.
	 * 
	 * @param x the data on which to compute (not altered)
	 * @param i the 1-based rank of the order statistic
	 * 
	 * @return the value at rank {@code i}, or {@code Float.NaN} when {@code i}
	 *         is smaller than 1 or larger than {@code x.length}
	 */
	public static float kthOrderStat(float []x, int i) {
		final int n = x.length;
		final boolean rankInRange = (i >= 1) && (i <= n);
		if (!rankInRange) {
			return Float.NaN;
		}

		final float[] ascending = Arrays.copyOf(x, n);
		Arrays.sort(ascending);
		return ascending[i - 1];
	}

	/**
	 * Writes the number of label pairs held by {@code lp} to standard output.
	 * 
	 * @param lp the label-pair collection to report on
	 */
	public static void printStat(LabelPairs lp) {
		final String summary = "Total Label Pairs:    " + lp.getNumberOfLabelPairs();
		System.out.println(summary);
	}

}
