#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <nifti1.h>
#include <nifti1_io.h>
#include "fusion_io.h"

/*
 * get_target - read the target intensity image from a NIfTI file.
 *
 * file: path of the NIfTI image to read (voxel data is loaded too).
 * dims: output array of at least 3 elements; receives the image's
 *       x/y/z extents taken from nim->dim[1..3].
 *
 * Returns a newly allocated dims[0] x dims[1] x dims[2] array of
 * intensity_t filled by get_nim_3D. Every level of the array is malloc'd
 * and owned by the caller. Exits the process with status 1 if the file
 * cannot be read or memory cannot be allocated.
 */
intensity_t *** get_target(const char * file,
                           size_t * dims)
{
    // initialize some variables
    intensity_t *** target;
    nifti_image * nim;

    // read the nifti file; nifti_image_read returns NULL on failure, and
    // the header must not be dereferenced in that case
    nim = nifti_image_read(file, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error could not read target nifti file %s\n", file);
        exit(1);
    }

    // set the dimensions of the target (dim[1..3] are the spatial extents)
    for (size_t i = 0; i < 3; i++)
        dims[i] = (size_t)nim->dim[i+1];

    // allocate the target matrix; a NULL result is only an error when a
    // non-zero size was requested (malloc(0) may legitimately return NULL)
    target = (intensity_t ***)malloc(dims[0] * sizeof(*target));
    if (dims[0] > 0 && target == NULL) {
        fprintf(stderr, "Error allocating memory for target\n");
        exit(1);
    }
    for (size_t x = 0; x < dims[0]; x++) {
        target[x] = (intensity_t **)malloc(dims[1] * sizeof(*target[x]));
        if (dims[1] > 0 && target[x] == NULL) {
            fprintf(stderr, "Error allocating memory for target\n");
            exit(1);
        }
        for (size_t y = 0; y < dims[1]; y++) {
            target[x][y] = (intensity_t *)malloc(dims[2] *
                                                 sizeof(*target[x][y]));
            if (dims[2] > 0 && target[x][y] == NULL) {
                fprintf(stderr, "Error allocating memory for target\n");
                exit(1);
            }
        }
    }

    // set the target
    get_nim_3D<intensity_t>(nim, target, dims);

    // free the nifti image
    nifti_image_free(nim);

    // return the target
    return(target);
}

/*
 * get_est - read a 3D label estimate and validate it against the target.
 *
 * file:       path of the NIfTI label image to read (voxel data loaded).
 * dims:       expected x/y/z extents; the image's dim[1..3] must match.
 * num_labels: expected label count; must equal (largest label value in
 *             the estimate) + 1.
 *
 * Returns a newly allocated dims[0] x dims[1] x dims[2] array of label_t
 * filled by get_nim_3D. Every level is malloc'd and owned by the caller.
 * Exits the process with status 1 on a read failure, an allocation
 * failure, a dimension mismatch or a label-count mismatch.
 *
 * NOTE(review): the label scan compares label_t against size_t, so it
 * presumably assumes labels are non-negative; confirm for label_t.
 */
label_t *** get_est(const char * file,
                    const size_t * dims,
                    const size_t num_labels)
{
    // initialize some variables
    label_t *** est;
    nifti_image * nim;

    // read the nifti file; nifti_image_read returns NULL on failure, and
    // the header must not be dereferenced in that case
    nim = nifti_image_read(file, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error could not read estimate nifti file %s\n", file);
        exit(1);
    }

    // make sure the estimate has the same dimensions as the target
    for (size_t i = 0; i < 3; i++) {
        if (dims[i] != (size_t)nim->dim[i+1]) {
            fprintf(stderr, "Error estimate dimensions do not match target\n");
            exit(1);
        }
    }

    // allocate the estimate matrix; a NULL result is only an error when a
    // non-zero size was requested (malloc(0) may legitimately return NULL)
    est = (label_t ***)malloc(dims[0] * sizeof(*est));
    if (dims[0] > 0 && est == NULL) {
        fprintf(stderr, "Error allocating memory for estimate\n");
        exit(1);
    }
    for (size_t x = 0; x < dims[0]; x++) {
        est[x] = (label_t **)malloc(dims[1] * sizeof(*est[x]));
        if (dims[1] > 0 && est[x] == NULL) {
            fprintf(stderr, "Error allocating memory for estimate\n");
            exit(1);
        }
        for (size_t y = 0; y < dims[1]; y++) {
            est[x][y] = (label_t *)malloc(dims[2] * sizeof(*est[x][y]));
            if (dims[2] > 0 && est[x][y] == NULL) {
                fprintf(stderr, "Error allocating memory for estimate\n");
                exit(1);
            }
        }
    }

    // set the estimate
    get_nim_3D<label_t>(nim, est, dims);

    // the label count is the largest label value present plus one
    size_t temp_num_labels = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > temp_num_labels)
                    temp_num_labels = est[x][y][z];
    temp_num_labels++;

    if (num_labels != temp_num_labels) {
        fprintf(stderr, "Differing number of labels on estimate and obs\n");
        exit(1);
    }

    // free the nifti image
    nifti_image_free(nim);

    // return the estimate
    return(est);
}

