// Utilities and system includes
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"
#include "NonLocalSTAPLE.h"

/*
 * Print the command-line help text to stdout and terminate the process.
 *
 * Called when no arguments are given, after an unrecognized argument, and
 * after required options are found missing. In every case it ends with
 * exit(0).
 * NOTE(review): exit status 0 is also used for error paths -- consider a
 * non-zero status there if scripts rely on it.
 */
void usage()
{
    printf("Non-Local STAPLE version 0.1\n\n");

    printf("Usage: nls [options] -l <labels> -i <intensities> -t <target> -o <outprefix>\n\n");

    printf("Required Options:\n");

    printf("\t-l, --labels <labels-file>\n");
    printf("\t\tThe atlas labels file (in .nii/.nii.gz format).\n\n");

    printf("\t-i, --intensities <intensities-file>\n");
    printf("\t\tThe atlas intensities file (in .nii/.nii.gz format).\n\n");

    printf("\t-t, --target <target-file>\n");
    printf("\t\tThe target intensity file (in .nii/.nii.gz format).\n\n");

    printf("\t-o, --output <outprefix>\n");
    printf("\t\tThe output prefix for the output files\n");
    printf("\t\tThe output files will append appropriate suffixes.\n\n");

    printf("Additional Options:\n");

    printf("\t-th, --thresh <val>\n");
    printf("\t\tThe convergence threshold for the E-M algorithm\n");
    printf("\t\tDefault value: 1e-5\n\n");

    printf("\t-sv, --search-volume <sx> <sy> <sz>\n");
    printf("\t\tDefines the search volume radii in the x, y, and z directions\n");
    printf("\t\tDefault value: 5 5 5\n\n");

    printf("\t-pv, --patch-volume <px> <py> <pz>\n");
    printf("\t\tDefines the patch volume radii in the x, y, and z directions\n");
    printf("\t\tDefault value: 1 1 1\n\n");

    // short/long option pairs are comma-separated like every other entry
    printf("\t-si, --stdev-intensity <sd>\n");
    printf("\t\tDefines the standard deviation for the gaussian\n");
    printf("\t\tintensity difference model (\\sigma_i)\n");
    printf("\t\tDefault value: 0.1\n\n");

    printf("\t-ss, --stdev-spatial <sd>\n");
    printf("\t\tDefines the standard deviation for the gaussian\n");
    printf("\t\tdistance-based decay model (\\sigma_d)\n");
    printf("\t\tDefault value: 2\n\n");

    printf("\t-k, --num-keep <num>\n");
    printf("\t\tDefines the number of neighbors to 'keep'\n");
    printf("\t\tDefault value: 15\n\n");

    printf("\t-p, --prior <prior-file>\n");
    printf("\t\tAn explicit voxelwise prior (in .nii/.nii.gz format)\n");
    printf("\t\tShould be X x Y x Z x L, where L is the number of labels\n");
    printf("\t\tDefault value: (NONE) - uses default prior\n\n");

    printf("\t-w, --save-W\n");
    printf("\t\tSave the label probabilities in addition to the estimate\n");
    printf("\t\tDefault value: Off - don't save W\n\n");

    printf("\t-st, --save-theta\n");
    printf("\t\tSave the performance level parameters\n");
    printf("\t\tDefault value: Off - don't save theta\n\n");
    exit(0);
}

// Return the value that follows the option at argv[*i] and advance *i past
// it. A missing value (option given as the last argument) is reported and
// usage() is called, which exits.
static char * nls_next_arg(int argc, char ** argv, int * i)
{
    if (*i + 1 >= argc) {
        fprintf(stderr, "Error: Missing value for input argument: %s\n\n",
                argv[*i]);
        usage();
    }
    *i += 1;
    return argv[*i];
}

// Parse a non-negative base-10 integer option value. Unlike atoi(), this
// rejects empty/non-numeric text, trailing garbage and negative numbers
// (a negative value would otherwise wrap to an enormous size_t).
static size_t nls_parse_count(const char * opt, const char * str)
{
    char * end = NULL;
    long val = strtol(str, &end, 10);
    if (end == str || *end != '\0' || val < 0) {
        fprintf(stderr, "Error: Invalid value for %s: %s\n\n", opt, str);
        usage();
    }
    return (size_t)val;
}

// Parse a floating-point option value, rejecting non-numeric text.
static double nls_parse_real(const char * opt, const char * str)
{
    char * end = NULL;
    double val = strtod(str, &end);
    if (end == str || *end != '\0') {
        fprintf(stderr, "Error: Invalid value for %s: %s\n\n", opt, str);
        usage();
    }
    return val;
}

/*
 * Parse the command line into the caller-provided output slots.
 *
 * Only options present on the command line overwrite their slot, so the
 * caller is expected to pre-load defaults (and NULL for the file names --
 * the required-option checks below test for NULL).
 * -sv also updates nsf[d] = 2 * ns[d] + 1 (full search-window extent).
 * Both "--num-keep" (as documented by usage()) and the legacy "--num_keep"
 * spellings are accepted.
 *
 * Any unknown option, missing option value, malformed number, or missing
 * required option prints an error and calls usage(), which exits.
 */
