###########################################################################
# @package SV_fuzzy
# @brief Main src that implements the MOE method
# It defines a SV_fuzzy class which has
# class outputvals as the class members
# fit, predict as the methods
#
# @author Harini Eavani
#
# @link: https://www.cbica.upenn.edu/sbia/software/
#
# @author: sbia-software@uphs.upenn.edu
##########################################################################

from sklearn.linear_model import Ridge
from sklearn.svm import SVC
from sklearn.utils import check_random_state
from sklearn.cluster import KMeans
import numpy as np
from QP import qp_solve
#from draw_moe_plot import draw_moe_plot


## notation and sizes
# nS - number of subjects
# n - number of experts
# p - dimensionality
# centroid - p x n
# hyperplanes p x n
# membership - nS x n
# err - objective function value
# corrW - the maximal cosine similarity (normalized inner product) between pairs of hyperplanes
# bpc - Bezdek participation coefficient

## class OutputVals
# defines the structure that stores all output variables
class OutputVals:
    """Container for every quantity produced by the SV_fuzzy optimisation.

    Shapes use nS subjects, n experts and p features:
        svms        -- list of the n expert models (SVC or Ridge)
        hyperplanes -- p x n matrix of expert weight vectors
        intercepts  -- length-n vector of expert intercepts
        memberships -- nS x n fuzzy membership matrix
        centroids   -- p x n matrix of cluster centroids
        err         -- objective value per iteration
        corrW       -- maximal cosine similarity between experts per iteration
        bpc         -- Bezdek participation coefficient per iteration
    """

    def __init__(self):
        # a fresh list per instance so containers never share experts
        self.svms = []
        self.hyperplanes = None
        # declared here so every OutputVals exposes the same attributes;
        # previously it was only attached from outside the class
        self.intercepts = None
        self.memberships = None
        self.centroids = None
        self.err = None
        self.corrW = None
        self.bpc = None

## initialization function for all output variables       
def initializeSV_fuzzy(X,y,n,C,max_iter,random_state=0):
    """Build the starting OutputVals for run_SV_fuzzy.

    Creates n untrained experts (linear SVC when y has exactly two distinct
    values, Ridge otherwise), zero-filled hyperplanes/intercepts/diagnostics,
    and hard initial memberships from k-means: every clustered sample gets
    membership 1 in its k-means cluster and 0 elsewhere.  In classification
    only the samples labelled 1 are clustered; all other samples keep the
    uniform 1/n membership.

    @param X data matrix, nS x p
    @param y labels (two distinct values) or continuous targets, length nS
    @param n number of experts / clusters
    @param C SVC cost value
    @param max_iter number of iterations to preallocate diagnostics for
    @param random_state seed or RandomState forwarded to KMeans (default 0)
    @return OutputVals holding the initial state
    """
    nS = y.shape[0]
    p = X.shape[1]

    ## check if input labels are discrete or continuous
    regress = 0 if len(np.unique(y)) == 2 else 1

    # Seed k-means so the initialisation (and therefore the whole fit) is
    # reproducible; previously this state was created but never used.
    random_state = check_random_state(random_state)

    ## initialize a set of output vals
    initialValues = OutputVals()
    if regress == 0:
        initialValues.svms = [SVC(C=C, kernel='linear') for _ in range(n)]
    else:
        initialValues.svms = [Ridge() for _ in range(n)]

    initialValues.hyperplanes = np.zeros((p, n))
    initialValues.intercepts = np.zeros(n)
    initialValues.memberships = (1. / n) * np.ones((nS, n))
    initialValues.err = np.zeros(max_iter)
    initialValues.corrW = np.zeros(max_iter)
    initialValues.bpc = np.zeros(max_iter)

    ## use k means to initialize membership vals
    # classification clusters only the label-1 samples; regression clusters all
    clustered = (y == 1) if regress == 0 else np.ones(nS, dtype=bool)
    km = KMeans(n_clusters=n, random_state=random_state)
    km.fit(X[clustered, :])
    initialValues.centroids = km.cluster_centers_.transpose()

    # km.labels_ follows the row order of X[clustered, :]
    rows = np.flatnonzero(clustered)
    initialValues.memberships[rows, :] = 0
    initialValues.memberships[rows, km.labels_] = 1

    return initialValues

