// Utilities and system includes
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"
#include "Correction.h"

/*
 * Print the command-line help text for segcorr and terminate the process.
 *
 * Always exits with status 0, including when reached from the argument
 * error paths in set_input_args.
 */
void usage()
{
    printf("Segmentation Correction version 0.1\n\n");

    // The labels file is passed with -l/--labels (there is no -e option).
    printf("Usage: segcorr [options] -l <labels> -t <target> -o <outfile>\n\n");

    printf("Required Options:\n");

    printf("\t-l, --labels <labels-file>\n");
    printf("\t\tThe atlas labels file (in .nii/.nii.gz format).\n\n");

    printf("\t-t, --target <target-file>\n");
    printf("\t\tThe target intensity file (in .nii/.nii.gz format).\n\n");

    printf("\t-o, --output <outfile>\n");
    printf("\t\tThe output file\n\n");

    printf("Additional Options:\n");

    printf("\t-sv, --search-volume <sx> <sy> <sz>\n");
    printf("\t\tDefines the search volume radii in the x, y, and z directions\n");
    printf("\t\tDefault value: 1 1 1\n\n");

    printf("\t-pv, --patch-volume <px> <py> <pz>\n");
    printf("\t\tDefines the patch volume radii in the x, y, and z directions\n");
    printf("\t\tDefault value: 1 1 1\n\n");

    printf("\t-c, --coefs <c1> <c2> <c3> <c4>\n");
    printf("\t\tDefines the MRF strength parameters\n");
    printf("\t\tDefault value: 1 1 1 1\n\n");

    printf("\t-si, --stdev-intensity <sd>\n");
    printf("\t\tDefines the standard deviation for the gaussian\n");
    printf("\t\tintensity difference model (\\sigma_i)\n");
    printf("\t\tDefault value: 0.1\n\n");

    printf("\t-ss, --stdev-spatial <sd>\n");
    printf("\t\tDefines the standard deviation for the gaussian\n");
    printf("\t\tdistance-based decay model (\\sigma_d)\n");
    printf("\t\tDefault value: 200\n\n");

    exit(0);
}

void set_input_args(int argc,
                    char ** argv,
                    char ** targfile,
                    char ** obsfile,
                    size_t * ns,
                    size_t * nc,
                    intensity_t * coefs,
                    intensity_t * stdev,
                    intensity_t * sp_stdev,
                    char ** outfile)
{

    if (argc == 1)
        usage();

    // iterate over the input arguments
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "-l") == 0 ||
            strcmp(argv[i], "--labels") == 0) {
            *obsfile = argv[++i];
        } else if (strcmp(argv[i], "-t") == 0 ||
                   strcmp(argv[i], "--target") == 0) {
            *targfile = argv[++i];
        } else if (strcmp(argv[i], "-sv") == 0 ||
                   strcmp(argv[i], "--search-volume") == 0) {
            ns[0] = (size_t)atoi(argv[++i]);
            ns[1] = (size_t)atoi(argv[++i]);
            ns[2] = (size_t)atoi(argv[++i]);
        } else if (strcmp(argv[i], "-pv") == 0 ||
                   strcmp(argv[i], "--patch-volume") == 0) {
            nc[0] = (size_t)atoi(argv[++i]);
            nc[1] = (size_t)atoi(argv[++i]);
            nc[2] = (size_t)atoi(argv[++i]);
        } else if (strcmp(argv[i], "-c") == 0 ||
                   strcmp(argv[i], "--coefs") == 0) {
            coefs[0] = (size_t)atof(argv[++i]);
            coefs[1] = (size_t)atof(argv[++i]);
            coefs[2] = (size_t)atof(argv[++i]);
            coefs[3] = (size_t)atof(argv[++i]);
        } else if (strcmp(argv[i], "-si") == 0 ||
                   strcmp(argv[i], "--stdev-intensity") == 0) {
            *stdev = atof(argv[++i]);
        } else if (strcmp(argv[i], "-ss") == 0 ||
                   strcmp(argv[i], "--stdev-spatial") == 0) {
            *sp_stdev = atof(argv[++i]);
        } else if (strcmp(argv[i], "-o") == 0 ||
                   strcmp(argv[i], "--output") == 0) {
            *outfile = argv[++i];
        } else {
            fprintf(stderr, "Error Reading Input Argument: %s\n\n", argv[i]);
            usage();
        }
    }

    // do some quick error checking
    bool error = false;
    if (*targfile == NULL) {
        fprintf(stderr, "Error: Target file not specified\n");
        error = true;
    }
    if (*obsfile == NULL) {
        fprintf(stderr, "Error: Target file not specified\n");
        error = true;
    }
    if (*outfile == NULL) {
        fprintf(stderr, "Error: Output prefix not specified\n");
        error = true;
    }
    if (error) {
        fprintf(stderr, "\n");
        usage();
    }
}