void set_input_args(int argc,
                    char ** argv,
                    char ** obsfile,
                    char ** imsfile,
                    char ** targfile,
                    intensity_t * epsilon,
                    size_t * ns,
                    size_t * nsf,
                    size_t * nc,
                    intensity_t * stdev,
                    intensity_t * sp_stdev,
                    char ** priorfile,
                    char ** out_prefix,
                    size_t * num_keep,
                    bool * save_labelprobs,
                    bool * save_theta)
{

    if (argc == 1)
        usage();

    // iterate over the input arguments
    for (int i = 1; i < argc; i++) {
        // remember the option name; nls_next_arg() advances i past values
        const char * opt = argv[i];

        if (strcmp(opt, "-l") == 0 ||
            strcmp(opt, "--labels") == 0) {
            *obsfile = nls_next_arg(argc, argv, &i);
        } else if (strcmp(opt, "-i") == 0 ||
                   strcmp(opt, "--intensities") == 0) {
            *imsfile = nls_next_arg(argc, argv, &i);
        } else if (strcmp(opt, "-t") == 0 ||
                   strcmp(opt, "--target") == 0) {
            *targfile = nls_next_arg(argc, argv, &i);
        } else if (strcmp(opt, "-th") == 0 ||
                   strcmp(opt, "--thresh") == 0) {
            *epsilon = nls_parse_real(opt, nls_next_arg(argc, argv, &i));
        } else if (strcmp(opt, "-sv") == 0 ||
                   strcmp(opt, "--search-volume") == 0) {
            for (int d = 0; d < 3; d++) {
                ns[d] = nls_parse_count(opt, nls_next_arg(argc, argv, &i));
                nsf[d] = 2 * ns[d] + 1;
            }
        } else if (strcmp(opt, "-pv") == 0 ||
                   strcmp(opt, "--patch-volume") == 0) {
            for (int d = 0; d < 3; d++)
                nc[d] = nls_parse_count(opt, nls_next_arg(argc, argv, &i));
        } else if (strcmp(opt, "-si") == 0 ||
                   strcmp(opt, "--stdev-intensity") == 0) {
            *stdev = nls_parse_real(opt, nls_next_arg(argc, argv, &i));
        } else if (strcmp(opt, "-ss") == 0 ||
                   strcmp(opt, "--stdev-spatial") == 0) {
            *sp_stdev = nls_parse_real(opt, nls_next_arg(argc, argv, &i));
        } else if (strcmp(opt, "-k") == 0 ||
                   strcmp(opt, "--num-keep") == 0 ||
                   strcmp(opt, "--num_keep") == 0) {
            *num_keep = nls_parse_count(opt, nls_next_arg(argc, argv, &i));
        } else if (strcmp(opt, "-p") == 0 ||
                   strcmp(opt, "--prior") == 0) {
            *priorfile = nls_next_arg(argc, argv, &i);
        } else if (strcmp(opt, "-o") == 0 ||
                   strcmp(opt, "--output") == 0) {
            *out_prefix = nls_next_arg(argc, argv, &i);
        } else if (strcmp(opt, "-w") == 0 ||
                   strcmp(opt, "--save-W") == 0) {
            *save_labelprobs = true;
        } else if (strcmp(opt, "-st") == 0 ||
                   strcmp(opt, "--save-theta") == 0) {
            *save_theta = true;
        } else {
            fprintf(stderr, "Error Reading Input Argument: %s\n\n", argv[i]);
            usage();
        }
    }

    // do some quick error checking
    bool error = false;
    if (*obsfile == NULL) {
        fprintf(stderr, "Error: Labels file not specified\n");
        error = true;
    }
    if (*imsfile == NULL) {
        fprintf(stderr, "Error: Intensities file not specified\n");
        error = true;
    }
    if (*targfile == NULL) {
        fprintf(stderr, "Error: Target file not specified\n");
        error = true;
    }
    if (*out_prefix == NULL) {
        fprintf(stderr, "Error: Output prefix not specified\n");
        error = true;
    }
    if (error) {
        fprintf(stderr, "\n");
        usage();
    }
}

/*
 * Echo the run configuration to stdout.
 *
 * size_t values are printed with %zu (portable; %lu is only correct where
 * size_t is unsigned long). intensity_t values are cast to double so %f
 * matches regardless of the underlying floating type.
 */
