#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"
#include "OutOfAtlasLabeling.h"


// Print the command-line help text to stdout and terminate the program.
//
// Always ends with exit(0). It is used both for an explicit -h/--help request
// and after an argument error has already been reported on stderr.
void usage()
{
    printf("Out-of-Atlas Labeling version 0.1\n\n");

    printf("Usage: oal [options] -l <labels> -i <ints> -t <targ> -e <est> -o <out>\n\n");

    printf("Required Options:\n");

    printf("\t-l, --labels <labels-file>\n");
    printf("\t\tThe atlas labels file (in .nii/.nii.gz format).\n\n");

    printf("\t-i, --intensities <intensities-file>\n");
    printf("\t\tThe atlas intensities file (in .nii/.nii.gz format).\n\n");

    printf("\t-t, --target <target-file>\n");
    printf("\t\tThe target intensity file (in .nii/.nii.gz format).\n\n");

    printf("\t-e, --estimate <estimate-file>\n");
    printf("\t\tThe target labels estimate file (in .nii/.nii.gz format).\n\n");

    printf("\t-o, --output <output-file>\n");
    printf("\t\tThe anomaly results file (must end in .nii/.nii.gz)\n");

    printf("\nAdditional Options:\n");

    printf("\t-v, --search-volume <sx> <sy> <sz>\n");
    printf("\t\tDefines the search volume radii in the x, y, and z directions\n");
    printf("\t\tDefault value: 5 5 5\n\n");

    printf("\t-b, --num-bins <num>\n");
    printf("\t\tThe number of bins used in the kernel density estimation\n");
    printf("\t\tDefault value: 64\n\n");

    printf("\t-s, --stdev <val>\n");
    printf("\t\tThe standard deviation used in the KDE\n");
    printf("\t\tDefault value: 0.1\n\n");

    // -h/--help is accepted by set_input_args(), so document it as well.
    printf("\t-h, --help\n");
    printf("\t\tPrint this help message and exit\n\n");

    exit(0);
}

// Return the value that follows option `opt` (currently argv[*i]) and advance
// *i onto it. If the option is the last argument, report it and show the usage
// text (which exits) instead of reading past the end of argv.
static char * oal_option_value(int argc, char ** argv, int * i, const char * opt)
{
    if (*i + 1 >= argc) {
        fprintf(stderr, "Error: Missing value for argument: %s\n\n", opt);
        usage();
    }
    *i += 1;
    return argv[*i];
}

// Parse a non-negative base-10 integer option value. Empty, signed or partially
// numeric text is reported and ends in usage() rather than silently becoming 0
// (or, for negatives, a huge size_t).
static size_t oal_parse_size(const char * opt, const char * str)
{
    char * end = NULL;
    unsigned long val = 0;
    // strtoul accepts and wraps a leading '-', so require a digit first.
    if (str[0] >= '0' && str[0] <= '9')
        val = strtoul(str, &end, 10);
    if (end == NULL || *end != '\0') {
        fprintf(stderr, "Error: Invalid value for %s: %s\n\n", opt, str);
        usage();
    }
    return (size_t)val;
}

// Parse a floating-point option value, rejecting empty or partially numeric text.
static float oal_parse_float(const char * opt, const char * str)
{
    char * end = NULL;
    const double val = strtod(str, &end);
    if (end == str || *end != '\0') {
        fprintf(stderr, "Error: Invalid value for %s: %s\n\n", opt, str);
        usage();
    }
    return (float)val;
}