/*
 * Print the run configuration to stdout.
 *
 * The configuration printed is:
 *   - the input files;
 *   - the image dimensions;
 *   - the search and patch radii;
 *   - the four MRF coefficients;
 *   - the rater and label counts;
 *   - the intensity and spatial standard deviations.
 *
 * size_t values use %zu. %lu only matches size_t on platforms where size_t
 * is unsigned long (it is not on LLP64 Windows).
 */
void print_input_parms(char * targfile,
                       char * obsfile,
                       size_t * dims,
                       size_t * ns,
                       size_t * nc,
                       size_t num_raters,
                       size_t num_labels,
                       intensity_t * coefs,
                       intensity_t stdev,
                       intensity_t sp_stdev)
{

    printf("Target intensity file: %s\n", targfile);
    printf("Target estimate file: %s\n", obsfile);
    printf("Image Dimensions: [%zu %zu %zu]\n", dims[0], dims[1], dims[2]);
    printf("Search Neighborhood Radii: [%zu %zu %zu]\n", ns[0], ns[1], ns[2]);
    printf("Patch Neighborhood Radii: [%zu %zu %zu]\n", nc[0], nc[1], nc[2]);
    printf("MRF Coefficients: [%.4f %.4f %.4f %.4f]\n",
           coefs[0], coefs[1], coefs[2], coefs[3]);
    printf("Number of raters: %zu\n", num_raters);
    printf("Number of labels: %zu\n", num_labels);
    printf("Intensity similarity std dev: %f\n", stdev);
    printf("Distance decay std dev: %f\n", sp_stdev);
}

/*
 * Write the label estimate to `outfile` as a NIfTI image.
 *
 * The target image is re-read and used as a template, so the output keeps
 * its header. Its datatype is switched to OUT_LABEL_TYPE, its voxels are
 * replaced with `estimate` (via set_nim_3D), and its header/image file
 * names are redirected to `outfile` before writing.
 *
 * Exits with EXIT_FAILURE if the template cannot be read or the file-name
 * buffers cannot be allocated. Previously both cases dereferenced NULL.
 */
void save_results(size_t * dims,
                  char * outfile,
                  char * targfile,
                  label_t *** estimate)
{

    // read the target image (header + data) to use as the output template
    nifti_image * nim = nifti_image_read(targfile, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error: Unable to read target file: %s\n", targfile);
        exit(EXIT_FAILURE);
    }

    // set the output datatype
    nim->datatype = OUT_LABEL_TYPE;

    // copy the label estimate into the image
    set_nim_3D<label_t>(nim, estimate, dims);

    // point both the header and image file names at the output path
    free(nim->fname);
    free(nim->iname);
    size_t ll = strlen(outfile);
    nim->fname = (char *)calloc(1, ll+8);
    nim->iname = (char *)calloc(1, ll+8);
    if (nim->fname == NULL || nim->iname == NULL) {
        fprintf(stderr, "Error: Out of memory setting output file name\n");
        nifti_image_free(nim);
        exit(EXIT_FAILURE);
    }
    strcpy(nim->fname, outfile);
    strcpy(nim->iname, outfile);

    // save the file
    nifti_image_write(nim);

    // free the image
    nifti_image_free(nim);
}

/*
 * Z-score normalise the target intensities in place.
 *
 * The mean and sample standard deviation are estimated only from voxels
 * whose estimated label is > 0, but every voxel of the target is transformed
 * as (v - mean) / stdev.
 *
 * Degenerate statistics are handled instead of producing NaN/inf:
 *   - no labelled voxels: the target is left untouched;
 *   - stdev not positive (fewer than two labelled voxels, or constant
 *     intensity): the target is only centred.
 */
void reset_target_intensity(intensity_t *** target,
                            label_t *** est,
                            size_t * dims)
{
    intensity_t mean = 0;
    intensity_t stdev = 0;
    size_t count = 0;

    // mean over the labelled voxels
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > 0) {
                    mean += target[x][y][z];
                    count++;
                }
    if (count == 0)
        return;  // nothing to estimate from
    mean /= count;

    // sample standard deviation over the labelled voxels
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > 0)
                    stdev += pow(target[x][y][z] - mean, 2);
    stdev = (count > 1) ? sqrt(stdev / (count - 1)) : 0;

    // normalise the whole target; only divide when the spread is usable
    const bool rescale = stdev > 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                target[x][y][z] = rescale
                                  ? (target[x][y][z] - mean) / stdev
                                  : target[x][y][z] - mean;
}