void print_input_parms(char * obsfile,
                       char * imsfile,
                       char * targfile,
                       char * priorfile,
                       size_t * dims,
                       size_t * ns,
                       size_t * nc,
                       intensity_t stdev,
                       intensity_t sp_stdev,
                       size_t num_raters,
                       size_t num_labels,
                       size_t num_keep)
{

    printf("Atlas labels file: %s\n", obsfile);
    printf("Atlas intensities file: %s\n", imsfile);
    printf("Target intensity file: %s\n", targfile);
    if (priorfile != NULL)
        printf("Prior probabilities file: %s\n", priorfile);
    else
        printf("Prior probabilities file: NONE - using default prior\n");
    printf("Image Dimensions: [%zu %zu %zu]\n", dims[0], dims[1], dims[2]);
    printf("Search Neighborhood Radii: [%zu %zu %zu]\n", ns[0], ns[1], ns[2]);
    printf("Patch Neighborhood Radii: [%zu %zu %zu]\n", nc[0], nc[1], nc[2]);
    printf("Intensity similarity std dev: %f\n", (double)stdev);
    printf("Distance decay std dev: %f\n", (double)sp_stdev);
    printf("Number of raters: %zu\n", num_raters);
    printf("Number of labels: %zu\n", num_labels);
    printf("Number of voxels to keep: %zu\n", num_keep);
}

/*
 * Build the initial majority-vote estimate and the consensus mask.
 *
 * For each voxel:
 *   - estimate[x][y][z] = most frequent label among the raters' obs
 *     (ties resolve to the lowest label index);
 *   - the voxel is "unanimous" when every rater gave that label.
 * A voxel is marked consensus == true only if every voxel within its
 * search window (radii ns, clipped to the image) is unanimous; each
 * non-unanimous voxel clears consensus over its whole window.
 *
 * consensus and estimate must already point to arrays of dims[0] slots
 * (allocated by the caller); this function allocates everything below that.
 * NOTE(review): assumes every obs label is < num_labels, since it is used
 * directly as an index into lp -- confirm against the label loader.
 */
void initialize_consensus_estimate(bool *** consensus,
                                   label_t **** obs,
                                   label_t *** estimate,
                                   size_t * ns,
                                   size_t * dims,
                                   size_t num_raters,
                                   size_t num_labels)
{
    const size_t nx = dims[0];
    const size_t ny = dims[1];
    const size_t nz = dims[2];

    // allocate the per-voxel outputs
    for (size_t x = 0; x < nx; x++) {
        consensus[x] = (bool **)malloc(ny * sizeof(*consensus[x]));
        estimate[x] = (label_t **)malloc(ny * sizeof(*estimate[x]));
        for (size_t y = 0; y < ny; y++) {
            consensus[x][y] = (bool *)malloc(nz * sizeof(*consensus[x][y]));
            estimate[x][y] = (label_t *)malloc(nz * sizeof(*estimate[x][y]));
        }
    }

    // temporary flat (x-major) unanimity mask and per-voxel vote counts
    bool * unanimous = (bool *)malloc(nx * ny * nz * sizeof(*unanimous));
    intensity_t * lp = (intensity_t *)malloc(num_labels * sizeof(*lp));

    // pass 1: majority vote and unanimity flags
    for (size_t x = 0; x < nx; x++)
        for (size_t y = 0; y < ny; y++)
            for (size_t z = 0; z < nz; z++) {

                // count the votes for each label
                for (size_t l = 0; l < num_labels; l++)
                    lp[l] = 0;
                for (size_t j = 0; j < num_raters; j++)
                    lp[obs[x][y][z][j]]++;

                // find the max value and label (first max wins ties)
                intensity_t max_val = 0;
                label_t max_label = 0;
                for (size_t l = 0; l < num_labels; l++)
                    if (lp[l] > max_val) {
                        max_val = lp[l];
                        max_label = l;
                    }

                estimate[x][y][z] = max_label;
                unanimous[(x * ny + y) * nz + z] = (max_val == num_raters);
                consensus[x][y][z] = true;
            }

    // pass 2: every non-unanimous voxel removes consensus from its window.
    // Bounds are computed in size_t without subtraction underflow; the old
    // "(int)x - ns[0]" mixed int with size_t, wrapped, and indexed
    // consensus[] out of bounds near the low image edges.
    for (size_t x = 0; x < nx; x++)
        for (size_t y = 0; y < ny; y++)
            for (size_t z = 0; z < nz; z++) {
                if (unanimous[(x * ny + y) * nz + z])
                    continue;

                const size_t xl = (x > ns[0]) ? x - ns[0] : 0;
                const size_t yl = (y > ns[1]) ? y - ns[1] : 0;
                const size_t zl = (z > ns[2]) ? z - ns[2] : 0;
                const size_t xh = (x + ns[0] < nx) ? x + ns[0] : nx - 1;
                const size_t yh = (y + ns[1] < ny) ? y + ns[1] : ny - 1;
                const size_t zh = (z + ns[2] < nz) ? z + ns[2] : nz - 1;

                for (size_t xp = xl; xp <= xh; xp++)
                    for (size_t yp = yl; yp <= yh; yp++)
                        for (size_t zp = zl; zp <= zh; zp++)
                            consensus[xp][yp][zp] = false;
            }

    free(unanimous);
    free(lp);
}

