package edu.jhu.bme.smile.commons.math;

import no.uib.cipr.matrix.*;
import no.uib.cipr.matrix.distributed.BlockDiagonalPreconditioner;
import no.uib.cipr.matrix.sparse.*;
import Jama.Matrix;

/* Derived from the l1_ls Matlab package.
 * This port is made available under the GPL (as opposed to the LGPL, which covers the rest of the package).
 *   
 * Adapted by Bennett Landman <landman@jhu.edu>
 * 2008.12.10
 * 
 * solve() follows the unconstrained l1_ls header below:
	% AUTHOR    Kwangmoo Koh <deneb1@stanford.edu>
	% UPDATE    Apr 8 2007
	%
	% COPYRIGHT 2008 Kwangmoo Koh, Seung-Jean Kim, and Stephen Boyd
	%
	% l1-Regularized Least Squares Problem Solver
	%
	%   l1_ls solves problems of the following form:
	%
	%       minimize ||A*x-y||^2 + lambda*sum|x_i|,
	%
	%   where A and y are problem data and x is variable (described below).


 * solveNonNeg() follows the non-negative variant's header below:
	% AUTHOR    Kwangmoo Koh <deneb1@stanford.edu>
	% UPDATE    Apr 10 2008
	%
	% COPYRIGHT 2008 Kwangmoo Koh, Seung-Jean Kim, and Stephen Boyd
	%
	% l1-Regularized Least Squares Problem Solver
	%
	%   l1_ls solves problems of the following form:
	%
	%       minimize   ||A*x-y||^2 + lambda*sum(x_i),
	%       subject to x_i >= 0, i=1,...,n
	%
	%   where A and y are problem data and x is variable (described below).
	%   x       : n vector; classifier
 */

/**
 * Truncated-Newton interior-point solver for l1-regularized least squares:
 * <pre>
 *   solve():        minimize ||A*x-y||^2 + lambda*sum|x_i|
 *   solveNonNeg():  minimize ||A*x-y||^2 + lambda*sum(x_i)  subject to x_i >= 0
 * </pre>
 * Each Newton system is solved iteratively (BiCGstab / CG from MTJ). Both
 * variants can be used on the same instance; the linear-system storage for
 * each variant is allocated the first time it is needed.
 */
public class L1LSCompressedSensing {

	//	% IPM PARAMETERS
	/** Updating factor for the barrier parameter t. */
	private static final double MU          = 2;
	/** Maximum number of interior-point (Newton) iterations. */
	private static final int    MAX_NT_ITER = 400;

	//	% LINE SEARCH PARAMETERS
	/** Minimum fraction of decrease in the objective (Armijo constant). */
	private static final double ALPHA       = 0.01;
	/** Step-size decrease factor of the backtracking line search. */
	private static final double BETA        = 0.5;
	/** Maximum number of backtracking line-search iterations. */
	private static final int    MAX_LS_ITER = 100;

	/** Number of consecutive "error too high and growing" iterations after which solveNonNeg gives up. */
	private static final int    MAX_ERROR_STRIKES = 3;

	private final Matrix A;
	private final Matrix At;
	/** 2*A'*A, the data part of both Hessians. */
	private final Matrix AtA2;
	private Matrix x; 				// the result
	private final int m; // number of examples (rows) of A
	private final int n; // number of features (columns) of A

	// Newton system of solve(): 2n x 2n Hessian plus the sparse matrix holding
	// its 2x2-block-diagonal part, which is used as the ILU preconditioner.
	private DenseMatrix hessUnconstrained;
	private CompRowMatrix precUnconstrained;
	// Newton system of solveNonNeg(): n x n Hessian.
	private DenseMatrix hessNonNeg;

	private boolean statusConverged = false;
	private boolean verbose = false;

	/**
	 * @param A          m x n problem matrix (wrapped, not copied)
	 * @param requirePos if true, pre-allocates the system used by solveNonNeg(),
	 *                   otherwise the one used by solve(); the other one is
	 *                   allocated lazily if it is ever needed
	 */
	public L1LSCompressedSensing(double [][]A, boolean requirePos) {
		this(new Matrix(A),requirePos);
	}

