% This script tests the accuracy of STAPLE and STAPLER on the cerebellum
% (axial) example data contributed by minimally trained people at
% Vanderbilt University

% algorithm settings -- forwarded unchanged to every STAPLE / STAPLER call
% below; their exact meaning is defined by those functions
% NOTE(review): epsilon is presumably the convergence threshold -- confirm
epsilon = 0.001;
prior_flag = 0;
init_flag = 0;
cons_flag = 0;

%
% run the testing data
%

% settings for the simulation
% (both directories are relative paths, so they are resolved against the
% current working directory)
st_slice = 1;
data_dir = '../example-data/cerebellum/axial/testing/obs/';
truth_dir = '../example-data/cerebellum/axial/testing/truth/';
fnames = dir([data_dir, '*.png']);

% fail early with a clear message instead of running every algorithm on an
% empty observation set and crashing much later with an unrelated error
if isempty(fnames)
    error('stapler_test:noObservations', ...
          'No observation files (*.png) found in "%s" (current directory: "%s").', ...
          data_dir, pwd);
end

% read the truth volume from the truth directory
truth = get_webmill_truth(truth_dir);

% training set: every other slice starting with the first one (1, 3, 5, ...),
% together with an observation struct sized to match
training_truth = truth(:, :, 1:2:end);
obs_training = create_obs('slice', size(training_truth));

% testing set: the remaining slices (2, 4, 6, ...), again with a matching
% observation struct
testing_truth = truth(:, :, 2:2:end);
obs_testing = create_obs('slice', size(testing_truth));
for i = 1:numel(fnames)

    % full path of the current observation file
    obs_file = [data_dir, fnames(i).name];

    % add the file to the training observations ('odd' selection), then
    % renumber the slice indices
    % NOTE(review): reorder_array is assumed to return zero-based indices,
    % hence the +1 -- confirm against its implementation
    obs_training = add_webmill_obs(obs_training, obs_file, st_slice, 'odd');
    obs_training.slices = 1 + reorder_array(obs_training.slices);

    % same for the testing observations ('even' selection)
    obs_testing = add_webmill_obs(obs_testing, obs_file, st_slice, 'even');
    obs_testing.slices = 1 + reorder_array(obs_testing.slices);
end

% derive the theta bias (and the training theta) from the training
% observations and the training truth
[theta_bias, training_theta] = construct_theta_bias( ...
    obs_training, training_truth, obs_testing, 0);

% all-zero bias of the same size, used for the "zero bias" STAPLER run
zero_bias = zeros(size(theta_bias));

%
% run the algorithms
%

% STAPLE, run independently on every testing slice (guaranteed full
% coverage); a slice without any observation keeps its all-zero estimate
S_est = zeros(obs_testing.dims);
for s = 1:obs_testing.dims(3)
    obs_slice = get_slice_from_obs(obs_testing, s);

    if obs_slice.num_obs > 0
        % store the slice estimate directly in the output volume; the two
        % remaining outputs are requested but not used afterwards
        [S_est(:, :, s), temp_W, temp_theta] = STAPLE( ...
            obs_slice, epsilon, prior_flag, init_flag, cons_flag);
    end
end

% majority vote over all testing observations
MV_est = majority_vote(obs_testing);

% STAPLER with the all-zero bias
[SRZ_est, SRZ_W, SRZ_theta] = STAPLER( ...
    obs_testing, epsilon, prior_flag, init_flag, cons_flag, zero_bias);

% STAPLER with the bias constructed from the training data
[SRT_est, SRT_W, SRT_theta] = STAPLER( ...
    obs_testing, epsilon, prior_flag, init_flag, cons_flag, theta_bias);

% accuracy of every estimate over the whole testing truth volume
MV_accuracy = fraction_correct(testing_truth, MV_est);
S_accuracy = fraction_correct(testing_truth, S_est);
SRZ_accuracy = fraction_correct(testing_truth, SRZ_est);
SRT_accuracy = fraction_correct(testing_truth, SRT_est);

% report the accuracy levels on the screen
disp('*** Results on Testing Data ***');
disp(['Majority Vote Accuracy: ', num2str(MV_accuracy)]);
disp(['STAPLE Accuracy: ', num2str(S_accuracy)]);
disp(['STAPLER (No Bias) Accuracy: ', num2str(SRZ_accuracy)]);
disp(['STAPLER (Training) Accuracy: ', num2str(SRT_accuracy)]);

% organize the results so we can plot them
%   results(s, :)     - accuracy on slice s of STAPLE (column 1), STAPLER with
%                       the training bias (column 2) and majority vote (column 3)
%   results_obs(s, i) - accuracy of the i-th individual observation of slice s;
%                       entries without an observation are NaN
num_slices = obs_testing.dims(3);
results = zeros(num_slices, 3);

% Size the per-observation table from the data. The width stays at least 100
% (the previous fixed width) so the layout is unchanged for typical inputs,
% but it can no longer be exceeded: growing the array implicitly would pad
% the other rows with 0, i.e. with bogus "0% accuracy" entries.
obs_per_slice = arrayfun(@(k) nnz(obs_testing.slices == k), 1:num_slices);
results_obs = NaN(num_slices, max([100, obs_per_slice]));

for s = 1:num_slices
    results(s, 1) = fraction_correct(testing_truth(:, :, s), S_est(:, :, s));
    results(s, 2) = fraction_correct(testing_truth(:, :, s), SRT_est(:, :, s));
    results(s, 3) = fraction_correct(testing_truth(:, :, s), MV_est(:, :, s));

    % accuracy of every individual observation of this slice
    sinds = find(obs_testing.slices == s);
    for i = 1:length(sinds)
        results_obs(s, i) = fraction_correct(testing_truth(:, :, s), ...
                                          obs_testing.data{sinds(i)});
    end
