function SNN = runSNNclustering(data,k,KNN,stepTh)

% Run SNN clustering for several sparsification thresholds and return the
% best solution on the basis of the minimum SSE criterion.
%
% Inputs:
% data - data matrix of size n by d, where n is the number of data points
% and d is dimensionality
% k - local neighborhood size (in number of data points)
% KNN - a struct returned by computeKNNgraph. The fields read here are
% KNN.nearestNeighborDists, KNN.nearestNeighborInds (only the first k
% columns of both are used) and KNN.distType.
% stepTh (whole number, optional) - sparsification step size parameter.
% For instance: stepTh = 1 goes through all sparsifications. It is the most
% accurate but also the most time consuming option. For small k-values
% (e.g. k <= 30), this might be a good option. However, because the number
% of sparsification thresholds increases with k, stepTh = k can be a good
% compromise for larger k-values. If stepTh is omitted or empty, a default
% is chosen from k and n (see the code below).
%
% Outputs:
% SNN - a struct containing SNN clustering results, where the fields are:
% SNN.mu - cluster centers of the best graph as means of connected
% components (one row per cluster)
% SNN.clustCent - cluster centers of the best graph as "densest" data
% points (indices of data points)
% SNN.nrClusters - number of non-singleton clusters for each sparsification
% threshold (vector; NaN for thresholds that were not evaluated or that
% produced fewer than two non-singleton clusters)
% SNN.SSE - SSE values for different sparsified graphs computed using the
% densest points in the connected components as centroids (NaN as above)
% SNN.SSE2 - SSE values for different sparsified graphs computed using the
% means of the connected components as centroids (NaN as above)
% SNN.connComp (cell array) - data point indices of each connected
% component (cluster) of the best graph; [] if no solution was found
% SNN.k - local neighborhood size k used
% SNN.clustPoints (cell array) - for each cluster, the rows of the k-nearest
% neighbor index list that belong to the points of the cluster
% SNN.time - time in seconds elapsed since the start of the SNN graph
% construction
%
% Jukka-Pekka Kauppi, jukka-pekka.kauppi@helsinki.fi

% 1: the best solution is selected using SSE2 (component means as
% centroids), 0: using SSE (densest points as centroids):
use_mean_as_centroid = 1;

nr_data_points = size(data,1);
% The search over thresholds is stopped when the SSE exceeds CONST times
% the best SSE found so far on two consecutive thresholds:
CONST = 2;

% If sparsification step size is not provided, we use the default step size:
if nargin < 4 || isempty(stepTh)
    if ( k <= 100 && nr_data_points < 50000 ) || ( k <= 15 )
        % for small data sets or small k, compute all thresholds:
        stepTh = 1;
        stepTh2 = 1; % fine tuning step size
    else
        % for larger data sets and k-values, use step size of k:
        stepTh = k;
        % Next, set the step size for the fine tuning in stepTh2.
        % Preferably, stepTh2 is set to 1 to find the most accurate solution.
        % However, to keep computation time feasible for large k-values, a
        % higher step size can be used.
        stepTh2 = max(1,floor(stepTh/10)); % fine tuning step size
    end
else
    stepTh2 = 1; % fine tuning step size
end

% Pick exactly k nearest neighbors from possibly a larger k-NN list:
nearestNeighborDists = KNN.nearestNeighborDists(:,1:k);
nearestNeighborInds = KNN.nearestNeighborInds(:,1:k);

% From here on data points are stored column-wise (d by n):
data = data';
distType = KNN.distType;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Obtain SNN graph and SNN densities for the given k:

disp(['Constructing SNN graph for k = ' num2str(k) '...'])
% Use an explicit timer handle so that tic/toc calls made inside other
% functions cannot disturb the measured time:
timerStart = tic;