	/**
	 * @param A          m x n problem matrix
	 * @param requirePos if true, pre-allocates the system used by solveNonNeg(),
	 *                   otherwise the one used by solve(); the other one is
	 *                   allocated lazily if it is ever needed
	 */
	public L1LSCompressedSensing(Matrix A, boolean requirePos) {
		this.A = A;
		At = A.transpose();
		AtA2 = At.times(A).times(2);
		n = A.getColumnDimension();
		m = A.getRowDimension();
		if (requirePos)
			ensureNonNegSystem();
		else
			ensureUnconstrainedSystem();
	}

	/** Allocates the 2n x 2n Newton system used by solve(), if not done yet. */
	private void ensureUnconstrainedSystem() {
		if (hessUnconstrained != null)
			return;
		// Hessian layout: [2*A'*A+diag(d1) diag(d2); diag(d2) diag(d1)].
		// Only the 2*A'*A block is constant; the diagonals are refreshed every Newton step.
		Matrix big = new Matrix(2*n, 2*n);
		big.setMatrix(0, n-1, 0, n-1, AtA2);
		hessUnconstrained = new DenseMatrix(big.getArray());
		// Preconditioner sparsity: rows i and i+n each couple columns i and i+n.
		int[][] nz = new int[2*n][];
		for (int i=0; i<n; i++) {
			nz[i]   = new int[] {i, i+n};
			nz[i+n] = new int[] {i, i+n};
		}
		precUnconstrained = new CompRowMatrix(2*n, 2*n, nz);
	}

	/** Allocates the n x n Newton system used by solveNonNeg(), if not done yet. */
	private void ensureNonNegSystem() {
		if (hessNonNeg == null)
			hessNonNeg = new DenseMatrix(AtA2.getArray()); // diagonal is refreshed every Newton step
	}

	/** @return true if the most recent solve reached the requested relative duality gap */
	public boolean isConverged() {
		return statusConverged;
	}

	/**
	 * @return true if the most recent solve reached the requested relative duality gap
	 * @deprecated misspelled; use {@link #isConverged()}
	 */
	@Deprecated
	public boolean isConvereged() {
		return statusConverged;
	}

	/** Enables or disables per-iteration progress output on System.out. */
	public void setVerbose(boolean verbose) {
		this.verbose = verbose; 
	}

	/** Wraps y as a y.length x 1 column matrix. */
	private static Matrix toColumn(double[] y) {
		return new Matrix(y, y.length);
	}

	/** Rejects observation vectors that are not m x 1. */
	private void checkObservation(Matrix y) {
		if (y.getRowDimension() != m || y.getColumnDimension() != 1)
			throw new IllegalArgumentException("y must be a " + m + "x1 column vector but is "
					+ y.getRowDimension() + "x" + y.getColumnDimension());
	}

	/** Prints the verbose-mode banner and column header. */
	private void printHeader(double lambda) {
		System.out.println("Solving a problem of size (m="+m+", n="+n+"), with lambda="+lambda);
		System.out.println("-----------------------------------------------------------------------------");
		System.out.println("iter\tgap\tprimobj\tdualobj\tstep\tlen\tpcg\titers");
	}

	/** Prints why a solve ended without converging (verbose mode only). */
	private static void reportFailure(int lsiter, int ntiter) {
		if (lsiter > MAX_LS_ITER) {
			System.out.println("MAX_LS_ITER exceeded in BLS"); 
		} else if (ntiter > MAX_NT_ITER) {
			System.out.println("MAX_NT_ITER exceeded.");
		} else {
			System.out.println("Unknown failure result.");
		}
	}

	/******************************************************************************
	 * Unconstrained x
	 ******************************************************************************/

	/** Unconstrained solve with reltol = 1e-3. @see #solve(Matrix, double, double) */
	public boolean solve(double []y, double lambda) {
		return solve(toColumn(y), lambda);
	}

	/** Unconstrained solve. @see #solve(Matrix, double, double) */
	public boolean solve(double []y, double lambda, double reltol) {
		return solve(toColumn(y), lambda, reltol);
	}

	/** Unconstrained solve with reltol = 1e-3. @see #solve(Matrix, double, double) */
	public boolean solve(Matrix y, double lambda) {
		return solve(y,lambda,1e-3);
	}