// Parse the command line into the caller's output variables.
//
// Only options that actually appear are written, so the caller must
// pre-initialise the optional values (ns, num_bins, stdev) and set the five
// file pointers to NULL; the "not specified" checks below rely on that.
// File names are not copied: the outputs point into argv.
// Any error prints a message to stderr and calls usage(), which exits.
void set_input_args(int argc,
                    char ** argv,
                    char ** obsfile,
                    char ** imsfile,
                    char ** targfile,
                    char ** estfile,
                    size_t * ns,
                    size_t * num_bins,
                    float * stdev,
                    char ** outfile)
{

    if (argc == 1)
        usage();

    // iterate over the input arguments
    for (int i = 1; i < argc; i++) {
        const char * opt = argv[i];
        if (strcmp(opt, "-l") == 0 ||
            strcmp(opt, "--labels") == 0) {
            *obsfile = oal_option_value(argc, argv, &i, opt);
        } else if (strcmp(opt, "-i") == 0 ||
                   strcmp(opt, "--intensities") == 0) {
            *imsfile = oal_option_value(argc, argv, &i, opt);
        } else if (strcmp(opt, "-t") == 0 ||
                   strcmp(opt, "--target") == 0) {
            *targfile = oal_option_value(argc, argv, &i, opt);
        } else if (strcmp(opt, "-e") == 0 ||
                   strcmp(opt, "--estimate") == 0) {
            *estfile = oal_option_value(argc, argv, &i, opt);
        } else if (strcmp(opt, "-o") == 0 ||
                   strcmp(opt, "--output") == 0) {
            *outfile = oal_option_value(argc, argv, &i, opt);
        } else if (strcmp(opt, "-v") == 0 ||
                   strcmp(opt, "--search-volume") == 0) {
            // three radii follow, in x, y, z order
            for (int d = 0; d < 3; d++)
                ns[d] = oal_parse_size(opt, oal_option_value(argc, argv, &i, opt));
        } else if (strcmp(opt, "-b") == 0 ||
                   strcmp(opt, "--num-bins") == 0) {
            *num_bins = oal_parse_size(opt, oal_option_value(argc, argv, &i, opt));
            // the bin grid is spaced by (max - min) / (num_bins - 1)
            if (*num_bins < 2) {
                fprintf(stderr, "Error: %s must be at least 2\n\n", opt);
                usage();
            }
        } else if (strcmp(opt, "-s") == 0 ||
                   strcmp(opt, "--stdev") == 0) {
            *stdev = oal_parse_float(opt, oal_option_value(argc, argv, &i, opt));
            // the Gaussian kernel factor divides by stdev^2 (also rejects NaN)
            if (!(*stdev > 0)) {
                fprintf(stderr, "Error: %s must be positive\n\n", opt);
                usage();
            }
        } else if (strcmp(opt, "-h") == 0 ||
                   strcmp(opt, "--help") == 0) {
            usage();
        } else {
            fprintf(stderr, "Error Reading Input Argument: %s\n\n", opt);
            usage();
        }
    }

    // do some quick error checking
    bool error = false;
    if (*obsfile == NULL) {
        fprintf(stderr, "Error: Labels file not specified\n");
        error = true;
    }
    if (*imsfile == NULL) {
        fprintf(stderr, "Error: Intensities file not specified\n");
        error = true;
    }
    if (*targfile == NULL) {
        fprintf(stderr, "Error: Target file not specified\n");
        error = true;
    }
    if (*estfile == NULL) {
        fprintf(stderr, "Error: Target labels estimate file not specified\n");
        error = true;
    }
    if (*outfile == NULL) {
        fprintf(stderr, "Error: Output file not specified\n");
        error = true;
    }
    if (error) {
        fprintf(stderr, "\n");
        usage();
    }
}

// Echo the run configuration to stdout, one "Name: value" line per setting.
// size_t values are printed with %zu, which matches size_t on every platform
// (%lu is wrong where size_t is not unsigned long, e.g. 64-bit Windows).
void print_input_parms(char * obsfile,
                       char * imsfile,
                       char * targfile,
                       char * estfile,
                       size_t * dims,
                       size_t * ns,
                       size_t num_bins,
                       size_t num_raters,
                       size_t num_labels,
                       float stdev)
{
    printf("Atlas labels file: %s\n", obsfile);
    printf("Atlas intensities file: %s\n", imsfile);
    printf("Target intensity file: %s\n", targfile);
    printf("Target Labels Estimate file: %s\n", estfile);
    printf("Image Dimensions: [%zu %zu %zu]\n", dims[0], dims[1], dims[2]);
    printf("Search Neighborhood Radii: [%zu %zu %zu]\n", ns[0], ns[1], ns[2]);
    printf("Number of bins: %zu\n", num_bins);
    printf("Standard Deviation for KDE: %f\n", stdev);
    printf("Number of raters: %zu\n", num_raters);
    printf("Number of labels: %zu\n", num_labels);
}