void set_dist_factor(size_t * nsf,
                     size_t * ns,
                     intensity_t sp_stdev,
                     intensity_t *** dist_factor)
{

    // set the standard deviations (allow for anisotropic)
    intensity_t sx = sp_stdev;
    intensity_t sy = sp_stdev;
    intensity_t sz = sp_stdev;

    // set the factor so that we only need to calculate it once
    intensity_t fx = -1 / pow(sx, 2);
    intensity_t fy = -1 / pow(sy, 2);
    intensity_t fz = -1 / pow(sz, 2);

    // set the distance factors for each element in the neighborhood (patch)
    for (int x = 0; x < (int)nsf[0]; x++)
        for (int y = 0; y < (int)nsf[1]; y++)
            for (int z = 0; z < (int)nsf[2]; z++)
                dist_factor[x][y][z] =
                    exp(fx * pow((intensity_t)((int)ns[0] - x), 2) +
                        fy * pow((intensity_t)((int)ns[1] - y), 2) +
                        fz * pow((intensity_t)((int)ns[2] - z), 2));
}

void get_majority_vote_consensus(bool *** cons,
                                 label_t **** obs,
                                 size_t num_raters,
                                 size_t * dims)
{
    // iterate over all of the voxels
    for (int x = 0; x < dims[0]; x++) {
        print_status((size_t)x, dims[0]);
        for (int y = 0; y < dims[1]; y++)
            for (int z = 0; z < dims[2]; z++) {

                // initialize the value to true
                bool c = true;
                label_t lab = obs[x][y][z][0];

                // see if all raters agree
                for (size_t j = 1; j < num_raters && c == true; j++)
                    if (lab != obs[x][y][z][j])
                        c = false;

                // set the final value
                cons[x][y][z] = c;
            }
    }

}

void set_majority_vote_info(bool *** cons,
                            label_t **** obs,
                            size_t num_raters,
                            size_t num_labels,
                            size_t * dims,
                            intensity_t **** prior,
                            label_t *** est)
{

    for (int x = 0; x < dims[0]; x++) {
        print_status((size_t)x, dims[0]);
        for (int y = 0; y < dims[1]; y++)
            for (int z = 0; z < dims[2]; z++)
                if (cons[x][y][z] == false) {

                    // set the initial W
                    for (size_t j = 0; j < num_raters; j++) {
                        label_t l = obs[x][y][z][j];
                        prior[x][y][z][l]++;
                    }

                    intensity_t norm_fact = 0;
                    intensity_t max_val = 0;
                    label_t max_lab = 0;

                    // set the estimate and normalize the probabilities
                    for (size_t l = 0; l < num_labels; l++) {
                        if (prior[x][y][z][l] > max_val) {
                            max_val = prior[x][y][z][l];
                            max_lab = l;
                        }
                        norm_fact += prior[x][y][z][l];
                    }
                    est[x][y][z] = max_lab;
                    for (size_t l = 0; l < num_labels; l++)
                        prior[x][y][z][l] /= norm_fact;
                } else {
                    est[x][y][z] = obs[x][y][z][0];
                }
    }

}

void initialize_consensus_estimate(bool *** cons,
                                   label_t *** est,
                                   size_t * ns,
                                   size_t * dims)
{

    // iterate over all of the voxels
    for (int x = 0; x < dims[0]; x++) {
        print_status((size_t)x, dims[0]);
        for (int y = 0; y < dims[1]; y++)
            for (int z = 0; z < dims[2]; z++) {

                // initialize the value to true
                bool c = true;
                label_t lab = est[x][y][z];

                // set the patch that we will be analyzing
                int xl = MAX(x - (int)ns[0], 0);
                int xh = MIN(x + (int)ns[0], dims[0]-1);
                int yl = MAX(y - (int)ns[1], 0);
                int yh = MIN(y + (int)ns[1], dims[1]-1);
                int zl = MAX(z - (int)ns[2], 0);
                int zh = MIN(z + (int)ns[2], dims[2]-1);

                // iterate over the patch of interest
                for (int xp = xl; xp <= xh && c == true; xp++)
                    for (int yp = yl; yp <= yh && c == true; yp++)
                        for (int zp = zl; zp <= zh && c == true; zp++)
                            if (lab != est[xp][yp][zp])
                                c = false;

                // set the final value
                cons[x][y][z] = c;
            }
    }
}