end

% Monte Carlo experiment: repeat the comparison mci times, each time on a
% 65% subset of the testing observations, so that we can assess the DSC
% variation
% results_dice(i, :, k) holds the dice values of iteration i for
% method k (1 = STAPLE, 2 = STAPLER, 3 = majority vote)
mci = 10;
frac = 0.65;
results_dice = zeros(mci, 6, 3);
for i = 1:mci

    % this iteration's subset of the testing observations
    % NOTE(review): presumably a random draw -- confirm in get_fraction_obs
    obs_frac = get_fraction_obs(obs_testing, frac);

    % rebuild the theta bias for the subset
    % NOTE: this overwrites theta_bias / training_theta computed earlier for
    % the full testing set
    [theta_bias, training_theta] = construct_theta_bias( ...
        obs_training, training_truth, obs_frac, 0);

    % STAPLE on each slice (guaranteed full coverage); slices without any
    % observation keep their all-zero estimate
    Sf_est = zeros(obs_frac.dims);
    for s = 1:obs_frac.dims(3)
        obs_slice = get_slice_from_obs(obs_frac, s);

        if obs_slice.num_obs > 0
            [Sf_est(:, :, s), temp_W, temp_theta] = STAPLE( ...
                obs_slice, epsilon, prior_flag, init_flag, cons_flag);
        end
    end

    % STAPLER with the subset-specific training bias
    [SRTf_est, SRTf_W, SRTf_theta] = STAPLER( ...
        obs_frac, epsilon, prior_flag, init_flag, cons_flag, theta_bias);

    % majority vote on the subset
    MVf_est = majority_vote(obs_frac);

    % dice of every method against the testing truth (second output of dice)
    estimates = {Sf_est, SRTf_est, MVf_est};
    for k = 1:numel(estimates)
        [m, d, u] = dice(testing_truth, estimates{k});
        results_dice(i, :, k) = d;
    end
end

% dice differences with the first dice column left out
% (presumably the background label -- confirm against dice):
%   rd  = STAPLER - STAPLE
%   rd2 = STAPLER - majority vote
rd = results_dice(:, 2:end, 2) - results_dice(:, 2:end, 1);
rd2 = results_dice(:, 2:end, 2) - results_dice(:, 2:end, 3);

% slice shown in the example figure, the observations belonging to that
% slice, and which of them (by position within the slice) are displayed
sl = 16;
obs_inds = [6 7 12];
sinds = find(obs_testing.slices == sl);

% minor truth fix
% (applied after all accuracies were computed, so within this script it only
% changes the truth image that is displayed)
testing_truth(99:end, 76:end, sl) = 0;
testing_truth(25:30, 73, sl) = 3;

% set the colormap values (one RGB row per entry)
cmap = [1 1 1; ...   % white
        0 0 1; ...   % blue
        0 1 0; ...   % green
        1 0 1; ...   % magenta
        1 0 0; ...   % red
        1 1 0];      % yellow

% plot the per slice accuracy
% top panel: accuracy of STAPLE / STAPLER / majority vote for every testing
% slice, overlaid with a box plot of the individual observations' accuracy
figure(1);
subplot(3, 1, 1:2);
slice_ids = 1:obs_testing.dims(3);
h_staple = plot(slice_ids, results(:, 1), 'b', 'LineWidth', 3);
hold on;
h_stapler = plot(slice_ids, results(:, 2), 'g', 'LineWidth', 3);
h_mv = plot(slice_ids, results(:, 3), 'r', 'LineWidth', 3);
boxplot(results_obs');
ylim([0.7 1.005]);
% Bind the legend explicitly to the three curves so that it cannot pick up
% graphics objects created by boxplot, and request automatic placement via
% the 'Location' option: the legacy numeric position argument (0 = best
% placement) is not accepted by current MATLAB releases.
legend([h_staple, h_stapler, h_mv], ...
       {'STAPLE', 'STAPLER', 'Majority Vote'}, 'Location', 'Best');
hold off;

% bottom panel: number of observations available for every testing slice
subplot(3, 1, 3);
hist(obs_testing.slices, slice_ids);
xlim([0.5 obs_testing.dims(3)+0.5]);
ylim([0 25]);

% box plots of the per label dice differences from the Monte Carlo runs
%   figure 2: rd  (STAPLER minus STAPLE)
%   figure 3: rd2 (STAPLER minus majority vote)
dice_diffs = {rd, rd2};
for f = 1:numel(dice_diffs)
    figure(f + 1);
    boxplot(dice_diffs{f});
    hold on;
    plot([0 6], [0 0], 'k--');   % dashed reference line at zero difference
    hold off;
    ylim([-0.02 0.2]);
end

% create a figure showing the results for slice sl
%   top row:    truth, STAPLE estimate, STAPLER (training bias) estimate
%   bottom row: three example observations of the same slice
panel_images = {testing_truth(:, :, sl), ...
                S_est(:, :, sl), ...
                SRT_est(:, :, sl), ...
                obs_testing.data{sinds(obs_inds(1))}, ...
                obs_testing.data{sinds(obs_inds(2))}, ...
                obs_testing.data{sinds(obs_inds(3))}};
panel_titles = {'Truth Slice', ...
                'STAPLE Result', ...
                'STAPLER Result', ...
                'Example Observation 1', ...
                'Example Observation 2', ...
                'Example Observation 3'};

figure(4);
for p = 1:numel(panel_images)
    subplot(2, 3, p);
    imshow(panel_images{p});
    caxis([0 5]);        % same color limits for every panel
    colormap(cmap);
    title(panel_titles{p});
end

