#!/usr/bin/env python
###########################################################################
# @package MOE
# @brief This script runs MOE with ROI data 
# It takes as input 
# (1) ROI spreadsheet with row and column indices
# It outputs 
# (1) spreadsheet with soft membership values for each subject
# (2) spreadsheet with hyperplane weights for each ROI
#
# @Dependencies
# sklearn version 0.15
#
# @author Harini Eavani
#
# @link: https://www.cbica.upenn.edu/sbia/software/
#
# @author: sbia-software@uphs.upenn.edu
##########################################################################

import os, sys, getopt, tempfile, shutil
sys.path.append('PATH_TO_INSTALL')
from MOEUtils import *

## usage function
## usage function
def usage():
    """Print the command-line help text for this script to stdout.

    The parenthesized print form is used so the block works on both
    Python 2 (single-expression print statement) and Python 3.
    """
    print("""
    moe --
    Runs MOE-based heterogeneity classification on spreadsheet data
    Works on any set of features -  Average regional volume/density/diffusion/connectivity, 
                                    clinical scores, cognitive data, etc
    *** For best performance, pick optimal -n, -c and -t values based on cross-validation ***
    
    Usage: moe [OPTIONS] 

    Required Options:
    [-d --data]                 Specify the spreadsheet with list of ROIs (required)
    [-p --prefix]               Specify the prefix of the output file  (required)
    [-H --header]               Specify the header column that has the discrete labels  (required) 
    [-l --label]                Specify the label for the heterogenous group (required)
    
    Options:
    [-n --nSVMs]                Specify the number of SVMs as a number. Default = 2.    
    [-c --cost]                 Specify the SVM cost value. Default = 1.    
    [-t --tradeoff]             Specify the classification vs. clustering tradeoff. Default = 0.1.  
    [-f --folds]                Specify the number of cross-validation folds to run. Default = 0. (not run)
                
    [-o --outputDir]            The output directory to write the results. Defaults to the location of the input file
    [-w --workingDir]           Specify a working directory. By default a tmp dir is created and used
    [-u --usage | -h --help]    Display this message
    [-v --verbose]              Verbose output
    [-V --Version]              Display version information

    Example:

    For a csv file ROI_values.csv that looks like below (patient group has label 'AD'):
    
    Subject_id, ROI_1, ROI_2, ROI_3, ..., Label, ...
    BLSA_0001, 0.234, 0.4545, 0.212,...., AD , ...
    BLSA_0002, 0.122, 0.1213, 0.3434,..., CN, ...  
    BLSA_0003, 0.452, 0.1413, 0.1434,..., AD, ... 
    ...
    ...
    
    Run the following command:
    moe -d ROI_values.csv -p ROI_results -H Label -l AD -n 3 -o ./ -v
    To get ten-fold cross-validation accuracy, run:
    moe -d ROI_values.csv -p ROI_results -H Label -l AD -n 3 -o ./ -v -f 10
    
    For more details see:
    Eavani, Harini, et al. "Capturing heterogeneous group differences using mixture-of-experts: 
    Application to a study of aging." NeuroImage 125 (2016): 498-514. 
    """)