// Build one Gaussian kernel density estimate of atlas intensities per label.
//
// For every voxel (x, y, z) and rater j with obs[x][y][z][j] = l > 0, a
// Gaussian kernel centred on ims[x][y][z][j] is evaluated at each point xi[b]
// and added to dens[l-1][b]. Label 0 is never accumulated (presumably
// background; confirm). Each row is then scaled so that its trapezoidal
// integral over xi is 1.
//
// Parameters:
//   obs, ims   - atlas labels / intensities indexed [x][y][z][rater]
//   xi         - num_bins evaluation points (assumed increasing)
//   dens       - output, caller-allocated: num_labels-1 rows of num_bins values
//   stdev      - standard deviation (width) of the Gaussian kernel
//   dims       - volume extent in x, y, z
//   num_raters, num_labels, num_bins - array extents
//
// Labels outside 1..num_labels-1 are skipped. A label with no samples keeps an
// all-zero row instead of the previous 0/0 = NaN.
void set_label_densities(label_t **** obs,
                         intensity_t **** ims,
                         intensity_t * xi,
                         intensity_t ** dens,
                         float stdev,
                         const size_t * dims,
                         const size_t num_raters,
                         const size_t num_labels,
                         const size_t num_bins)
{
    // labels 1..num_labels-1 map to rows 0..num_labels-2 of dens
    if (num_labels < 2 || num_bins == 0)
        return;
    const size_t num_rows = num_labels - 1;

    // NOTE(review): the kernel width used to be written as
    // stdev * pow(count_l, -1/5). -1/5 is integer division (== 0), so the
    // count factor was always 1 and the effective width was plain stdev. That
    // effective behaviour is kept here, which makes the per-label counting pass
    // unnecessary. A real n^(-1/5) scaling (presumably the intent) would shrink
    // the kernels sharply for large label counts, and the default stdev would
    // need re-tuning, so it is not switched on silently.
    const double bw = (double)stdev;
    const intensity_t ffact = (intensity_t)(1.0 / (-2.0 * bw * bw));

    // initialize the densities to zero
    for (size_t l = 0; l < num_rows; l++)
        for (size_t b = 0; b < num_bins; b++)
            dens[l][b] = 0;

    // accumulate one kernel per labelled (voxel, rater) sample
    for (size_t x = 0; x < dims[0]; x++) {
        print_status(x, dims[0]);
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                for (size_t j = 0; j < num_raters; j++) {
                    const label_t l = obs[x][y][z][j];
                    // skip background and labels that have no row in dens
                    if (!(l > 0) || (size_t)l > num_rows)
                        continue;
                    intensity_t * row = dens[l - 1];
                    const intensity_t v = ims[x][y][z][j];
                    for (size_t b = 0; b < num_bins; b++)
                        row[b] += exp(ffact * pow(xi[b] - v, 2));
                }
    }

    // Normalise each row to unit area with the trapezoidal rule
    // (segment area = mean height * width; the width used to be divided).
    for (size_t l = 0; l < num_rows; l++) {
        double area = 0;
        for (size_t b = 0; b + 1 < num_bins; b++)
            area += (dens[l][b] + dens[l][b + 1]) / 2 * (xi[b + 1] - xi[b]);
        // rows without samples have zero area; leave them at 0 rather than NaN
        if (area > 0)
            for (size_t b = 0; b < num_bins; b++)
                dens[l][b] /= area;
    }
}

