function [V,h,Ph,F,Fa,Fc] = ssfmri_reml(YY,dist_matrix,c0_init,c1_init,phi_init,N,D,t)
% ReML estimation of a distance-dependent covariance model from Y*Y'
% FORMAT [V,h,Ph,F,Fa,Fc] = ssfmri_reml(YY,dist_matrix,c0_init,c1_init,phi_init,[N],[D],[t])
%
% Covariance model, with hyperparameters h = [c0; c1; phi]:
%
%     C(h) = c0^2*I + c1^2*exp(-dist_matrix*phi^2)     (exp elementwise)
%
% The design matrix is a constant, X = [1 .. 1]', so the mean is treated
% as a fixed effect and removed by the ReML projector.
%
% YY          - (m x m) sample second-moment matrix Y*Y'
%               {Y = (m x N) data matrix}
% dist_matrix - (m x m) matrix entering exp(-dist_matrix*phi^2)
%               (presumably pairwise distances between the m points)
% c0_init     - initial value of h(1) (c0)
% c1_init     - initial value of h(2) (c1)
% phi_init    - initial value of h(3) (phi)
% N           - number of samples                              (default 1)
% D           - flag: if > 0, check that C stays positive definite after
%               each step and back off the step (up to 8 times) if not
%                                                              (default 0)
% t           - initial log-regularisation passed to spm_dx    (default 4)
%
% V   - (n x n) estimated covariance C(h) at the final h
% h   - (3 x 1) ReML hyperparameters [c0; c1; phi]
% Ph  - (3 x 3) expected precision of h, i.e. -E{d2F/dh2}
%
% F   - [-ve] free energy F = log evidence = p(Y|X,Q) = ReML objective
% Fa  - accuracy term of F
% Fc  - complexity term of F  (F = Fa - Fc)
%       F, Fa and Fc are only computed when more than 3 outputs are requested.
%
% Rows/columns of YY containing any non-finite value are discarded (from
% both YY and dist_matrix) before estimation, so n <= m and V is n x n.
%
% c0, c1 and phi only enter C squared, so only their magnitudes matter:
% the sign of each element of the returned h is arbitrary.
%
% If D is 0 and the final V has an eigenvalue below 0.1, the estimation is
% restarted from the initial values with D = 1 and t = 2. It is also
% restarted, with D = 1 and t reduced by 2, if a predicted change in F
% exceeds 1e6 in magnitude.
%
% Performs a Fisher-Scoring ascent on F to find ReML variance parameter
% estimates.
%
% see also: spm_reml (from which this routine is adapted) and spm_reml_sc
% for the equivalent scheme using log-normal hyperpriors
%__________________________________________________________________________
%
% SPM ReML routines:
%
%      spm_reml:    no positivity constraints on covariance parameters
%      spm_reml_sc: positivity constraints on covariance parameters
%      spm_sp_reml: for sparse patterns (c.f., ARD)
%
%__________________________________________________________________________
% Copyright (C) 2008 Wellcome Trust Centre for Neuroimaging

% John Ashburner & Karl Friston
% $Id: spm_reml.m 3791 2010-03-19 17:52:12Z karl $
 
 
% check defaults
%--------------------------------------------------------------------------
if nargin < 6, N = 1; end          % assume a single sample if not specified
if nargin < 7, D = 0; end          % default: no positive-definite checking
if nargin < 8, t = 4; end          % default regularisation
K = 100;                           % maximum number of iterations
 
% discard rows/columns with non-finite entries (e.g. missing data)
%--------------------------------------------------------------------------
q           = find(all(isfinite(YY)));
YY          = YY(q,q);
dist_matrix = dist_matrix(q,q);
 
% dimensions
%--------------------------------------------------------------------------
n     = length(dist_matrix);
m     = 3;                         % number of hyperparameters
X     = ones(n,1);                 % design matrix: constant term only
 
% covariance model C(h), used everywhere C or V is (re)built
%--------------------------------------------------------------------------
covfun = @(p) p(1)^2*eye(n) + p(2)^2*exp(-dist_matrix*p(3)^2);
 
% Alternative parameterisations tried previously, with d = dist_matrix and
% the derivatives dC/dh(i) that enter PQ{i} = P*dC/dh(i):
%   C = h(1)^2*I + h(2)^2*exp(-d*h(3))
%       dC/dh = {2*h(1)*I, 2*h(2)*exp(-d*h(3)), h(2)^2*exp(-d*h(3)).*(-d)}
%   C = h(1)*I + h(2)*exp(-d*h(3))
%       dC/dh = {I, exp(-d*h(3)), h(2)*exp(-d*h(3)).*(-d)}
%   C = h(1)*(1-h(2))*I + h(1)*h(2)*exp(-d*h(3))
%       dC/dh = {(1-h(2))*I + h(2)*exp(-d*h(3)), -h(1)*I + h(1)*exp(-d*h(3)),
%                h(1)*h(2)*exp(-d*h(3)).*(-d)}
 