void update_consensus_estimate(bool *** consensus,
                               bool *** consensus2,
                               label_t **** estnl,
                               intensity_t **** alphanl,
                               size_t num_keep,
                               size_t * dims)
{
    label_t lab;
    bool con;

    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (consensus[x][y][z] == false) {

                    // initialize
                    con = true;
                    lab = estnl[x][y][z][0];

                    // check if it is in consensus
                    for (size_t k = 1; k < num_keep && con == true; k++)
                        if (estnl[x][y][z][k] != lab && alphanl[x][y][z][k] > 0)
                            con = false;

                    // set the consensus value
                    consensus2[x][y][z] = con;
                } else
                    consensus2[x][y][z] = true;
}

/*
 * Similarity factor between a reference density and the patch at (xp, yp, zp).
 *
 * Steps:
 *   1. The patch density is estimated into `pdens` with set_target_density.
 *      `pdens` is a scratch buffer and is overwritten.
 *   2. `pdens` is replaced by the pointwise KL integrand
 *      cdens[b] * log(cdens[b] / pdens[b]). Bins where either density is
 *      zero contribute 0.
 *   3. The integrand is summed over bin intervals.
 *   4. The function returns exp(-1.5 * sum).
 *
 * NOTE(review): the interval sum divides by the bin width (num / den), where
 * the trapezoidal rule would multiply. set_target_density normalises with the
 * same expression. For uniformly spaced xkl, and provided cdens also came from
 * set_target_density, the two deviations cancel exactly, so the result equals
 * the trapezoidal KL. Change both together or neither.
 *
 * With fewer than two bins there is no interval to integrate over, so the
 * function returns exp(0) = 1. With num_bins == 0 this avoids the
 * `num_bins - 1` underflow.
 */
intensity_t get_kl_fact(int xp,
                        int yp,
                        int zp,
                        size_t * dims,
                        size_t * nc,
                        intensity_t *** target,
                        intensity_t ff,
                        intensity_t * cdens,
                        intensity_t * pdens,
                        intensity_t * xkl,
                        size_t num_bins)
{
    if (num_bins < 2)
        return 1;

    // set the patch density
    set_target_density(xp, yp, zp, dims, nc, num_bins, ff, pdens, xkl, target);

    // pointwise KL integrand (reusing pdens as storage)
    for (size_t b = 0; b < num_bins; b++)
        if (cdens[b] > 0 && pdens[b] > 0)
            pdens[b] = cdens[b] * log(cdens[b] / pdens[b]);
        else
            pdens[b] = 0;

    intensity_t fact = 0;
    for (size_t b = 0; b + 1 < num_bins; b++) {
        intensity_t num = (pdens[b] + pdens[b+1]) / 2;
        intensity_t den = xkl[b+1] - xkl[b];
        fact += num / den;
    }
    return(exp(-1.5 * fact));
}

/*
 * Mean squared intensity difference between two patches of the target.
 *
 * The patch of radii nc centred at (x, y, z) is compared voxel-by-voxel with
 * the patch centred at (xp, yp, zp), at matching offsets. An offset is used
 * only when it falls inside the image for BOTH centres.
 *
 * Returns the sum of squared differences divided by the number of offsets
 * used, or 0 when no offset was usable.
 */
intensity_t get_norm_diff(int x,
                          int y,
                          int z,
                          int xp,
                          int yp,
                          int zp,
                          size_t * dims,
                          size_t * nc,
                          intensity_t *** target)
{
    // Signed radii, so offsets run from -r to +r without relying on the
    // implementation-defined conversion of -nc[i] (an unsigned value) to int.
    const int rx = (int)nc[0];
    const int ry = (int)nc[1];
    const int rz = (int)nc[2];

    size_t num_count = 0;
    intensity_t diff = 0;

    for (int xi = -rx; xi <= rx; xi++) {
        const int xt = x + xi;
        const int xj = xp + xi;
        if (xt < 0 || xj < 0 ||
            (size_t)xt >= dims[0] || (size_t)xj >= dims[0])
            continue;

        for (int yi = -ry; yi <= ry; yi++) {
            const int yt = y + yi;
            const int yj = yp + yi;
            if (yt < 0 || yj < 0 ||
                (size_t)yt >= dims[1] || (size_t)yj >= dims[1])
                continue;

            for (int zi = -rz; zi <= rz; zi++) {
                const int zt = z + zi;
                const int zj = zp + zi;
                if (zt < 0 || zj < 0 ||
                    (size_t)zt >= dims[2] || (size_t)zj >= dims[2])
                    continue;

                // add the contribution from this voxel pair
                diff += pow(target[xt][yt][zt] - target[xj][yj][zj], 2);
                num_count++;
            }
        }
    }

    // average L2 difference (guard against an empty overlap)
    if (num_count == 0)
        return 0;
    return diff / num_count;
}

