function [Bhat,varargout] = rlogistic(X,Y,params)
% Robust logistic regression (rLR) of Y against X
% Bhat = RLOGISTIC(X,Y)
% Bhat = RLOGISTIC(X,Y,params)
% [Bhat,Yhat] = RLOGISTIC(X,Y,...)
% 
% Mandatory Inputs: 
%     X: MxN design matrix: 
%         - typically includes a ones column
%     Y: Mx1 
%         - must have elements between [0 1] (usually binary mask)
% 
% Optional Params Struct Input:
%     params.Ginit: 2-element array of [g00, g11] for initialization 
%         <default: [.8 .8]>
%         g00: prob mask shows background where truly background
%         g11: prob mask shows foreground where truly foreground
%     params.Binit: [Nx1] array to initialize Beta weights
%         <default: zeros(size(X,2),1)>
%     params.max_its: maximum iterations of the rLR process
%         <default: 25> - usually takes < 5
%     params.verbose: 1/0 that determines level of printed statements
%         <default: 0>
%     params.l1_lamb, params.l2_lamb: forwarded unchanged to loglikelihood
%         (presumably L1 / L2 penalty weights - confirm in loglikelihood)
%         <default: 0>
%     params.wentnuts: 1/0. When 1 and X has more than 9 columns, Bhat is
%         first fit on the first 10 columns only and that result (padded
%         with zeros) seeds the full fit; params.Binit is not used then.
%         <default: 0>
% 
% Mandatory Output:
%     Bhat: Nx1 vector of design matrix column weights
% 
% Optional Output:
%     Yhat: Mx1 Vector == 1 ./(1+exp(-X*Bhat));

% params is optional: default it here, otherwise the inputcheck call below
% would reference an undefined variable on a two-argument call.
if nargin < 2
    error('Not enough arguments');
end
if nargin < 3
    params = [];
end

% validate inputs and fill in defaults for missing params fields.
[X,Y,params] = inputcheck(X,Y,params);

if params.verbose
    fprintf('\nROBUST ITERATIVE METHOD\n'); 
end

% initialize
% g00: P(obs label = 0 | true label = 0) aka specificity
% g11: P(obs label = 1 | true label = 1) aka sensitivity
g00_init = params.Ginit(1); g00 = g00_init;
g11_init = params.Ginit(2); g11 = g11_init;
old_g00 = 0;
old_g11 = 0;
iter = 0;
l2lamb = params.l2_lamb;
l1lamb = params.l1_lamb;

g_tol = .0005;                          % convergence tolerance on g00/g11
opts  = optimset('TolFun', 0.0001);     % shared by every fminsearch call

% Arguments listed after the options struct are forwarded by fminsearch to
% loglikelihood, i.e. it is called as
% loglikelihood(B, Y, X, 1, g00, g11, l2lamb, l1lamb).

if params.wentnuts==1 && size(X,2)>9
    
    % iterate once on just the first 10 features, then pass that initialization
    % for the big regression. should be more stable.
    if params.verbose
        fprintf('Feature matrix very large. Seeking good initialization.\n');
    end
    
    Xsub = X(:,1:min(size(X,2),10));
    Bsub = fminsearch(@loglikelihood, zeros(size(Xsub,2),1), ...
        opts, Y, Xsub, 1, g00_init, g11_init, l2lamb, l1lamb);
   
    % first 10 values of Bhat are good - now need to append the rest (as zeroes)
    Bhat = [Bsub; zeros(size(X,2)-size(Bsub,1),1)];
    if params.verbose
        fprintf('Bhat initialization found. Beginning full regression.\n');
    end
    
else
    Bhat = params.Binit;
end

% fit once with the initial g00/g11, then iterate
Bhat = fminsearch(@loglikelihood, Bhat, ...
    opts, Y, X, 1, g00_init, g11_init, l2lamb, l1lamb);

% column k+1 holds Bhat after iteration k (column 1 = initial fit)
Bhatsaved = Bhat;

% Keep iterating while EITHER rate is still moving; stopping as soon as one
% of them is stable would declare convergence while the other still changes.
while (abs(old_g00 - g00) > g_tol || abs(old_g11 - g11) > g_tol) && ...
        iter < params.max_its
    
    % increment counter for printing and saving Bhat
    iter = iter+1;
    if params.verbose
        fprintf('\n %u. g00 = %.3f \t g11 = %.3f ', iter, g00, g11);
    end
    
    % estimate P(Y|X,B)
    Yhat = 1./(1+exp(-X*Bhat));
    
    % re-estimate g00/g11 from agreement between prediction and observation.
    % The division is skipped when nothing is predicted in that class, which
    % leaves the (zero) count in place instead of producing 0/0.
    old_g00 = g00;
    old_g11 = g11;
    n_pred0 = sum(Yhat <  .5);
    n_pred1 = sum(Yhat >= .5);
    g00 = sum(Yhat <  .5 & Y == 0);
    if n_pred0 > 0
        g00 = g00/n_pred0;
    end 
    g11 = sum(Yhat >= .5 & Y == 1);
    if n_pred1 > 0
        g11 = g11/n_pred1; 
    end
    
    if isnan(g00) || isnan(g11)
        error('rlogistic:nanRate', ...
            ['g00 or g11 went NaN. Check code in rlogistic.\n' ...
             'Time = %s\nSize(X) = [%d %d %d], Size(Y) = [%d %d]\n'], ...
            datestr(now), size(X,1), size(X,2), size(X,3), ...
            size(Y,1), size(Y,2));
    end
    
    % don't let g00/g11 move by more than .2 per iteration for robustness
    [g00, g11] = constrain_g(g00,g11,old_g00,old_g11);
    
    % refit the weights under the updated g00/g11, warm-started from Bhat
    Bhat = fminsearch(@loglikelihood, Bhat, ...
        opts, Y, X, 1, g00, g11, l2lamb, l1lamb); 
    
    Bhatsaved(:,iter+1) = Bhat;
