#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1_io.h>
#include <nifti1.h>
#include "fusion_io.h"
#include "fusion_tools.h"

/* Running (Welford) statistics used by normalize_intensity_unit. */
struct ni_unit_stats {
    size_t n;     // number of samples seen so far
    double mean;  // running mean of the samples
    double m2;    // running sum of squared deviations from the mean
};

/* Add one sample to the running statistics (single pass, numerically stable). */
static void ni_unit_stats_add(struct ni_unit_stats * s, const double v)
{
    const double delta = v - s->mean;
    s->n++;
    s->mean += delta / (double)s->n;
    s->m2 += delta * (v - s->mean);
}

/*
 * Divisor used by the unit normalization: twice the sample standard
 * deviation. Returns 0 when that is undefined (fewer than two samples);
 * it may also come out as 0 (constant intensities) -- callers treat any
 * non-positive result as "skip this volume".
 */
static double ni_unit_stats_scale(const struct ni_unit_stats * s)
{
    if (s->n < 2)
        return 0;
    return 2 * sqrt(s->m2 / (double)(s->n - 1));
}

/*
 * Standardize the target and every atlas in place:
 *     v -> (v - mean) / (2 * stdev)
 * where mean/stdev (sample standard deviation) are computed only over the
 * volume's own masked voxels -- est > 0 for the target, obs[..][j] > 0 for
 * atlas j -- but the transform is applied to every voxel of the volume.
 *
 * A volume whose mask holds fewer than two voxels, or whose masked
 * intensities are constant, has no usable spread; it is left unchanged and a
 * warning is printed to stderr (instead of being filled with NaN/Inf).
 *
 * dims holds the three spatial extents; ims/obs are indexed [x][y][z][j]
 * for j < num_raters.
 */
void normalize_intensity_unit(intensity_t *** target,
                              label_t *** est,
                              intensity_t **** ims,
                              label_t **** obs,
                              const size_t * dims,
                              const size_t num_raters)
{
    struct ni_unit_stats st = {0, 0.0, 0.0};
    double scale;

    // target statistics come from the voxels labelled in est only
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > 0)
                    ni_unit_stats_add(&st, target[x][y][z]);

    scale = ni_unit_stats_scale(&st);
    if (scale > 0) {
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    target[x][y][z] = (intensity_t)
                        ((target[x][y][z] - st.mean) / scale);
    } else {
        fprintf(stderr, "normalize_intensity_unit: target mask has fewer "
                        "than 2 voxels or constant intensity; "
                        "target left unnormalized\n");
    }

    for (size_t j = 0; j < num_raters; j++) {
        st = (struct ni_unit_stats){0, 0.0, 0.0};

        // atlas statistics come from the voxels labelled in its own obs
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    if (obs[x][y][z][j] > 0)
                        ni_unit_stats_add(&st, ims[x][y][z][j]);

        scale = ni_unit_stats_scale(&st);
        if (!(scale > 0)) {
            fprintf(stderr, "normalize_intensity_unit: atlas %zu mask has "
                            "fewer than 2 voxels or constant intensity; "
                            "atlas left unnormalized\n", j);
            continue;
        }

        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++)
                    ims[x][y][z][j] = (intensity_t)
                        ((ims[x][y][z][j] - st.mean) / scale);
    }
}

/* Evaluate the polynomial sum_{c < num_coef} coef[c] * v^c. */
static intensity_t ni_poly_eval(const intensity_t * coef,
                                const size_t num_coef,
                                const intensity_t v)
{
    intensity_t val = 0;
    for (size_t c = 0; c < num_coef; c++)
        val += coef[c] * pow(v, c);
    return val;
}

/*
 * Map a label value to its compact "kept label" index through lablist.
 * Returns -1 when the label lies outside [0, num_labels) (negative values
 * wrap to huge size_t values and are rejected too) or was not kept.
 */
static long ni_kept_index(const long * lablist,
                          const size_t num_labels,
                          const label_t lab)
{
    const size_t l = (size_t)lab;
    return (l < num_labels) ? lablist[l] : -1;
}

/*
 * Fit, independently for each rater j, coefficients alphas[j*num_coef + c]
 * of a polynomial mapping the rater's per-label mean intensity
 * ameans[k*num_raters + j] onto the target's per-label mean tmeans[k].
 * Pairs (k, j) with acounts[k*num_raters + j] == 0 (rater j never showed
 * label k, so it has no mean) are left out of both the gradient and the MSE.
 *
 * Batch gradient descent on the squared error, starting from the values
 * already in alphas. It stops when an iteration lowers the mean squared
 * error by no more than th, or raises it, or the error becomes NaN.
 * grad is scratch space for num_coef values.
 */