/*
 * Kernel density estimate of the target intensities in a patch.
 *
 * Each voxel value v in the box of radii nc around (x, y, z), clipped to the
 * image, adds exp(ff * (xkl[b] - v)^2) to bin b. ff is presumably negative;
 * TODO confirm at the call sites. The bins are then divided by
 *   sum_b ((dens[b] + dens[b+1]) / 2) / (xkl[b+1] - xkl[b]).
 *
 * NOTE(review): that normaliser divides by the bin width rather than
 * multiplying (trapezoidal rule). get_kl_fact uses the same form, and for
 * uniform bins the two cancel, so do not change one without the other.
 *
 * Degenerate inputs:
 *   - num_bins == 0: nothing is written.
 *   - normaliser of 0 (e.g. a single bin): `dens` is left unnormalised
 *     instead of being divided by zero.
 */
void set_target_density(int x,
                        int y,
                        int z,
                        size_t * dims,
                        size_t * nc,
                        size_t num_bins,
                        intensity_t ff,
                        intensity_t * dens,
                        intensity_t * xkl,
                        intensity_t *** target)
{
    if (num_bins == 0)
        return;

    // patch bounds, computed in signed arithmetic and clipped to the image
    const int xl0 = MAX(x - (int)nc[0], 0);
    const int xh0 = MIN(x + (int)nc[0], dims[0]-1);
    const int yl0 = MAX(y - (int)nc[1], 0);
    const int yh0 = MIN(y + (int)nc[1], dims[1]-1);
    const int zl0 = MAX(z - (int)nc[2], 0);
    const int zh0 = MIN(z + (int)nc[2], dims[2]-1);

    for (size_t b = 0; b < num_bins; b++)
        dens[b] = 0;
    for (int xs = xl0; xs <= xh0; xs++)
        for (int ys = yl0; ys <= yh0; ys++)
            for (int zs = zl0; zs <= zh0; zs++) {
                intensity_t v = target[xs][ys][zs];
                for (size_t b = 0; b < num_bins; b++)
                    dens[b] += exp(ff * pow(xkl[b] - v, 2));
            }

    intensity_t fact = 0;
    for (size_t b = 0; b + 1 < num_bins; b++) {
        intensity_t num = (dens[b] + dens[b+1]) / 2;
        intensity_t den = xkl[b+1] - xkl[b];
        fact += num / den;
    }
    if (fact == 0)
        return;  // cannot normalise; avoid division by zero
    for (size_t b = 0; b < num_bins; b++)
        dens[b] /= fact;
}

void set_alpha_for_voxel(int x,
                         int y,
                         int z,
                         size_t num_keep,
                         size_t * dims,
                         size_t * ns,
                         size_t * nc,
                         intensity_t ff,
                         intensity_t **** alphanl,
                         intensity_t * alphar,
                         label_t ***** locnl,
                         label_t ** locr,
                         intensity_t *** target,
                         intensity_t *** dist_factor)
{
    intensity_t diff;
    size_t count;
    int xp, yp, zp;
    intensity_t max;
    intensity_t df;
    size_t k, i, ii;
    size_t numel = (2*ns[0]+1) * (2*ns[1]+1) * (2*ns[2]+1);

    // set the patch that we will be analyzing
    int xl = x - ns[0];
    int xh = x + ns[0];
    int yl = y - ns[1];
    int yh = y + ns[1];
    int zl = z - ns[2];
    int zh = z + ns[2];

    // set the patch that we will be analyzing
    int xl0 = MAX(xl, 0);
    int xh0 = MIN(xh, dims[0]-1);
    int yl0 = MAX(yl, 0);
    int yh0 = MIN(yh, dims[1]-1);
    int zl0 = MAX(zl, 0);
    int zh0 = MIN(zh, dims[2]-1);

    // iterate over the patch of interest
    count = 0;
    for (xp = xl0; xp <= xh0; xp++)
        for (yp = yl0; yp <= yh0; yp++)
            for (zp = zl0; zp <= zh0; zp++) {

                // set the distance factor
                df = dist_factor[xp-xl][yp-yl][zp-zl];

                // get the local similarity
                if (ff < -1e-5)
                    diff = get_norm_diff(x, y, z, xp, yp, z, dims, nc, target);
                else
                    diff = 0;
                alphar[count] = df * exp(ff*diff);
                locr[count][0] = xp;
                locr[count][1] = yp;
                locr[count][2] = zp;

                // increment the count
                count++;
            }

    // find the top "num_keep" alpha values
    intensity_t alpha_sum = 0;
    for (k = 0; k < num_keep; k++) {
        if (k < count) {
            alphanl[x][y][z][k] = alphar[k];
            locnl[x][y][z][k][0] = locr[k][0];
            locnl[x][y][z][k][1] = locr[k][1];
            locnl[x][y][z][k][2] = locr[k][2];
            alpha_sum += alphar[k];
        } else {
            alphanl[x][y][z][k] = -10000;
            locnl[x][y][z][k][0] = 0;
            locnl[x][y][z][k][1] = 0;
            locnl[x][y][z][k][2] = 0;
        }
    }
    for (k = 0; k < count; k++)
        alphanl[x][y][z][k] *= count / alpha_sum;

}