// Clip the window [c - r, c + r] to [0, dim - 1] using only unsigned arithmetic
// (requires c < dim). This avoids converting a wrapped size_t to int for
// negative lower bounds.
static void oal_patch_range(size_t c, size_t r, size_t dim, size_t * lo, size_t * hi)
{
    *lo = (c > r) ? c - r : 0;
    *hi = (dim - 1 - c > r) ? c + r : dim - 1;
}

// Fraction of the local target intensity distribution around (x, y, z) that
// the atlas label densities do not explain.
//
// The patch has radii ns and is clipped to the volume. Within it, every voxel
// with est > 0 adds a Gaussian kernel centred on its target intensity to
// target_dens, and num_app[l-1] counts the voxels of each estimated label.
// atlas_dens is the num_app-weighted sum of the rows of dens. Both curves are
// scaled to a peak of 1, and diff_dens keeps the positive part of
// (target - atlas). The result is area(diff_dens) / area(target_dens),
// integrated with the trapezoidal rule over xi, so it lies in [0, 1] for
// increasing xi.
//
// target_dens, atlas_dens and diff_dens (num_bins each) and num_app
// (num_labels-1) are caller-provided scratch buffers and are overwritten.
// Returns 0 when the patch has no labelled voxel or either curve is all-zero,
// instead of dividing by zero.
intensity_t get_err_val(intensity_t *** target,
                        label_t *** est,
                        intensity_t ** dens,
                        intensity_t * target_dens,
                        intensity_t * atlas_dens,
                        intensity_t * diff_dens,
                        intensity_t * xi,
                        float stdev,
                        size_t * num_app,
                        const size_t * dims,
                        const size_t * ns,
                        const size_t num_bins,
                        const size_t num_labels,
                        const size_t x,
                        const size_t y,
                        const size_t z)
{
    // labels 1..num_labels-1 map to rows 0..num_labels-2 of dens / num_app
    const size_t num_rows = (num_labels > 0) ? num_labels - 1 : 0;

    // set the patch that we will be analyzing (clipped to the volume)
    size_t xl, xh, yl, yh, zl, zh;
    oal_patch_range(x, ns[0], dims[0], &xl, &xh);
    oal_patch_range(y, ns[1], dims[1], &yl, &yh);
    oal_patch_range(z, ns[2], dims[2], &zl, &zh);

    // reset the scratch buffers
    for (size_t l = 0; l < num_rows; l++)
        num_app[l] = 0;
    for (size_t b = 0; b < num_bins; b++) {
        atlas_dens[b] = 0;
        target_dens[b] = 0;
    }

    // NOTE(review): the kernel width used to be stdev * pow(n, -1/5), with n
    // the number of labelled patch voxels. -1/5 is integer division (== 0), so
    // that factor was always 1. The effective width (plain stdev) is kept, and
    // the separate pass that counted n is dropped. A genuine n^(-1/5) scaling
    // would change the scores, so it is not switched on silently.
    const double bw = (double)stdev;
    const intensity_t ffact = (intensity_t)(1.0 / (-2.0 * bw * bw));

    // set the target density and keep track of the number of appearances
    size_t num_labelled = 0;
    for (size_t xp = xl; xp <= xh; xp++)
        for (size_t yp = yl; yp <= yh; yp++)
            for (size_t zp = zl; zp <= zh; zp++) {
                const label_t l = est[xp][yp][zp];
                if (!(l > 0))
                    continue;
                num_labelled++;
                // labels without a row in dens still shape the target curve
                if ((size_t)l <= num_rows)
                    num_app[l - 1]++;
                const intensity_t v = target[xp][yp][zp];
                for (size_t b = 0; b < num_bins; b++)
                    target_dens[b] += exp(ffact * pow(xi[b] - v, 2));
            }
    if (num_labelled == 0)
        return 0;

    // atlas density: per-label densities weighted by their patch counts
    for (size_t l = 0; l < num_rows; l++)
        if (num_app[l] > 0)
            for (size_t b = 0; b < num_bins; b++)
                atlas_dens[b] += num_app[l] * dens[l][b];

    // find the normalization values (peak heights)
    intensity_t max_t = 0;
    intensity_t max_a = 0;
    for (size_t b = 0; b < num_bins; b++) {
        if (atlas_dens[b] > max_a)
            max_a = atlas_dens[b];
        if (target_dens[b] > max_t)
            max_t = target_dens[b];
    }
    // an all-zero curve cannot be peak-normalised (previously 0/0 = NaN)
    if (!(max_a > 0) || !(max_t > 0))
        return 0;

    // normalise both curves to peak 1 and keep where the target exceeds the atlas
    for (size_t b = 0; b < num_bins; b++) {
        atlas_dens[b] /= max_a;
        target_dens[b] /= max_t;
        if (atlas_dens[b] <= target_dens[b])
            diff_dens[b] = target_dens[b] - atlas_dens[b];
        else
            diff_dens[b] = 0;
    }

    // trapezoidal areas (mean height * width) of the target and the difference
    double area_t = 0;
    double area_d = 0;
    for (size_t b = 0; b + 1 < num_bins; b++) {
        const double w = xi[b + 1] - xi[b];
        area_t += (target_dens[b] + target_dens[b + 1]) / 2 * w;
        area_d += (diff_dens[b] + diff_dens[b + 1]) / 2 * w;
    }
    if (!(area_t > 0))
        return 0;

    return (intensity_t)(area_d / area_t);
}