	/**
	 * Minimizes ||A*x-y||^2 + lambda*sum|x_i| with a log-barrier interior-point
	 * method on the equivalent problem in (x,u) with -u <= x <= u.
	 * The iterate is available afterwards from getResult()/getMatrixResult(),
	 * also when this returns false.
	 *
	 * @param y      m x 1 observation vector
	 * @param lambda regularization weight
	 * @param reltol target relative duality gap, gap/|dual objective|
	 * @return true if the relative duality gap dropped below reltol
	 * @throws IllegalArgumentException if y is not m x 1
	 */
	public boolean solve(Matrix y, double lambda, double reltol) {
		checkObservation(y);
		ensureUnconstrainedSystem();
		statusConverged = false;

		x = new Matrix(n,1);
		Matrix u = new Matrix(n,1,1);
		Matrix f = MatrixMath.concatinateRows(x.minus(u),x.times(-1).minus(u)); // f = [x-u;-x-u]
		double t = Math.min(Math.max(1,1/lambda), n/1e-3);

		if (verbose)
			printHeader(lambda);

		double dobj = Double.NEGATIVE_INFINITY; 
		double s    = Double.POSITIVE_INFINITY;
		int pitr    = 0;  // iterations of the last linear solve (verbose output only)
		int ntiter;
		int lsiter  = 0;
		// Newton direction [dx; du]; also the warm start of the next linear solve.
		Matrix dxu = new Matrix(2*n,1);
		Matrix dx  = new Matrix(n,1);
		Matrix du  = new Matrix(n,1);

		for (ntiter = 0; ntiter <= MAX_NT_ITER; ntiter++) {
			//		    %------------------------------------------------------------
			//		    %       CALCULATE DUALITY GAP
			//		    %------------------------------------------------------------
			Matrix z = A.times(x).minus(y); // z = A*x-y
			Matrix nu = z.times(2);         // nu = 2*z

			// Scale nu into the dual feasible set ||A'*nu||_inf <= lambda.
			double maxAnu = At.times(nu).normInf();
			if (maxAnu > lambda)
				nu = nu.times(lambda/maxAnu);

			double pobj = MatrixMath.dotProduct(z,z)+lambda*MatrixMath.colSumNorm(x);
			dobj = Math.max(MatrixMath.dotProduct(nu,nu)*(-.25)-MatrixMath.dotProduct(nu,y),dobj); // dobj = max(-0.25*nu'*nu-nu'*y,dobj)
			double gap = pobj - dobj;
			if (verbose)
				System.out.println(ntiter+"\t"+gap+"\t"+pobj+"\t"+dobj+"\t"+s+"\t"+pitr);

			//		    %------------------------------------------------------------
			//		    %   STOPPING CRITERION
			//		    %------------------------------------------------------------
			if (gap/Math.abs(dobj) < reltol) {
				if (verbose)
					System.out.println("Converged.");
				statusConverged = true;
				return statusConverged;
			}

			//		    %------------------------------------------------------------
			//		    %       UPDATE t
			//		    %------------------------------------------------------------
			// 2n inequality constraints, hence 2*n: t = max(min(2*n*MU/gap, MU*t), t)
			if (s >= 0.5) 
				t = Math.max(Math.min(2*n*MU/gap, MU*t), t);

			//		    %------------------------------------------------------------
			//		    %       CALCULATE NEWTON STEP
			//		    %------------------------------------------------------------
			Matrix q1 = MatrixMath.elementInverse(u.plus(x));
			Matrix q2 = MatrixMath.elementInverse(u.minus(x));
			Matrix q1sq = MatrixMath.elementProduct(q1, q1);
			Matrix q2sq = MatrixMath.elementProduct(q2, q2);
			Matrix d1 = q1sq.plus(q2sq).times(1./t);
			Matrix d2 = q1sq.minus(q2sq).times(1./t);

			// gradphi = [At*(z*2)-(q1-q2)/t; lambda*ones(n,1)-(q1+q2)/t]
			Matrix gradphi = MatrixMath.concatinateRows(
					At.times(z.times(2)).minus(q1.minus(q2).times(1./t)), 
					MatrixMath.add(q1.plus(q2).times(-1./t),lambda));

			// Hessian [2*A'*A+diag(d1) diag(d2); diag(d2) diag(d1)]; the preconditioner is
			// its 2x2-block-diagonal part, including the diagonal of 2*A'*A as in l1_ls.
			for (int i=0; i<n; i++) {
				double d1i = d1.get(i, 0);
				double d2i = d2.get(i, 0);
				double hii = AtA2.get(i, i)+d1i;

				hessUnconstrained.set(i, i, hii);
				hessUnconstrained.set(i, i+n, d2i);
				hessUnconstrained.set(i+n, i, d2i);
				hessUnconstrained.set(i+n, i+n, d1i);

				// Every pattern entry is rewritten here, so it does not matter whether
				// the ILU below overwrites this matrix with its factors (NOTE(review):
				// MTJ's ILU appears to factor its constructor argument in place).
				precUnconstrained.set(i, i, hii);
				precUnconstrained.set(i, i+n, d2i);
				precUnconstrained.set(i+n, i, d2i);
				precUnconstrained.set(i+n, i+n, d1i);
			}

			DenseVector cgB = new DenseVector(gradphi.times(-1).getColumnPackedCopy());
			// Warm start from the previous Newton direction, as the reference pcg call does.
			DenseVector cgX = new DenseVector(dxu.getColumnPackedCopy());

			IterativeSolver solver = new BiCGstab(cgB);
			Preconditioner M = new ILU(precUnconstrained);
			M.setMatrix(precUnconstrained);
			solver.setPreconditioner(M);

			try {
				solver.solve(hessUnconstrained, cgB, cgX);
			}
			catch (IterativeSolverNotConvergedException e) {
				// Best effort: continue with the last iterate as the search direction.
				System.err.println("Iterative solver failed to converge");
			}
			pitr = solver.getIterationMonitor().iterations();

			// Copy the solution into dxu (used for gdx and the next warm start) and its halves dx, du.
			for (int i=0; i<n; i++) {
				double dxi = cgX.get(i);
				double dui = cgX.get(i+n);
				dxu.set(i, 0, dxi);
				dxu.set(i+n, 0, dui);
				dx.set(i, 0, dxi);
				du.set(i, 0, dui);
			}

			//		    %------------------------------------------------------------
			//		    %   BACKTRACKING LINE SEARCH
			//		    %------------------------------------------------------------
			double phi = MatrixMath.dotProduct(z,z)+lambda*MatrixMath.sum(u)-MatrixMath.sum(MatrixMath.log(f.times(-1)))/t;
			s = 1.0;
			// Directional derivative of phi along the Newton step (gdx = gradphi'*dxu).
			double gdx = MatrixMath.dotProduct(gradphi,dxu); 
			Matrix newx = null;
			Matrix newu = null;
			Matrix newf = null;
			for (lsiter = 1; lsiter <= MAX_LS_ITER; lsiter++) {
				newx = x.plus(dx.times(s)); // newx = x+s*dx
				newu = u.plus(du.times(s)); // newu = u+s*du
				newf = MatrixMath.concatinateRows(newx.minus(newu),newx.times(-1).minus(newu)); // f = [x-u;-x-u]
				if (MatrixMath.max(newf) < 0) { // strictly feasible
					Matrix newz = A.times(newx).minus(y);
					double newphi = MatrixMath.dotProduct(newz,newz)+lambda*MatrixMath.sum(newu)-MatrixMath.sum(MatrixMath.log(newf.times(-1)))/t;
					if (newphi-phi <= ALPHA*s*gdx) // Armijo sufficient decrease
						break;
				}
				s = BETA*s;
			}

			if (lsiter > MAX_LS_ITER) 
				break; // exit by BLS

			x = newx;
			u = newu;
			f = newf;
		}

		statusConverged = false;
		if (verbose)
			reportFailure(lsiter, ntiter);
		return false;
	}

