function results = plotSegmentationResults(results,Params,maxNrClusters,nrNoiseClust,clipNames,resultsPrev,clust_ind,rem_clust)

% USE: results = plotSegmentationResults(results,Params,maxNrClusters,nrNoiseClust,clipNames,resultsPrev,clust_ind,rem_clust)
%
% Visualization of the functional segmentation ISC results. Three figures
% are generated:
% 1) Figure showing clusters detected as noisy based on the total
% number of voxels located in ventricles, white matter or brain-stem
% (drawn only when at least one cluster overlaps these areas).
% 2) Figure showing the clusters over an anatomical image. Set the user
% parameters below to control visualization (together with input
% maxNrClusters).
% 3) Figure showing the ISC mean and ISC variability information for each
% cluster (GMM centroids).
%
% Note! Clusters are organized in increasing order of the ISC variability
% they exhibit, i.e., clusters with the most stable segments are shown first.
% As a simple example, plot the 5 most stable ISC segments:
% 
% [Fuse,R] = loadSegmentationResults(Params);
% R = plotSegmentationResults(R,Params,5,0)
% 
% Inputs:
%
% results - segmentation results struct (e.g. the second output of
% loadSegmentationResults as in the example above).
%
% Params - parameter struct of your ISC project
% (Params.PrivateParams.nrSessions gives the number of clips/sessions).
%
% maxNrClusters - Maximum number of (non-removed) clusters to visualize to
% allow better visual control when the total number of clusters is large.
% For instance, maxNrClusters = 5 means that the 5 clusters having the
% most stable ISC values are visualized over an anatomical brain image. 
% Set this to Inf to visualize all (non-removed) clusters. Default is 5.
% At most 53 clusters can be shown (largest supported colormap).
%
% nrNoiseClust - number of "noise clusters" to remove in the white matter, 
% ventricle and brain-stem areas. Clusters are ranked by the number of their
% voxels located in these areas. For instance, nrNoiseClust = 2 means that
% the 2 clusters with the most voxels in white matter, ventricle or
% brain-stem areas are removed. Inf removes every cluster even partly
% located in these areas. Default is 0 (nothing removed).
%
% clipNames - a cell array containing names of the clips/sessions of
% interest to make interpretation of the results easier. It must contain
% one element per session. The default names are 'Clip1', 'Clip2', and so
% forth.
%
% resultsPrev - (optional) side-by-side comparison of two data sets.
% [] (default): a single cluster map is drawn. 1: the map is drawn as
% "Data set #1" and the returned struct stores what a second call needs.
% A struct returned by such a first call: the map is drawn as
% "Data set #2" with its colors matched to the first data set. Fewer
% slices are shown in both comparison modes.
%
% clust_ind - (optional) explicit vector of cluster indices to visualize;
% when non-empty it replaces the selection made with maxNrClusters and
% nrNoiseClust.
%
% rem_clust - (optional) cluster indices to leave out of the visualization.
%
% Output:
%
% results - the input struct with visualization fields added, e.g.
% clustersToVisualize, removedInd, clusterNames, clusterNumbers and
% nrClusters (the latter holds numel(results.clustInd{k}) for each
% visualized cluster k).
%
% See also:
% RUNFUSE


sortClusters = 1; % order clusters by increasing ISC variability

% User parameters for anatomical visualization, adjust them if needed:
z_slice_index = 20:1:71; % axial slice indices for visualization (6:1:71 is a wider alternative)
nrColumns = 8; % number of columns in the anatomical slice image series
textOffset = 8; % adjust placement of the MNI coordinate labels
contrast_val = 45; % adjust contrast of the anatomical image

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Parse inputs and set defaults if needed:
disp('Checking inputs...')
if nargin < 2
    error('Too few input arguments!')
end
if nargin > 8
    error('Too many input arguments!')
end
if nargin < 8
    rem_clust = [];
end
if nargin < 7
    clust_ind = [];