/*
 * Allocate the per-voxel non-local observation tables.
 *
 * For every non-consensus voxel, obsnl[x][y][z][j] and alphanl[x][y][z][j]
 * receive uninitialized buffers of num_keep labels / weights for each rater
 * j. Consensus voxels carry no non-local data: their obsnl[x][y][z] and
 * alphanl[x][y][z] slots are set to NULL rather than left indeterminate,
 * so any later inspection or free() of those slots is well-defined.
 *
 * obsnl and alphanl must already point to arrays of dims[0] slots.
 */
void allocate_non_local(label_t ***** obsnl,
                        intensity_t ***** alphanl,
                        bool *** consensus,
                        size_t * dims,
                        size_t num_raters,
                        size_t num_keep)
{
    for (size_t x = 0; x < dims[0]; x++) {
        obsnl[x] = (label_t ****)malloc(dims[1] * sizeof(*obsnl[x]));
        alphanl[x] = (intensity_t ****)malloc(dims[1] * sizeof(*alphanl[x]));
        for (size_t y = 0; y < dims[1]; y++) {
            obsnl[x][y] =
                (label_t ***)malloc(dims[2] * sizeof(*obsnl[x][y]));
            alphanl[x][y] =
                (intensity_t ***)malloc(dims[2] * sizeof(*alphanl[x][y]));
            for (size_t z = 0; z < dims[2]; z++) {
                if (consensus[x][y][z]) {
                    obsnl[x][y][z] = NULL;
                    alphanl[x][y][z] = NULL;
                    continue;
                }

                obsnl[x][y][z] =
                    (label_t **)malloc(num_raters * sizeof(*obsnl[x][y][z]));
                alphanl[x][y][z] =
                    (intensity_t **)malloc(num_raters *
                                           sizeof(*alphanl[x][y][z]));
                for (size_t j = 0; j < num_raters; j++) {
                    obsnl[x][y][z][j] =
                        (label_t *)malloc(num_keep *
                                          sizeof(*obsnl[x][y][z][j]));
                    alphanl[x][y][z][j] =
                        (intensity_t *)malloc(num_keep *
                                              sizeof(*alphanl[x][y][z][j]));
                }
            }
        }
    }
}

/*
 * Allocate the per-voxel label-probability vectors W[x][y][z][0..num_labels).
 * W must already point to an array of dims[0] slots; every level below is
 * allocated here, and the innermost vectors are left uninitialized.
 */
void allocate_W(intensity_t **** W,
                size_t * dims,
                size_t num_labels)
{
    const size_t nx = dims[0];
    const size_t ny = dims[1];
    const size_t nz = dims[2];
    const size_t vec_bytes = num_labels * sizeof(intensity_t);

    for (size_t i = 0; i < nx; ++i) {
        intensity_t *** plane =
            (intensity_t ***)malloc(ny * sizeof(intensity_t **));
        for (size_t j = 0; j < ny; ++j) {
            intensity_t ** column =
                (intensity_t **)malloc(nz * sizeof(intensity_t *));
            for (size_t k = 0; k < nz; ++k)
                column[k] = (intensity_t *)malloc(vec_bytes);
            plane[j] = column;
        }
        W[i] = plane;
    }
}

/*
 * Allocate a performance-level table theta[l][s][j]
 * (num_labels x num_labels x num_raters). theta must already point to an
 * array of num_labels slots; the values are left uninitialized.
 */
void allocate_theta(intensity_t *** theta,
                    size_t num_labels,
                    size_t num_raters)
{
    const size_t rater_bytes = num_raters * sizeof(intensity_t);

    for (size_t row = 0; row < num_labels; ++row) {
        intensity_t ** table =
            (intensity_t **)malloc(num_labels * sizeof(intensity_t *));
        for (size_t col = 0; col < num_labels; ++col)
            table[col] = (intensity_t *)malloc(rater_bytes);
        theta[row] = table;
    }
}

/*
 * Fill obsnl/alphanl for every non-consensus voxel.
 *
 * Sets up the shared scratch state once:
 *   - the intensity-similarity factor -1 / stdev^2;
 *   - a spatial decay table over the full search window (set_dist_factor);
 *   - candidate weight/label buffers sized for a full window.
 * It then delegates each voxel to set_alpha_for_voxel, printing progress
 * once per x slice. All scratch memory is released before returning.
 */