	/******************************************************************************
	 * Add non-negative constraints 
	 ******************************************************************************/

	/** Non-negative solve with reltol = 1e-3 and maxError = 10. @see #solveNonNeg(Matrix, double, double, double) */
	public boolean solveNonNeg(double []y, double lambda) {
		return solveNonNeg(toColumn(y), lambda);
	}

	/** Non-negative solve without the early-abort check. @see #solveNonNeg(Matrix, double, double, double) */
	public boolean solveNonNeg(double []y, double lambda, double reltol) {
		return solveNonNeg(y, lambda, reltol, -1);
	}

	/** Non-negative solve. @see #solveNonNeg(Matrix, double, double, double) */
	public boolean solveNonNeg(double []y, double lambda, double reltol, double maxError) {
		return solveNonNeg(toColumn(y), lambda, reltol, maxError);
	}

	/** Non-negative solve with reltol = 1e-3 and maxError = 10. @see #solveNonNeg(Matrix, double, double, double) */
	public boolean solveNonNeg(Matrix y, double lambda) {
		return solveNonNeg(y,lambda,1e-3, 10);
	}

	/**
	 * Minimizes ||A*x-y||^2 + lambda*sum(x_i) subject to x_i >= 0 with a
	 * log-barrier interior-point method. The iterate is available afterwards
	 * from getResult()/getMatrixResult(), also when this returns false.
	 *
	 * @param y        m x 1 observation vector
	 * @param lambda   regularization weight
	 * @param reltol   target relative duality gap, gap/|dual objective|
	 * @param maxError if positive, the solve gives up (returns false) after
	 *                 three consecutive iterations in which the relative gap
	 *                 is above maxError and strictly increasing; non-positive
	 *                 disables this check
	 * @return true if the relative duality gap dropped below reltol
	 * @throws IllegalArgumentException if y is not m x 1
	 */
	public boolean solveNonNeg(Matrix y, double lambda, double reltol, double maxError) {
		checkObservation(y);
		ensureNonNegSystem();
		statusConverged = false;

		x = new Matrix(n,1,.01);
		double t = Math.min(Math.max(1,1/lambda), n/1e-3);

		if (verbose)
			printHeader(lambda);

		double dobj = Double.NEGATIVE_INFINITY; 
		double s    = Double.POSITIVE_INFINITY;
		int pitr    = 0;  // iterations of the last linear solve (verbose output only)
		int ntiter;
		int lsiter  = 0;
		int errorCount = 0;
		double prevError = -1;
		// Newton direction; also the warm start of the next linear solve.
		Matrix dx = new Matrix(n,1);

		for (ntiter = 0; ntiter <= MAX_NT_ITER; ntiter++) {
			Matrix z = A.times(x).minus(y); // z = A*x-y
			Matrix nu = z.times(2);         // nu = 2*z

			// Scale nu into the dual feasible set A'*nu >= -lambda.
			double minAnu = MatrixMath.min(At.times(nu));
			if (minAnu < -lambda)
				nu = nu.times(lambda/(-minAnu));

			double pobj = MatrixMath.dotProduct(z,z)+lambda*MatrixMath.sum(x); // pobj = z'*z+lambda*sum(x,1)
			dobj = Math.max(MatrixMath.dotProduct(nu,nu)*(-.25)-MatrixMath.dotProduct(nu,y),dobj); // dobj = max(-0.25*nu'*nu-nu'*y,dobj)
			double gap = pobj - dobj;
			if (verbose)
				System.out.println(ntiter+"\t"+gap+"\t"+pobj+"\t"+dobj+"\t"+s+"\t"+pitr);

			//		    %------------------------------------------------------------
			//		    %   STOPPING CRITERION
			//		    %------------------------------------------------------------
			double relError = gap/Math.abs(dobj);
			if (relError < reltol) {
				if (verbose)
					System.out.println("Converged.");
				statusConverged = true;
				return statusConverged;
			}
			else if (maxError > 0 && relError > maxError) {
				if (verbose)
					System.out.println("Error too high.");
				// A strike is the first too-high error, or one larger than the previous strike;
				// a too-high error that did not grow resets the count.
				if (prevError == -1 || relError > prevError) {
					errorCount++;
					prevError = relError;
				}
				else {
					errorCount = 0;
					prevError = -1;
				}
				if (errorCount == MAX_ERROR_STRIKES) {
					statusConverged = false;
					return statusConverged;
				}
			}
			else {
				errorCount = 0;
				prevError = -1;
			}

			//		    %------------------------------------------------------------
			//		    %       UPDATE t
			//		    %------------------------------------------------------------
			if (s >= 0.5) 
				t = Math.max(Math.min(n*MU/gap, MU*t), t); // t = max(min(n*MU/gap, MU*t), t)

			//		    %------------------------------------------------------------
			//		    %       CALCULATE NEWTON STEP
			//		    %------------------------------------------------------------
			Matrix d1 = MatrixMath.elementInverse(MatrixMath.elementProduct(x,x).times(t)); // d1 = (1/t)./(x.^2)

			// gradphi = At*(z*2)+lambda-(1/t)./x
			Matrix gradphi = MatrixMath.add(At.times(z.times(2)),lambda).minus(MatrixMath.elementInverse(x.times(t)));

			// Hessian 2*A'*A+diag(d1)
			for (int i=0; i<n; i++)
				hessNonNeg.set(i, i, AtA2.get(i, i)+d1.get(i, 0));

			DenseVector cgB = new DenseVector(gradphi.times(-1).getColumnPackedCopy());
			// Warm start from the previous Newton direction, as the reference pcg call does.
			DenseVector cgX = new DenseVector(dx.getColumnPackedCopy());

			IterativeSolver solver = new CG(cgB);
			// Jacobi preconditioner: inverse of the Hessian diagonal, diag(2*A'*A)+d1.
			Preconditioner M = new DiagonalPreconditioner(n);
			M.setMatrix(hessNonNeg);
			solver.setPreconditioner(M);

			try {
				solver.solve(hessNonNeg, cgB, cgX);
			}
			catch (IterativeSolverNotConvergedException e) {
				// Best effort: continue with the last iterate as the search direction.
				System.err.println("Iterative solver failed to converge");
			}
			pitr = solver.getIterationMonitor().iterations();

			for (int i=0; i<n; i++)
				dx.set(i, 0, cgX.get(i));

			//		    %------------------------------------------------------------
			//		    %   BACKTRACKING LINE SEARCH
			//		    %------------------------------------------------------------
			double phi = MatrixMath.dotProduct(z,z)+lambda*MatrixMath.sum(x)-MatrixMath.sum(MatrixMath.log(x))/t;
			s = 1.0;
			double gdx = MatrixMath.dotProduct(gradphi,dx); 
			Matrix newx = null;
			for (lsiter = 1; lsiter <= MAX_LS_ITER; lsiter++) {
				newx = x.plus(dx.times(s)); // newx = x+s*dx
				if (MatrixMath.min(newx) > 0) { // strictly feasible
					Matrix newz = A.times(newx).minus(y);
					double newphi = MatrixMath.dotProduct(newz,newz)+lambda*MatrixMath.sum(newx)-MatrixMath.sum(MatrixMath.log(newx))/t;
					if (newphi-phi <= ALPHA*s*gdx) // Armijo sufficient decrease
						break;
				}
				s = BETA*s;
			}

			if (lsiter > MAX_LS_ITER) 
				break; // exit by BLS

			x = newx;
		}

		statusConverged = false;
		if (verbose)
			reportFailure(lsiter, ntiter);
		return false;
	}