end
if nargin < 6
    resultsPrev = [];
end

nrClips = Params.PrivateParams.nrSessions;
% short-circuit: clipNames is only inspected when it was actually passed
if nargin < 5 || isempty(clipNames)
    clipNames = cell(1,nrClips);
    for m = 1:nrClips
        clipNames{m} = ['Clip' num2str(m)];
    end
elseif length(clipNames) ~=  nrClips
    error(['Clip names must be given as a cell-struct containg ' ...
        num2str(nrClips) ' elements!'])
end

if nargin < 4 || isempty(nrNoiseClust) || nrNoiseClust == 0
    nrNoiseClust = 0;
    disp(['Number of noise clusters not provided by user, no noise mask applied.'])
end
if nargin < 3 || isempty(maxNrClusters)
    maxNrClusters = 5;
    disp(['Maximum number of clusters not provided by user, ' ...
        num2str(maxNrClusters) ...
        ' most stable clusters visualized by default. To visualize all clusters, set maxNrClusters = inf'])
else
    disp(['Visualizing ' num2str(maxNrClusters) ' most stable clusters...'])
end

if ~isempty(resultsPrev)
    % In this case, two brain maps plotted side by side so we reduce the number of slices:
    nrColumns = 3; % number of columns in the anatomical slice image series
    z_slice_index = 26:4:68; % axial slice indices for visualization
end
maxColmapSize = 53; % largest colormap currently supported.

% Remove noise clusters according to user input parameter nrNoiseClust:
if nrNoiseClust == inf
    warning('All noise clusters even partly located in the ventricle/brain stem/white matter areas are removed!')
else
    disp([num2str(nrNoiseClust) ' most dominant noise clusters in the ventricle/brain stem/white matter areas are removed...'])
end

[results.clustersToVisualize,results.removedInd] = removeNoiseClusters(results,nrNoiseClust,sortClusters);

% If restricted by a user, visualize only "maxNrClusters" clusters
% having the most stable ISCs:
results.clustersToVisualize = results.clustersToVisualize(1:(min(length(...
    results.clustersToVisualize),maxNrClusters)));

% An explicit cluster list overrides the automatic selection:
if ~isempty(clust_ind)
    results.clustersToVisualize = clust_ind;
end
% Drop every requested cluster index wherever it occurs in the list:
if ~isempty(rem_clust)
    results.clustersToVisualize(ismember(results.clustersToVisualize,rem_clust)) = [];
end

% Apply the colormap limit last so that it also covers a user-given clust_ind:
if length(results.clustersToVisualize) > maxColmapSize
    warning(['Found number of clusters (' num2str(length(results.clustersToVisualize)) ...
        ') exceeds the currently supported maximum value for visualization, limiting the number of clusters to first ' ...
        num2str(maxColmapSize)])
    results.clustersToVisualize = results.clustersToVisualize(1:maxColmapSize);
end

results = create_cluster_names(results);
[IM,colM,LL,~,~,results] = plot_clusters2(...
    results.clustersToVisualize,results,results.brain,z_slice_index,results.AffMat,...
    textOffset,'clustermap',[],contrast_val,nrColumns,resultsPrev);

results.nrClusters = LL;
G = gcf; % save figure handle
plot_GMM_centers(results.Clust,results,IM,colM,results.clustersToVisualize,results.Di,...
    results.rel_variance,results.E,clipNames);