void set_nonlocal_data(label_t **** obs,
                       label_t ***** obsnl,
                       intensity_t ***** alphanl,
                       intensity_t **** ims,
                       intensity_t *** target,
                       bool *** consensus,
                       size_t * dims,
                       size_t * ns,
                       size_t * nc,
                       size_t * nsf,
                       size_t num_raters,
                       size_t num_keep,
                       intensity_t stdev,
                       intensity_t sp_stdev)
{
    // Gaussian intensity-similarity exponent factor
    const intensity_t intensity_factor = -1 / pow(stdev, 2);

    // voxel count of a full (unclipped) search window
    const size_t window_size = nsf[0] * nsf[1] * nsf[2];

    // candidate weights and labels for each window voxel and rater
    intensity_t ** cand_alpha =
        (intensity_t **)malloc(window_size * sizeof(intensity_t *));
    label_t ** cand_label =
        (label_t **)malloc(window_size * sizeof(label_t *));
    for (size_t n = 0; n < window_size; ++n) {
        cand_alpha[n] = (intensity_t *)malloc(num_raters * sizeof(intensity_t));
        cand_label[n] = (label_t *)malloc(num_raters * sizeof(label_t));
    }

    // spatial decay weights over the search window
    intensity_t *** decay =
        (intensity_t ***)malloc(nsf[0] * sizeof(intensity_t **));
    for (size_t a = 0; a < nsf[0]; ++a) {
        decay[a] = (intensity_t **)malloc(nsf[1] * sizeof(intensity_t *));
        for (size_t b = 0; b < nsf[1]; ++b)
            decay[a][b] = (intensity_t *)malloc(nsf[2] * sizeof(intensity_t));
    }
    set_dist_factor(nsf, ns, sp_stdev, decay);

    const int nx = (int)dims[0];
    const int ny = (int)dims[1];
    const int nz = (int)dims[2];
    for (int x = 0; x < nx; ++x) {
        print_status((size_t)x, dims[0]);
        for (int y = 0; y < ny; ++y) {
            for (int z = 0; z < nz; ++z) {
                if (consensus[x][y][z])
                    continue;
                set_alpha_for_voxel(x, y, z, num_raters, num_keep, dims,
                                    ns, nc, intensity_factor, obsnl,
                                    cand_label, obs, alphanl, cand_alpha,
                                    decay, ims, target);
            }
        }
    }

    // release the scratch state
    for (size_t a = 0; a < nsf[0]; ++a) {
        for (size_t b = 0; b < nsf[1]; ++b)
            free(decay[a][b]);
        free(decay[a]);
    }
    free(decay);

    for (size_t n = 0; n < window_size; ++n) {
        free(cand_label[n]);
        free(cand_alpha[n]);
    }
    free(cand_label);
    free(cand_alpha);
}

void set_alpha_for_voxel(int x,
                         int y,
                         int z,
                         size_t num_raters,
                         size_t num_keep,
                         size_t * dims,
                         size_t * ns,
                         size_t * nc,
                         intensity_t ff,
                         label_t ***** obsnl,
                         label_t ** obsr,
                         label_t **** obs,
                         intensity_t ***** alphanl,
                         intensity_t ** alphar,
                         intensity_t *** dist_factor,
                         intensity_t **** ims,
                         intensity_t *** target)
{
    intensity_t diff, df, alpha_sum, max;
    size_t count, i, ii, j, k;
    int xp, yp, zp;

    // set the patch that we will be analyzing
    int xl = x - ns[0];
    int xh = x + ns[0];
    int yl = y - ns[1];
    int yh = y + ns[1];
    int zl = z - ns[2];
    int zh = z + ns[2];

    // set the patch that we will be analyzing
    int xl0 = MAX(xl, 0);
    int xh0 = MIN(xh, dims[0]-1);
    int yl0 = MAX(yl, 0);
    int yh0 = MIN(yh, dims[1]-1);
    int zl0 = MAX(zl, 0);
    int zh0 = MIN(zh, dims[2]-1);

    // get the number of elements
    size_t numel = (xh0 - xl0 + 1) * (yh0 - yl0 + 1) * (zh0 - zl0 + 1);

    // iterate over the patch of interest
    count = 0;
    for (xp = xl0; xp <= xh0; xp++)
        for (yp = yl0; yp <= yh0; yp++)
            for (zp = zl0; zp <= zh0; zp++) {

                // set the distance factor for this voxel
                df = dist_factor[xp-xl][yp-yl][zp-zl];

                for (j = 0; j < num_raters; j++) {
                    // get the local similarity
                    diff = get_norm_diff(x, y, z, xp, yp, zp, j,
                                         dims, nc, ims, target);
                    alphar[count][j] = df * exp(ff*diff);
                    obsr[count][j] = obs[xp][yp][zp][j];
                }

                // increment the count
                count++;
            }

    // iterate over the raters
    for (j = 0; j < num_raters; j++) {

        // find the top "num_keep" alpha values
        alpha_sum = 0;
        for (k = 0; k < num_keep; k++) {

            // find the max
            max = alphar[0][j];
            ii = 0;
            for (i = 1; i < numel; i++)
                if (alphar[i][j] > max) {
                    max = alphar[i][j];
                    ii = i;
                }

            // set the non-local parameters
            alphanl[x][y][z][j][k] = max;
            obsnl[x][y][z][j][k] = obsr[ii][j];
            alpha_sum += max;
            alphar[ii][j] = 0;
        }

        // normalize the alphas
        if (alpha_sum == 0) {
            alphanl[x][y][z][j][0] = 1;
            obsnl[x][y][z][j][0] = obs[x][y][z][j];
        } else {
            for (k = 0; k < num_keep; k++)
                alphanl[x][y][z][j][k] /= alpha_sum;
        }
    }
}

