###########################################################################
# @package draw_moe_plot
# @brief Function that plots result from 2D synthetic test cases
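# It takes as input:
# @param X input 2D data; X[:, 0] (and X[:, 1] for classification) are plotted
# @param y input labels (-1/+1 for classification) or regression targets
# @param out_vals output structure returned by MOE (memberships, hyperplanes, intercepts)
# @param filename filename to save the plot to
# It saves the plot to the given file (e.g. a png) for each test case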
#
# @author Harini Eavani
#
# @link: https://www.cbica.upenn.edu/sbia/software/
#
# @author: sbia-software@uphs.upenn.edu
##########################################################################

import matplotlib.pyplot as plt
import colorsys
import numpy as np

def draw_moe_plot(X,y,out_vals,filename):
    ## initialize figure
    fig1 = plt.figure()
    ax1 = fig1.add_subplot(111)
    nS = y.shape[0]
    nS,n = out_vals.memberships.shape
    
    ## generate n colors
    H_val_models = np.linspace(0,2*np.pi,n);
    H_val_models = H_val_models[0:n];
    S_val_models = 1*np.ones(n);
    V_val_models = 0.9*np.ones(n);
    model_colors = np.zeros((n,3))

    for nn in range(n):
        model_colors[nn,:] = colorsys.hsv_to_rgb(H_val_models[nn],S_val_models[nn],V_val_models[nn])
        
    ## calculate colors for each data point        
    data_colors = np.dot(out_vals.memberships,model_colors)

    ## get min and max ylims
    if len(np.unique(y))==2:
        regress=0
        ymin = np.min(X[:,1])-1
        ymax = np.max(X[:,1])+1
    else:
        regress=1
        ymin = np.min(y)-1
        ymax = np.max(y)+1
        
    ## plot data points      
    if regress==0:        
        data_colors[y==-1,:] = np.repeat(0.5*np.ones((1,3)),np.sum(y==-1),axis=0);
        ax1.scatter(X[y==-1,0],X[y==-1,1], c=data_colors[y==-1,:], s=60,alpha=0.5)
        ax1.scatter(X[y==1,0],X[y==1,1], c=data_colors[y==1,:], s=60,alpha=0.5,marker="^")
    else:
        ax1.scatter(X[:,0],y, c=data_colors, s=60,alpha=0.5)        
    
    ## get min and max xlims
    xmin = np.min(X[:,0])-1
    xmax = np.max(X[:,0])+1
    
    ## plot hyperplanes   
    xx = np.linspace(xmin,xmax,100)
    for nn in range(n):
        if regress==0:        
            yy = (-1*out_vals.hyperplanes[0,nn]*xx - out_vals.intercepts[nn])/(out_vals.hyperplanes[1,nn]+0.001)
        else:
            yy = out_vals.hyperplanes[0,nn]*xx + out_vals.intercepts[nn]           
        ax1.plot(xx,yy,color=model_colors[nn,:],linewidth=3)
    
    ## set lims
    ax1.set_xlim([xmin, xmax])
    ax1.set_ylim([ymin, ymax])
    plt.xticks([])
    plt.yticks([])
    
    ## save fig
    plt.savefig(filename, dpi=200, transparent=False,bbox_inches='tight')