figure(G); % move spatial plot on top of the GMM plot for clarity

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Subfunctions:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [sortedClusterIndex,removed_clust] = removeNoiseClusters(results,nrNoiseClust,sortClusters)
% Select the clusters to visualize after discarding "noise clusters".
%
% Each cluster is scored by the number of its voxels whose atlas label
% (results.atlasSub.img) is one of 101, 103, 108, 112 or 114 (the
% ventricle/white matter/brain-stem labels according to the caller's
% messages - TODO confirm against the atlas). The nrNoiseClust clusters
% with the highest score are removed (capped at the number of clusters);
% nrNoiseClust = Inf removes every cluster with a non-zero score.
% Clusters whose results.rel_variance_sorted entry is NaN or <= 0 are
% always removed. A bar plot of the scores is drawn when any cluster
% overlaps the noise areas.
%
% Inputs:
% results      - struct with fields rel_variance_sorted_index,
%                rel_variance_sorted, atlasSub.img, mask.img, clustInd
%                (cell of voxel indices per cluster)
% nrNoiseClust - number of highest-scoring clusters to remove, or Inf
% sortClusters - if true, clusters are listed in the order of
%                results.rel_variance_sorted_index
%
% Outputs:
% sortedClusterIndex - remaining cluster indices
% removed_clust      - removed cluster indices
%
% NOTE(review): fiNaN/fiNeg are positions in results.rel_variance_sorted,
% which line up with sortedClusterIndex only when sortClusters is true.

if sortClusters
    sortedClusterIndex = results.rel_variance_sorted_index;
else
    sortedClusterIndex = 1:length(results.rel_variance_sorted_index);
end

atlasSub = results.atlasSub;
mask = results.mask;

% Voxels carrying any of the noise-area atlas labels:
noiseLabels = [103 108 101 112 114];
noiseMask = double(ismember(atlasSub.img,noiseLabels));

% Columns: noise voxels, cluster size, noise fraction of the cluster,
% noise voxels relative to the number of in-mask voxels.
nrClust = length(results.clustInd);
noise_areas = zeros(nrClust,4);
S = sum(double(mask.img(:))>0);
for k = 1:nrClust
    vox = results.clustInd{sortedClusterIndex(k)};
    nrNoiseVox = sum(noiseMask(vox));
    noise_areas(k,1) = nrNoiseVox;
    noise_areas(k,2) = length(vox);
    noise_areas(k,3) = nrNoiseVox/length(vox);
    noise_areas(k,4) = nrNoiseVox/S;
end

[sn,sni] = sort(noise_areas(:,1),'descend');
if nrNoiseClust == inf
    % Remove every cluster that is even partly located in the noise areas
    % (1:Inf would be an invalid index range).
    fi = sni(sn > 0);
else
    % Never request more clusters than exist.
    fi = sni(1:min(nrNoiseClust,length(sni)));
end

% remove also clusters whose relative variance could not be determined:
fiNaN = find(isnan(results.rel_variance_sorted));
fiNeg = find(results.rel_variance_sorted<=0);
fi2 = [fiNaN(:) ; fiNeg(:)];
if ~isempty(fi2)
    warning([num2str(length(fi2)) ...
        ' cluster(s) have zero or negative ISC mean or variance, these custers are not visualized!'])
end

fi = unique([fi(:) ;fi2(:)]);

disp(['Number of "noise clusters" removed: ' num2str(length(fi))])
if sum(noise_areas(:,1)) ~= 0
    figure
    ZZ = bar(noise_areas(:,1));
    set(gca,'XTick',1:length(sortedClusterIndex),'XTickLabel',sortedClusterIndex)
    grid on;hold on
    legendEntries = {'Retained clusters'};
    % Highlight removed clusters only when there are any, so the legend
    % never references a bar series that was not drawn.
    if ~isempty(fi)
        bar(fi,noise_areas(fi,1),'FaceColor','r','BarWidth',ZZ.BarWidth);
        legendEntries{end+1} = 'Removed clusters';
    end
    xlim([0.5 0.5+length(noise_areas(:,1))]);
    title('Number of voxels located in ventricles, white matter or brain-stem');
    xlabel('Cluster index #');ylabel('Number of voxels')
    legend(legendEntries,'Location', 'NorthWest')
end

removed_clust = sortedClusterIndex(fi);
sortedClusterIndex(fi) = [];

