
// Utilities and system includes
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"
#include "Correction.h"

////////////////////////////////////////////////////////////////////////////////
// Main program
////////////////////////////////////////////////////////////////////////////////
int main(int argc,
         char ** argv)
{
    // Entry point: reads a target image and a set of rater observations,
    // derives a majority-vote consensus, prior and initial estimate, builds
    // non-local weights/locations and transition probabilities, runs the
    // correction algorithm, and saves the final estimate to `outfile`.
    //
    // Returns EXIT_SUCCESS on completion, or EXIT_FAILURE if the small
    // parameter arrays below cannot be allocated.

    // initialize some variables
    intensity_t *** target = NULL,
                **** alphanl = NULL,
                **** W = NULL,
                **** prior = NULL,
                ** G = NULL,
                * coefs = NULL;
    label_t **** obs = NULL,
            *** est = NULL,
            ***** locnl = NULL;
    char * targfile = NULL,
         * obsfile = NULL,
         * outfile = NULL;
    size_t num_labels = 0,
           num_raters = 0,
           num_keep = 10,
           * dims = NULL,
           * nc = NULL,
           * nsf = NULL,
           * ns = NULL;
    intensity_t stdev = 0.1,
                sp_stdev = 200;
    bool *** cons = NULL;

    // allocate the small parameter arrays (4 coefficients, 3-element size vectors)
    coefs = (intensity_t *)malloc(4 * sizeof(*coefs));
    dims = (size_t *)malloc(3 * sizeof(*dims));
    ns = (size_t *)malloc(3 * sizeof(*ns));
    nsf = (size_t *)malloc(3 * sizeof(*nsf));
    nc = (size_t *)malloc(3 * sizeof(*nc));
    if (coefs == NULL || dims == NULL || ns == NULL || nsf == NULL || nc == NULL) {
        fprintf(stderr, "Error: unable to allocate parameter arrays\n");
        // free(NULL) is a no-op, so releasing all of them is safe
        free(coefs); free(dims); free(ns); free(nsf); free(nc);
        return EXIT_FAILURE;
    }
    for (size_t i = 0; i < 3; i++) {
        ns[i] = 1;
        nc[i] = 1;
    }

    coefs[0] = 1;
    coefs[1] = 1;
    coefs[2] = 1;
    coefs[3] = 1;

    // set the input arguments (may override the defaults set above)
    set_input_args(argc, argv, &targfile, &obsfile,
                   ns, nc, coefs, &stdev, &sp_stdev, &outfile);

    // full window size per dimension is 2*ns+1; keep every voxel in the window
    for (size_t i = 0; i < 3; i++)
        nsf[i] = 2*ns[i]+1;
    num_keep = nsf[0]*nsf[1]*nsf[2];

    // read the target
    fprintf(stdout, "-> Reading the target image.\n");
    target = get_target(targfile, dims);

    // read the rater observations (also sets num_raters and num_labels)
    fprintf(stdout, "-> Reading the target observations.\n");
    obs = get_obs(obsfile, dims, &num_raters, &num_labels);

    // print the parameters to the screen
    fprintf(stdout, "-> Parsed Input Parameters:\n");
    print_input_parms(targfile, obsfile, dims, ns, nc, num_raters,
                      num_labels, coefs, stdev, sp_stdev);

    // compute the majority vote consensus mask
    fprintf(stdout, "-> Determine Majority Vote Consensus\n");
    cons = allocate_3D_data<bool>(dims);
    get_majority_vote_consensus(cons, obs, num_raters, dims);

    // allocate some space
    fprintf(stdout, "-> Allocating some space\n");
    alphanl = allocate_4D_data<intensity_t>(dims, num_keep, cons);
    locnl = allocate_5D_data<label_t>(dims, num_keep, 3, cons);
    prior = allocate_4D_data<intensity_t>(dims, num_labels, cons);
    est = allocate_3D_data<label_t>(dims);

    // set the majority vote prior and estimate
    fprintf(stdout, "-> Setting Majority Vote Prior and Estimate\n");
    set_majority_vote_info(cons, obs, num_raters, num_labels, dims, prior, est);

    // initialize W from the prior at every voxel where cons is false
    W = allocate_4D_data<intensity_t>(dims, num_labels, cons);
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (cons[x][y][z] == false)
                    for (size_t l = 0; l < num_labels; l++)
                        W[x][y][z][l] = prior[x][y][z][l];

    // normalize the intensity for the target image
    fprintf(stdout, "-> Resetting the target intensity\n");
    reset_target_intensity(target, est, dims);

    // run the non-local reformulation for the target image
    fprintf(stdout, "-> Running the non-local reformulation\n");
    set_nonlocal_data(target, cons, alphanl, locnl, dims,
                      nc, ns, nsf, num_keep, stdev, sp_stdev);

    // construct the transition matrices
    fprintf(stdout, "-> Set the Transition Probabilities\n");
    G = get_transition_probabilities(obs, cons, dims, ns, num_labels,
                                     num_raters, num_keep);

    // apply the correction
    fprintf(stdout, "-> Running the correction algorithm\n");
    run_correction_algorithm(est, W, prior, alphanl, locnl, G, coefs,
                             cons, dims, num_keep, num_labels);

    // save the output file
    fprintf(stdout, "-> Saving the final estimate\n");
    save_results(dims, outfile, targfile, est);

    // free the remaining allocated data
    // NOTE(review): targfile/obsfile/outfile ownership depends on
    // set_input_args; they are not freed here — confirm.
    fprintf(stdout, "-> Freeing the remaining data\n");
    // presumably G has one row per label (it was built with num_labels) — verify
    for (size_t l = 0; l < num_labels; l++)
        free(G[l]);
    free(G);
    free_3D_data<intensity_t>(target, dims);
    free(ns); free(nc); free(nsf); free(coefs);
    free_3D_data<label_t>(est, dims);
    free_4D_data(obs, dims);
    // the cons-aware 4D/5D frees must run before cons itself is released
    free_4D_data<intensity_t>(alphanl, dims, cons);
    free_4D_data<intensity_t>(W, dims, cons);
    free_4D_data<intensity_t>(prior, dims, cons);
    free_5D_data<label_t>(locnl, dims, num_keep, cons);
    free_3D_data<bool>(cons, dims);
    free(dims);

    return EXIT_SUCCESS;
}