void set_dist_factor(size_t * nsf,
                     size_t * ns,
                     intensity_t sp_stdev,
                     intensity_t *** dist_factor)
{

    // set the standard deviations (allow for anisotropic)
    intensity_t sx = sp_stdev;
    intensity_t sy = sp_stdev;
    intensity_t sz = sp_stdev;

    // set the factor so that we only need to calculate it once
    intensity_t fx = -1 / pow(sx, 2);
    intensity_t fy = -1 / pow(sy, 2);
    intensity_t fz = -1 / pow(sz, 2);

    // set the distance factors for each element in the neighborhood (patch)
    for (int x = 0; x < (int)nsf[0]; x++)
        for (int y = 0; y < (int)nsf[1]; y++)
            for (int z = 0; z < (int)nsf[2]; z++)
                dist_factor[x][y][z] =
                    exp(fx * pow((intensity_t)((int)ns[0] - x), 2) +
                        fy * pow((intensity_t)((int)ns[1] - y), 2) +
                        fz * pow((intensity_t)((int)ns[2] - z), 2));
}

/*
 * Initialize theta with a diagonal-dominant guess: 1 where the observed
 * label equals the true label and 0.1 elsewhere. Each (true label, rater)
 * column is then normalized to sum to 1 via normalize_theta.
 */
void set_initial_theta(intensity_t *** theta,
                       size_t num_raters,
                       size_t num_labels)
{
    for (size_t obs_l = 0; obs_l < num_labels; ++obs_l) {
        for (size_t true_l = 0; true_l < num_labels; ++true_l) {
            const intensity_t v = (obs_l == true_l) ? 1 : 0.1;
            intensity_t * per_rater = theta[obs_l][true_l];
            for (size_t j = 0; j < num_raters; ++j)
                per_rater[j] = v;
        }
    }

    normalize_theta(theta, num_raters, num_labels);
}

/*
 * Normalize theta so that for each rater j and true label s the values
 * theta[0..num_labels)[s][j] sum to 1. A column whose sum is not positive
 * (including NaN) is reset to the identity: 1 at l == s, 0 elsewhere.
 */
void normalize_theta(intensity_t *** theta,
                     size_t num_raters,
                     size_t num_labels)
{
    for (size_t rater = 0; rater < num_raters; ++rater) {
        for (size_t truth = 0; truth < num_labels; ++truth) {
            intensity_t total = 0;
            for (size_t row = 0; row < num_labels; ++row)
                total += theta[row][truth][rater];

            if (total > 0) {
                for (size_t row = 0; row < num_labels; ++row)
                    theta[row][truth][rater] /= total;
            } else {
                for (size_t row = 0; row < num_labels; ++row)
                    theta[row][truth][rater] = (row == truth) ? 1 : 0;
            }
        }
    }
}

/*
 * Copy theta into theta_prev element by element, then clear theta to 0 so
 * it can accumulate the next M-step.
 */
void reset_theta(intensity_t *** theta,
                 intensity_t *** theta_prev,
                 size_t num_raters,
                 size_t num_labels)
{
    for (size_t row = 0; row < num_labels; ++row) {
        for (size_t col = 0; col < num_labels; ++col) {
            intensity_t * cur = theta[row][col];
            intensity_t * prev = theta_prev[row][col];
            for (size_t j = 0; j < num_raters; ++j) {
                prev[j] = cur[j];
                cur[j] = 0;
            }
        }
    }
}

/*
 * Convergence measure for the EM loop: the absolute change in the mean of
 * the diagonal entries theta[l][l][j] over all raters and labels, between
 * the previous and current iterations.
 */
intensity_t calc_convergence_factor(intensity_t *** theta,
                                    intensity_t *** theta_prev,
                                    size_t num_raters,
                                    size_t num_labels)
{
    intensity_t now = 0;
    intensity_t before = 0;

    for (size_t rater = 0; rater < num_raters; ++rater) {
        for (size_t lab = 0; lab < num_labels; ++lab) {
            now += theta[lab][lab][rater];
            before += theta_prev[lab][lab][rater];
        }
    }

    const intensity_t count = (intensity_t)(num_raters * num_labels);
    const intensity_t mean_change = (now - before) / count;
    return fabs(mean_change);
}

/*
 * Mean squared intensity difference between two patches of radius nc:
 * the target patch centred at (x, y, z) and rater j's atlas patch centred
 * at (xp, yp, zp).
 *
 * Patch offsets that fall outside the image for either centre are skipped.
 * The average is taken over the offsets actually used. If no offset is
 * usable (only possible when a centre itself lies outside the image),
 * 0 is returned instead of dividing by zero.
 *
 * Bounds are checked explicitly in signed arithmetic. The previous version
 * relied on negating an unsigned radius and on size_t wrap-around to reject
 * negative indices.
 */