function results = create_cluster_names(results)
% Build display labels for the clusters selected for visualization.
%
% Reads results.clustersToVisualize and results.clustNames_abbr (cell of
% abbreviated cluster names) and adds:
%   results.cluster_order  - copy of results.clustersToVisualize
%   results.clusterNames   - '#k <name>' for the k-th visualized cluster;
%                            a two-part name 'X.R, Y.L' is reordered to
%                            'Y.L, X.R' so the '.L' part comes first
%   results.clusterNumbers - 1-by-N cell of '#k' labels
% Works for row or column name cells and for an empty selection.

clusters_to_visualize = results.clustersToVisualize;
clusterNames = results.clustNames_abbr(clusters_to_visualize);
results.cluster_order = clusters_to_visualize;

% True when string s ends with suffix (safe for strings shorter than it):
hasSuffix = @(s,suffix) numel(s) >= numel(suffix) && ...
    strcmp(s(end-numel(suffix)+1:end),suffix);

% Preallocate so the field exists even when nothing is visualized.
clusterNumbers = cell(1,numel(clusterNames));
for pp = 1:numel(clusterNames)
    ind = strfind(clusterNames{pp},',');
    if ~isempty(ind)
        % Split at the first comma; the parts are separated by ', '.
        str1 = clusterNames{pp}(1:ind(1)-1);
        str2 = clusterNames{pp}(ind(1)+2:end);
        if hasSuffix(str1,'.R') && hasSuffix(str2,'.L')
            clusterNames{pp} = [str2 ', ' str1];
        end
    end
    clusterNames{pp} = ['#' num2str(pp) ' ' clusterNames{pp}];
    clusterNumbers{pp} = ['#' num2str(pp)];
end

results.clusterNames = clusterNames;
results.clusterNumbers = clusterNumbers;

function [IM,colM,LL,Zval2,clust_ind,results] = plot_clusters2(...
    clust_ind,results,brain,leikkeet,AffMat,textOffset,style,Th,...
    contrast_val,nrColumns,resultsPrev)
% Draw a mosaic of axial slices of the anatomical image with clusters (or
% an ISC map) overlaid using direct colormap indexing, label each slice
% with its MNI z coordinate and (style 'clustermap') add a colorbar.
%
% Inputs:
% clust_ind    - style 'clustermap': indices of the clusters to draw;
%                styles 'ISCcomb'/'ISCmap': a struct whose .img field holds
%                the overlay volume
% results      - segmentation results struct (reads clusterNames,
%                removedInd, clustInd, subclusters)
% brain        - anatomical image struct; brain.img is a 3-D volume
% leikkeet     - indices of the axial slices to show ("leikkeet" = slices)
% AffMat       - 4x4 affine mapping 0-based voxel indices to MNI coordinates
% textOffset   - horizontal offset of the z labels (image columns)
% style        - 'clustermap', 'ISCmap' or 'ISCcomb'
% Th           - overlay threshold (used by 'ISCmap' only)
% contrast_val - second argument of quant() applied to each anatomical slice
% nrColumns    - number of slices per mosaic row
% resultsPrev  - [] single map, 1 first data set of a pair, or the struct
%                returned by that first call (second data set)
%
% Outputs:
% IM        - rotated volume of colormap indices (all slices)
% colM      - colormap set on the figure: overlay colours followed by gray
%             levels for the anatomy
% LL        - 'clustermap' only: numel(results.clustInd{k}) per drawn
%             cluster. NOTE(review): LL is never assigned for the other
%             styles, so requesting it there errors.
% Zval2     - z coordinates of the shown slices in reversed order; entries
%             used for labels are replaced by their absolute values
% clust_ind - clusters actually drawn (recomputed when resultsPrev is used)
% results   - input struct plus pairing fields (removed_clusts1,
%             subclust1, colMap, colbar, title, DiceVals) when resultsPrev
%             is used

rem_nonactive = 0;
plot_ground_truth = 0; % debug switch: when set, only slice leikkeet(3) is drawn

