% This script calculates the accuracy of the webmill data (cerebellum
% cross-sections and simulated cylinder) with respect to number of coverages.
% STAPLER is used to perform the analysis.

% ---- simulation settings ----
mci = 10;                            % Monte Carlo iterations per coverage count
covs = 3:15;                         % coverage counts that are swept over
frac_training = [0 0.1 0.25 0.5 1];  % training fractions (0 skips training)

% Result storage. The trailing dimensions hold 3 accuracy measures for each
% of the 4 data sets. `results` is rebuilt from `res` once the simulation
% has finished; `res` keeps one array per coverage count.
results = zeros([numel(covs), mci, numel(frac_training), 3, 4]);
res = repmat({zeros(mci, numel(frac_training), 3, 4)}, 1, numel(covs));

% ---- algorithm settings (forwarded to STAPLER / construct_theta_bias) ----
epsilon = 1e-3;
prior_flag = 0;
init_flag = 0;
cons_flag = 0;

%
% collect all of the data
%
% Each data set is prepared completely before the next one is touched:
% directories, the *_st_slice value handed to add_webmill_obs, the truth,
% an empty observation struct sized like the truth, the list of PNG
% observation files, and finally the observations themselves.

% --- cerebellum, axial ---
axial_st_slice = 1;
axial_data_dir = '../example-data/cerebellum/axial/testing/obs/';
axial_truth_dir = '../example-data/cerebellum/axial/testing/truth/';
axial_truth = get_webmill_truth(axial_truth_dir);
axial_obs = create_obs('slice', size(axial_truth));
axial_fnames = dir([axial_data_dir, '*.png']);
for i = 1:numel(axial_fnames)
    obs_path = [axial_data_dir, axial_fnames(i).name];
    axial_obs = add_webmill_obs(axial_obs, obs_path, axial_st_slice);
end

% --- cerebellum, sagittal ---
sagittal_st_slice = 71;
sagittal_data_dir = '../example-data/cerebellum/sagittal/testing/obs/';
sagittal_truth_dir = '../example-data/cerebellum/sagittal/testing/truth/';
sagittal_truth = get_webmill_truth(sagittal_truth_dir);
sagittal_obs = create_obs('slice', size(sagittal_truth));
sagittal_fnames = dir([sagittal_data_dir, '*.png']);
for i = 1:numel(sagittal_fnames)
    obs_path = [sagittal_data_dir, sagittal_fnames(i).name];
    sagittal_obs = add_webmill_obs(sagittal_obs, obs_path, sagittal_st_slice);
end

% --- cerebellum, coronal ---
coronal_st_slice = 1;
coronal_data_dir = '../example-data/cerebellum/coronal/testing/obs/';
coronal_truth_dir = '../example-data/cerebellum/coronal/testing/truth/';
coronal_truth = get_webmill_truth(coronal_truth_dir);
coronal_obs = create_obs('slice', size(coronal_truth));
coronal_fnames = dir([coronal_data_dir, '*.png']);
for i = 1:numel(coronal_fnames)
    obs_path = [coronal_data_dir, coronal_fnames(i).name];
    coronal_obs = add_webmill_obs(coronal_obs, obs_path, coronal_st_slice);
end

% --- simulated cylinder ---
% The truth and every loaded observation are passed through reorder_array.
sim_st_slice = 1;
sim_data_dir = '../example-data/cylinder-simulation/testing/obs/';
sim_truth_dir = '../example-data/cylinder-simulation/testing/truth/';
sim_truth = reorder_array(get_webmill_truth(sim_truth_dir));
sim_obs = create_obs('slice', size(sim_truth));
sim_fnames = dir([sim_data_dir, '*.png']);
for i = 1:numel(sim_fnames)
    obs_path = [sim_data_dir, sim_fnames(i).name];
    sim_obs = add_webmill_obs(sim_obs, obs_path, sim_st_slice);
    % relies on add_webmill_obs having stored the new observation at data{i}
    sim_obs.data{i} = reorder_array(sim_obs.data{i});
end