intensity_t get_norm_diff(int x,
                          int y,
                          int z,
                          int xp,
                          int yp,
                          int zp,
                          size_t j,
                          size_t * dims,
                          size_t * nc,
                          intensity_t **** ims,
                          intensity_t *** target)
{
    const int rx = (int)nc[0];
    const int ry = (int)nc[1];
    const int rz = (int)nc[2];
    const int nx = (int)dims[0];
    const int ny = (int)dims[1];
    const int nz = (int)dims[2];

    size_t num_count = 0;
    intensity_t diff = 0;

    for (int xi = -rx; xi <= rx; xi++) {
        const int xt = x + xi;
        const int xj = xp + xi;
        if (xt < 0 || xj < 0 || xt >= nx || xj >= nx)
            continue;
        for (int yi = -ry; yi <= ry; yi++) {
            const int yt = y + yi;
            const int yj = yp + yi;
            if (yt < 0 || yj < 0 || yt >= ny || yj >= ny)
                continue;
            for (int zi = -rz; zi <= rz; zi++) {
                const int zt = z + zi;
                const int zj = zp + zi;
                if (zt < 0 || zj < 0 || zt >= nz || zj >= nz)
                    continue;

                // add the contribution from this voxel
                diff += pow(target[xt][yt][zt] - ims[xj][yj][zj][j], 2);
                num_count++;
            }
        }
    }

    if (num_count == 0)
        return 0;

    // return the average L2 difference
    return diff / num_count;
}

/*
 * E-step for voxel (x, y, z) plus its contribution to the next theta.
 *
 * For each true label s:
 *     jsum[s][j] = sum_k theta_prev[l_jk][s][j] * w_jk
 * over the kept non-local observations l_jk of rater j that have weight
 * w_jk = alphanl[x][y][z][j][k] > 0. Then
 *     lp[s] = pval[s] * prod_j jsum[s][j]
 * is normalized to a posterior. The arg-max label (first wins ties) is
 * written to estimate[x][y][z], and lp[s] * w_jk is accumulated into
 * theta[l_jk][s][j].
 *
 * prior_type == 1: pval (zeroed by the caller) is built here as the
 *   rater-averaged non-local vote, pval[l] += w / num_raters.
 * Otherwise pval is used exactly as supplied by the caller.
 *
 * If no label has positive unnormalized support (sum == 0), dividing by it
 * would produce NaN. That NaN would poison theta and W and silently stop
 * the EM loop, so a uniform posterior is used instead. The selected label
 * in that case is still 0, as before.
 *
 * lp, pval and jsum are caller-owned scratch buffers of size num_labels,
 * num_labels and num_labels x num_raters.
 */
void run_EM_for_voxel(size_t x,
                      size_t y,
                      size_t z,
                      size_t num_raters,
                      size_t num_labels,
                      size_t num_keep,
                      size_t prior_type,
                      label_t ***** obsnl,
                      label_t *** estimate,
                      intensity_t ***** alphanl,
                      intensity_t *** theta,
                      intensity_t *** theta_prev,
                      intensity_t * lp,
                      intensity_t * pval,
                      intensity_t ** jsum)
{
    // initialize the sums
    for (size_t s = 0; s < num_labels; s++)
        for (size_t j = 0; j < num_raters; j++)
            jsum[s][j] = 0;

    // accumulate the per-rater evidence (and the vote-based prior if needed)
    for (size_t j = 0; j < num_raters; j++)
        for (size_t k = 0; k < num_keep; k++) {
            const intensity_t w = alphanl[x][y][z][j][k];
            if (w > 0) {
                const label_t l = obsnl[x][y][z][j][k];

                if (prior_type == 1)
                    pval[l] += w / num_raters;

                for (size_t s = 0; s < num_labels; s++)
                    jsum[s][j] += theta_prev[l][s][j] * w;
            }
        }

    // unnormalized label probability and the partition function
    intensity_t sum = 0;
    for (size_t s = 0; s < num_labels; s++) {
        lp[s] = pval[s];
        for (size_t j = 0; j < num_raters; j++)
            lp[s] *= jsum[s][j];
        sum += lp[s];
    }

    // normalize (guarding the degenerate zero-support case)
    if (sum > 0) {
        for (size_t s = 0; s < num_labels; s++)
            lp[s] /= sum;
    } else {
        for (size_t s = 0; s < num_labels; s++)
            lp[s] = (intensity_t)1 / num_labels;
    }

    // arg-max label
    intensity_t max_val = 0;
    label_t label = 0;
    for (size_t s = 0; s < num_labels; s++)
        if (lp[s] > max_val) {
            max_val = lp[s];
            label = s;
        }
    estimate[x][y][z] = label;

    // add the impact to theta
    for (size_t j = 0; j < num_raters; j++)
        for (size_t k = 0; k < num_keep; k++) {
            const intensity_t w = alphanl[x][y][z][j][k];
            if (w > 0) {
                const label_t l = obsnl[x][y][z][j][k];
                for (size_t s = 0; s < num_labels; s++)
                    theta[l][s][j] += lp[s] * w;
            }
        }
}

/*
 * Run the Non-Local STAPLE EM iterations until the convergence factor
 * (calc_convergence_factor) drops to epsilon or below.
 *
 * prior == NULL selects prior_type 1: run_EM_for_voxel builds each voxel's
 * prior from its non-local votes. Otherwise prior[x][y][z][s] is used.
 * Consensus voxels are not updated; they keep their initial estimate.
 * If W != NULL, W[x][y][z][*] receives the final label probabilities:
 * the posterior for non-consensus voxels, and one-hot on the estimate for
 * consensus voxels.
 * NOTE(review): a NaN convergence factor also ends the loop, because
 * "NaN > epsilon" is false.
 */
