% NOTE: THIS IS A MODIFIED VERSION OF THE ORIGINAL ALGORITHM.
% THE REGULARIZATION PART AND THE INITIALIZATION OF SIGMA HAVE BEEN
% MODIFIED, AND AN OPTION FOR DIAGONAL COVARIANCE MATRICES HAS BEEN ADDED.
% REPLACING THE INV FUNCTION WAS ALSO CONSIDERED, BUT INV APPEARED TO BE
% THE BEST CHOICE IN TERMS OF COMPUTATIONAL SPEED. JPK 13.4.2017.
%
% The EM-algorithm for clustering based on Gaussian mixture models.
% Copyright (c) 2002 - 2003 Jussi Tohka
% Institute of Signal Processing
% Tampere University of Technology
% Finland
% P.O. Box 553 FIN-33101
% jussi.tohka@tut.fi
% ----------------------------------------------------------------
% Permission to use, copy, modify, and distribute this software
% for any purpose and without fee is hereby
% granted, provided that the above copyright notice appear in all
% copies.  The author and Tampere University of Technology make no representations
% about the suitability of this software for any purpose.  It is
% provided "as is" without express or implied warranty.

% Version 1.1. A hack that prevents singular covariance matrices
% from appearing has been added. It simply detects ill-conditioned
% (nearly singular) covariance matrices and adds a constant matrix
% (i.e. a diagonal matrix with a constant-valued diagonal) to them.
% This way covariance matrices can be made non-singular without
% changing them too much and without destroying positive definiteness
% or symmetry.
%
%

% [mu,sigma,prob,labels] = emclustering3(data,k,maxiter,mu,sigma,prob,forceDiag)
%
% INPUT
% data: data points as n x m matrix where n is the number of
% datapoints and m is their dimension. An error is declared if
% n < m, i.e. the matrix data must be "tall" (at least as many rows
% as columns).
%
% k : the number of clusters
%
% maxiter : maximum number of iterations (default 50)
%
% mu: initial cluster centres, as m x k matrix (optional, you can also
% give an empty matrix if you want to specify sigma or prob)
%
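% sigma: initial covariance matrices, as m x m x k matrix, or as
% one m x m matrix when the same initial covariance matrix is used
% for each class (optional; give an empty matrix to use the default
% diag(var(data)) for each class). With forceDiag, sigma may also be
% given as a 1 x m (or 1 x m x k) array of variances.
%
% prob: initial a-priori probabilities for each class (optional)
%
% forceDiag: if nonzero, the covariance matrices are restricted to be
% diagonal (optional, default 0).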
%
% OUTPUT:
% mu : final cluster centres
% sigma: final covariances
% prob: final probabilities
% labels: labeling of the datapoints based on Bayesian classifier
% with parameters from clustering.

function [mu,sigma,prob,labels] = emclustering3(data,k,varargin)
% EMCLUSTERING3 Fit a k-component Gaussian mixture model to DATA with EM.
%
%   [mu,sigma,prob,labels] = emclustering3(data,k,maxiter,mu,sigma,prob,forceDiag)
%
%   Every argument after k is optional; an empty matrix selects the default
%   for that argument while still allowing later arguments to be given.
%
%   data      : n x m matrix, one datapoint per row; n >= m is required.
%   k         : number of mixture components.
%   maxiter   : maximum number of EM iterations (default 50).
%   mu        : m x k initial means (default: k points evenly spaced on the
%               diagonal of the bounding box of the data).
%   sigma     : m x m x k initial covariances, or one m x m matrix shared by
%               all components (default: diag(var(data)) for each). With
%               forceDiag, variances may also be given as 1 x m or 1 x m x k.
%   prob      : initial mixing weights, one per component (default uniform).
%   forceDiag : nonzero restricts all covariances to diagonal matrices.
%   Any further arguments are accepted and ignored.
%
%   Components whose initial covariance has an all-zero diagonal are dropped
%   before EM starts, so the outputs may describe fewer than k components.
%   labels(j) is the index of the most probable component for data(j,:).
%
%   Raises 'emclustering3:tooFewPoints' when n < m.

thr = 10000;                % condition-number limit for covariance matrices
plotSingularSolutions = 0;  % set to 1 to print every regularization event

n = size(data,1);
m = size(data,2);
if n < m
    error('emclustering3:tooFewPoints', ...
        'The number of datapoints must not be smaller than their dimension.');
end

nopt = numel(varargin);

if nopt > 0 && ~isempty(varargin{1})
    max_iter = varargin{1};
else
    max_iter = 50;
end

if nopt > 1 && ~isempty(varargin{2})
    mu = varargin{2};
else
    % k points evenly spaced along the diagonal of the data bounding box
    lo = min(data)';
    span = (max(data) - min(data))';
    mu = zeros(m,k);
    for i = 1:k
        mu(:,i) = lo + (i/(k + 1))*span;
    end
end

if nopt > 4
    forceDiag = varargin{5};
else
    forceDiag = 0;
end

if nopt > 2 && ~isempty(varargin{3})
    sigma = varargin{3};
    if ndims(sigma) == 2
        % one matrix shared by all components
        sigma = repmat(sigma,[1 1 k]);
    end
else
    sigma = repmat(diag(var(data)),[1 1 k]);
end

if forceDiag
    % Keep only the variances. They are either stored as 1 x m rows or as
    % the diagonals of m x m matrices; taking the first ROW of an m x m
    % matrix (as earlier versions did) discards all but one variance.
    Sigma = zeros(m,m,size(sigma,3));
    for i = 1:size(sigma,3)
        if size(sigma,1) == 1
            v = sigma(1,:,i);
        else
            v = diag(sigma(:,:,i));
        end
        Sigma(:,:,i) = diag(v);
    end
    sigma = Sigma;
end

% Components with zero initial variance cannot be evaluated: drop them.
rem_comp = false(1,size(sigma,3));
for i = 1:size(sigma,3)
    rem_comp(i) = sum(diag(sigma(:,:,i))) == 0;
end

if nopt > 3 && ~isempty(varargin{4})
    prob = varargin{4};
else
    prob = repmat((1/k),k,1);
end

sigma(:,:,rem_comp) = [];
mu(:,rem_comp) = [];
prob(rem_comp) = [];
k = k - sum(rem_comp);

nsingular = 0;   % number of covariance regularizations performed
iter = 0;
changed = true;
p = zeros(n,k);
while iter < max_iter && changed
    iter = iter + 1;
    old_p = p;

    % E-step: posterior probability of each component for each datapoint
    [p,sigma,nsingular] = gmm_estep(data,mu,sigma,prob,forceDiag,thr, ...
        plotSingularSolutions,nsingular);

    % M-step: re-estimate mixing weights, means and covariances
    for i = 1:k
        w = p(:,i);
        nk = sum(w);
        prob(i) = nk/n;
        if nk > 0
            % a component with no responsibility keeps its old parameters;
            % dividing by nk == 0 would produce NaN covariances
            mu(:,i) = (data'*w)/nk;
            dist = bsxfun(@minus,data,mu(:,i)');
            sigma(:,:,i) = bsxfun(@times,dist,w)'*dist/nk;
        end
    end

    changed = sum(abs(p(:) - old_p(:)))/(n*k) > 0.0001;
    if mod(iter,100) == 0
        disp(['Iteration: ' num2str(iter)])
    end
end

% Final E-step with the converged parameters; it also regularizes (and, with
% forceDiag, diagonalizes) the returned covariances.
[p,sigma,nsingular] = gmm_estep(data,mu,sigma,prob,forceDiag,thr, ...
    plotSingularSolutions,nsingular);
[tmp,labels] = max(p,[],2); %#ok<ASGLU>

disp(['Final number of iterations: ' num2str(iter)])
% 'changed' is still true only if the loop stopped on the iteration limit
if changed
    warning(['Maximum number of iterations (' num2str(max_iter) ') reached, GMM not converged!'])
end
disp(['Total number of singular solutions ' num2str(nsingular)])


function [p,sigma,nsingular] = gmm_estep(data,mu,sigma,prob,forceDiag,thr,verbose,nsingular)
% GMM_ESTEP Row-normalized component posteriors p (n x k) of a Gaussian
% mixture. Each covariance is first diagonalized when forceDiag is set. It is
% then regularized by adding (3*m^2/thr)*eye(m) when its condition number is
% not below thr. The possibly modified sigma and the updated regularization
% count nsingular are returned.
n = size(data,1);
m = size(data,2);
k = size(mu,2);
p = zeros(n,k);
for i = 1:k
    S = sigma(:,:,i);
    if forceDiag
        S = diag(diag(S));
    end
    % test the matrix that is actually used (the diagonal one with forceDiag)
    if ~(cond(S) < thr)
        nsingular = nsingular + 1;
        if verbose
            disp(['singular solution #' num2str(nsingular)])
        end
        S = S + eye(m)*3*m^2/thr;
    end
    sigma(:,:,i) = S;
    dist = bsxfun(@minus,data,mu(:,i)');
    invS = inv(S);   % inv measured faster than mrdivide here (see header)
    % Gaussian density without (2*pi)^(-m/2); it cancels in the normalization
    p(:,i) = (1/sqrt(det(S)))*exp(-0.5*sum((dist*invS).*dist,2))*prob(i);
end
% Normalize rows; rows where every density underflowed get uniform weights.
nf = sum(p,2);
zero_rows = (nf == 0);
nf(zero_rows) = 1;
p = bsxfun(@rdivide,p,nf);
p(zero_rows,:) = 1/k;