function [Bhat, varargout] = l1_logreg_wrapper(X,Y,opts)
% L1_LOGREG_WRAPPER  l1-regularized logistic regression of Y against X,
% computed by shelling out to the l1_logreg command line tools
% (l1_logreg_train / l1_logreg_classify).
%
% Bhat = l1_logreg_wrapper(X,Y)
% Bhat = l1_logreg_wrapper(X,Y,opts)
% [Bhat,Yhat] = l1_logreg_wrapper(X,Y,...)
%
% Mandatory Inputs:
%     X: MxN design matrix:
%         - typically includes a ones column; a ones column that is the
%           first (and only) ones column is removed before training
%     Y: Mx1
%         - must have elements between [0 1] (usually binary mask)
%         - zeros are recoded to -1 before being written to disk
%
% Optional Opts Struct Input (validated by check_inputs):
%
%     opts.standardize: {0/1} whether to standardize features
%         (adds the '-s' flag to l1_logreg_train)
%         <default: 1>
%     opts.verbosity: {0,1,2,3} that determines level of printed statements
%         <default: 2>
%     opts.lambda: regularization parameter; either a number >= 0 or a
%         char string that is inserted verbatim into the training command
%         <default: '-r .25'>
%
% Mandatory Output:
%     Bhat: Nx1 vector of design matrix column weights
%       - THE FIRST ARGUMENT WILL BE THE INTERCEPT WHETHER OR NOT YOU
%           INCLUDE A ONES COLUMN!!!
%
% Optional Output:
%     Yhat: Mx1 Vector == 1 ./(1+exp(-X*Bhat));
%         (read back from the output of 'l1_logreg_classify -p')
%
% Side effects: creates a temporary directory under /tmp, writes the
% MatrixMarket files there, and removes the directory again when the
% function exits (normally or through an error).

% check inputs
if nargin<2
    error('Not enough arguments');
end
if nargin<3 || isempty(opts)
    % must be a struct: the field-name listing in check_inputs fails on
    % cells and other non-struct values
    opts = struct();
end
[X,Y,standardize,verb,lambda] = check_inputs(X,Y,opts);

vprintf(verb,1,'\n === === Starting l1_logreg wrapper === === \n');

% make sure we have needed software
check_software(verb);

% define temp dir and fnames to temp store data
fpx = timestring; % get a fileprefix string
tmpdir = ['/tmp/l1_logrep_tmp_' fpx '/'];
[success,~,~] = mkdir(tmpdir);
if ~success
    error('Couldn''t write to /tmp for temporary directory. Permissions?');
end
% remove the temp dir however we leave this function (including errors)
tmpCleanup = onCleanup(@() rmdir(tmpdir,'s'));

% fix dimensionality and other quirks
Y(Y==0)=-1;

% look for and remove a leading ones column
% (all(...,1) works column-wise even when X has a single row)
onescols = all(X == 1, 1);
if any(onescols)
    if onescols(1) && ~any(onescols(2:end))
        % exactly one ones column and it is the first one: drop it
        X = X(:,2:end);
        vprintf(verb,2,'Removed one''s column from matrix\n');
    else
        wrnstr = ['A one''s column was detected that wasn''t the first column. '...
            'It has been left in there.'];
        warning(wrnstr);
    end
else
    wrnstr = ['No one''s column detected in X matrix. Be aware that the '...
        'first element of Bhat WILL BE THE INTERCEPT.'];
    warning(wrnstr);
end

% write input variables to MatrixMarket format;
vprintf(verb,1,'Writing matricies to disk\n');
X_file = [tmpdir 'X'];
Y_file = [tmpdir 'Y'];
model_file = [tmpdir 'model'];
result_file = [tmpdir 'result'];
mmwrite(X_file,X);
mmwrite(Y_file,Y);

% TRAIN
vprintf(verb,1,'Training\n');
% build command string
train_cmmd = 'l1_logreg_train ';
if standardize % standardize data
    train_cmmd = [train_cmmd '-s '];
end
train_cmmd = [train_cmmd X_file ' ' Y_file ' ' lambda ' ' model_file];
vprintf(verb,2,'train_cmmd = \n%s\n',train_cmmd);

if verb>2 % print console output for verb level 3
    [err, msg] = system(train_cmmd,'-echo');
else
    [err, msg] = system(train_cmmd);
end

