% This script demonstrates the potential for STAPLE to converge to an
% undesired local optimum for a given simulation. A boundary variation model
% is used with a collection of 50 raters.

% load the big brain: raw file read as uint8 samples, reshaped to 172x216x164
big_path = '../example-data/brain-atlas/brain-atlas-11obj-topology.raw';
[fp, open_msg] = fopen(big_path,'rb');
if fp < 0
    % fail here with the path and system message instead of letting fread
    % report an invalid file identifier
    error('Could not open %s: %s', big_path, open_msg);
end
dat = fread(fp,inf,'uint8');
fclose(fp);
z = reshape(dat,[172 216 164]);

% load the little brain: same procedure, reshaped to 86x108x82
small_path = '../example-data/brain-atlas/brain-atlas-11obj-topology-small.raw';
[fp, open_msg] = fopen(small_path,'rb');
if fp < 0
    error('Could not open %s: %s', small_path, open_msg);
end
dat = fread(fp,inf,'uint8');
fclose(fp);
smz = reshape(dat,[86 108 82]);

% pick the working volume: slices 34 and 35 of the small brain
tempvol = smz(:, :, 34:35);

% remap the distinct values found in tempvol onto 0, 1, 2, ... following
% the sorted order returned by unique
U = unique(tempvol(:));
vol = zeros(size(tempvol));
for k = 1:numel(U)
    vol(tempvol == U(k)) = k - 1;
end

% region of the volume to analyze (currently the full extent on every axis)
xs = 1:size(vol, 1);
ys = 1:size(vol, 2);
zs = 1:size(vol, 3);
vol = vol(xs, ys, zs);

% reorder_array returns the (possibly relabelled) volume and the label count
[vol, num_labels] = reorder_array(vol);

% set the number of raters
num_raters = 50;

% create the observation structs:
%   obs  - the rater observations that are fused by STAPLE/STAPLER
%   obst - a second, separately generated set of observations that is
%          handed to construct_theta_bias
obs = create_obs('slice', [length(xs) length(ys) length(zs)]);
obst = create_obs('slice', [length(xs) length(ys) length(zs)]);

% per-rater parameter passed to apply_boundary_model, evenly spaced over
% [0.5, 2] so that each rater gets a different value
shiftiness = linspace(0.5, 2, num_raters);

% generate and store the observations of every rater
for i = 1:num_raters

    % apply a boundary model and store each slice in obs
    data = apply_boundary_model(vol, shiftiness(i));
    for s = 1:size(data, 3)
        obs = add_obs(obs, data(:, :, s), s, i);
    end

    % second call to the boundary model for the same rater (presumably a
    % fresh random realization -- confirm in apply_boundary_model); its
    % slices are accumulated in obst, NOT in obs
    data = apply_boundary_model(vol, shiftiness(i));
    for s = 1:size(data, 3)
        obst = add_obs(obst, data(:, :, s), s, i);
    end

end

% Settings shared by the STAPLE and STAPLER calls.
% NOTE(review): the meaning of the three flags is defined inside
% STAPLE/STAPLER and is not visible in this script.
epsilon = 1e-5;
prior_flag = 0;
init_flag = 0;
cons_flag = 0;

% Bias term for STAPLER, built from the second observation set, the
% reference volume and the observations being fused.
[theta_bias, training_theta] = construct_theta_bias(obst, vol, obs);

% Run STAPLE on the observations.
[estimate, W, theta] = STAPLE(obs, epsilon, prior_flag, init_flag, cons_flag);

% Run STAPLER with the same settings plus the bias term.
[est_SR, W_SR, theta_SR] = STAPLER( ...
    obs, epsilon, prior_flag, init_flag, cons_flag, theta_bias);

% slice index at half the depth of the volume (rounded)
sl = round(size(vol, 3) / 2);

close all;

% the four images to compare, each transposed for display, with captions
panels = {vol(:, :, sl)', ...
          obs.data{1}(:, :, sl)', ...
          estimate(:, :, sl)', ...
          est_SR(:, :, sl)'};
captions = {'Truth', 'Observation', 'STAPLE', 'STAPLER'};

% draw a 2x2 grid; every panel uses the jet colormap scaled to the label
% range 0 .. num_labels-1 so colours are comparable across panels
figure(1);
for p = 1:numel(panels)
    subplot(2, 2, p);
    imshow(panels{p});
    colormap(jet);
    caxis([0 num_labels-1]);
    title(captions{p});
end

% report the score returned by fraction_correct for each estimate
staple_msg = sprintf('STAPLE Accuracy: %f', fraction_correct(vol, estimate));
disp(staple_msg);
stapler_msg = sprintf('STAPLER Accuracy: %f', fraction_correct(vol, est_SR));
disp(stapler_msg);
