/*
 * Copyright 1993-2010 NVIDIA Corporation.  All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

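/*
 * Command-line driver for Non-Local STAPLE label fusion: it reads a target
 * image, atlas label observations and atlas intensities, builds the non-local
 * reformulation of the observations, runs the E-M algorithm and saves the
 * resulting label estimate.
 */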

// Utilities and system includes
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"
#include "NonLocalSTAPLE.h"

////////////////////////////////////////////////////////////////////////////////
// Main program
////////////////////////////////////////////////////////////////////////////////
/*
 * Allocate `bytes` bytes with malloc() and terminate the program with a
 * diagnostic on stderr if the allocation fails. `what` names the buffer in
 * the error message. A zero-byte request is passed straight through (malloc
 * may legitimately return NULL for it).
 */
static void *alloc_or_die(size_t bytes, const char *what)
{
    void *ptr = malloc(bytes);
    if (ptr == NULL && bytes != 0) {
        fprintf(stderr, "Error: unable to allocate memory for %s\n", what);
        exit(EXIT_FAILURE);
    }
    return ptr;
}

/*
 * Entry point: parse the command line, load the target / atlas data, build
 * the non-local observation model, run Non-Local STAPLE and save the result.
 *
 * Returns 0 on success; exits with EXIT_FAILURE if a top-level allocation
 * fails. Error handling inside the project helpers (set_input_args,
 * get_target, get_obs, ...) is whatever those helpers implement.
 */
int main(int argc,
         char **argv)
{

    //
    // Variable Declarations
    //
    char * obsfile,
         * imsfile,
         * targfile,
         * priorfile,
         * out_prefix;
    size_t num_raters,
           num_labels,
           num_keep,
           num_coef,
           * dims,
           * ns,
           * nc,
           * nsf;
    label_t **** obs,
            *** estimate,
            ***** obsnl;
    intensity_t **** ims,
                *** target,
                *** theta,
                **** prior,
                **** W,
                ***** alphanl,
                rho,
                epsilon,
                stdev,
                sp_stdev;
    bool *** consensus,
         save_labelprobs,
         save_theta;

    // initialize some variables
    obsfile = NULL;
    imsfile = NULL;
    targfile = NULL;
    priorfile = NULL;
    out_prefix = NULL;
    W = NULL;
    prior = NULL;

    // allocate the 3-element per-axis arrays (x, y, z)
    dims = (size_t *)alloc_or_die(3 * sizeof(*dims), "dims");
    ns = (size_t *)alloc_or_die(3 * sizeof(*ns), "ns");
    nc = (size_t *)alloc_or_die(3 * sizeof(*nc), "nc");
    nsf = (size_t *)alloc_or_die(3 * sizeof(*nsf), "nsf");

    // set the default parameter values (may be overridden on the command line)
    epsilon = 1e-5;
    num_coef = 3;
    rho = 0.1;
    ns[0] = 5; ns[1] = 5; ns[2] = 5;
    nc[0] = 1; nc[1] = 1; nc[2] = 1;
    stdev = 0.1;
    sp_stdev = 2;
    num_keep = 15;
    save_labelprobs = false;
    save_theta = false;

    // get the input arguments
    set_input_args(argc, argv, &obsfile, &imsfile, &targfile, &epsilon, ns, nsf,
                   nc, &stdev, &sp_stdev, &priorfile, &out_prefix, &num_keep,
                   &save_labelprobs, &save_theta);

    // Prior support is currently disabled: any prior file given on the command
    // line is discarded and `prior` stays NULL (see the commented-out loading
    // code further down).
    priorfile = NULL;

    // numel is the voxel count of a (2*ns+1)^3 window; num_keep can never
    // exceed it.
    size_t numel = (2*ns[0]+1) * (2*ns[1]+1) * (2*ns[2]+1);
    if (num_keep > numel)
        num_keep = numel;

    // With a single-voxel window, the nc radii are zeroed as well.
    // NOTE(review): presumably ns is the search radius and nc the comparison
    // patch radius -- confirm against set_nonlocal_data.
    if (numel == 1) {
        nc[0] = 0;
        nc[1] = 0;
        nc[2] = 0;
    }


    // read the data from the input files (get_target fills in dims)
    fprintf(stdout, "-> Reading the target image.\n");
    fflush(stdout);
    target = get_target(targfile, dims);
    fprintf(stdout, "-> Reading the atlas label observations.\n");
    fflush(stdout);
    obs = get_obs(obsfile, dims, &num_raters, &num_labels);
    fprintf(stdout, "-> Reading the atlas intensities.\n");
    fflush(stdout);
    ims = get_ims(imsfile, dims, num_raters);

    // print the input parameters to the screen
    fprintf(stdout, "-> Parsed Input Parameters:\n");
    fflush(stdout);
    print_input_parms(obsfile, imsfile, targfile, priorfile, dims, ns, nc,
                      stdev, sp_stdev, num_raters, num_labels, num_keep);

    // initialize the consensus and estimate matrices
    fprintf(stdout, "-> Initializing Consensus\n");
    fflush(stdout);
    consensus = (bool ***)alloc_or_die(dims[0] * sizeof(*consensus),
                                       "consensus");
    estimate = (label_t ***)alloc_or_die(dims[0] * sizeof(*estimate),
                                         "estimate");
    initialize_consensus_estimate(consensus, obs, estimate, ns, dims,
                                  num_raters, num_labels);

    // normalize the intensity
    fprintf(stdout, "-> Normalizing intensity\n");
    fflush(stdout);
    normalize_intensity(target, estimate, ims, obs, dims, num_raters,
                        num_labels, num_coef, rho);

    // set the non-local correspondence data
    fprintf(stdout, "-> Calculating the non-local reformulation ");
    fprintf(stdout, "(This can take a long time)\n");
    fflush(stdout);
    obsnl = (label_t *****)alloc_or_die(dims[0] * sizeof(*obsnl), "obsnl");
    alphanl = (intensity_t *****)alloc_or_die(dims[0] * sizeof(*alphanl),
                                              "alphanl");
    allocate_non_local(obsnl, alphanl, consensus, dims, num_raters, num_keep);
    set_nonlocal_data(obs, obsnl, alphanl, ims, target, consensus, dims, ns,
                      nc, nsf, num_raters, num_keep, stdev, sp_stdev);

    // free the data we're done with (the E-M step only needs the non-local
    // reformulation, the estimate and the consensus mask)
    fprintf(stdout, "-> Freeing temporary data\n");
    fflush(stdout);
    free_3D_data<intensity_t>(target, dims);
    free_4D_data<intensity_t>(ims, dims);
    free_4D_data<label_t>(obs, dims);
    free(ns); free(nsf); free(nc);

    //// load the prior if one was specified
    //if (priorfile != NULL) {
    //    prior = (intensity_t ****)malloc(dims[0] * sizeof(*prior));
    //    set_prior(priorfile, prior, dims, num_labels);
    //}

    // allocate the labelprobs if necessary; otherwise W stays NULL
    if (save_labelprobs) {
        W = (intensity_t ****)alloc_or_die(dims[0] * sizeof(*W), "W");
        allocate_W(W, dims, num_labels);
    }

    // allocate theta
    theta = (intensity_t ***)alloc_or_die(num_labels * sizeof(*theta),
                                          "theta");
    allocate_theta(theta, num_labels, num_raters);

    // run the Non-Local STAPLE algorithm
    fprintf(stdout, "-> Running the E-M Algorithm\n");
    fflush(stdout);
    run_NonLocalSTAPLE(obsnl, alphanl, estimate, theta, prior, W, epsilon,
                       consensus, dims, num_raters, num_labels, num_keep);

    // save the results to the appropriate files
    fprintf(stdout, "-> Saving Results\n");
    fflush(stdout);
    save_results(dims, out_prefix, targfile, estimate);

    // NOTE(review): consensus, estimate, obsnl, alphanl, theta and W are not
    // released here and live until process exit; freeing them needs routines
    // that mirror the allocate_* / initialize_* helpers.
    free(dims);
    return 0;
}