/*
 * get_est2 - read a 3D label estimate and report its label count.
 *
 * file:       path of the NIfTI label image to read (voxel data loaded).
 * dims:       expected x/y/z extents; the image's dim[1..3] must match.
 * num_labels: output; receives (largest label value in the estimate) + 1.
 *
 * Returns a newly allocated dims[0] x dims[1] x dims[2] array of label_t
 * filled by get_nim_3D. Every level is malloc'd and owned by the caller.
 * Exits the process with status 1 on a read failure, an allocation
 * failure or a dimension mismatch.
 *
 * NOTE(review): the label scan compares label_t against size_t, so it
 * presumably assumes labels are non-negative; confirm for label_t.
 */
label_t *** get_est2(const char * file,
                     const size_t * dims,
                     size_t * num_labels)
{
    // initialize some variables
    label_t *** est;
    nifti_image * nim;

    // read the nifti file; nifti_image_read returns NULL on failure, and
    // the header must not be dereferenced in that case
    nim = nifti_image_read(file, 1);
    if (nim == NULL) {
        fprintf(stderr, "Error could not read estimate nifti file %s\n", file);
        exit(1);
    }

    // make sure the estimate has the same dimensions as the target
    for (size_t i = 0; i < 3; i++) {
        if (dims[i] != (size_t)nim->dim[i+1]) {
            fprintf(stderr, "Error estimate dimensions do not match target\n");
            exit(1);
        }
    }

    // allocate the estimate matrix; a NULL result is only an error when a
    // non-zero size was requested (malloc(0) may legitimately return NULL)
    est = (label_t ***)malloc(dims[0] * sizeof(*est));
    if (dims[0] > 0 && est == NULL) {
        fprintf(stderr, "Error allocating memory for estimate\n");
        exit(1);
    }
    for (size_t x = 0; x < dims[0]; x++) {
        est[x] = (label_t **)malloc(dims[1] * sizeof(*est[x]));
        if (dims[1] > 0 && est[x] == NULL) {
            fprintf(stderr, "Error allocating memory for estimate\n");
            exit(1);
        }
        for (size_t y = 0; y < dims[1]; y++) {
            est[x][y] = (label_t *)malloc(dims[2] * sizeof(*est[x][y]));
            if (dims[2] > 0 && est[x][y] == NULL) {
                fprintf(stderr, "Error allocating memory for estimate\n");
                exit(1);
            }
        }
    }

    // set the estimate
    get_nim_3D<label_t>(nim, est, dims);

    // the label count is the largest label value present plus one
    size_t temp_num_labels = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                if (est[x][y][z] > temp_num_labels)
                    temp_num_labels = est[x][y][z];
    temp_num_labels++;
    *num_labels = temp_num_labels;

    // free the nifti image
    nifti_image_free(nim);

    // return the estimate
    return(est);
}

/*
 * get_ims - read the registered atlas intensity images.
 *
 * file:       either a NIfTI file (detected with
 *             nifti_is_complete_filename) whose dim[1..3] are the spatial
 *             extents and dim[4] the rater count, or otherwise a text file
 *             parsed by read_4D_textfile.
 * dims:       expected x/y/z extents; must match the input exactly.
 * num_raters: expected number of raters; must match the input exactly.
 *
 * Returns a dims[0] x dims[1] x dims[2] x num_raters array of
 * intensity_t, owned by the caller. On the NIfTI path every level is
 * malloc'd here; on the text path the array comes from read_4D_textfile.
 * Exits the process with status 1 on a read failure, an allocation
 * failure, or a dimension or rater-count mismatch.
 */