void run_NonLocalSTAPLE(label_t ***** obsnl,
                        intensity_t ***** alphanl,
                        label_t *** estimate,
                        intensity_t *** theta,
                        intensity_t **** prior,
                        intensity_t **** W,
                        intensity_t epsilon,
                        bool *** consensus,
                        size_t * dims,
                        size_t num_raters,
                        size_t num_labels,
                        size_t num_keep)
{
    const size_t prior_type = (prior == NULL) ? 1 : 0;
    intensity_t convergence_factor = 999999;

    // scratch buffers reused for every voxel
    intensity_t * lp = (intensity_t *)malloc(num_labels * sizeof(*lp));
    intensity_t * pval = (intensity_t *)malloc(num_labels * sizeof(*pval));
    intensity_t *** theta_prev =
        (intensity_t ***)malloc(num_labels * sizeof(*theta_prev));
    allocate_theta(theta_prev, num_labels, num_raters);
    intensity_t ** jsum = (intensity_t **)malloc(num_labels * sizeof(*jsum));
    for (size_t s = 0; s < num_labels; s++)
        jsum[s] = (intensity_t *)malloc(num_raters * sizeof(*jsum[s]));

    // set the initial theta value
    set_initial_theta(theta, num_raters, num_labels);

    // iterate until convergence
    while (convergence_factor > epsilon) {

        // set theta_prev = theta and theta = 0
        reset_theta(theta, theta_prev, num_raters, num_labels);

        for (size_t x = 0; x < dims[0]; x++) {
            print_status(x, dims[0]);
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++) {
                    if (!consensus[x][y][z]) {

                        // initialize the prior
                        for (size_t s = 0; s < num_labels; s++)
                            pval[s] = (prior_type == 0) ? prior[x][y][z][s] : 0;

                        run_EM_for_voxel(x, y, z, num_raters, num_labels,
                                         num_keep, prior_type, obsnl,
                                         estimate, alphanl, theta, theta_prev,
                                         lp, pval, jsum);

                        if (W != NULL)
                            for (size_t l = 0; l < num_labels; l++)
                                W[x][y][z][l] = lp[l];

                    } else if (W != NULL) {
                        // consensus voxels are fixed: one-hot on the estimate
                        // (only computed when W is actually being saved)
                        for (size_t l = 0; l < num_labels; l++)
                            W[x][y][z][l] = estimate[x][y][z] == l ? 1 : 0;
                    }
                }
        }

        normalize_theta(theta, num_raters, num_labels);

        convergence_factor = calc_convergence_factor(theta, theta_prev,
                                                     num_raters, num_labels);

        fprintf(stdout, "Convergence Factor: %f\n", convergence_factor);
    }

    // free the allocated memory
    for (size_t l = 0; l < num_labels; l++) {
        for (size_t s = 0; s < num_labels; s++)
            free(theta_prev[l][s]);
        free(theta_prev[l]);
        free(jsum[l]);
    }
    free(theta_prev);
    free(jsum);
    free(lp);
    free(pval);
}

/*
 * Write the label estimate to outfile as a NIfTI image.
 *
 * The target image is re-read (with its data) to reuse its header. The
 * datatype is set to OUT_LABEL_TYPE, the voxel data is replaced by estimate
 * via set_nim_3D, and both fname and iname are pointed at outfile.
 * NOTE(review): presumably set_nim_3D also updates nbyper/data for the new
 * datatype -- confirm in fusion_io.
 *
 * A failure to read the target image, or to allocate the filename buffers,
 * is reported on stderr and exits with EXIT_FAILURE. Previously these cases
 * dereferenced NULL.
 */
void save_results(size_t * dims,
                  char * outfile,
                  char * targfile,
                  label_t *** estimate)
{

    // read the nifti image
    nifti_image * nim = nifti_image_read(targfile, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error: Unable to read target file: %s\n", targfile);
        exit(EXIT_FAILURE);
    }

    // set the output datatype
    nim->datatype = OUT_LABEL_TYPE;

    // set the target
    set_nim_3D<label_t>(nim, estimate, dims);

    // set the output names (free(NULL) is a no-op)
    free(nim->fname);
    free(nim->iname);
    size_t ll = strlen(outfile);
    nim->fname = (char *)calloc(1, ll+8);
    nim->iname = (char *)calloc(1, ll+8);
    if (nim->fname == NULL || nim->iname == NULL) {
        fprintf(stderr, "Error: Out of memory setting output file name: %s\n",
                outfile);
        exit(EXIT_FAILURE);
    }
    strcpy(nim->fname, outfile);
    strcpy(nim->iname, outfile);

    // save the file
    nifti_image_write(nim);

    // free the file
    nifti_image_free(nim);
}