## main function     
def main():
    """Parse CLI options, fit the MOE (mixture-of-experts) model on the input
    spreadsheet, optionally cross-validate, and write the results.

    Outputs written to the output directory:
      <prefix>_hyperplanes.csv  - one SVM hyperplane (plus intercept) per group
      <prefix>_memberships.csv  - soft membership values per patient subject
      <prefix>_allResults.npz   - raw arrays, incl. CV metrics when -f > 0
      <prefix>_opt_cv.txt       - per-fold CV scores (only when -f > 0)

    Returns 0 on success; exits early via usage()/cryandexit() on bad input.
    """

    ## the defaults
    rOpts = 0        # counts the 4 required options (-d, -p, -H, -l)
    verbose = 0
    nSVMs = 2        # -n: number of SVM experts
    cost = 1.        # -c: SVM cost parameter C
    tradeoff = 0.1   # -t: classification vs. clustering tradeoff
    folds = 0        # -f: 0 disables cross-validation

    outDir = None
    workingDir = None

    ## parsing arguments
    print('Parsing arguments')
    try:
        opts, files = getopt.gnu_getopt(sys.argv[1:], "l:H:hd:n:o:p:vVuw:t:c:f:",
        ["label=","header=","help", "data=","nSVMs=","outputDir=","prefix=","verbose","Version","usage","workingDir=","tradeoff=","cost=","folds="]) # parameters with : or = will expect an argument!
    except getopt.GetoptError as err:  # 'as' form works on Python 2.6+ and 3.x
        usage()
        print(str(err)) # will print something like "option -a not recognized"
        sys.exit(2)

    for o, a in opts:
        if o in ("-v", "--verbose"):
            verbose += 1
        elif o in ("-h", "--help", "-u", "--usage"):
            usage()
            sys.exit(0)
        elif o in ("-V", "--Version"):
            version()
            sys.exit(0)
        elif o in ("-d", "--data"):
            dataFile = a
            rOpts += 1 # required option
        elif o in ("-p", "--prefix"):
            prefix = a
            rOpts += 1 # required option
        elif o in ("-H", "--header"):
            header = a
            rOpts += 1 # required option
        elif o in ("-l", "--label"):
            label = a
            rOpts += 1 # required option
        elif o in ("-n", "--nSVMs"):
            nSVMs = int(a)
        elif o in ("-c", "--cost"):
            cost = float(a)
        elif o in ("-o", "--outputDir"):
            outDir = a
        elif o in ("-w", "--workingDir"):
            workingDir = a
        elif o in ("-t", "--tradeoff"):
            tradeoff = float(a)
        elif o in ("-f", "--folds"):
            folds = int(a)
        else:
            assert False, "unhandled option"

    if rOpts != 4:
        usage()
        cryandexit("Please specify all required options")

    ## expand the files into absolute paths
    dataFile = os.path.realpath(dataFile)

    ## check input
    if not fileExists(dataFile):
        cryandexit("File not found", dataFile)

    ## check prefix
    idStr = getFileBase(dataFile)
    if idStr == prefix:
        cryandexit("To avoid confusion the prefix must be different from the base of the input data file", prefix)

    ## check output dir; make it absolute so it stays valid after os.chdir below
    if not outDir:
        outDir = getFilePath(dataFile)
    outDir = os.path.realpath(outDir)

    ## make working directory
    if verbose > 0:
        print('Making working directory\n')
    if workingDir == None:
        # honor the SBIA_TMPDIR override when set, else use the system temp dir
        if 'SBIA_TMPDIR' in os.environ:  # has_key() was removed in Python 3
            cwDir = tempfile.mkdtemp(prefix='SVRHet', dir=os.environ['SBIA_TMPDIR'])
        else:
            cwDir = tempfile.mkdtemp(prefix='SVRHet')
    else:
        cwDir = os.path.realpath(workingDir)
    if not os.path.exists(cwDir):
        os.makedirs(cwDir)
    elif not os.path.isdir(cwDir):
        cryandexit("Working dir is not a directory", cwDir)
    # BUGFIX: this chdir used to sit unreachably after cryandexit() above
    os.chdir(cwDir) # change to working dir!

    ## some verbose messages
    if verbose > 0:
        print("dataFile   : "+dataFile)
        print("nSVMs      : "+str(nSVMs))
        print("cost       : "+str(cost))
        print("tradeoff   : "+str(tradeoff))
        print("prefix     : "+prefix)
        print("cwDir      : "+cwDir)
        print("Folds      : "+str(folds))
        print("output dir : "+outDir)

    print('Starting...')

    import pandas as pd
    from sklearn.metrics import accuracy_score, mean_squared_error
    from sklearn.cross_validation import KFold, StratifiedKFold
    from sklearn.preprocessing import scale
    from SV_fuzzy import SV_fuzzy
    import numpy as np

    ## load data from csv using pandas; the label column is forced to str
    mydata = pd.read_csv(dataFile, header=0, quotechar='"', dtype={header: str},
                         sep=',', na_values=['NA', '-', '.', ''], index_col=0)
    if not header in mydata.columns:
        cryandexit("Unable to find column in spreadsheet: "+header)
    else:
        y_old = list(mydata[header])

        # exactly two distinct labels -> binary classification, else regression
        if len(set(y_old)) == 2:
            regress = 0
        else:
            regress = 1

        if regress == 0:
            # +1 for the heterogeneous (patient) label, -1 for everyone else
            y = -1*np.ones(len(y_old))
            if label in set(y_old):
                ind = [i for i, e in enumerate(y_old) if e == label]
                y[ind] = 1
            else:
                cryandexit("Unable to find "+label+" in column "+header)
        else:
            # BUGFIX: y_old is a plain list (lists have no .as_matrix()); the
            # column was read as str, so convert to a numeric target array
            y = np.asarray(y_old, dtype='float64')

        orig_labels = mydata[header]
        mydata = mydata.drop(header, 1)
        X = mydata.as_matrix().astype('float64')
        X = scale(X, axis=0, with_mean=True, with_std=True, copy=False)

    ## if cross validation: score each fold first; in all cases fit on all data
    acc_all = []
    corrWAll = []
    bpcAll = []
    if folds == 0:
        # load sv_f
        sv_f = SV_fuzzy(n=nSVMs, C=cost, t=tradeoff, max_iter=100, verbose=verbose)
        sv_f.fit(X, y) # fit

    else:
        # stratified folds for classification, plain K-fold for regression
        if regress == 0:
            ss = StratifiedKFold(y, n_folds=folds, shuffle=True)
        else:
            ss = KFold(y.shape[0], n_folds=folds, shuffle=True)
        sv_f = SV_fuzzy(n=nSVMs, C=cost, t=tradeoff, max_iter=100, verbose=verbose)

        # use the performance metric that matches the problem type
        if regress == 0:
            scoring_func = accuracy_score
        else:
            scoring_func = mean_squared_error

        # cross-validate
        y_pred = np.zeros(y.shape)
        for train_index, test_index in ss:
            sv_f.fit(X[train_index, :], y[train_index])
            y_pred[test_index] = sv_f.predict(X[test_index, :])
            acc_all.append(scoring_func(y[test_index], y_pred[test_index]))
            corrWAll.append(sv_f.corrW[-1])
            bpcAll.append(sv_f.bpc[-1])

        sv_f.fit(X, y)  # final model trained on the full data set

        cvBest = np.mean(acc_all)
        cvStd = np.std(acc_all)
        print("%d-fold cross validated accuracy is %0.3f (+/-%0.03f)" % (folds, cvBest, cvStd))
        # 'with' closes the file; the old explicit fp.close() was redundant
        with open(os.path.join(outDir, prefix+'_opt_cv.txt'), 'w') as fp:
            for a in acc_all:
                fp.write(str(a))
                fp.write('\n')

    # write outputs as csv files
    nResult = sv_f.hyperplanes_.shape[1]
    headers = list(mydata.columns)

    hyperplanes = pd.DataFrame(sv_f.hyperplanes_.transpose(),
                               index=['Group'+str(i) for i in range(1, nResult+1)], columns=headers)
    hyperplanes['Intercept'] = sv_f.intercepts_
    outputfile = os.path.join(outDir, prefix+'_hyperplanes.csv')
    hyperplanes.to_csv(outputfile)

    # make the memberships related to controls '-1' empty
    out_memberships = sv_f.memberships_.copy()
    for e, v in enumerate(y):
        if v == -1:
            out_memberships[e, :] = np.nan

    cols = ['Group'+str(i) for i in range(1, nResult+1)]
    memberships = pd.DataFrame(out_memberships, index=mydata.index,
                               columns=cols)

    memberships['label'] = orig_labels
    # hard assignment = the group with the highest soft membership
    memberships['BinaryGroups'] = memberships[cols].idxmax(axis=1, skipna=True)
    memberships = memberships.fillna('')

    outputfile = os.path.join(outDir, prefix+'_memberships.csv')
    memberships.to_csv(outputfile)

    outputfile = os.path.join(outDir, prefix+'_allResults')
    np.savez(outputfile, cvAccuracy=acc_all, hyperplanes=sv_f.hyperplanes_,
             intercepts=sv_f.intercepts_, memberships=sv_f.memberships_,
             cvBPC=bpcAll, cvCorrW=corrWAll, centroids=sv_f.centroids_)

    # delete the temp dir only when we created it ourselves
    if verbose > 0:
        print('Cleaning temp directory\n')
    if workingDir == None:
        os.chdir(outDir)  # step out of cwDir before removing it
        shutil.rmtree(cwDir)

    return 0
            
if __name__ == '__main__':
    # propagate main()'s return code (0 on success) as the process exit status
    sys.exit(main())