intensity_t **** get_ims(const char * file,
                         const size_t * dims,
                         const size_t num_raters)
{
    // initialize some variables
    intensity_t **** ims;

    // check if the input is a valid nifti file or not
    if (nifti_is_complete_filename(file)) {

        // read the nifti file; nifti_image_read returns NULL on failure,
        // and the header must not be dereferenced in that case
        nifti_image * nim = nifti_image_read(file, 1);
        if (nim == NULL) {
            fprintf(stderr, "Error could not read atlas nifti file %s\n",
                    file);
            exit(1);
        }

        // make sure the atlases have the same dimensions as the target
        for (size_t i = 0; i < 3; i++) {
            if (dims[i] != (size_t)nim->dim[i+1]) {
                fprintf(stderr, "Error atlas dimensions do not match target\n");
                exit(1);
            }
        }

        // check the number of raters
        if (num_raters != (size_t)nim->dim[4]) {
            fprintf(stderr, "Error number of raters mismatch\n");
            exit(1);
        }

        // allocate the atlas matrix; a NULL result is only an error when a
        // non-zero size was requested (malloc(0) may legitimately return
        // NULL)
        ims = (intensity_t ****)malloc(dims[0] * sizeof(*ims));
        if (dims[0] > 0 && ims == NULL) {
            fprintf(stderr, "Error allocating memory for atlases\n");
            exit(1);
        }
        for (size_t x = 0; x < dims[0]; x++) {
            ims[x] = (intensity_t ***)malloc(dims[1] * sizeof(*ims[x]));
            if (dims[1] > 0 && ims[x] == NULL) {
                fprintf(stderr, "Error allocating memory for atlases\n");
                exit(1);
            }
            for (size_t y = 0; y < dims[1]; y++) {
                ims[x][y] = (intensity_t **)malloc(dims[2] *
                                                   sizeof(*ims[x][y]));
                if (dims[2] > 0 && ims[x][y] == NULL) {
                    fprintf(stderr, "Error allocating memory for atlases\n");
                    exit(1);
                }
                for (size_t z = 0; z < dims[2]; z++) {
                    ims[x][y][z] = (intensity_t *)malloc(num_raters *
                                                         sizeof(*ims[x][y][z]));
                    if (num_raters > 0 && ims[x][y][z] == NULL) {
                        fprintf(stderr,
                                "Error allocating memory for atlases\n");
                        exit(1);
                    }
                }
            }
        }

        // set the ims
        get_nim_4D<intensity_t>(nim, ims, dims, num_raters);

        // free the nifti image
        nifti_image_free(nim);
    } else {

        fprintf(stdout, "\tAssuming text file input (not nifti file)\n");

        // read_4D_textfile fills these in; the original code freed its
        // buffer afterwards, so the reader does not keep the pointer and a
        // stack array is sufficient
        size_t temp_dims[3] = {0, 0, 0};
        size_t num = 0;

        // set the ims
        ims = read_4D_textfile<intensity_t>(file, temp_dims, &num);

        // make sure the number of raters match
        if (num != num_raters) {
            fprintf(stderr, "Error number of raters mismatch\n");
            exit(1);
        }

        // make sure the dimensions match
        for (size_t i = 0; i < 3; i++) {
            if (dims[i] != temp_dims[i]) {
                fprintf(stderr, "Error atlas dimensions do not match target\n");
                exit(1);
            }
        }

        // NOTE(review): assumes read_4D_textfile signals failure by
        // returning NULL; this is a defensive check only
        if (ims == NULL) {
            fprintf(stderr, "Error could not read atlas text file %s\n",
                    file);
            exit(1);
        }
    }

    // return the atlases
    return(ims);
}

/*
 * get_obs - read the raters' label observations.
 *
 * file:       either a NIfTI file (detected with
 *             nifti_is_complete_filename) whose dim[1..3] are the spatial
 *             extents and dim[4] the rater count, or otherwise a text file
 *             parsed by read_4D_textfile.
 * dims:       expected x/y/z extents; must match the input exactly.
 * num_raters: output; receives the number of raters found in the input.
 * num_labels: output; receives (largest label value observed) + 1.
 *
 * Returns a dims[0] x dims[1] x dims[2] x *num_raters array of label_t,
 * owned by the caller. On the NIfTI path every level is malloc'd here;
 * on the text path the array comes from read_4D_textfile. Exits the
 * process with status 1 on a read failure, an invalid rater count, an
 * allocation failure or a dimension mismatch.
 *
 * NOTE(review): the label scan compares label_t against size_t, so it
 * presumably assumes labels are non-negative; confirm for label_t.
 */