% initialise h and specify (very weak) hyperpriors
%==========================================================================
h     = zeros(m,1);
h(1)  = c0_init;
h(2)  = c1_init;
h(3)  = phi_init;
hE    = sparse(m,1);
hP    = speye(m,m)/exp(32);
dF    = Inf;                       % predicted change in F (none yet)
D     = 8*(D > 0);                 % max. number of step back-offs per iteration
 
 
% ReML (EM/VB)
%--------------------------------------------------------------------------
for k = 1:K
    
    % compute current estimate of covariance
    %----------------------------------------------------------------------
    C     = covfun(h);
 
    % positive [semi]-definite check: if the last Fisher-scoring step left
    % C with an eigenvalue below 0.1, undo it and retry with stronger
    % regularisation. On the first iteration there is no previous step
    % (dh, dFdh and dFdhh do not exist yet), so nothing can be undone.
    %----------------------------------------------------------------------
    if k > 1
        for i = 1:D
            if min(real(eig(full(C)))) < 1e-1
 
                % increase regularisation and re-evaluate C
                %----------------------------------------------------------
                t     = t - 1;
                h     = h - dh;
                dh    = spm_dx(dFdhh,dFdh,{t});
                h     = h + dh;
                C     = covfun(h);
            else
                break
            end
        end
    end
 
    % convergence: stop once the predicted change in F from the previous
    % step is small (dF is Inf before the first step)
    %----------------------------------------------------------------------
    if abs(dF) < 1e-1, break, end
 
    % E-step: conditional covariance cov(B|y) {Cq}
    %======================================================================
    iC    = spm_inv(C);
    iCX   = iC*X;
    Cq    = spm_inv(X'*iCX);
 
    % M-step: ReML estimate of hyperparameters
    %======================================================================
 
    % ReML projector P and residual-forming term U
    %----------------------------------------------------------------------
    P     = iC - iCX*Cq*iCX';
    U     = speye(n) - P*YY/N;
 
    % PQ{i} = P*dC/dh(i) for the model C = h(1)^2*I + h(2)^2*exp(-d*h(3)^2)
    %----------------------------------------------------------------------
    Ed    = exp(-dist_matrix*h(3)^2);
    PQ    = cell(m,1);
    PQ{1} = P*2*h(1);
    PQ{2} = P*2*h(2)*Ed;
    PQ{3} = P*(h(2)^2*(Ed.*(-dist_matrix)*2*h(3)));
 
    % Gradient dF/dh (first derivatives)
    %----------------------------------------------------------------------
    dFdh  = zeros(m,1);
    for i = 1:m
 
        % dF/dh = -trace(dF/diC*iC*Q{i}*iC)
        %------------------------------------------------------------------
        dFdh(i) = -sum(sum(PQ{i}'.*U))*N/2;
 
    end
 
    % Expected curvature E{dF/dhh} (second derivatives)
    %----------------------------------------------------------------------
    dFdhh = zeros(m,m);
    for i = 1:m
        for j = i:m
 
            % dF/dhh = -trace{P*Q{i}*P*Q{j}}
            %--------------------------------------------------------------
            dFdhh(i,j) = -sum(sum(PQ{i}'.*PQ{j}))*N/2;
            dFdhh(j,i) =  dFdhh(i,j);
 
        end
    end
 
    % add hyperpriors
    %----------------------------------------------------------------------
    e     = h     - hE;
    dFdh  = dFdh  - hP*e;
    dFdhh = dFdhh - hP;
 
    % Fisher scoring: update dh = -inv(ddF/dhh)*dF/dh
    %----------------------------------------------------------------------
    dh    = spm_dx(dFdhh,dFdh,{t});
    h     = h + dh;
 
    % predicted change in F - increase regularisation if increasing
    %----------------------------------------------------------------------
    pF    = dFdh'*dh;
    if pF > dF
        t = t - 1;
    else
        t = t + 1/4;
    end
    
    % revert to SPD checking, if near phase-transition
    %----------------------------------------------------------------------
    if abs(pF) > 1e6
        [V,h,Ph,F,Fa,Fc] = ssfmri_reml(YY,dist_matrix,c0_init,c1_init,phi_init,N,1,t-2);
        return
    else
        dF = pF;
    end
    
    % report progress
    %======================================================================
    fprintf('%s %-23d: %10s%e \n','  ReML Iteration',k,'...',full(pF));
 
end
 
 
% re-build predicted covariance
%==========================================================================
V     = covfun(h);
 
% check V is positive definite (if not already checked); otherwise restart
% with positive-definite checking switched on
%==========================================================================
if ~D
    if min(eig(V)) < 1e-1
        [V,h,Ph,F,Fa,Fc] = ssfmri_reml(YY,dist_matrix,c0_init,c1_init,phi_init,N,1,2);
        return
    end
end
 
% log evidence = ln p(y|X,Q) = ReML objective = F = trace(R'*iC*R*YY)/2 ...
% (Ph, Cq, P and e come from the last completed scoring iteration)
%--------------------------------------------------------------------------
Ph    = -dFdhh;
if nargout > 3
 
    % tr(hP*inv(Ph)) - nh + tr(pP*inv(Pp)) - np (pP = 0)
    %----------------------------------------------------------------------
    Ft = trace(hP*inv(Ph)) - length(Ph) - length(Cq);
 
    % complexity - KL(Ph,hP)
    %----------------------------------------------------------------------
    Fc = Ft/2 + e'*hP*e/2 + spm_logdet(Ph*inv(hP))/2 - N*spm_logdet(Cq)/2;
 
    % Accuracy - ln p(Y|h)
    %----------------------------------------------------------------------
    Fa = Ft/2 - trace(C*P*YY*P)/2 - N*n*log(2*pi)/2 - N*spm_logdet(C)/2;
 
    % Free-energy
    %----------------------------------------------------------------------
    F  = Fa - Fc;
end