## function for calculating the objective function value
def get_objective(X,y,out_vals,C,t):
    """Evaluate the SV_fuzzy objective and two diagnostics for the current state.

    f = sum_k 0.5*||w_k||^2
        + C * sum_k sum_i u_ik * slack_ik
        + t * sum_k sum_i u_ik^2 * ||x_i - c_k||^2 / (nS*n*p)

    slack is the hinge loss when y has exactly two distinct values and the
    squared error otherwise.  The fuzzy term is normalised once by (nS*n*p),
    which is the same scaling used for the membership QP in run_SV_fuzzy.

    @param X data matrix, nS x p
    @param y labels / targets, length nS
    @param out_vals OutputVals with fitted svms, memberships (nS x n) and centroids (p x n)
    @param C expert cost value
    @param t expert-mixture tradeoff value
    @return (f, corrW, bpc): objective value, maximal cosine similarity between
            pairs of expert weight vectors (1 when n == 1), and the normalised
            Bezdek participation coefficient
    """
    nS, n = out_vals.memberships.shape
    p = X.shape[1]
    regress = 0 if len(np.unique(y)) == 2 else 1

    ## compute objective
    f_w = 0.
    f_slack = 0.
    for nn in range(n):
        expert = out_vals.svms[nn]
        f_w += 0.5 * np.sum(expert.coef_ ** 2)
        if regress == 0:
            slack = np.maximum(1 - y * expert.decision_function(X), 0)
        else:
            # regressors (Ridge) expose predict(), which is their linear output
            slack = (y - expert.predict(X)) ** 2
        f_slack += np.sum(out_vals.memberships[:, nn] * slack)

    # squared distance of every sample to every centroid, nS x n
    centroids = out_vals.centroids[:, :n]
    sq_dist = np.sum((X[:, :, np.newaxis] - centroids[np.newaxis, :, :]) ** 2, axis=1)
    # normalise once (previously applied inside the expert loop, compounding it)
    f_fuzzy = np.sum(out_vals.memberships ** 2 * sq_dist) / (1. * nS * n * p)
    f = f_w + C * f_slack + t * f_fuzzy

    ## compute maximal normalised inner product between expert hyperplanes
    if n > 1:
        corrW = []
        for first in range(n):
            w_first = np.ravel(out_vals.svms[first].coef_)
            for second in range(first + 1, n):
                w_second = np.ravel(out_vals.svms[second].coef_)
                corrW.append(np.inner(w_first, w_second)
                             / (np.linalg.norm(w_first) * np.linalg.norm(w_second)))
    else:
        corrW = 1

    ## compute Bezdek's Participation Coeff over the samples whose memberships are optimised
    active = (y == 1) if regress == 0 else np.ones(nS, dtype=bool)
    bpc = np.sum(out_vals.memberships[active, :] ** 2) / np.sum(active)
    if n > 1:
        bpc = 1 - (1. * n / (n - 1)) * (1 - bpc)

    return f, np.max(corrW), bpc