/*
 * Compute non-local weights and neighbour locations for every voxel that is
 * not yet in consensus.
 *
 * The intensity-similarity exponent is ff = -1 / (2 * stdev^2). A table of
 * spatial weights for the search window (extent nsf, radii ns) is built once
 * with set_dist_factor. For every voxel with consensus == false, the
 * num_keep slots of alphanl / locnl are zeroed and then filled by
 * set_alpha_for_voxel. Consensus voxels are not touched. All temporary
 * buffers are freed before returning.
 *
 * NOTE(review): the scratch buffers hold nsf[0]*nsf[1]*nsf[2] entries, while
 * set_alpha_for_voxel may write up to (2*ns+1)^3. This assumes
 * nsf[i] == 2*ns[i] + 1; confirm at the call site.
 */
void set_nonlocal_data(intensity_t *** target,
                       bool *** consensus,
                       intensity_t **** alphanl,
                       label_t ***** locnl,
                       size_t * dims,
                       size_t * nc,
                       size_t * ns,
                       size_t * nsf,
                       size_t num_keep,
                       intensity_t stdev,
                       intensity_t sp_stdev)
{
    const intensity_t ff = -1 / (2*pow(stdev, 2));
    const size_t window_size = nsf[0]*nsf[1]*nsf[2];
    const int nx = (int)dims[0];
    const int ny = (int)dims[1];
    const int nz = (int)dims[2];

    // spatial weight for each offset of the search window
    intensity_t *** dist_factor =
        (intensity_t ***)malloc(nsf[0] * sizeof(*dist_factor));
    for (size_t i = 0; i < nsf[0]; i++) {
        intensity_t ** plane =
            (intensity_t **)malloc(nsf[1] * sizeof(*plane));
        for (size_t j = 0; j < nsf[1]; j++)
            plane[j] = (intensity_t *)malloc(nsf[2] * sizeof(*plane[j]));
        dist_factor[i] = plane;
    }
    set_dist_factor(nsf, ns, sp_stdev, dist_factor);

    // per-voxel scratch reused for every voxel: weights and (x, y, z) locations
    intensity_t * alphar =
        (intensity_t *)malloc(window_size * sizeof(*alphar));
    label_t ** locr = (label_t **)malloc(window_size * sizeof(*locr));
    for (size_t n = 0; n < window_size; n++)
        locr[n] = (label_t *)malloc(3 * sizeof(*locr[n]));

    // wipe previous results of the voxels we are about to recompute
    for (int x = 0; x < nx; x++)
        for (int y = 0; y < ny; y++)
            for (int z = 0; z < nz; z++) {
                if (consensus[x][y][z])
                    continue;
                intensity_t * weights = alphanl[x][y][z];
                label_t ** where = locnl[x][y][z];
                for (size_t k = 0; k < num_keep; k++) {
                    weights[k] = 0;
                    where[k][0] = 0;
                    where[k][1] = 0;
                    where[k][2] = 0;
                }
            }

    // recompute the non-local data, one x-slice at a time
    for (int x = 0; x < nx; x++) {
        print_status((size_t)x, dims[0]);
        for (int y = 0; y < ny; y++)
            for (int z = 0; z < nz; z++)
                if (!consensus[x][y][z])
                    set_alpha_for_voxel(x, y, z, num_keep, dims, ns, nc, ff,
                                        alphanl, alphar, locnl, locr,
                                        target, dist_factor);
    }

    // release the temporaries
    for (size_t n = 0; n < window_size; n++)
        free(locr[n]);
    free(locr);
    free(alphar);
    for (size_t i = 0; i < nsf[0]; i++) {
        for (size_t j = 0; j < nsf[1]; j++)
            free(dist_factor[i][j]);
        free(dist_factor[i]);
    }
    free(dist_factor);
}

/*
 * Fraction of non-consensus voxels whose label changed since the previous
 * iteration (est vs estp).
 *
 * Returns 0 when no voxel is outside consensus, so there is nothing left to
 * change. This avoids the NaN that 0/0 would otherwise produce and print.
 */