if err || ~exist(model_file,'file')
    disp(msg);
    disp('Something didn''t work when running l1_logreg_train');
    disp('train_cmmd = ');
    disp(train_cmmd);
    error('Something didn''t work when running l1_logreg_train')
end

% CLASSIFY
vprintf(verb,1,'Classifying\n');
error_rate = NaN; % only computed (and only printed) when verb > 0
if verb>0
    test_classify_cmmd = ...
        ['l1_logreg_classify -t ' Y_file ' ' model_file ...
        ' ' X_file ' ' result_file];
    vprintf(verb,2,'=== === CHECK CLASSIFICATION RATES! === ===\n');
    % '%s' so that '%' or '\' in paths/output is not taken as a format spec
    vprintf(verb,2,'%s\n',test_classify_cmmd);
    [err2, msg2] = system(test_classify_cmmd);
    vprintf(verb,3,'%s',msg2);
    if err2
        warning('l1_logreg_classify -t failed; error rate is unavailable.');
    else
        error_rate = parse_msg(msg2); % get error rate
    end
end

classify_cmmd = ...
    ['l1_logreg_classify -p ' model_file ' ' X_file ' ' result_file];
vprintf(verb,2,'classify_cmmd = \n%s\n',classify_cmmd);
if verb>2
    disp('=== === CREATE PROBABLISTIC RESULTS === ===');
    disp(classify_cmmd);
    [err, msg] = system(classify_cmmd,'-echo');
else
    [err, msg] = system(classify_cmmd);
end

% verify everything worked and output exists
if err || ~exist(result_file,'file')
    disp(msg);
    error('Something didn''t work when running l1_logreg_classify')
end

% read results back in (output)
vprintf(verb,1,'Reading results from disk\n');
Yhat = mmread(result_file);
Bhat = mmread(model_file);

% report sparsity and error rate
vprintf(verb,1,'Sparsity: %i of %i features (+1 intercept) are non-zero \n',...
    sum(Bhat~=0)-1,numel(Bhat)-1);
vprintf(verb,2,'Bhat = \n');
vprintf(verb,2,'%f\n',Bhat);
vprintf(verb,1,'Error Rate = %4.3f\n',error_rate);

% handle optional output
if nargout>1
    varargout{1}  = Yhat;
end

% delete temp files (clearing the onCleanup object runs the rmdir now)
vprintf(verb,1,'\nCleaning up /tmp files\n');
clear tmpCleanup;

vprintf(verb,1,'\n === === Exiting l1_logreg wrapper === === \n');


function str = timestring
% TIMESTRING  Return the current date/time as a compact digit-only string.
%
% Output:
%     str: 1x17 char of the form yyyymmddHHMMSSFFF (FFF = milliseconds)
%
% Every field is zero-padded so the result is unambiguous (an unpadded
% concatenation cannot tell 1/11 from 11/1), and the millisecond field
% makes two calls in quick succession less likely to produce the same
% string.

str = datestr(now,'yyyymmddHHMMSSFFF');

function [X,Y,standardize,verbosity,lambda] = check_inputs(X,Y,opts)
% CHECK_INPUTS  Validate X, Y and the options struct; fill in defaults.
%
% Inputs:
%     X, Y: data; Y must be a non-empty column with values in [0, 1] and
%           X must have the same number of rows as Y
%     opts: scalar struct with any of the fields 'standardize',
%           'verbosity', 'lambda'. An empty value of any type is treated
%           as "no options given".
%
% Outputs:
%     X, Y:        returned unchanged
%     standardize: 0/1                    <default: 1>
%     verbosity:   integer in {0,1,2,3}   <default: 2>
%     lambda:      char string to splice into the training command;
%                  numeric input is converted to text
%                                         <default: '-r .25'>
%
% Errors if any input fails validation; warns about unknown opts fields.

% handle input sizes
if size(Y,2) ~= 1
    error('Y must be Mx1');
end
if isempty(Y)
    error('Y must not be empty');
end
if size(X,1) ~= size(Y,1)
    error('X and Y must have same number of rows (size(*,1))');
end

% check on y
if max(Y(:))>1 || min(Y(:))<0
    error('Elements of Y must be in range [0, 1]');
end

% normalise opts so that isfield/fieldnames below are always valid
if isempty(opts)
    opts = struct();
elseif ~isstruct(opts) || numel(opts) ~= 1
    error('opts must be a scalar struct');