label_t **** get_obs(const char * file,
                     const size_t * dims,
                     size_t * num_raters,
                     size_t * num_labels)
{
    // initialize some variables
    label_t **** obs;

    // check if the input is a valid nifti file or not
    if (nifti_is_complete_filename(file)) {

        // read the nifti file; nifti_image_read returns NULL on failure,
        // and the header must not be dereferenced in that case
        nifti_image * nim = nifti_image_read(file, 1);
        if (nim == NULL) {
            fprintf(stderr, "Error could not read obs nifti file %s\n", file);
            exit(1);
        }

        // set the number of raters; a non-positive dim[4] would wrap to a
        // huge size_t and break the allocation below
        if (nim->dim[4] < 1) {
            fprintf(stderr, "Error invalid number of raters in obs\n");
            exit(1);
        }
        *num_raters = (size_t)nim->dim[4];

        // make sure the observations have the same dimensions as the target
        for (size_t i = 0; i < 3; i++) {
            if (dims[i] != (size_t)nim->dim[i+1]) {
                fprintf(stderr, "Error obs dimensions do not match target\n");
                exit(1);
            }
        }

        // allocate the observation matrix; a NULL result is only an error
        // when a non-zero size was requested (malloc(0) may legitimately
        // return NULL)
        obs = (label_t ****)malloc(dims[0] * sizeof(*obs));
        if (dims[0] > 0 && obs == NULL) {
            fprintf(stderr, "Error allocating memory for obs\n");
            exit(1);
        }
        for (size_t x = 0; x < dims[0]; x++) {
            obs[x] = (label_t ***)malloc(dims[1] * sizeof(*obs[x]));
            if (dims[1] > 0 && obs[x] == NULL) {
                fprintf(stderr, "Error allocating memory for obs\n");
                exit(1);
            }
            for (size_t y = 0; y < dims[1]; y++) {
                obs[x][y] = (label_t **)malloc(dims[2] * sizeof(*obs[x][y]));
                if (dims[2] > 0 && obs[x][y] == NULL) {
                    fprintf(stderr, "Error allocating memory for obs\n");
                    exit(1);
                }
                for (size_t z = 0; z < dims[2]; z++) {
                    // *num_raters >= 1 was validated above
                    obs[x][y][z] = (label_t *)malloc(*num_raters *
                                                        sizeof(*obs[x][y][z]));
                    if (obs[x][y][z] == NULL) {
                        fprintf(stderr, "Error allocating memory for obs\n");
                        exit(1);
                    }
                }
            }
        }

        // read the data from the nifti image
        get_nim_4D<label_t>(nim, obs, dims, *num_raters);

        // free the nifti image
        nifti_image_free(nim);
    } else {

        fprintf(stdout, "\tAssuming text file input (not nifti file)\n");

        // read_4D_textfile fills these in; the original code freed its
        // buffer afterwards, so the reader does not keep the pointer and a
        // stack array is sufficient
        size_t temp_dims[3] = {0, 0, 0};

        // set the obs
        obs = read_4D_textfile<label_t>(file, temp_dims, num_raters);

        // make sure the dimensions match
        for (size_t i = 0; i < 3; i++) {
            if (dims[i] != temp_dims[i]) {
                fprintf(stderr, "Error atlas dimensions do not match target\n");
                exit(1);
            }
        }

        // obs is dereferenced by the label scan below.
        // NOTE(review): assumes read_4D_textfile signals failure by
        // returning NULL; this is a defensive check only
        if (obs == NULL) {
            fprintf(stderr, "Error could not read obs text file %s\n", file);
            exit(1);
        }
    }

    // the label count is the largest label value observed plus one
    *num_labels = 0;
    for (size_t x = 0; x < dims[0]; x++)
        for (size_t y = 0; y < dims[1]; y++)
            for (size_t z = 0; z < dims[2]; z++)
                for (size_t j = 0; j < *num_raters; j++)
                    if (obs[x][y][z][j] > *num_labels)
                        *num_labels = obs[x][y][z][j];
    (*num_labels)++;

    // return the observations
    return(obs);
}