double get_convergence_factor_est(label_t *** est,
                                  label_t *** estp,
                                  bool *** cons,
                                  size_t * dims)
{
    double c1 = 0;  // voxels not in consensus
    double c2 = 0;  // ...of which changed label
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (cons[x][y][z] == false) {
                    c1++;
                    if (est[x][y][z] != estp[x][y][z])
                        c2++;
                }
    if (c1 == 0)
        return 0;
    return(c2 / c1);
}

/*
 * Build the num_labels x num_labels label-interaction matrix G used by the MRF.
 *
 * G[l1][l2] is +1 when l1 == l2 and -1 otherwise. The matrix is
 * heap-allocated as an array of num_labels row pointers, each pointing to
 * num_labels values. The caller owns it and must free every row, then G.
 *
 * The observation-related parameters (obs, cons, dims, ns, num_raters,
 * num_keep) are not used. The empirical co-occurrence statistics this
 * function used to accumulate from them were unconditionally overwritten by
 * the +/-1 matrix. That costly pass has been removed, and the parameters are
 * kept only so callers are unaffected if an empirical estimate returns.
 */
intensity_t ** get_transition_probabilities(label_t **** obs,
                                            bool *** cons,
                                            size_t * dims,
                                            size_t * ns,
                                            size_t num_labels,
                                            size_t num_raters,
                                            size_t num_keep)
{
    (void)obs;
    (void)cons;
    (void)dims;
    (void)ns;
    (void)num_raters;
    (void)num_keep;

    // allocate the matrix
    intensity_t ** G = (intensity_t **)malloc(num_labels * sizeof(*G));
    for (size_t l = 0; l < num_labels; l++)
        G[l] = (intensity_t *)malloc(num_labels * sizeof(*G[l]));

    // reward agreeing labels, penalise disagreeing ones
    for (size_t l1 = 0; l1 < num_labels; l1++)
        for (size_t l2 = 0; l2 < num_labels; l2++)
            G[l1][l2] = (l1 == l2) ? 1 : -1;

    return(G);
}

void apply_MRF_voxel(size_t x,
                     size_t y,
                     size_t z,
                     bool *** cons,
                     size_t num_labels,
                     size_t num_keep,
                     intensity_t * coefs,
                     intensity_t **** alphanl,
                     label_t ***** locnl,
                     intensity_t ** G,
                     label_t *** est,
                     label_t *** estp,
                     intensity_t **** W,
                     intensity_t **** Wp,
                     intensity_t **** prior)
{

    size_t xp, yp, zp;
    label_t lab = 0;
    intensity_t * lp;
    intensity_t * pp;
    intensity_t intfact;
    intensity_t labfact;
    intensity_t currlabfact;
    intensity_t max_val = -1;
    intensity_t normfact = 0;
    intensity_t lprb;

    intensity_t f1, f2, f3, f4, count;

    // allocate some temporary space
    lp = (intensity_t *)malloc(num_labels * sizeof(*lp));
    pp = (intensity_t *)malloc(num_labels * sizeof(*pp));

    // initialize the probabilities using the prior
    for (size_t l = 0; l < num_labels; l++) {
        lp[l] = 0;
        pp[l] = prior[x][y][z][l];
    }

    // iterate over all of the labels
    for (size_t l = 0; l < num_labels; l++)
        if (pp[l] > 0) {
            f1 = 0; f2 = 0; f3 = 0; f4 = 0; count = 0;
            for (size_t k = 0; k < num_keep; k++) {

                // set the intensity factor
                intfact = alphanl[x][y][z][k];

                if (intfact > -10000) {

                    // set the current location
                    xp = locnl[x][y][z][k][0];
                    yp = locnl[x][y][z][k][1];
                    zp = locnl[x][y][z][k][2];

                    // set the current label fact
                    if (cons[xp][yp][zp])
                        currlabfact = (estp[xp][yp][zp] == l) ? 1 : 0;
                    else
                        currlabfact = Wp[xp][yp][zp][l];

                    // set the label factor
                    labfact = 0;
                    if (cons[xp][yp][zp])
                        labfact += G[l][estp[xp][yp][zp]];
                    else
                        for (size_t s = 0; s < num_labels; s++)
                            if (G[l][s] != 0) {
                                lprb = Wp[xp][yp][zp][s];
                                if (lprb > 0)
                                    labfact += G[l][s] * lprb;
                            }

                    // increment the factors
                    f1 += currlabfact;
                    f2 += labfact;
                    f3 += currlabfact * intfact;
                    f4 += intfact * labfact;
                    count++;

                    if (x == 101 && y == 112 && z == 5)
                        printf("%lu %lu %lu %d\n", xp, yp, zp, estp[xp][yp][zp]);
                }
            }

            // set the final label probability
            lp[l] = pp[l] * exp((1/count) * (coefs[0]*f1 + coefs[1]*f2 +
                                             coefs[2]*f3 + coefs[3]*f4));
            normfact += lp[l];

            if (x == 101 && y == 112 && z == 5)
                printf("%f %f %f %f\n", f2, f3, lp[l], pp[l]);
        }

    // normalize the label probability and set the estimate
    max_val = -1;
    for (size_t l = 0; l < num_labels; l++) {
        lp[l] /= normfact;
        if (lp[l] > max_val) {
            max_val = lp[l];
            lab = l;
        }
        W[x][y][z][l] = lp[l];
    }
    est[x][y][z] = lab;
    if (x == 101 && y == 112 && z == 5)
        printf("%f %d\n", normfact, lab);

    // free the temporary memory
    free(lp);
    free(pp);
}