% Monte Carlo sweep. For every coverage count (ci), iteration (m) and
% training fraction (f), STAPLER is run on each of the four data sets and
% three accuracy measures are stored in res{ci}(m, f, measure, data_set):
%   measure  1: fraction_correct(truth, estimate)
%            2: fraction_correct(truth, estimate, obs)
%            3: dice(truth, estimate)
%   data_set 1: axial, 2: sagittal, 3: coronal, 4: cylinder simulation
for ci = 1:length(covs)
    c = covs(ci);

    for m = 1:mci

        for f = 1:length(frac_training)

            % get a random collection of observations to construct the coverages
            % NOTE: In some cases it will not be possible to get full coverages;
            %       we will just use all available observations for those
            %       slices in those scenarios.
            axial_obs_cov = get_random_coverages(axial_obs, c);
            sagittal_obs_cov = get_random_coverages(sagittal_obs, c);
            coronal_obs_cov = get_random_coverages(coronal_obs, c);
            sim_obs_cov = get_random_coverages(sim_obs, c);

            % STAPLER (axial)
            % bias stays all-zero unless a non-zero training fraction is used
            num_labels = length(unique(axial_truth));
            num_raters = length(unique(axial_obs_cov.raters));
            bias = zeros([num_labels, num_labels, num_raters]);
            if (frac_training(f) > 0)
                axial_obs_tr = get_fraction_training(axial_obs, ...
                                                     axial_obs_cov, ...
                                                     frac_training(f));
                [bias, axial_tr_theta] = construct_theta_bias(axial_obs_tr, ...
                                                              axial_truth, ...
                                                              axial_obs_cov, ...
                                                              cons_flag);
            end
            [SRZ_est, SRZ_W, SRZ_theta] = STAPLER(axial_obs_cov, epsilon, ...
                                                  prior_flag, init_flag, ...
                                                  cons_flag, bias);
            res{ci}(m, f, 1, 1) = fraction_correct(axial_truth, SRZ_est);
            res{ci}(m, f, 2, 1) = fraction_correct(axial_truth, SRZ_est, ...
                                                   axial_obs);
            res{ci}(m, f, 3, 1) = dice(axial_truth, SRZ_est);

            % STAPLER (sagittal)
            % NOTE(review): only the sagittal data adds 1 to the label count;
            %               the reason is not visible here -- confirm that it
            %               is intentional.
            num_labels = length(unique(sagittal_truth))+1;
            num_raters = length(unique(sagittal_obs_cov.raters));
            bias = zeros([num_labels, num_labels, num_raters]);
            if (frac_training(f) > 0)
                sagittal_obs_tr = get_fraction_training(sagittal_obs, ...
                                                        sagittal_obs_cov, ...
                                                        frac_training(f));
                [bias, sagittal_tr_theta] = construct_theta_bias(...
                                                           sagittal_obs_tr, ...
                                                           sagittal_truth, ...
                                                           sagittal_obs_cov, ...
                                                           cons_flag);
            end
            % cons_flag is passed ahead of bias, exactly as in the axial call,
            % so that bias is not consumed as the consensus flag
            [SRZ_est, SRZ_W, SRZ_theta] = STAPLER(sagittal_obs_cov, epsilon, ...
                                                  prior_flag, init_flag, ...
                                                  cons_flag, bias);
            res{ci}(m, f, 1, 2) = fraction_correct(sagittal_truth, SRZ_est);
            res{ci}(m, f, 2, 2) = fraction_correct(sagittal_truth, SRZ_est, ...
                                                   sagittal_obs);
            res{ci}(m, f, 3, 2) = dice(sagittal_truth, SRZ_est);

            % STAPLER (coronal)
            num_labels = length(unique(coronal_truth));
            num_raters = length(unique(coronal_obs_cov.raters));
            bias = zeros([num_labels, num_labels, num_raters]);
            if (frac_training(f) > 0)
                coronal_obs_tr = get_fraction_training(coronal_obs, ...
                                                       coronal_obs_cov, ...
                                                       frac_training(f));
                [bias, coronal_tr_theta] = construct_theta_bias(...
                                                            coronal_obs_tr, ...
                                                            coronal_truth, ...
                                                            coronal_obs_cov, ...
                                                            cons_flag);
            end
            [SRZ_est, SRZ_W, SRZ_theta] = STAPLER(coronal_obs_cov, epsilon, ...
                                                  prior_flag, init_flag, ...
                                                  cons_flag, bias);
            res{ci}(m, f, 1, 3) = fraction_correct(coronal_truth, SRZ_est);
            res{ci}(m, f, 2, 3) = fraction_correct(coronal_truth, SRZ_est, ...
                                                   coronal_obs);
            res{ci}(m, f, 3, 3) = dice(coronal_truth, SRZ_est);

            % STAPLER (cylinder-simulation)
            num_labels = length(unique(sim_truth));
            num_raters = length(unique(sim_obs_cov.raters));
            bias = zeros([num_labels, num_labels, num_raters]);
            if (frac_training(f) > 0)
                sim_obs_tr = get_fraction_training(sim_obs, ...
                                                   sim_obs_cov, ...
                                                   frac_training(f));
                [bias, sim_tr_theta] = construct_theta_bias(sim_obs_tr, ...
                                                            sim_truth, ...
                                                            sim_obs_cov, ...
                                                            cons_flag);
            end
            [SRZ_est, SRZ_W, SRZ_theta] = STAPLER(sim_obs_cov, epsilon, ...
                                                  prior_flag, init_flag, ...
                                                  cons_flag, bias);
            res{ci}(m, f, 1, 4) = fraction_correct(sim_truth, SRZ_est);
            res{ci}(m, f, 2, 4) = fraction_correct(sim_truth, SRZ_est, ...
                                                   sim_obs);
            res{ci}(m, f, 3, 4) = dice(sim_truth, SRZ_est);
        end
    end
end

% Stack the per-coverage arrays along a new fifth dimension, then rotate
% that dimension to the front so that results is indexed as
% (coverage, iteration, training fraction, measure, data set).
results = permute(cat(5, res{:}), [5 1 2 3 4]);

% Plot accuracy measure t (4th index of results) against the number of
% coverages: one figure per data set, one error-bar curve per training
% fraction. Each curve is the mean over the Monte Carlo iterations and the
% bars are the corresponding standard deviation.
t = 2;

% order matches the 5th index of results
set_names = {'axial', 'sagittal', 'coronal', 'cylinder simulation'};
% one color per training fraction; cycled if there are more fractions
line_colors = 'kbrgc';

for d = 1:length(set_names)
    figure(d);
    curve_labels = cell(1, length(frac_training));
    for f = 1:length(frac_training)
        % rows: coverage counts, columns: Monte Carlo iterations
        vals = results(:, :, f, t, d);
        col = line_colors(mod(f - 1, length(line_colors)) + 1);
        errorbar(covs, mean(vals, 2), std(vals, [], 2), col);
        % the first curve replaces any old plot; later curves are added to it
        hold on;
        curve_labels{f} = sprintf('training fraction = %g', frac_training(f));
    end
    hold off
    xlabel('number of coverages');
    title(set_names{d});
    legend(curve_labels{:}, 'Location', 'Best');
end