clusterNames = results.clusterNames;

% Dat encodes the comparison mode: 0 = single map, 1 = first data set of a
% pair, 2 = second data set (resultsPrev is the struct from the first call).
if isempty(resultsPrev)
    Dat = 0;
elseif isequal(resultsPrev,1)
    Dat = 1;
elseif isstruct(resultsPrev)
    Dat = 2;
else
    error('Wrong input for resultsPrev!')
end

% rem_noise: whether removeNoiseClusters dropped any clusters.
if ~isempty(results.removedInd)
    rem_noise = 1;
else
    rem_noise = 0;
end

switch style
    case 'ISCcomb'
        
        quant_factor = 100; % isc map quant factor
        figure
        colM = [1 0 0; 0 1 0; 0 0 1; 1 0.6 0; 0 1 1];
        
        for mm = 1:size(brain.img,3)
            brain.img(:,:,mm) = quant(brain.img(:,:,mm),contrast_val);
        end
        
        % Row 1 of the colormap is black; overlay colours follow.
        colM = [[0 0 0];colM];
        % Relabel the non-zero quantized intensities to 1..numel(Un);
        % zero voxels stay 0.
        Un = nonzeros(unique(brain.img));
        Br_tmp = brain.img;
        for mm = 1:length(Un)
            Br_tmp(find(brain.img==Un(mm))) = mm;
        end
        brain.img = Br_tmp;
        
        % NOTE(review): length() of a 3-D array is its largest dimension,
        % not the number of gray levels (length(Un)); confirm intent.
        nr_intens_lev = length(brain.img);
        Cm = colormap(gray(nr_intens_lev));
        FirstGrayLev = size(colM,1);
        colM = [colM;Cm];
        set(gcf,'Colormap',colM)
        % Shift anatomy indices past the overlay colours.
        D = double(brain.img);
        D = D + FirstGrayLev;
        D = D + 1;
        D_isc_map = D; % + 100 - FirstGrayLev;
        
        Im = clust_ind;
        D_isc_map(Im.img>0) = Im.img(Im.img>0);
        % Background voxels (brain.img == 0) end up at FirstGrayLev+1 = 7
        % here; map them to index 0.
        D_isc_map(D==7) = 0;
        ISCmap = zeros(size(brain.img,2),size(brain.img,1),1,size(brain.img,3));
        for mm = 1:size(brain.img,3)
            ISCmap(:,:,mm) = rot90(D_isc_map(:,:,mm));
        end
        
        ISCmap = squeeze(ISCmap);
        IM = ISCmap;
        IM2 = IM;
        IM = IM + 1;
        IM2 = IM2 + 1;
        
    case 'ISCmap'
        
        nr_clust = 60;
        quant_factor = 100; % isc map quant factor
        figure
        colM = colormap(hot(nr_clust));
        
        for mm = 1:size(brain.img,3)
            brain.img(:,:,mm) = quant(brain.img(:,:,mm),contrast_val);
        end
        
        % Rows 1-2: black and light gray, then the 'hot' overlay colours.
        colM = [[0 0 0];[1 1 1]*0.8;colM];
        Un = nonzeros(unique(brain.img));
        Br_tmp = brain.img;
        for mm = 1:length(Un)
            Br_tmp(find(brain.img==Un(mm))) = mm;
        end
        brain.img = Br_tmp;
        
        % NOTE(review): see 'ISCcomb' - length() is the largest dimension.
        nr_intens_lev = length(brain.img);
        Cm = colormap(gray(nr_intens_lev));
        FirstGrayLev = size(colM,1);
        colM = [colM;Cm];
        set(gcf,'Colormap',colM)
        D = double(brain.img);
        D = D + FirstGrayLev;
        D = D + 1;
        
        D_isc_map = D; % + 100 - FirstGrayLev;
        
        % ISC map:
        % Voxels above Th get the ISC value scaled by quant_factor as their
        % colormap index. Th is rescaled afterwards but not used again;
        % mima is computed but unused.
        Im = clust_ind;
        P = Im.img > Th;
        Th = Th*quant_factor;
        Im.img = round(Im.img*quant_factor);
        D_isc_map(P) = Im.img(P);
        isc_val = Im.img(P);
        mima = [min(isc_val(:)) max(isc_val(:))];
        % NOTE(review): background voxels sit at FirstGrayLev+1 (= 63 with
        % this colormap), not 101; confirm what index 101 is meant to clear.
        D_isc_map(D_isc_map==101) = 0;
        ISCmap = zeros(size(brain.img,2),size(brain.img,1),1,size(brain.img,3));
        for mm = 1:size(brain.img,3)
            ISCmap(:,:,mm) = rot90(D_isc_map(:,:,mm));
        end
        ISCmap = squeeze(ISCmap);
        IM = ISCmap;
        IM2 = IM;
        
    case 'clustermap'
        
        plotCentres = 0;
        
        clust = results.clustInd(clust_ind);
        nr_clust = length(results.clustInd);
        figure;set(gcf,'Units','Normalized','Position',[0.03 0.1 0.9 0.7])
        Nr_clust = nr_clust+2;        
        ind = [1 2];
        if 1
            for mm = 1:size(brain.img,3)
                brain.img(:,:,mm) = quant(brain.img(:,:,mm),contrast_val);
            end
        end
        if Dat > 0
            if Dat == 1
                % First data set of a pair: pick a colormap file by the
                % number of subclusters (the loaded file presumably defines
                % colM - TODO confirm), draw all non-removed subclusters and
                % store what the second call needs.
                subclust1 = results.subclusters;
                nrClust1 = length(subclust1);
                if nrClust1 < 20
                    load colMap_small
                elseif nrClust1 >= 20 && nrClust1 < 27
                    load colMap_mid
                else
                    load colMap_large
                end
                clust_ind = 1:nrClust1;
                if rem_noise
                    fi = [];
                    for nn = 1:length(results.removedInd)
                        fi = [fi;find(clust_ind == results.removedInd(nn))];
                    end
                    clust_ind(fi) = [];
                end
                removed_clusts1 = results.removedInd;
                results.removed_clusts1 = removed_clusts1;
                results.subclust1 = subclust1;
                results.colMap = colM;
            else
                % Second data set: reuse the first data set's colormap and
                % re-assign colours with match_colors, which also returns
                % the Dice values shown on the first map's colorbar.
                colM = resultsPrev.colMap;
                removed_clusts1 = resultsPrev.removed_clusts1;
                subclust1 = resultsPrev.subclust1;
                subclust2 = results.subclusters;
                nrClust1 = length(subclust1);
                nrClust2 = length(subclust2);
                clust_ind1 = 1:nrClust1;
                clust_ind2 = 1:nrClust2;
                
                if rem_noise
                    for nn = 1:length(removed_clusts1)
                        clust_ind1(clust_ind1 == removed_clusts1(nn)) = [];
                    end
                    for nn = 1:length(results.removedInd)
                        clust_ind2(clust_ind2 == results.removedInd(nn)) = [];
                    end
                end
                clust_ind = clust_ind2;
                [colM,DiceVals] = match_colors(colM,clust_ind1,...
                    clust_ind2,subclust1,subclust2);
            end
        else
            subclust1 = results.subclusters;
            nrClust1 = length(subclust1);
            if nrClust1 < 20
                load colMap_small
            elseif nrClust1 >= 20 && nrClust1 < 27
                load colMap_mid
            else
                load colMap_large
            end
        end
        % Rows 1-2: black and light gray; cluster colours start at row 3.
        colM = [[0 0 0];[1 1 1]*0.8;colM];
        Un = nonzeros(unique(brain.img));
        Br_tmp = brain.img;
        for mm = 1:length(Un)
            Br_tmp(find(brain.img==Un(mm))) = mm;
        end
        brain.img = Br_tmp;
        
        % NOTE(review): see 'ISCcomb' - length() is the largest dimension.
        nr_intens_lev = length(brain.img);
        Cm = colormap(gray(nr_intens_lev));
        FirstGrayLev = size(colM,1);
        colM = [colM;Cm];
        
        set(gcf,'Colormap',colM)
        D = double(brain.img);
        D = D + FirstGrayLev;
        D = D + 1;        
        % List the clusters:
        clear LL
        rem_nonactive = 0;    
        plot_contrast_im = 0;

        if plot_contrast_im
            load contrastIm
            for mm = 1:2
                D(find(D==FirstGrayLev)) = 0;
                if mm == 1
                    D( T.img <= 64 ) = mm + 2;
                else
                    D( T.img > 129 & T.img < 210 ) = mm + 2;
                end
                LL(mm) = length(results.clustInd{clust_ind(mm)});
            end
            colM(4,:) = [1 0 0];
            colM(3,:) = [0 0 1];
            set(gcf,'Colormap',colM)
        else
            % Paint the voxels of the mm-th drawn cluster with colormap
            % index mm+2 (the first cluster colour) and record cluster sizes.
            for mm = 1:length(clust_ind)
                % NOTE(review): background voxels are at FirstGrayLev+1, so
                % this comparison only hits a voxel already painted with
                % index FirstGrayLev; compare with D==7 in 'ISCcomb'.
                D(find(D==FirstGrayLev)) = 0;
                for mmm = 1:length(results.subclusters{clust_ind(mm)})
                    D(results.subclusters{clust_ind(mm)}{mmm}) = mm + 2;
                end
                LL(mm) = length(results.clustInd{clust_ind(mm)});
            end
        end
        % Dead code: rem_nonactive is 0 and nonactiveRegions is undefined.
        if rem_nonactive
            D(~mask) = nonactiveRegions;
        end
        
        IM = zeros(size(brain.img,2),size(brain.img,1),1,size(brain.img,3));
        for mm = 1:size(brain.img,3)
            IM(:,:,mm) = rot90(D(:,:,mm));
        end
        
        % NOTE(review): this compacted relabelling is overwritten by
        % IM2 = squeeze(IM) right after the switch, so it has no effect.
        IM2 = IM;
        un = unique(IM);
        un = un(find(un > 1 & un < nr_clust+3));
        for s = 1:length(un)
            IM2(IM==un(s)) = (s+1);
        end