/*
 * Iterate the MRF label update until the labels stop changing.
 *
 * Each iteration has two phases:
 *   1. For every non-consensus voxel, est is copied into estp and W into Wp.
 *      A voxel whose probability for some label is exactly 1 is promoted to
 *      consensus. Its per-voxel arrays in W, Wp, prior, alphanl and locnl
 *      are freed here and the pointers set to NULL.
 *   2. apply_MRF_voxel updates every remaining non-consensus voxel.
 *
 * The loop stops once get_convergence_factor_est (the fraction of
 * non-consensus voxels whose label changed) is <= 1e-4, or after 30
 * iterations. Each factor is printed to stdout.
 *
 * NOTE(review): `cons` is modified in place, and the caller's later cleanup
 * of W / prior / alphanl / locnl presumably skips consensus voxels; confirm.
 * The NULLed pointers make a plain free() of them harmless either way.
 */
void run_correction_algorithm(label_t *** est,
                              intensity_t **** W,
                              intensity_t **** prior,
                              intensity_t **** alphanl,
                              label_t ***** locnl,
                              intensity_t ** G,
                              intensity_t * coefs,
                              bool *** cons,
                              size_t * dims,
                              size_t num_keep,
                              size_t num_labels)
{
    // stopping criteria
    const double thresh = 1e-4;
    const int max_num_iters = 30;

    double convergence_factor = 10000;
    int curr_num_iters = 0;

    // previous-iteration state
    label_t *** estp = allocate_3D_data<label_t>(dims);
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                estp[x][y][z] = est[x][y][z];
    intensity_t **** Wp = allocate_4D_data<intensity_t>(dims, num_labels, cons);

    while (convergence_factor > thresh && curr_num_iters < max_num_iters) {

        // snapshot the current state and promote settled voxels
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++) {

                    if (cons[x][y][z])
                        continue;

                    estp[x][y][z] = est[x][y][z];

                    for (size_t l = 0; l < num_labels; l++) {
                        Wp[x][y][z][l] = W[x][y][z][l];

                        // all probability mass on one label: voxel is settled
                        if (W[x][y][z][l] == 1)
                            cons[x][y][z] = true;
                    }

                    if (cons[x][y][z]) {
                        // per-voxel data is no longer needed once settled
                        free(W[x][y][z]);
                        W[x][y][z] = NULL;
                        free(Wp[x][y][z]);
                        Wp[x][y][z] = NULL;
                        free(prior[x][y][z]);
                        prior[x][y][z] = NULL;
                        free(alphanl[x][y][z]);
                        alphanl[x][y][z] = NULL;
                        for (size_t k = 0; k < num_keep; k++)
                            free(locnl[x][y][z][k]);
                        free(locnl[x][y][z]);
                        locnl[x][y][z] = NULL;
                    }
                }

        // Apply the MRF
        for (size_t x = 0; x < dims[0]; x++) {
            print_status(x, dims[0]);
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    if (cons[x][y][z] == false)
                        apply_MRF_voxel(x, y, z, cons, num_labels, num_keep,
                                        coefs, alphanl, locnl, G, est, estp,
                                        W, Wp, prior);
        }

        // set the convergence factor
        convergence_factor = get_convergence_factor_est(est, estp, cons, dims);
        fprintf(stdout, "Convergence Factor: %f\n", convergence_factor);
        curr_num_iters++;
    }

    // free the allocated memory
    free_3D_data<label_t>(estp, dims);
    free_4D_data<intensity_t>(Wp, dims, cons);
}