nrData = size(data,2); % number of data points
% create sparse connectivity matrix from the nearest neighbor list:
rows = repmat((1:nrData)',[1 k]);
% k-nearest-neighbor graph:
Wnn = logical(sparse(rows,nearestNeighborInds,1,nrData,nrData));
% compute SNN graph (mutual nearest neighbor graph):
W = min(Wnn, Wnn'); % unweighted SNN graph
Wred = logical(Wnn - W); % removed (one-directional) connections

% For every data point s, sj(startInd(s):endInd(s)) lists the column
% indices of the nonzero entries on row s of Wnn, i.e. the neighbors of s:
[ii,jj] = find(Wnn);
[si,sii] = sort(ii);
sj = jj(sii);
hi = hist(si,1:nrData); % number of neighbors per data point
startInd = [1 1+cumsum(hi)];
endInd = cumsum(hi);

snnDensities = zeros(nrData,1);

% column-wise processing with sparse matrices in Matlab is much faster
% than row-wise processing -> take transpose:
W = W';
Wred = Wred';
Wnn = Wnn';

disp(['Computing SNN densities...'])
densVals = zeros(size(nearestNeighborInds));
for s = 1:nrData
    % neighbors of point s
    % NOTE(review): the reordering below presumes that fi is in ascending
    % index order and that every point has k distinct neighbors - confirm.
    fi = sj(startInd(s):endInd(s));
    % neighbor lists of the neighbors of s (one column per neighbor):
    F = Wnn(:,fi);
    % neighbor list of s itself, replicated for the comparison:
    Fs = repmat(Wnn(:,s),1,size(F,2));
    % 1 + number of neighbors shared between s and each of its neighbors:
    vals = 1 + sum(F&Fs,1);
    % mark each neighbor as one-directional (0) or mutual (1):
    fi_red = find(Wred(:,s));
    fi_rel = find(W(:,s));
    inds = [fi_red;fi_rel];
    markers = [zeros(length(fi_red),1);ones(length(fi_rel),1)];
    [~,soi] = sort(inds);
    markerss = markers(soi);
    
    % nssi is the inverse of the permutation that sorts the neighbor list
    % of s; it maps the sorted values back to the original neighbor order:
    [~,nsi] = sort(nearestNeighborInds(s,:),2);
    [~,nssi] = sort(nsi);
    % similarities towards one-directional neighbors are set to zero:
    vals2 = vals;
    vals2(markerss==0) = 0;
    densVals(s,:) = vals2(nssi);
    
    % SNN density = sum of similarities over the mutual neighbors:
    snnDensities(s) = sum(vals(markerss==1));
    if mod(s,10000) == 0
        disp(['Processed ' num2str(s) '/' num2str(nrData) ' data points'])
    end
end

% neighbor list where the removed (zero similarity) links are marked by 0:
snnInds = nearestNeighborInds;
snnInds(densVals==0) = 0;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Obtain sparsification thresholds based on SNN densities (the largest
% density value is left out):
Ths = unique(snnDensities);Ths = Ths(1:end-1);
nrTh = length(Ths);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Find initialization points using graph sparsification.

% The maximum possible number of sparsification thresholds is k^2.
% Thus, when k is very small, it is wise to go through every threshold
% to guarantee good approximation:
ThSteps = nrTh:-stepTh:1;

nrClusters = NaN(1,nrTh);
SSEtot = NaN(1,nrTh);
SSEtot2 = NaN(1,nrTh);

% Quantities needed for evaluating a single threshold:
P.data = data;
P.Ths = Ths;
P.snnDensities = snnDensities;
P.snnInds = snnInds;
P.nearestNeighborDists = nearestNeighborDists;
P.k = k;
P.distType = distType;
P.useMean = use_mean_as_centroid;
P.CONST = CONST;

% Current best solution (nothing found yet):
best.minSSE = inf;
best.clustCent = [];
best.mu = [];
best.points = [];
best.nrClusters = NaN;

disp('Evaluating SNN graphs using several thresholds...')
[best,nrClusters,SSEtot,SSEtot2] = ...
    snnScanThresholds(ThSteps,best,nrClusters,SSEtot,SSEtot2,P);

% After finding approximate solutions, find the final solution between
% the two best approximate solutions.
% Note: if stepTh == 1, all thresholds have been already carried
% out and this fine tuning step is not needed. The step is also skipped
% when there are no thresholds at all (all SNN densities equal).
if stepTh > 1 && nrTh > 0
    if use_mean_as_centroid
        totSSEtmp = SSEtot2;
    else
        totSSEtmp = SSEtot;
    end
    [~, miIndSSE1] = min(totSSEtmp);
    totSSEtmp(miIndSSE1) = NaN;
    [~, miIndSSE2] = min(totSSEtmp);
    sortedInds = sort([miIndSSE1 miIndSSE2]);
    
    disp(' ')
    disp('Fine tuning solution...')
    [best,nrClusters,SSEtot,SSEtot2] = ...
        snnScanThresholds(sortedInds(1):stepTh2:sortedInds(2), ...
        best,nrClusters,SSEtot,SSEtot2,P);
end

% create SNN-struct:
SNN.clustCent = best.clustCent;
SNN.mu = best.mu;
SNN.nrClusters = nrClusters;
SNN.SSE = SSEtot;
SNN.SSE2 = SSEtot2;
SNN.connComp = best.points;
SNN.k = k;

% save the nearest neighbor lists of the points of each connected
% component (the field exists, as an empty cell, also when no solution
% was found):
SNN.clustPoints = cell(1,length(SNN.connComp));
for m = 1:length(SNN.connComp)
    SNN.clustPoints{m} = nearestNeighborInds(SNN.connComp{m},:);
end

SNN.time = toc(timerStart);

disp(['The best solution (for k = ' num2str(SNN.k) ') contains ' num2str(size(SNN.mu,1)) ' clusters.'])
disp('SNN clustering step finished.')


function [best,nrClusters,SSEtot,SSEtot2] = snnScanThresholds(thInds,best,nrClusters,SSEtot,SSEtot2,P)
% Evaluate the sparsified SNN graphs for the threshold indices in thInds
% (row vector, processed in the given order) and update the best solution.
%
% best - struct holding the best solution so far (fields minSSE, clustCent,
% mu, points, nrClusters)
% nrClusters, SSEtot, SSEtot2 - per-threshold result vectors; the entries
% at the visited indices are overwritten
% P - struct of fixed quantities (see runSNNclustering)
%
% The scan stops early when the SSE exceeds P.CONST times the best SSE on
% two consecutive thresholds.

nrSteps = length(thInds);
flag = 0; % number of consecutive notably degraded solutions
iter = 0;
for m = thInds
    iter = iter + 1;
    [nrClusters(m),SSEtot(m),SSEtot2(m),clustMed,clustMu,clustMedPoints] = ...
        snnEvaluateThreshold(P.Ths(m),P);
    if isnan(nrClusters(m))
        NrC = '#Clusters --';
    else
        NrC = ['#Clusters ' num2str(nrClusters(m))];
    end
    
    % SSE used as the selection criterion:
    if P.useMean
        SSEcurrent = SSEtot2(m);
    else
        SSEcurrent = SSEtot(m);
    end
    
    % update current best solution (a NaN SSE never passes this test):
    if SSEcurrent < best.minSSE
        best.minSSE = SSEcurrent;
        best.clustCent = clustMed;
        best.mu = clustMu;
        best.points = clustMedPoints;
        best.nrClusters = nrClusters(m);
    end
    if ~isinf(best.minSSE)
        disp(['Current best: ' num2str(best.minSSE) ' (#Clusters ' num2str(best.nrClusters) ')' '  Current: ' num2str(SSEcurrent) ' (' NrC ')  ' 'threshold ' num2str(iter) '/' num2str(nrSteps) ])
    end
    
    if SSEcurrent > P.CONST*best.minSSE
        flag = flag + 1;
        if flag == 2
            disp('Solution quality degraded notably, stop searching...')
            break
        end
    else
        flag = 0;
    end
end


function [nrC,sse,sse2,clustMed,clustMu,clustMedPoints] = snnEvaluateThreshold(th,P)
% Sparsify the SNN graph with the density threshold th, extract the
% connected components and compute the centroids and SSE values of the
% non-singleton components.
%
% Outputs:
% nrC - number of non-singleton components (NaN if fewer than two)
% sse - SSE using the densest point of each component as centroid (NaN if
% fewer than two non-singleton components)
% sse2 - SSE using the mean of each component as centroid (NaN as above)
% clustMed - indices of the densest points (one per component)
% clustMu - plain means of the components (one row per component)
% clustMedPoints - cell array of the point indices of each component

data = P.data;
nrData = size(data,2);

% SNN graph sparsification:
sparsVals = double(P.snnDensities < th); % obtain sparsified data points
ZZ = sparsifyGraph_mex(P.snnInds',sparsVals'); % sparsify SNN graph
if sum(abs(ZZ(:))) == 0
    % no links left: every point forms its own component
    connC = 1:nrData;
else
    rows = repmat((1:nrData)',[1 P.k]);
    columns = ZZ(:,1:nrData)';
    % removed links (zeros) are replaced by self-links:
    columns(columns==0) = rows(columns==0);
    G = logical(sparse(rows,columns,1,nrData,nrData));
    [~,connC] = FuSeConncomp(G);
end

% Consider only non-singleton clusters.
% NOTE(review): the histogram bin index is used directly as the component
% label below, which presumes that the labels in connC start from 1 and
% are contiguous - confirm against FuSeConncomp.
nrSamples = hist(connC,min(connC):max(connC));
clusterIndicesNoSingletons = find(nrSamples~=1);
nrNoSingletons = length(clusterIndicesNoSingletons);

clustMed = zeros(nrNoSingletons,1);
clustMed2 = zeros(nrNoSingletons,size(data,1));
clustMu = clustMed2;
clustMedPoints = cell(1,nrNoSingletons);

if nrNoSingletons >= 2
    for n = 1:nrNoSingletons
        % get connected points:
        densestPoints = find(connC == clusterIndicesNoSingletons(n));
        % take the densest point; linear indexing with point indices
        % addresses the first column of the distance list:
        [~,mii] = min(P.nearestNeighborDists(densestPoints));
        clustMed(n,:) = densestPoints(mii);
        % take mean of the connected points:
        Mea = mean(data(:,densestPoints),2)';
        clustMu(n,:) = Mea;
        if strcmp(P.distType,'cor')
            % for the SSE computation, center the mean and scale it to
            % unit norm:
            Mea = Mea - mean(Mea);
            Mnorm = sqrt(sum(Mea.^2));
            Mea = Mea./Mnorm;
        end
        clustMed2(n,:) = Mea;
        clustMedPoints{n} = densestPoints;
    end
    nrC = nrNoSingletons;
    % use densest point in the connected component as centroid and compute SSE:
    sse = computeMSE(data',P.distType,data(:,clustMed)');
    % use mean of connected component as centroid and compute SSE:
    sse2 = computeMSE(data',P.distType,clustMed2);
else
    nrC = NaN;
    sse = NaN;
    sse2 = NaN;
end