end

Ha = gca;
ite = 1;
IM2 = squeeze(IM);
IMvol = IM2;

if plot_ground_truth
    leikkeet = leikkeet(3);
    nrColumns = 1;
end
% Keep only the requested slices, in reversed order.
IM2 = IM2(:,:,leikkeet);

ZZ = size(IM2,3);
IM2 = IM2(:,:,ZZ:-1:1);

% plot images:
% Tile the slices into NN rows of nrColumns, separated by L-pixel borders
% filled with the last colormap index. Slices beyond NN*nrColumns are not
% drawn. paikka_x/paikka_y ("paikka" = position) store label positions.
NN = floor(ZZ/nrColumns);
L = 2;
offset_y = 8;
offset_x = textOffset;

Dtot = [];
for m = 1:NN
    D = [];
    for n = 1:nrColumns
        D = [D size(colM,1)*ones(size(squeeze(IM2(:,:,ite)),1),L) ...
            squeeze(IM2(:,:,ite))];
        paikka_x(m,n) = size(D,2)-size(IM2,2)-L + 1 + offset_x;
        paikka_y(m,n) = offset_y + (size(IM2,1)+L)*(m-1);
        ite = ite + 1;
    end
    D = [D size(colM,1)*ones(size(IM2,1),L)];
    D = [D;size(colM,1)*ones(L,size(D,2))];
    Dtot = [Dtot;D];