end

% check on opts struct, fill in missing fields
% (level-0 messages are printed at every verbosity setting)
if ~isfield(opts,'verbosity')
    verbosity = 2;
    vprintf(verbosity,0,'Using default opts.verbosity arg\n');
else
    verbosity = opts.verbosity;
end

% check on verbosity (before it is used to gate any further printing)
if ~isnumeric(verbosity) || ~isscalar(verbosity) || ...
        round(verbosity)~=verbosity || verbosity<0 || verbosity>3
    error('opts.verbosity argument must be a single integer: {0,1,2,3}');
end

if ~isfield(opts,'standardize')
    standardize = 1;
    vprintf(verbosity,0,'Using default opts.standardize arg\n');
else
    standardize = opts.standardize;
end
if ~isfield(opts,'lambda')
    lambda = '-r .25';
    vprintf(verbosity,0,'Using default opts.lambda arg = ''-r .25''\n');
else
    lambda = opts.lambda;
end

% report any field that is not a recognised option
known = {'standardize','lambda','verbosity'}; % all possible options
fds = fieldnames(opts);
if ~isempty(fds)
    unused = fds(~ismember(fds,known));
    if ~isempty(unused)
        warning('opts struct includes unused field! Check spelling!')
        vprintf(verbosity,0,'Unused field: %s\n',unused{:});
    end
end

% check on standardize
if ~isscalar(standardize) || ~(standardize == 1 || standardize == 0)
    error('opts.standardize must be 0/1');
end

% check on lambda
if ischar(lambda)
    lambda = strtrim(lambda);
elseif isnumeric(lambda) && isscalar(lambda) && lambda >= 0
    % %.15g keeps the full value; num2str would round to ~5 digits
    lambda = sprintf('%.15g',lambda);
else
    error('Lambda must be of the forms: ''-r 0.1'' or a number >= 0')
end

function check_software(verbosity)
% CHECK_SOFTWARE  Make sure the external pieces this wrapper needs exist.
%
% 1. The l1_logreg_train binary must be runnable through system(). If it
%    is not, it is searched for on the MATLAB path and then with the
%    shell command "locate"; the directory found is prepended to the
%    PATH environment variable of this MATLAB session.
% 2. mmwrite and mmread (MatrixMarket read and write) must be visible to
%    MATLAB. If they are not, a 'util_m' directory next to a PATH entry
%    containing 'src_c' is tried, then the folder of any mmwrite.m that
%    which() can find.
%
% Input:
%     verbosity: verbosity level forwarded to vprintf
%
% Errors if either requirement cannot be satisfied.

vprintf(verbosity,2,...
    '\nChecking for binaries and mfiles needed for rest of function.\n');

% ---- binaries ----------------------------------------------------------
% check if binaries are part of PATH, if not add them so that we only
% have to do this once per session
[err,~] = system('l1_logreg_train -h');
if ~err
    vprintf(verbosity,2,'l1_logreg_train already on PATH. Proceeding\n');
    stillneed = 0;
else
    vprintf(verbosity,1,'l1_logreg binaries not on path. Looking for them.\n');

    loc = which('l1_logreg_train.');
    if ~isempty(loc)
        disp('Found l1_logreg_train in matlab path. Adding to PATH');
        add_binaries_to_path(fileparts(loc), verbosity); % errors on failure
        stillneed = 0;
    else
        disp('Couldn''t find l1_logreg_train on matlab path');
        stillneed = 1;
    end
end

if stillneed == 1
    [err,msg] = system('locate l1_logreg_train');

    % one hit per line; drop blank lines
    hits = regexp(strtrim(msg), '\r?\n', 'split');
    hits = hits(~cellfun('isempty', hits));

    if ~err && ~isempty(hits)
        disp('Found l1_logreg_train using "locate"');

        % prefer the hit inside the src_c directory, else take the first
        insrc = ~cellfun('isempty', strfind(hits, 'src_c'));
        if any(insrc)
            pth = hits{find(insrc,1,'first')};
        else
            pth = hits{1};
        end

        add_binaries_to_path(fileparts(pth), verbosity); % errors on failure
        stillneed = 0;
    else
        disp('Can''t find binaries using  "locate"')
    end
end

if stillneed == 1
    error('l1_logreg_wrapper:binariesNotFound', ...
        ['Could not find l1_logreg_train on MATLAB path or by using '...
        '"locate." Find and add l1_logreg functions to matlab path '...
        'and try again.']);