// Compute the out-of-atlas score for every voxel of the volume.
//
// Steps:
//   1. Build num_bins evenly spaced evaluation points xi. They span the target
//      intensities of labelled voxels (est > 0), widened by 1 on each side.
//   2. Estimate the per-label atlas densities with set_label_densities().
//   3. Set errim[x][y][z] to get_err_val() for labelled voxels, and 0 elsewhere.
//
// If no voxel is labelled, errim is filled with 0 and nothing else is computed.
// Exits with EXIT_FAILURE if the scratch buffers cannot be allocated.
void run_anomaly_detection(intensity_t *** target,
                           label_t **** obs,
                           intensity_t **** ims,
                           label_t *** est,
                           intensity_t *** errim,
                           float stdev,
                           const size_t * dims,
                           const size_t * ns,
                           const size_t num_bins,
                           const size_t num_raters,
                           const size_t num_labels)
{
    // labels 1..num_labels-1 occupy rows 0..num_labels-2 of the density table
    const size_t num_rows = (num_labels > 0) ? num_labels - 1 : 0;

    // Target intensity range over labelled voxels. The range is seeded from the
    // first labelled voxel rather than fixed +/-10000 sentinels, so any
    // intensity scale works.
    bool found = false;
    intensity_t min = 0;
    intensity_t max = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > 0) {
                    const intensity_t v = target[x][y][z];
                    if (!found || v < min)
                        min = v;
                    if (!found || v > max)
                        max = v;
                    found = true;
                }

    if (!found) {
        // nothing labelled: the score is 0 everywhere
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    errim[x][y][z] = 0;
        return;
    }
    min -= 1;
    max += 1;

    // Scratch buffers. One spare element keeps every request non-zero, so a
    // NULL result always means allocation failure. dens is calloc'ed so that
    // rows never allocated stay NULL and can be freed unconditionally.
    size_t * num_app = (size_t *)malloc((num_rows + 1) * sizeof(*num_app));
    intensity_t * xi = (intensity_t *)malloc((num_bins + 1) * sizeof(*xi));
    intensity_t * target_dens = (intensity_t *)malloc((num_bins + 1) * sizeof(*target_dens));
    intensity_t * atlas_dens = (intensity_t *)malloc((num_bins + 1) * sizeof(*atlas_dens));
    intensity_t * diff_dens = (intensity_t *)malloc((num_bins + 1) * sizeof(*diff_dens));
    intensity_t ** dens = (intensity_t **)calloc(num_rows + 1, sizeof(*dens));
    bool ok = num_app && xi && target_dens && atlas_dens && diff_dens && dens;
    for (size_t l = 0; ok && l < num_rows; l++) {
        dens[l] = (intensity_t *)malloc((num_bins + 1) * sizeof(*dens[l]));
        ok = (dens[l] != NULL);
    }

    if (ok) {
        // Evenly spaced points from min to max, each computed from its index.
        // Accumulating a floating-point step (v += step while v <= max) could
        // emit num_bins + 1 points, overrunning xi, or leave the last one unset.
        const intensity_t step = (num_bins > 1) ? (max - min) / (num_bins - 1) : 0;
        for (size_t b = 0; b < num_bins; b++)
            xi[b] = min + step * b;

        // set the label densities
        fprintf(stdout, "-> Setting the Label Density Functions.\n");
        set_label_densities(obs, ims, xi, dens, stdev, dims, num_raters,
                            num_labels, num_bins);

        // run the anomaly detection
        fprintf(stdout, "-> Running the Out-of-Atlas Labeling Algorithm.\n");
        for (size_t x = 0; x < dims[0]; x++) {
            print_status(x, dims[0]);
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    if (est[x][y][z] > 0) {
                        errim[x][y][z] = get_err_val(target, est, dens,
                                                     target_dens, atlas_dens,
                                                     diff_dens, xi, stdev, num_app,
                                                     dims, ns, num_bins, num_labels,
                                                     x, y, z);
                    } else
                        errim[x][y][z] = 0;
        }
    } else {
        fprintf(stderr, "Error: Out of memory in run_anomaly_detection\n");
    }

    // free the allocated space (free(NULL) is a no-op)
    free(num_app);
    free(xi);
    free(target_dens);
    free(atlas_dens);
    free(diff_dens);
    if (dens != NULL)
        for (size_t l = 0; l < num_rows; l++)
            free(dens[l]);
    free(dens);

    if (!ok)
        exit(EXIT_FAILURE);
}