	/** @return a copy of the last iterate as an n x 1 matrix, or null if no solve has been run */
	public Matrix getMatrixResult() {
		if(x==null)
			return null;
		return (Matrix)x.clone();
	}

	/** @return a copy of the last iterate as an array of length n, or null if no solve has been run */
	public double[] getResult() {
		if (x == null)
			return null;
		return x.getColumnPackedCopy();
	}

	/** Small timing demo: 1000 unconstrained and 1000 non-negative solves of a 3x4 problem. */
	public static void main(String args[]) {

		double [][]A = {{1, 0, 0, .5}, {0, 1, .2, 0.3}, {0, .1, 1, .2}};
		double []y = {1,.2,1};
		L1LSCompressedSensing l1ls = new L1LSCompressedSensing(A,true);
		l1ls.setVerbose(false);
		long tic = System.currentTimeMillis();
		for(int i=0;i<1000;i++) 
			l1ls.solve(y,0.01);
		System.out.println("Unconstrained Solutions: Milliseconds per solution: "+(System.currentTimeMillis()-tic)/1000.f);
		double []res=(l1ls.getResult());
		for(int i=0;i<res.length;i++)
			System.out.println(res[i]);


		l1ls.setVerbose(false);
		tic = System.currentTimeMillis();
		for(int i=0;i<1000;i++) 
			l1ls.solveNonNeg(y,0.01);
		System.out.println("Nonnegative Solutions: Milliseconds per solution: "+(System.currentTimeMillis()-tic)/1000.f);
		res=(l1ls.getResult());
		for(int i=0;i<res.length;i++)
			System.out.println(res[i]);
	}


}
