#!/usr/bin/env python

import os, sys, getopt
import CHIMERA

def usage():
    """Print the command-line help text for chimera to stdout."""
    # Call form of print with a single argument: prints the same text under
    # Python 2 and is also valid Python 3, matching the style used by version().
    print("""
    CHIMERA--
    Clustering heterogenous disease effects via distribution matching of imaging patterns.
    
    *** For best performance, pick optimal -m, -n values based on cross-validation ***    
    
    Usage: chimera [OPTIONS] 

    Required:
    [-i <string>  ]    File name of data in .csv format with header line, one subject per row 
                       Header line can have types of:
                           ID    : subject ID (optional)
                           Group : subject diagnosis(1: patient; 0: normal control) (required)
                           Set   : subject dataset ID (optional)
                           COVAR : covariates (optional)
                           IMG   : imaging features (required)
    [-r <string>  ]    File name of clustering outputs
    [-k <int>     ]    Number of sub-groups to find
    
    Options: (See reference for more details)
    [-m <float>   ]    The value of lambda1, for b (non-negative, default 10)
    [-n <float>   ]    The value of lambda2, for A (non-negative, default 100)
    [-T <string>  ]    Transformation to be used. (affine/duo/trans, default affine)
    [-e <float>   ]    Stopping criterion tolerance (default 0.001)
    [-M <int>     ]    Maximum iteration allowed (default 1000)
    [-N <int>     ]    Number of Runs of optimization with different initialization (default 50)
    [-o <string>  ]    File name to save trained model (default not saving)
    [-s <float>   ]    Weight of covariates distance (default auto)
    [-t <float>   ]    Weight of set distance (default 10)
    [-w <int>     ]    Data normalization. 0: no action; 1: 0-1 normalization; 2: z-score. (default 1)
    [-c <int>     ]    Save most reproducible (c=1)/ minimal energy (c=2) result among runs (default 2)
    [-v --verbose ]    Verbose output

    [-u --usage   ]    Display this message
    [-h --help    ]    Display this message
    [-V --Version ]    Display version information

    Example:
    chimera -i data.csv -r output.txt -k 3 -o model.cpkl -N 10 -m 20 -v

    csv file data.csv should look like below (must contain fields IMG and Group):
        
        ID,        COVAR,COVAR, IMG,   IMG, ..., Group, Set
        ADNI_0001, 15.1, 0.454, 0.212, 0.13,....,0,     1
        ADNI_0002, 20.9, 0.121, 0.343, 1.32,..., 0,     2  
        ADNI_0003, 21.2, 0.141, 0.143, 0.21,..., 1,     2
        ...
        ...
    
    Reference:
    Dong, et al. "CHIMERA: Clustering heterogenous disease effects via distribution matching
    of imaging patterns" IEEE Transactions on Medical Imaging, 2016.

    Copyright (c) 2017 University of Pennsylvania. All rights reserved.
    Contact: sbia-software@uphs.upenn.edu
    """)
   
def version():
    """Print the program banner, the CHIMERA package version and the licence notice."""
    # The fragments are concatenated without a separator, so each one carries
    # its own line breaks; print() then appends the final newline.
    fragments = (
        "CHIMERA--\n",
        "Clustering heterogenous disease effects via distribution matching of imaging patterns.\n\n",
        "version " + CHIMERA.__version__,
        "\n",
        "Copyright (c) 2017 University of Pennsylvania. All rights reserved.\n",
        "See https://www.med.upenn.edu/sbia/software-agreement.html or COPYING file\n",
    )
    print("".join(fragments))