// Write errim to `outfile` as a NIfTI image, reusing the header of `targfile`.
//
// targfile is re-read together with its data. Its datatype is set to
// OUT_INTENSITY_TYPE, set_nim_3D() copies errim into it, and both the header
// and image file names are replaced by outfile before writing.
// Exits with EXIT_FAILURE if targfile cannot be read or memory runs out.
void save_outfile(char * targfile,
                  char * outfile,
                  intensity_t *** errim,
                  const size_t * dims)
{
    // read the nifti image (header + data) to use as the output template
    nifti_image * nim = nifti_image_read(targfile, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error: Could not read target file: %s\n", targfile);
        exit(EXIT_FAILURE);
    }

    // set the output datatype
    // NOTE(review): nbyper/swapsize are not updated here; presumably
    // set_nim_3D() handles that along with the data buffer. Confirm.
    nim->datatype = OUT_INTENSITY_TYPE;

    // set the target
    set_nim_3D<intensity_t>(nim, errim, dims);

    // Point both names at outfile. Using the same name for iname only suits
    // single-file outputs (.nii/.nii.gz), which the usage text requires.
    // The 8 spare zeroed bytes are retained from the original allocation; their
    // purpose is not visible here.
    const size_t ll = strlen(outfile);
    char * fname = (char *)calloc(1, ll + 8);
    char * iname = (char *)calloc(1, ll + 8);
    if (fname == NULL || iname == NULL) {
        free(fname);
        free(iname);
        nifti_image_free(nim);
        fprintf(stderr, "Error: Out of memory while saving %s\n", outfile);
        exit(EXIT_FAILURE);
    }
    memcpy(fname, outfile, ll + 1);
    memcpy(iname, outfile, ll + 1);
    free(nim->fname);
    free(nim->iname);
    nim->fname = fname;
    nim->iname = iname;

    // save the file
    nifti_image_write(nim);

    // free the image (header, data and the names set above)
    nifti_image_free(nim);
}

