###########################################################################
# @package QP
# @brief Helpers for solving quadratic programs (QP) with scipy.
# It defines a qp_loss class whose methods evaluate the quadratic
# objective (loss) and its gradient (jac), and a qp_solve function
# that builds a qp_loss and passes it to scipy's SLSQP optimizer.
#
# @author Harini Eavani
#
# @link: https://www.cbica.upenn.edu/sbia/software/
#
# @author: sbia-software@uphs.upenn.edu
##########################################################################

import numpy as np
from scipy import optimize
#from matplotlib import pyplot as plt
#from mpl_toolkits.mplot3d.axes3d import Axes3D

## Example of a quadratic problem (QP) of the kind solved in this module,
# for instance the following:
# minimize
#     F = x[1]^2 + 4x[2]^2 -32x[2] + 64
# subject to:
#      x[1] + x[2] <= 7
#     -x[1] + 2x[2] <= 4
#      x[1] >= 0
#      x[2] >= 0
#      x[2] <= 4
# in matrix notation:
#     F = (1/2)*x.T*H*x + c*x + c0
# subject to:
#     Ax <= b
# where:
#     H = [[2, 0],
#          [0, 8]]
#     c = [0, -32]
#     c0 = 64
#     A = [[ 1, 1],
#          [-1, 2],
#          [-1, 0],
#          [0, -1],
#          [0,  1]]
#     b = [7,4,0,0,4]
#H = np.array([[2., 0.],
#              [0., 8.]])
#c = np.array([0, -32])
#c0 = 64
#A = np.array([[ 1., 1.],
#              [-1., 2.],
#              [-1., 0.],
#              [0., -1.],
#              [0.,  1.]])
#b = np.array([7., 4., 0., 0., 4.])
#x0 = np.random.randn(2)

## Quadratic objective function and its gradient.
# Evaluates sign * ((1/2)*x.T*H*x + c*x + c0) and its gradient, so that the
# two methods can be passed to scipy.optimize.minimize as fun and jac.
class qp_loss:
    ## init function
    # @param self object pointer
    # @param sign scalar multiplier applied to the whole objective; 1.0 keeps
    #        it unchanged, -1.0 negates it (minimizing the negated objective
    #        maximizes the original quadratic)
    # @param H square matrix, quadratic coefficient in objective; it does not
    #        need to be symmetric (only its symmetric part affects the value)
    # @param c linear coefficient in objective
    # @param c0 constant in objective
    def __init__(self, sign, H, c, c0):
        self.sign = sign
        self.H = H
        self.c = c
        self.c0 = c0

    ## loss (objective value)
    # @param self object pointer
    # @param x data point (presumably a 1-D array, as scipy passes it)
    # @return sign * ((1/2)*x.T*H*x + c*x + c0)
    def loss(self, x):
        return self.sign * (0.5 * np.dot(x.T, np.dot(self.H, x)) + np.dot(self.c, x) + self.c0)

    ## gradient (Jacobian) of the loss
    # The gradient of (1/2)*x.T*H*x is (1/2)*(H + H.T)*x. That equals H*x only
    # when H is symmetric, so the symmetric part is used to keep jac consistent
    # with loss for any square H (identical result for symmetric H).
    # @param self object pointer
    # @param x data point
    # @return sign * ((1/2)*(H + H.T)*x + c), laid out like x.T*H
    def jac(self, x):
        H = np.asarray(self.H)  # allows H given as a nested list (needs .T)
        return self.sign * (0.5 * (np.dot(x.T, H) + np.dot(x.T, H.T)) + self.c)

## Solve a quadratic program with scipy's SLSQP method.
# minimize    F = (1/2)*x.T*H*x + c*x + c0
# subject to: A*x == b        (equality constraints)
#             A_in*x <= b_in  (inequality constraints)
# @param H quadratic coefficient in objective (square matrix)
# @param c linear coefficient in objective
# @param c0 constant in objective
# @param A linear coefficient in equality; pass None (with b) to omit the
#        equality constraints
# @param b constant in equality
# @param A_in linear coefficient in inequality; pass None (with b_in) to omit
#        the inequality constraints
# @param b_in constant in inequality
# @param x0 initial data point
# @param tol termination tolerance forwarded to scipy.optimize.minimize.
#        NOTE(review): the default 10**-20 is far below double precision, so
#        SLSQP may rarely meet it and stop at maxiter instead; it is kept for
#        backward compatibility
# @param maxiter maximum number of SLSQP iterations
# @return scipy OptimizeResult; if SLSQP stops on the iteration limit
#         (status 9), 'x' is replaced by a copy of x0 and 'fun'/'jac' are
#         re-evaluated at x0 so the result stays self-consistent
def qp_solve(H, c, c0, A, b, A_in, b_in, x0, tol=10**-20, maxiter=200):

    cons = []
    if A is not None:
        # asarray so the jac lambda's unary minus also works for list input
        A = np.asarray(A, dtype=float)
        b = np.asarray(b, dtype=float)
        # SLSQP 'eq' constraints enforce fun(x) == 0, i.e. A*x == b
        cons.append({'type': 'eq',
                     'fun': lambda x: b - np.dot(A, x),
                     'jac': lambda x: -A})
    if A_in is not None:
        A_in = np.asarray(A_in, dtype=float)
        b_in = np.asarray(b_in, dtype=float)
        # SLSQP 'ineq' constraints enforce fun(x) >= 0, i.e. A_in*x <= b_in
        cons.append({'type': 'ineq',
                     'fun': lambda x: b_in - np.dot(A_in, x),
                     'jac': lambda x: -A_in})

    opt = {'disp': False, 'maxiter': maxiter}

    qp_loss_f = qp_loss(sign=1.0, H=H, c=c, c0=c0)

    res_cons = optimize.minimize(qp_loss_f.loss, x0, jac=qp_loss_f.jac,
                                 constraints=cons, method='SLSQP',
                                 options=opt, tol=tol)
    if res_cons['status'] == 9:
        # Iteration limit reached: fall back to the starting point. Copy it so
        # the result does not alias the caller's array, and keep 'fun'/'jac'
        # consistent with the returned 'x'.
        x_fallback = np.array(x0, dtype=float)
        res_cons['x'] = x_fallback
        res_cons['fun'] = qp_loss_f.loss(x_fallback)
        res_cons['jac'] = qp_loss_f.jac(x_fallback)
    return res_cons
                                   
    
#
#    print '\nConstrained:'
#    print res_cons
#
#    print '\nUnconstrained:'
#    print res_uncons
#
#    x1, x2 = res_cons
#    f = res_cons['fun']
#
#    x1_unc, x2_unc = res_uncons['x']
#    f_unc = res_uncons['fun']

    # plotting
#    xgrid = np.mgrid[-2:4:0.1, 1.5:5.5:0.1]
#    xvec = xgrid.reshape(2, -1).T
#    F = np.vstack([qp_loss_f.loss(xi) for xi in xvec]).reshape(xgrid.shape[1:])
#
#    ax = plt.axes(projection='3d')
#    ax.hold(True)
#    ax.plot_surface(xgrid[0], xgrid[1], F, rstride=1, cstride=1,
#                    cmap=plt.cm.jet, shade=True, alpha=0.9, linewidth=0)
#    ax.plot3D([x1], [x2], [f], 'og', mec='w', label='Constrained minimum')
#    ax.plot3D([x1_unc], [x2_unc], [f_unc], 'oy', mec='w',
#              label='Unconstrained minimum')
#    ax.legend(fancybox=True, numpoints=1)
#    ax.set_xlabel('x1')
#    ax.set_ylabel('x2')
##    ax.set_zlabel('F')