static void ni_poly_fit(intensity_t * alphas,
                        intensity_t * grad,
                        const intensity_t * ameans,
                        const size_t * acounts,
                        const intensity_t * tmeans,
                        const size_t num_keep,
                        const size_t num_raters,
                        const size_t num_coef)
{
    const intensity_t inc = 0.0001;  // gradient descent step size
    const intensity_t th = 1e-9;     // minimum MSE improvement to continue
    intensity_t mse = HUGE_VAL;
    intensity_t convergence_factor = HUGE_VAL;

    while (convergence_factor > th) {
        const intensity_t mse_prev = mse;
        size_t num_terms = 0;

        // one gradient step per rater; the full gradient is evaluated at the
        // current coefficients before any of them is updated
        for (size_t j = 0; j < num_raters; j++) {
            intensity_t * coef = &alphas[j * num_coef];

            for (size_t c = 0; c < num_coef; c++)
                grad[c] = 0;
            for (size_t k = 0; k < num_keep; k++) {
                const size_t kj = k * num_raters + j;
                if (acounts[kj] == 0)
                    continue;
                // the residual does not depend on c: compute it once per k
                const intensity_t resid =
                    ni_poly_eval(coef, num_coef, ameans[kj]) - tmeans[k];
                for (size_t c = 0; c < num_coef; c++)
                    grad[c] += resid * pow(ameans[kj], c);
            }
            for (size_t c = 0; c < num_coef; c++)
                coef[c] -= inc * grad[c];
        }

        // mean squared error over all usable (label, rater) pairs
        mse = 0;
        for (size_t j = 0; j < num_raters; j++)
            for (size_t k = 0; k < num_keep; k++) {
                const size_t kj = k * num_raters + j;
                if (acounts[kj] == 0)
                    continue;
                const intensity_t diff = tmeans[k] -
                    ni_poly_eval(&alphas[j * num_coef], num_coef, ameans[kj]);
                mse += diff * diff;
                num_terms++;
            }
        mse = (num_terms > 0) ? mse / num_terms : 0;

        // NaN compares false, so a diverged fit also ends the loop
        convergence_factor = mse_prev - mse;
    }
}

/*
 * Map each atlas's intensities through a per-atlas polynomial (num_coef
 * coefficients, constant term first) fitted so that the atlas's per-label
 * mean intensities match the target's per-label means.
 *
 * Only labels covering more than 50 voxels of the target estimate est take
 * part; label values outside [0, num_labels) are ignored. The target is
 * read but not modified. An atlas with no usable label (or a call with
 * nothing to fit) is left unchanged rather than being zeroed. On
 * allocation failure a message is printed to stderr and nothing is changed.
 *
 * dims holds the three spatial extents; ims/obs are indexed [x][y][z][j].
 */