end

% ---- MatrixMarket m-files ----------------------------------------------
if exist('mmwrite','file') && exist('mmread','file')
    % nothing, we're set
    vprintf(verbosity,2,...
        'mmwrite and mmread (MatrixMarket read and write) exist. Proceeding\n');
    return;
end

vprintf(verbosity,2,...
    'mmwrite.m and mmread.m not found. Looking for them.\n');

% try <prefix>util_m, where <prefix>src_c... is a single PATH entry
entries = regexp(getenv('PATH'), pathsep, 'split');
hit = find(~cellfun('isempty', strfind(entries,'src_c')), 1, 'first');
if ~isempty(hit)
    entry = entries{hit};
    k = strfind(entry,'src_c');
    utildir = [entry(1:k(1)-1) 'util_m'];
    if exist(utildir,'dir')
        addpath(utildir);
        if exist('mmwrite','file') && exist('mmread','file')
            vprintf(verbosity,2,...
                'Found mmwrite.m and mmread.m in l1_logreg directory.\n');
        end
    end
end

% fall back to wherever which() can see mmwrite.m
if ~(exist('mmwrite','file') && exist('mmread','file'))
    loc = which('mmwrite.m');
    if ~isempty(loc)
        mdir = fileparts(loc);
        addpath(genpath(mdir));
        vprintf(verbosity,2,'Found mmwrite.m in:\n%s\n',mdir);
    end
end

if exist('mmwrite','file') && exist('mmread','file')
    vprintf(verbosity,2,'Added to path. Proceeding\n');
else
    error('l1_logreg_wrapper:mmfilesNotFound', ...
        ['Could not find mmwrite and read\n'...
        'Find and add mmwrite.m and mmread.m (MatrixMarket '...
        'read & write to path and then try again.\nExiting']);
end

function add_binaries_to_path(bindir, verbosity)
% Prepend BINDIR to the PATH environment variable and verify that
% l1_logreg_train can then be run; errors if it still cannot.
setenv('PATH', [bindir pathsep getenv('PATH')]);

[err, ~] = system('l1_logreg_train -h');
if err
    error('l1_logreg_wrapper:pathUpdateFailed', ...
        'Successfully found binaries. Could not be added to path\nExiting');
end
vprintf(verbosity,2,'Found and added binaries to path\n');
% '%s' so that '%' or '\' inside PATH is not taken as a format spec
vprintf(verbosity,2,'$PATH=%s\n',getenv('PATH'));

function vprintf(v,l,varargin)
% VPRINTF  fprintf that only prints when the verbosity is high enough.
%
%   vprintf(v, l, fprintf_args...)
%
%     v        : current verbosity setting
%     l        : verbosity level of this statement
%     varargin : forwarded unchanged to fprintf
%
% The text is printed when v >= l. At least one fprintf argument is
% required, otherwise an error is raised.
%
% example:
%
% verbosity = 1;
% vprintf(verbosity,1,'Printed when verbosity is at least 1\n Verbosity = %i\n',verbosity);

% Frederick Bryan, Vanderbilt, July 2013

if nargin>=3
    if v>=l % statement level is within the current verbosity
        fprintf(varargin{:});
    end
else
    error('not enough arguments');
end

function er = parse_msg(msg)
% PARSE_MSG  Pull a number out of console output.
%
% Takes the text that follows the first '=' after the first occurrence of
% the word 'error', up to the next newline (or the end of the string when
% there is no newline), and converts it with str2double.
%
% Input:
%     msg: char string of console output
% Output:
%     er:  the parsed number; NaN when msg is empty, is not a char string,
%          has no 'error ... =' pattern, or the text is not a number

er = NaN;
if ~ischar(msg) || isempty(msg)
    return;
end

loc = strfind(msg,'error'); % find the word error
if isempty(loc)
    return;
end
msg = msg(loc(1):end); % get rid of everything before that

loc = strfind(msg,'='); % find =
if isempty(loc)
    return;
end
msg = msg(loc(1)+1:end); % get rid of = and everything before

loc = strfind(msg,sprintf('\n')); % find newline
if ~isempty(loc)
    msg = msg(1:loc(1)-1); % get rid of everything after
end

er = str2double(strtrim(msg)); % trim whitespace, convert