end

% loop exited. print last results
if params.verbose
    fprintf('\n Final: g00 = %.3f \t g11 = %.3f \n', g00, g11); 
    fprintf('Bhat through the iterations (as columns,left -> right):\n'); 
    disp(Bhatsaved);
    fprintf('\n');
end

% if exited because max its (either rate still moving):
if iter==params.max_its && ...
        (abs(old_g00-g00) > g_tol || abs(old_g11-g11) > g_tol)
    warning('Max iterations reached. Exited');
end

% math done, prepare output
if nargout>1 % output Yhat too
    varargout{1} = 1 ./(1+exp(-X*Bhat));
end



function [g00, g11] = constrain_g(g00,g11,oldg00,oldg11)
% CONSTRAIN_G  Limit how far g00 and g11 may move in one iteration.
%   [g00, g11] = CONSTRAIN_G(g00, g11, oldg00, oldg11) returns each new
%   value unchanged when it is less than .2 away from its previous value;
%   otherwise it is replaced by the previous value plus or minus .2, in the
%   direction of the change. NaN values are returned as-is. A short status
%   note is printed whenever a value is clamped or is NaN.

max_step = .2;
g00 = limit_step(g00, oldg00, max_step, 'g00');
g11 = limit_step(g11, oldg11, max_step, 'g11');


function g = limit_step(g, oldg, max_step, label)
% LIMIT_STEP  Clamp g to within max_step of oldg; label names it in messages.

if abs(g - oldg) < max_step
    return  % small move: accept as is
end

if g > oldg
    g = oldg + max_step; fprintf('\t (constrained %s)', label);
elseif g < oldg
    g = oldg - max_step; fprintf('\t (constrained %s)', label);
elseif isnan(g)
    % every comparison above is false for NaN, so it lands here; leave the
    % value untouched and just report it.
    fprintf('\t (%s is NaN)', label);
end


function [X,Y,params] = inputcheck(X,Y,params)
% INPUTCHECK  Validate X and Y and fill in defaults for the params struct.
%   [X,Y,params] = INPUTCHECK(X,Y,params)
%
%   X and Y are returned unchanged. params is returned as a scalar struct
%   with these fields, each defaulted when absent:
%       Ginit    [.8 .8]
%       Binit    zeros(size(X,2),1)
%       verbose  0
%       max_its  25
%       l1_lamb  0
%       l2_lamb  0
%       wentnuts 0
%
%   Errors when: fewer than two arguments are given; Y is empty, is not a
%   single column, or has a different row count than X; any element of Y or
%   Ginit lies outside [0, 1] (NaN included); Ginit does not have exactly 2
%   elements; Binit does not have size(X,2) elements; params is neither
%   empty nor a scalar struct.
 
if nargin<2
    error('Not enough arguments');
end
% treat a missing or empty params (including an empty struct array, which
% cannot take a dot assignment) as "use all defaults".
if nargin<3 || isempty(params)
    params = struct();
end
if ~isstruct(params) || ~isscalar(params)
    error('params must be a scalar struct (or empty)');
end

% handle input sizes
if isempty(Y)
    error('Y must not be empty');
end
if size(Y,2) ~= 1
    error('Y must be Mx1');
end
if size(X,1) ~= size(Y,1)
    error('X and Y must have same number of rows (size(*,1))');
end

% check on y: the negated in-range test also rejects NaN, which a plain
% max/min comparison would let through.
if any(~(Y(:) >= 0 & Y(:) <= 1))
    error('Elements of Y must be in range [0, 1]');
end

% handle params struct defaults
if ~isfield(params,'Ginit')
    params.Ginit = [.8 .8];
else % is a field
    if numel(params.Ginit)~=2
        error('if given, params.ginit must be 2 element array,');
    end
    if any(~(params.Ginit(:) >= 0 & params.Ginit(:) <= 1))
        error('elements of init must be between 0 and 1');
    end
end
if ~isfield(params,'Binit')
    params.Binit = zeros(size(X,2),1);
elseif ~isequal(size(params.Binit),[size(X,2), 1])
    warning('Binit should be size: [size(X,2) 1]');
    if numel(params.Binit) ~= size(X,2)
        error('params.Binit has incorrect number of elements');
    end
    % (:) yields an Nx1 column for any input shape; a transpose only works
    % for a 1xN row and would conjugate complex values.
    params.Binit = params.Binit(:);
    disp('Transposed params.Binit to make it the correct size');
end

% remaining scalar options: {field name, default value}
defaults = { ...
    'verbose',  0; ...
    'max_its',  25; ...
    'l1_lamb',  0; ...   % none by default
    'l2_lamb',  0; ...   % none by default
    'wentnuts', 0};
for k = 1:size(defaults,1)
    if ~isfield(params, defaults{k,1})
        params.(defaults{k,1}) = defaults{k,2};
    end
end