void normalize_intensity_poly(intensity_t *** target,
                              label_t *** est,
                              intensity_t **** ims,
                              label_t **** obs,
                              const size_t * dims,
                              const size_t num_raters,
                              const size_t num_labels,
                              const size_t num_coef)
{
    // a label is used for the fit only if it covers more than this many
    // voxels of the target estimate
    const size_t thresh = 50;

    size_t * counts = NULL;         // [l]: voxels of label l in est
    long * lablist = NULL;          // [l]: kept index of label l, or -1
    intensity_t * tmeans = NULL;    // [k]: target mean of kept label k
    size_t * tcounts = NULL;        // [k]
    intensity_t * ameans = NULL;    // [k * num_raters + j]
    size_t * acounts = NULL;        // [k * num_raters + j]
    intensity_t * alphas = NULL;    // [j * num_coef + c]
    intensity_t * grad = NULL;      // [c], scratch for the fit
    unsigned char * fitted = NULL;  // [j]: rater j has a usable label
    size_t num_keep = 0;

    if (num_labels == 0 || num_raters == 0 || num_coef == 0)
        return;

    counts = calloc(num_labels, sizeof *counts);
    lablist = malloc(num_labels * sizeof *lablist);
    if (counts == NULL || lablist == NULL) {
        fprintf(stderr, "normalize_intensity_poly: out of memory\n");
        goto cleanup;
    }

    // count the appearances of each in-range label in the target estimate
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++) {
                const size_t l = (size_t)est[x][y][z];
                if (l < num_labels)
                    counts[l]++;
            }

    // give every sufficiently common label a compact index
    for (size_t l = 0; l < num_labels; l++)
        lablist[l] = (counts[l] > thresh) ? (long)num_keep++ : -1;

    // nothing to fit against: leave the atlases untouched
    if (num_keep == 0)
        goto cleanup;

    tmeans = calloc(num_keep, sizeof *tmeans);
    tcounts = calloc(num_keep, sizeof *tcounts);
    ameans = calloc(num_keep * num_raters, sizeof *ameans);
    acounts = calloc(num_keep * num_raters, sizeof *acounts);
    alphas = calloc(num_raters * num_coef, sizeof *alphas);
    grad = malloc(num_coef * sizeof *grad);
    fitted = calloc(num_raters, sizeof *fitted);
    if (tmeans == NULL || tcounts == NULL || ameans == NULL ||
        acounts == NULL || alphas == NULL || grad == NULL || fitted == NULL) {
        fprintf(stderr, "normalize_intensity_poly: out of memory\n");
        goto cleanup;
    }

    // accumulate per-label intensity sums for the target and every atlas
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++) {
                long k = ni_kept_index(lablist, num_labels, est[x][y][z]);
                if (k >= 0) {
                    tmeans[k] += target[x][y][z];
                    tcounts[k]++;
                }

                for (size_t j = 0; j < num_raters; j++) {
                    k = ni_kept_index(lablist, num_labels, obs[x][y][z][j]);
                    if (k >= 0) {
                        const size_t kj = (size_t)k * num_raters + j;
                        ameans[kj] += ims[x][y][z][j];
                        acounts[kj]++;
                    }
                }
            }

    // turn the sums into means; tcounts[k] > thresh by construction, while
    // an atlas may never show some kept label (acounts == 0 -> no mean)
    for (size_t k = 0; k < num_keep; k++) {
        tmeans[k] /= tcounts[k];
        for (size_t j = 0; j < num_raters; j++) {
            const size_t kj = k * num_raters + j;
            if (acounts[kj] > 0) {
                ameans[kj] /= acounts[kj];
                fitted[j] = 1;
            }
        }
    }

    ni_poly_fit(alphas, grad, ameans, acounts, tmeans,
                num_keep, num_raters, num_coef);

    // apply the fitted polynomial to every voxel of each fitted atlas
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                for (size_t j = 0; j < num_raters; j++) {
                    if (!fitted[j])
                        continue;
                    ims[x][y][z][j] = ni_poly_eval(&alphas[j * num_coef],
                                                   num_coef,
                                                   ims[x][y][z][j]);
                }

cleanup:
    free(counts);
    free(lablist);
    free(tmeans);
    free(tcounts);
    free(ameans);
    free(acounts);
    free(alphas);
    free(grad);
    free(fitted);
}

/*
 * Return nonzero if any voxel in the 3x3x3 neighbourhood of (x, y, z),
 * clipped to the volume bounds, holds a value <= layer. vals is a flat
 * dims[0] x dims[1] x dims[2] buffer indexed (x * dims[1] + y) * dims[2] + z.
 */
static int ni_decay_touches_layer(const intensity_t * vals,
                                  const size_t * dims,
                                  const size_t x,
                                  const size_t y,
                                  const size_t z,
                                  const intensity_t layer)
{
    const size_t ny = dims[1], nz = dims[2];
    const size_t x0 = (x > 0) ? x - 1 : 0;
    const size_t x1 = (x + 1 < dims[0]) ? x + 1 : x;
    const size_t y0 = (y > 0) ? y - 1 : 0;
    const size_t y1 = (y + 1 < ny) ? y + 1 : y;
    const size_t z0 = (z > 0) ? z - 1 : 0;
    const size_t z1 = (z + 1 < nz) ? z + 1 : z;

    for (size_t xi = x0; xi <= x1; xi++)
        for (size_t yi = y0; yi <= y1; yi++)
            for (size_t zi = z0; zi <= z1; zi++)
                if (vals[(xi * ny + yi) * nz + zi] <= layer)
                    return 1;
    return 0;
}

/*
 * Grow a chessboard (26-neighbour) distance map out of the zero voxels of
 * vals, one layer per pass, for at most num layers. On entry vals holds 0
 * inside the mask and `unreached` elsewhere; on return each voxel within
 * distance d <= num of the mask holds d and the rest still hold `unreached`.
 * Voxels assigned during a pass receive layer + 1 and therefore cannot seed
 * other voxels in the same pass, which makes the result independent of the
 * scan order.
 */
static void ni_decay_distance(intensity_t * vals,
                              const size_t * dims,
                              const int num,
                              const intensity_t unreached)
{
    for (int i = 0; i < num; i++) {
        const intensity_t layer = (intensity_t)i;
        int changed = 0;
        size_t idx = 0;

        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++, idx++)
                    if (vals[idx] == unreached &&
                        ni_decay_touches_layer(vals, dims, x, y, z, layer)) {
                        vals[idx] = layer + 1;
                        changed = 1;
                    }

        // nothing grew in this pass, so no later pass can grow anything
        if (!changed)
            break;
    }
}