end

image(Dtot);

% MNI z coordinate of every axial slice: map 0-based voxel indices
% (x, y at the voxel of the MNI origin) through AffMat.
Voxel_index_of_mni_origin = round(inv(AffMat)*[0 0 0 1]')+1;
vox_ind = [Voxel_index_of_mni_origin(1)*ones(size(brain.img,3),1)-1 ...
    Voxel_index_of_mni_origin(2)*ones(size(brain.img,3),1)-1 ...
    (1:size(brain.img,3))'-1 ones(size(brain.img,3),1)];
B = AffMat*(vox_ind)';

% Round to one decimal, keep the shown slices and match the reversed order.
Zval2 = round(B(3,:)*10)/10;
Zval2 = Zval2(leikkeet);
Zval2 = Zval2(length(Zval2):-1:1);

% Colorbar: cluster names (single map), or Dice values written onto the
% first map's colorbar when the second data set is drawn.
% NOTE(review): colorbar('peer',...) is deprecated syntax; the first yt
% assignments below are immediately overwritten.
switch style
    case 'clustermap'
        nr_clips = length(clusterNames);
        if Dat == 0
            H = colorbar('peer',Ha);
            yt = 3:(length(clust_ind)+2);
            set(H,'Direction','reverse')
            yt = (size(get(gcf,'Colormap'),1)-length(clust_ind)-1):(size(get(gcf,'Colormap'),1)-1);
            set(H,'YLim',[yt(1) yt(end)])
            set(H,'YTick',yt+0.5,'YTickLabel',clusterNames)
        elseif Dat == 1
            H = colorbar('peer',Ha);
            yt = 3:(length(clust_ind)+2);
            set(H,'Direction','reverse')
            yt = (size(get(gcf,'Colormap'),1)-length(clust_ind)-1):(size(get(gcf,'Colormap'),1)-1);
            set(H,'YLim',[yt(1) yt(end)])
            set(H,'YTick',yt+0.5)
            L = get(H,'Label');
            set(L,'String','Dice index between the clusters in the two data sets')
            set(H,'Location','westoutside')
            results.colbar = H;
            results.title = title('Data set #1');
        elseif Dat == 2
            set(resultsPrev.colbar,'YTickLabel',DiceVals)
            results.title = title('Data set #2');
            results.DiceVals = DiceVals;
        else
            
        end
end
% Write the z coordinate of each slice onto the mosaic.
hold on
ite = 1;
for m = 1:NN
    for n = 1:nrColumns
        fl = 0;
        % Non-positive values get a Unicode minus sign (U+2212).
        % NOTE(review): z = 0 is therefore labelled with a minus sign.
        if Zval2(ite) > 0
            S = '';
        else
            S = char(hex2dec('2212'));
            fl = 1;
        end
        if Zval2(ite) == round(Zval2(ite))
            D = '.0';
        else
            D = '';
        end
        Zval2(ite) = abs(Zval2(ite));
        % Nudge the label horizontally depending on its text length.
        if length([S num2str(Zval2(ite)) D]) < 4
            paikka_x(m,n) = paikka_x(m,n) + 5;
        elseif length([S num2str(Zval2(ite)) D]) == 5
            paikka_x(m,n) = paikka_x(m,n) - 7;    
        elseif length([S num2str(Zval2(ite)) D]) == 4 && fl
            paikka_x(m,n) = paikka_x(m,n) - 3;
        else
            
        end
        t = [S num2str(Zval2(ite)) D];
        % Drop the last two characters: removes the '.0' appended to
        % integers. NOTE(review): for non-integer z (e.g. '12.5') this
        % yields '12'; confirm intent.
        t = t(1:end-2);
        text(paikka_x(m,n),paikka_y(m,n),t,'Color',[1 1 1]*1,'Units','data','FontSize',7)
        ite = ite + 1;
    end
end

axis tight
axis equal
axis off
zoom on