## function that runs the main iterative method
def run_SV_fuzzy(X, y, n,C,t,max_iter,verbose):
    """Run the alternating optimisation of the SV-fuzzy mixture of experts.

    Even iterations refit every expert (linear SVC for two-valued y, Ridge
    otherwise) using the current memberships as sample weights.  Odd iterations
    recompute the centroids and solve one QP per sample to update its
    memberships.  After every iteration, experts whose total membership over
    the optimised samples falls below 1 are pruned, together with all of their
    per-expert state.  Iteration stops after max_iter steps or once the
    relative change of the objective over 4 iterations drops to 1e-3.

    @param X data matrix, nS x p
    @param y labels / targets, length nS
    @param n initial number of experts
    @param C expert cost value
    @param t expert-mixture tradeoff value
    @param max_iter maximum number of iterations
    @param verbose print per-iteration progress if > 0
    @return OutputVals with the fitted experts and per-iteration diagnostics
    """
    ## initialize
    out_vals = initializeSV_fuzzy(X,y,n,C,max_iter)
    nS, p = X.shape

    ## check if classification or regression
    regress = 0 if len(np.unique(y)) == 2 else 1
    # samples whose memberships are optimised: label-1 samples in classification, all otherwise
    active = (y == 1) if regress == 0 else np.ones(nS, dtype=bool)

    ## set up while loop
    delE = 1
    iternum = 0
    while iternum < max_iter and delE > 10**-3:

        if iternum % 2 == 0:
            ## update experts with memberships as sample weights
            slack = np.zeros((nS, n))
            out_vals.hyperplanes = np.zeros((p, n))
            out_vals.intercepts = np.zeros(n)
            for nn in range(n):
                expert = SVC(C=C, kernel='linear') if regress == 0 else Ridge()
                expert.fit(X, y, sample_weight=out_vals.memberships[:, nn].ravel())
                out_vals.svms[nn] = expert
                out_vals.hyperplanes[:, nn] = np.ravel(expert.coef_)
                # SVC stores intercept_ as a length-1 array, Ridge as a scalar
                out_vals.intercepts[nn] = np.ravel(expert.intercept_)[0]
                if regress == 0:
                    predictions = expert.decision_function(X)
                    slack[:, nn] = np.maximum(1 - y * predictions, 0)
                else:
                    # regressors (Ridge) expose predict(), which is their linear output
                    predictions = expert.predict(X)
                    slack[:, nn] = (y - predictions) ** 2

        else:
            ## update centroids: membership^2-weighted means, p x n
            m_squared = out_vals.memberships ** 2
            out_vals.centroids = X.T.dot(m_squared) / np.sum(m_squared, axis=0)
            # squared distance of every sample to every centroid, nS x n
            dist2centroids = np.sum(
                (X[:, :, np.newaxis] - out_vals.centroids[np.newaxis, :, :]) ** 2, axis=1)

            ## update memberships: one box- and simplex-constrained QP per active sample
            scale = 2 * (1. * t / (nS * n * p))
            for ss in np.flatnonzero(active):
                H = scale * np.diag(dist2centroids[ss, :])
                f = C * slack[ss, :]
                f0 = np.zeros(1)
                A = np.ones(n)
                b = np.array(1)
                A_in = np.concatenate((-1 * np.eye(n), np.eye(n)), axis=0)
                b_in = np.concatenate((np.zeros(n), np.ones(n)), axis=0)
                x0 = out_vals.memberships[ss, :]
                res_cons = qp_solve(H, f, f0, A, b, A_in, b_in, x0)
                soln = res_cons['x']
                # drop negligible memberships and renormalise so the row sums to 1
                soln[soln < 10**-3] = 0
                out_vals.memberships[ss, :] = soln / np.sum(soln)

        ## clean out (almost) empty clusters and every piece of per-expert state
        clusters_to_remove = np.flatnonzero(np.sum(out_vals.memberships[active, :], axis=0) < 1)
        if clusters_to_remove.size:
            for i in sorted(clusters_to_remove, reverse=True):
                del out_vals.svms[i]
            out_vals.memberships = np.delete(out_vals.memberships, clusters_to_remove, 1)
            out_vals.memberships = out_vals.memberships / np.sum(out_vals.memberships, axis=1)[:, np.newaxis]
            out_vals.hyperplanes = np.delete(out_vals.hyperplanes, clusters_to_remove, 1)
            # centroids/intercepts must shrink too, otherwise expert k is paired
            # with another expert's centroid in the objective and in predict
            out_vals.intercepts = np.delete(out_vals.intercepts, clusters_to_remove)
            out_vals.centroids = np.delete(out_vals.centroids, clusters_to_remove, 1)
            slack = np.delete(slack, clusters_to_remove, 1)

        ## check change in objective value
        n = out_vals.memberships.shape[1]
        out_vals.err[iternum], out_vals.corrW[iternum], out_vals.bpc[iternum] = get_objective(X, y, out_vals, C, t)
        if iternum > 4:
            delE = abs(out_vals.err[iternum] - out_vals.err[iternum - 4]) / out_vals.err[iternum]

        ## print to stdout
        if verbose > 0:
            print("Iteration "+str(iternum)+" complete, bpc:"+str(out_vals.bpc[iternum])+", corrw:"+str(out_vals.corrW[iternum])+", error:"+str(out_vals.err[iternum])+
                                ", nExperts:"+str(n))

        iternum += 1

    # keep only the diagnostics of the iterations actually run
    out_vals.err = out_vals.err[0:iternum]
    out_vals.corrW = out_vals.corrW[0:iternum]
    out_vals.bpc = out_vals.bpc[0:iternum]
    print("Error is "+str(out_vals.err[-1])+', inn-prod is '+str(out_vals.corrW[-1])+', BPC is '+str(out_vals.bpc[-1]))
    return out_vals
    