/*
 * Decay intensities outside each volume's mask toward the global minimum
 * `mina` of the target:
 *     v -> (v - mina) * exp(-rho * d) + mina
 * where d is the chessboard distance to the volume's own mask (est > 0 for
 * the target, obs[..][j] > 0 for atlas j), grown for round(5*sqrt(1/rho))
 * layers; voxels farther out use d = 10000, which drives them to ~mina.
 * Atlas intensities below mina are first raised to mina.
 *
 * Does nothing (and prints a notice to stdout) when rho < 0.01 or rho is
 * NaN; prints to stderr and does nothing if the work buffer cannot be
 * allocated. dims holds the three spatial extents.
 */
void normalize_intensity_decay(intensity_t *** target,
                               label_t *** est,
                               intensity_t **** ims,
                               label_t **** obs,
                               const size_t * dims,
                               const size_t num_raters,
                               const intensity_t rho)
{
    // distance assigned to voxels the propagation never reaches
    const intensity_t unreached = 10000;

    // make sure it's a valid value for rho (the negated test rejects NaN)
    if (!(rho >= 0.01)) {
        fprintf(stdout, "Intensity Decay not performed (rho < 0.01)\n");
        return;
    }

    const int num = (int)round(5 * sqrt(1 / rho));
    const size_t nvox = dims[0] * dims[1] * dims[2];
    if (nvox == 0)
        return;

    // the overall minimum of the target is the floor everything decays to
    intensity_t mina = HUGE_VAL;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (target[x][y][z] < mina)
                    mina = target[x][y][z];

    // flat distance buffer, indexed in the same x-y-z order as the loops
    intensity_t * vals = malloc(nvox * sizeof *vals);
    if (vals == NULL) {
        fprintf(stderr, "normalize_intensity_decay: out of memory; "
                        "decay not performed\n");
        return;
    }

    // first, do the target
    size_t idx = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++, idx++)
                vals[idx] = (est[x][y][z] > 0) ? 0 : unreached;

    ni_decay_distance(vals, dims, num, unreached);

    idx = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++, idx++)
                target[x][y][z] = (target[x][y][z] - mina) *
                                  exp(-rho * vals[idx]) + mina;

    // then every atlas, each against its own mask
    for (size_t j = 0; j < num_raters; j++) {
        idx = 0;
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++, idx++) {
                    vals[idx] = (obs[x][y][z][j] > 0) ? 0 : unreached;
                    // raise values below the target floor so that
                    // (v - mina) is never negative in the decay below
                    if (ims[x][y][z][j] < mina)
                        ims[x][y][z][j] = mina;
                }

        ni_decay_distance(vals, dims, num, unreached);

        idx = 0;
        for (size_t x = 0; x < dims[0]; x++)
            for (size_t y = 0; y < dims[1]; y++)
                for (size_t z = 0; z < dims[2]; z++, idx++)
                    ims[x][y][z][j] = (ims[x][y][z][j] - mina) *
                                      exp(-rho * vals[idx]) + mina;
    }

    free(vals);
}

void normalize_intensity(intensity_t *** target,
                         label_t *** est,
                         intensity_t **** ims,
                         label_t **** obs,
                         const size_t * dims,
                         const size_t num_raters,
                         const size_t num_labels,
                         const size_t num_coef,
                         const intensity_t rho)
{
    /*
     * Full intensity-normalization pipeline. All three stages modify the
     * target and/or atlases in place, and each stage consumes the output of
     * the previous one, so the order below matters:
     *   1. standardize each volume using statistics from its own mask;
     *   2. map each atlas through a polynomial (num_coef coefficients) fitted
     *      on per-label mean intensities (labels < num_labels);
     *   3. decay intensities away from each mask, controlled by rho.
     */
    normalize_intensity_unit(target, est, ims, obs,
                             dims, num_raters);

    normalize_intensity_poly(target, est, ims, obs,
                             dims, num_raters,
                             num_labels, num_coef);

    normalize_intensity_decay(target, est, ims, obs,
                              dims, num_raters,
                              rho);
}

/*
 * Draw a 20-segment text progress bar on stdout for step `ind` of `num`
 * (0-based). Step 0 prints "[", a step whose 20*ind/num share crosses a
 * new integer prints "=", and the final step (num - 1) prints "]\n".
 * Intended to be called once per step, in order. Does nothing if num == 0.
 */
void print_status(size_t ind,
                  size_t num)
{
    // an empty run has nothing to draw (and would divide by zero below)
    if (num == 0)
        return;

    if (ind == 0)
        fputs("[", stdout);

    // deliberately not chained to the "[" test: with num == 1 the first
    // step is also the last, and the bar must still be closed
    if (ind == num - 1)
        fputs("]\n", stdout);
    else if (ind > 0 && (20 * ind) / num > (20 * (ind - 1)) / num)
        fputs("=", stdout);

    fflush(stdout);
}