def main():
    """Command-line entry point: parse options, validate them and run CHIMERA.

    Options are read from ``sys.argv`` (``usage()`` prints the full list).
    After the required arguments and the value ranges have been checked,
    ``CHIMERA.algorithm.clustering(dataFile, outFile, config)`` is called and
    the existence of the output file is verified.

    Returns:
        0 once the output file has been generated.

    Exits (via ``sys.exit``):
        0 -- after printing the help or the version text.
        1 -- a required argument is missing, a numeric value cannot be
             parsed, a validation check fails, or the output file was not
             produced.
        2 -- an option was rejected by ``getopt``.
    """
    config = {'verbose' : False,
              'lambda1' : 10.0,  #lambda1 for b
              'lambda2' : 100.0, #lambda2 for A
              'r'       : -1.0,  #weight of covariate distance
              'rs'      : 10,    #weight of set distance
              'eps'     : 0.001, #precision tolerance
              'max_iter': 1000,
              'numRun'  : 50,   #number of different initializations
              'modelFile' : "",
              'transform' : "affine",
              'norm'      : 1,
              'mode'      : 2,   #save result based on reproducibility or energy function
              'quiet'     : True
              }

    # Options whose argument is stored directly in config:
    # flag -> (config key, converter applied to the raw string).
    valueOptions = {
        "-k": ('K', int),          # number of clusters (required)
        "-m": ('lambda1', float),
        "-n": ('lambda2', float),
        "-T": ('transform', str),
        "-e": ('eps', float),
        "-M": ('max_iter', int),
        "-N": ('numRun', int),
        "-o": ('modelFile', str),
        "-s": ('r', float),
        "-t": ('rs', float),
        "-w": ('norm', int),
        "-c": ('mode', int),
    }

    sys.stdout.write("Parsing arguments...\n")
    try:
        # "verbose" must be listed here, otherwise the documented --verbose
        # flag is rejected by getopt before the loop below can handle it.
        opts, _ = getopt.getopt(sys.argv[1:], "i:r:k:m:n:T:e:M:N:o:s:t:w:c:vuhV",
                                ["help", "usage", "Version", "verbose"])
    except getopt.GetoptError as err:
        usage()
        print(str(err)) # will print something like "option -a not recognized"
        sys.exit(2)

    # Required values start out unset so that a missing one is detected
    # directly, rather than by counting how many required flags were seen
    # (a repeated flag would fool a counter).
    dataFile = None
    outFile = None

    for o, a in opts:
        if o in ("-v", "--verbose"):
            config['verbose'] = True
        elif o in ("-h", "--help", "-u", "--usage"):
            usage()
            sys.exit(0)
        elif o in ("-V", "--Version"):
            version()
            sys.exit(0)
        elif o == "-i":
            dataFile = a
        elif o == "-r":
            outFile = a
        elif o in valueOptions:
            key, convert = valueOptions[o]
            try:
                config[key] = convert(a)
            except ValueError:
                # Report malformed numbers as a normal input error instead
                # of letting the traceback escape.
                sys.stdout.write("Error: invalid value '" + a + "' for option " + o + ".\n")
                sys.exit(1)
        else:
            # getopt only returns declared options, so this is a programming
            # error; raise explicitly because `assert` is stripped under -O.
            raise AssertionError("unhandled option " + o)

    if dataFile is None or outFile is None or 'K' not in config:
        usage()
        sys.stdout.write("Error: please specify all required arguments.\n")
        sys.exit(1)
    numClusters = config['K']

    # check input
    if not os.path.exists(dataFile):
        sys.stdout.write("File " + dataFile + " not found\n")
        sys.exit(1)
    if dataFile == outFile:
        sys.stdout.write("Input file and output file name should not be the same.\n")
        sys.exit(1)
    if numClusters <= 1:
        sys.stdout.write("Error: number of clusters should be larger than 1, got " + str(numClusters) +'\n')
        sys.exit(1)

    # check regularizations
    if config['lambda1'] < 0 or config['lambda2'] < 0:
        sys.stdout.write("Error: lambda must be non-negative value.\n")
        sys.exit(1)

    # check transformation model
    if config['transform'] not in ["affine","duo","trans"]:
        sys.stdout.write("Error: transform should be chosen from affine/duo/trans.\n")
        sys.exit(1)

    # check norm and mode
    if config['norm'] not in [0,1,2]:
        sys.stdout.write("Error: normalization type should be chosen from 0/1/2.\n")
        sys.exit(1)
    if config['mode'] not in [1,2]:
        sys.stdout.write("Error: result mode should be either 1 or 2.\n")
        sys.exit(1)

    # run optimization
    sys.stdout.write("Starting CHIMERA clustering...\n")
    CHIMERA.algorithm.clustering(dataFile,outFile,config)

    # check output
    if not os.path.exists(outFile):
        sys.stdout.write("Error in generating clusters, " + outFile + " not generated\n")
        sys.exit(1)

    sys.stdout.write("Finished.\n")
    return 0
            
# Run the command-line driver only when executed as a script, not on import.
if __name__ == '__main__':
    main()