## main class that combines SV/Ridge + Fuzzy C-Means
# @param SVC inherits from sklearn's SVC    
class SV_fuzzy(SVC):
    """Mixture of linear experts combined through fuzzy c-means memberships.

    When y has exactly two distinct values the experts are linear SVCs and
    predict returns the sign of the membership-weighted expert labels;
    otherwise the experts are Ridge regressors and predict returns the
    membership-weighted expert predictions.
    """

    ## init function
    # @param self the object pointer
    # @param n the number of experts
    # @param C the expert cost value
    # @param t the expert-mixture tradeoff value
    # @param max_iter maximum number of iterations
    # @param verbose enable verbose output if >0
    def __init__(self,n=2,C=1,t=1,max_iter=100,verbose=0):
        # SVC.__init__ is not called; only these hyperparameters are stored
        self.n = n
        self.C = C
        self.t = t
        self.max_iter = max_iter
        self.verbose = verbose

    ## fit function
    # @param X the data
    # @param y the labels
    def fit(self,X,y):
        """Fit the mixture with run_SV_fuzzy and store the learned state.

        @param X data matrix, nS x p
        @param y labels (two distinct values) or continuous targets, length nS
        @return self
        """
        self.regress = 0 if len(np.unique(y)) == 2 else 1
        out_vals = run_SV_fuzzy(X, y, n=self.n, C=self.C, t=self.t,
                                max_iter=self.max_iter, verbose=self.verbose)
        # number of experts that survived pruning
        self.n_experts_ = len(out_vals.svms)
        # NOTE(review): overwriting the hyperparameter n is kept for backward
        # compatibility, but it means a later fit() starts from the pruned count.
        self.n = self.n_experts_
        self.svms_ = out_vals.svms
        self.hyperplanes_ = out_vals.hyperplanes
        self.intercepts_ = out_vals.intercepts
        self.memberships_ = out_vals.memberships
        self.centroids_ = out_vals.centroids
        self.err = out_vals.err
        self.corrW = out_vals.corrW
        self.bpc = out_vals.bpc
        return self

    ## predict function
    # @param self the object pointer
    # @param X the test data
    def predict(self,X):
        """Predict labels (classification) or values (regression) for X.

        Each sample's memberships come from a QP on its squared distances to
        the learned centroids (no expert-loss term, sum-to-one and [0, 1]
        constraints); the experts' predictions are then combined with those
        memberships.

        @param X test data, nS x p
        @return numpy array of length nS
        """
        nS = X.shape[0]
        n = self.hyperplanes_.shape[1]
        y = np.zeros(nS)
        if nS == 0:
            return y

        # squared distance of every test sample to every centroid, nS x n
        centroids = self.centroids_[:, :n]
        dist2centroids = np.sum((X[:, :, np.newaxis] - centroids[np.newaxis, :, :]) ** 2, axis=1)
        # each expert predicts the whole batch once (scikit-learn estimators
        # reject the 1-D single rows the previous per-sample call passed), nS x n
        decision_fn = np.column_stack([np.ravel(expert.predict(X)) for expert in self.svms_[:n]])

        ## update memberships per sample and combine the experts
        for ss in range(nS):
            H = 2*self.t*np.diag(dist2centroids[ss,:])
            f = np.zeros(n)
            f0 = np.zeros(1)
            A = np.ones(n)
            b = np.array(1)
            A_in = np.concatenate((-1*np.diag(np.ones(n)), np.diag(np.ones(n))), axis=0)
            b_in = np.concatenate((np.zeros(n), np.ones(n)), axis=0)
            x0 = np.zeros(n)
            res_cons = qp_solve(H,f,f0,A,b,A_in,b_in,x0)
            predict_memberships = res_cons['x']
            combined = np.sum(predict_memberships * decision_fn[ss, :])
            y[ss] = np.sign(combined) if self.regress == 0 else combined
        return y

