/**
 * @file  reversedeformationfield.cxx
 * @brief Estimate inverse deformation field.
 *
 * Copyright (c) 2008, 2009, 2012 University of Pennsylvania.
 *
 * This file is part of DTI-DROID.
 *
 * DTI-DROID is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * DTI-DROID is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with DTI-DROID.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Contact: SBIA Group <sbia-software at uphs.upenn.edu>
 */

#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include "mvcddgtools.h"
#include "cresdgtools.h"
#include "matrixSHENdgtools.h"  /*by SHEN*/

/* features for vertices */
/* NOTE(review): none of these five labels is referenced in the visible part
 * of this file; CSF/VN/GM/WM are presumably tissue-class intensities --
 * confirm before relying on the names. */
#define BG    0            /* background */
#define CSF   10
#define VN    50
#define GM    150
#define WM    250

/* Values stored in the HeadInField flag. */
#define YYES    1
#define NNO     0

/* Zscale is not referenced in the visible part of this file. */
#define Zscale  1.5
/* SHIFT: number of padding voxels added on every side of the accumulation
 * volumes used while estimating the inverse field; target positions are
 * offset by SHIFT before being converted to array indices. */
#define SHIFT   2 /* at least > 1 */
/* OUTSIDE: value written to all three components of an inverse-field voxel
 * that received no contribution (accumulated weight not greater than zero). */
#define OUTSIDE 100000.0

/* HeadInField: YYES when the vector-field file starts with two ints
 *              (image size, z size); set by the -h command-line option.
 * Req_z_size : number of z slices of the estimated inverse field; set by the
 *              -z option and raised to the input z size when it is smaller. */
int    HeadInField, Req_z_size ;

/* MarcMain is only declared here; no definition appears in the visible part
 * of this file. */
void MarcMain(int, char *[]);
void ShenMain(int argc, char *argv[]) ;
void show_usage_SHEN() ;


/**
 * @brief Write an 8-bit image volume to a raw file.
 *
 * The volume is written slice by slice and row by row: for every slice k and
 * every row i, image_size bytes starting at data[k][i] are appended to the
 * file.  No header is written.
 *
 * @param filename    Path of the output file; opened through myopen().
 * @param data        Volume indexed as data[k][i][j].
 * @param image_size  Number of rows per slice and number of bytes per row.
 * @param z_size      Number of slices.
 *
 * A failed open, a short write or a failed close is reported on stderr and
 * terminates the program with exit status 1, so that a truncated file is
 * never mistaken for a valid result.
 */
void WriteImg(char filename[80], unsigned char ***data, int image_size, int z_size) {
    FILE *fp = myopen(filename, "w");

    /* NOTE(review): myopen() is defined elsewhere and may already abort on
     * failure; this check is only a safeguard. */
    if (fp == NULL) {
        fprintf(stderr, "WriteImg: cannot open %s for writing\n", filename);
        exit(1);
    }

    for (int k = 0; k < z_size; k++) {
        for (int i = 0; i < image_size; i++) {
            if (fwrite(data[k][i], 1, image_size, fp) != (size_t)image_size) {
                fprintf(stderr, "WriteImg: short write to %s (slice %d, row %d)\n", filename, k, i);
                exit(1);
            }
        }
    }

    /* fclose() flushes buffered data, so a failure here also means data loss. */
    if (fclose(fp) != 0) {
        fprintf(stderr, "WriteImg: error while closing %s\n", filename);
        exit(1);
    }
}

/**
 * @brief Write a deformation field to a raw file.
 *
 * For every slice k and every row i, image_size Fvector3d elements starting
 * at DeformFld[k][i] are appended to the file.  No size header is written;
 * the file contains the vectors only.
 *
 * @param filename    Path of the output file; opened through myopen().
 * @param DeformFld   Field indexed as DeformFld[k][i][j].
 * @param image_size  Number of rows per slice and number of vectors per row.
 * @param z_size      Number of slices.
 *
 * A failed open, a short write or a failed close is reported on stderr and
 * terminates the program with exit status 1, so that a truncated field is
 * never mistaken for a valid result.
 */
void WriteDeformationField(char filename[180], Fvector3d ***DeformFld, int image_size, int z_size) {
    FILE *fp = myopen(filename, "w");

    /* NOTE(review): myopen() is defined elsewhere and may already abort on
     * failure; this check is only a safeguard. */
    if (fp == NULL) {
        fprintf(stderr, "WriteDeformationField: cannot open %s for writing\n", filename);
        exit(1);
    }

    for (int k = 0; k < z_size; k++) {
        for (int i = 0; i < image_size; i++) {
            if (fwrite(DeformFld[k][i], sizeof(Fvector3d), image_size, fp) != (size_t)image_size) {
                fprintf(stderr, "WriteDeformationField: short write to %s (slice %d, row %d)\n", filename, k, i);
                exit(1);
            }
        }
    }

    /* fclose() flushes buffered data, so a failure here also means data loss. */
    if (fclose(fp) != 0) {
        fprintf(stderr, "WriteDeformationField: error while closing %s\n", filename);
        exit(1);
    }
}

/**
 * @brief Read a deformation field from a raw file into a preallocated volume.
 *
 * When the global HeadInField equals YYES the file is expected to start with
 * two ints, which are read, printed and otherwise ignored; the dimensions
 * used for reading are always the image_size and z_size arguments.  The
 * vectors follow slice by slice and row by row, image_size Fvector3d elements
 * per row.
 *
 * @param filename    Path of the input file; opened through myopen().
 * @param DeformFld   Destination, indexed as DeformFld[k][i][j]; must already
 *                    provide z_size slices of image_size x image_size vectors.
 * @param image_size  Number of rows per slice and number of vectors per row.
 * @param z_size      Number of slices to read.
 *
 * A failed open or a short read is reported on stderr and terminates the
 * program with exit status 1; previously a truncated file silently left part
 * of the field uninitialized.
 */
void OpenDeformationField(char filename[80], Fvector3d ***DeformFld, int image_size, int z_size) {
    FILE *fp = myopen(filename, "r");

    /* NOTE(review): myopen() is defined elsewhere and may already abort on
     * failure; this check is only a safeguard. */
    if (fp == NULL) {
        fprintf(stderr, "OpenDeformationField: cannot open %s for reading\n", filename);
        exit(1);
    }

    if (HeadInField == YYES) {
        int Last_image_size = 0;
        int Last_z_size = 0;

        if (fread(&Last_image_size, sizeof(int), 1, fp) != 1) {
            fprintf(stderr, "OpenDeformationField: cannot read the header of %s\n", filename);
            exit(1);
        }
        printf("Last_image_size=%d ", Last_image_size) ;

        if (fread(&Last_z_size, sizeof(int), 1, fp) != 1) {
            fprintf(stderr, "OpenDeformationField: cannot read the header of %s\n", filename);
            exit(1);
        }
        printf("Last_z_size=%d\n", Last_z_size) ;
    }

    for (int k = 0; k < z_size; k++) {
        for (int i = 0; i < image_size; i++) {
            if (fread(DeformFld[k][i], sizeof(Fvector3d), image_size, fp) != (size_t)image_size) {
                fprintf(stderr, "OpenDeformationField: %s is too short (slice %d, row %d)\n", filename, k, i);
                exit(1);
            }
        }
    }

    fclose(fp);
}



/**
 * @brief Scatter one displacement sample onto the eight voxels around a
 *        target position.
 *
 * The target position (ii, jj, kk) is first offset by SHIFT so that it can be
 * used as an index into the padded accumulation volumes.  The sample is then
 * distributed over the eight surrounding voxels with trilinear weights: each
 * voxel's entry in TotalWeights grows by its weight and its entry in
 * ReversedField_BigSpace grows by the displacement multiplied by that weight.
 * Positions whose lower corner falls outside the padded volume are ignored.
 *
 * @param Displace_subVoxel       Displacement carried by the sample.
 * @param ii, jj, kk              Target position (second, third and first
 *                                array index respectively), before padding.
 * @param ReversedField_BigSpace  Padded accumulator of weighted displacements.
 * @param TotalWeights            Padded accumulator of weights.
 * @param image_size              Unpadded size of the second and third index.
 * @param z_size                  Unpadded size of the first index.
 *
 * Note: the float-to-int conversion truncates toward zero, so a shifted
 * coordinate in (-1, 0) is mapped to index 0 with a negative fraction.
 */
void IterativelyEstimateInverseDeformationField(Fvector3d  Displace_subVoxel, float ii, float jj, float kk, Fvector3d ***ReversedField_BigSpace, float ***TotalWeights, int image_size, int z_size) {
    const int xy_limit = image_size+2*SHIFT-1 ;
    const int z_limit  = z_size+2*SHIFT-1 ;
    float     wi[2], wj[2], wk[2] ;   /* [0]: weight of the lower voxel, [1]: of the upper one */
    float     norm ;
    int       i0, j0, k0 ;
    int       di, dj, dk ;

    /* Move into the padded index space. */
    ii += SHIFT ;
    jj += SHIFT ;
    kk += SHIFT ;

    i0 = (int)ii ;
    j0 = (int)jj ;
    k0 = (int)kk ;

    /* Both the lower corner and its +1 neighbours must be valid indices. */
    if( i0<0 || i0>=xy_limit || j0<0 || j0>=xy_limit || k0<0 || k0>=z_limit )
        return ;

    wi[1] = ii-i0 ;    wi[0] = 1.-wi[1] ;
    wj[1] = jj-j0 ;    wj[0] = 1.-wj[1] ;
    wk[1] = kk-k0 ;    wk[0] = 1.-wk[1] ;

    /* Sum of the eight unnormalized weights. */
    norm = ( wk[0]*((wi[0]*wj[0])+(wi[1]*wj[0])+(wi[0]*wj[1])+(wi[1]*wj[1])) + wk[1]*((wi[0]*wj[0])+(wi[1]*wj[0])+(wi[0]*wj[1])+(wi[1]*wj[1])) ) ;

    for( dk=0; dk<2; dk++ )
        for( di=0; di<2; di++ )
            for( dj=0; dj<2; dj++ ) {
                const float share = wk[dk]*(wi[di]*wj[dj])/norm ;
                Fvector3d   *cell = &ReversedField_BigSpace[k0+dk][i0+di][j0+dj] ;

                TotalWeights[k0+dk][i0+di][j0+dj] += share ;
                cell->x += Displace_subVoxel.x*share ;
                cell->y += Displace_subVoxel.y*share ;
                cell->z += Displace_subVoxel.z*share ;
            }
}


/**
 * @brief Sample an 8-bit volume at a fractional position.
 *
 * The position is converted to voxel indices by truncation (ni, nj, nk).
 *  - If any index lies outside the volume, 0 is returned.
 *  - If any index is the last one along its axis (so that no +1 neighbour
 *    exists), the value of voxel Img[nk][ni][nj] is returned.  This covers
 *    faces, edges and corners of the volume.
 *  - Otherwise the value is interpolated trilinearly from the eight
 *    surrounding voxels, clamped to [0, 255] and truncated to an integer.
 *
 * Previously positions outside the volume, and edge/corner positions with
 * two or more indices at their maximum, returned an uninitialized value, and
 * negative interpolation results wrapped around when converted.
 *
 * Because truncation rounds toward zero, a coordinate in (-1, 0) is treated
 * as index 0 with a negative fraction, i.e. it is linearly extrapolated.
 *
 * @param ii, jj, kk  Position; ii selects the second, jj the third and kk the
 *                    first index of Img.
 * @param Img         Volume indexed as Img[k][i][j].
 * @param image_size  Size of the second and third index.
 * @param z_size      Size of the first index.
 * @return Sampled grey value.
 */
unsigned char InterpolatedIntensity(float ii, float jj, float kk, unsigned char ***Img, int image_size, int z_size) {
    const int ni = (int)ii ;
    const int nj = (int)jj ;
    const int nk = (int)kk ;
    float     b, c, d, b1, c1, d1 ;
    float     CurrentV ;

    /* Outside the volume: nothing to sample. */
    if( ni<0 || ni>=image_size || nj<0 || nj>=image_size || nk<0 || nk>=z_size )
        return 0 ;

    /* On the last row, column or slice there is no +1 neighbour. */
    if( ni==image_size-1 || nj==image_size-1 || nk==z_size-1 )
        return Img[nk][ni][nj] ;

    {
        const int niP1 = ni+1 ;
        const int njP1 = nj+1 ;
        const int nkP1 = nk+1 ;

        b = ii-ni ;        b1 = 1.-b ;
        c = jj-nj ;        c1 = 1.-c ;
        d = kk-nk ;        d1 = 1.-d ;

        CurrentV = ( d1*(Img[nk][ni][nj]*(b1*c1) + Img[nk][niP1][nj]*(b*c1) + Img[nk][ni][njP1]*(b1*c) + Img[nk][niP1][njP1]*(b*c)) + d*(Img[nkP1][ni][nj]*(b1*c1) + Img[nkP1][niP1][nj]*(b*c1) + Img[nkP1][ni][njP1]*(b1*c) + Img[nkP1][niP1][njP1]*(b*c)) )/( d1*((b1*c1)+(b*c1)+(b1*c)+(b*c)) + d*((b1*c1)+(b*c1)+(b1*c)+(b*c)) ) ;
    }

    if( CurrentV>255 )
        return 255 ;
    if( CurrentV<0 )
        return 0 ;
    return (unsigned char)(int)CurrentV ;
}

/**
 * @brief Sample a deformation field at a fractional position.
 *
 * The position is converted to voxel indices by truncation (ni, nj, nk).
 *  - If any index lies outside the field, *Displace_subVoxel is left
 *    unchanged (as before).
 *  - If any index is the last one along its axis (so that no +1 neighbour
 *    exists), the vector of voxel DeformFld[nk][ni][nj] is copied.  This
 *    covers faces, edges and corners of the field.
 *  - Otherwise each component is interpolated trilinearly from the eight
 *    surrounding voxels.
 *
 * Previously edge and corner positions (two or more indices at their
 * maximum) were not handled at all, so the caller kept whatever value the
 * output already held.
 *
 * Because truncation rounds toward zero, a coordinate in (-1, 0) is treated
 * as index 0 with a negative fraction, i.e. it is linearly extrapolated.
 *
 * @param Displace_subVoxel  Output vector.
 * @param ii, jj, kk         Position; ii selects the second, jj the third and
 *                           kk the first index of DeformFld.
 * @param DeformFld          Field indexed as DeformFld[k][i][j].
 * @param image_size         Size of the second and third index.
 * @param z_size             Size of the first index.
 */
void InterpolatedDisplacement(Fvector3d *Displace_subVoxel, float ii, float jj, float kk, Fvector3d ***DeformFld, int image_size, int z_size) {
    const int ni = (int)ii ;
    const int nj = (int)jj ;
    const int nk = (int)kk ;
    float     b, c, d, b1, c1, d1 ;
    float     w00, w10, w01, w11, norm ;
    int       niP1, njP1, nkP1 ;

    /* Outside the field: keep the caller's value untouched. */
    if( ni<0 || ni>=image_size || nj<0 || nj>=image_size || nk<0 || nk>=z_size )
        return ;

    /* On the last row, column or slice there is no +1 neighbour. */
    if( ni==image_size-1 || nj==image_size-1 || nk==z_size-1 ) {
        (*Displace_subVoxel).x = DeformFld[nk][ni][nj].x ;
        (*Displace_subVoxel).y = DeformFld[nk][ni][nj].y ;
        (*Displace_subVoxel).z = DeformFld[nk][ni][nj].z ;
        return ;
    }

    niP1 = ni+1 ;
    njP1 = nj+1 ;
    nkP1 = nk+1 ;

    b = ii-ni ;        b1 = 1.-b ;
    c = jj-nj ;        c1 = 1.-c ;
    d = kk-nk ;        d1 = 1.-d ;

    /* In-plane weights of the four (i, j) corners, shared by both slices. */
    w00 = b1*c1 ;      /* (ni,   nj)   */
    w10 = b*c1 ;       /* (niP1, nj)   */
    w01 = b1*c ;       /* (ni,   njP1) */
    w11 = b*c ;        /* (niP1, njP1) */

    /* Sum of the eight weights, used to normalize the result. */
    norm = ( d1*((w00)+(w10)+(w01)+(w11)) + d*((w00)+(w10)+(w01)+(w11)) ) ;

    (*Displace_subVoxel).x = ( d1*(DeformFld[nk][ni][nj].x*(w00) + DeformFld[nk][niP1][nj].x*(w10) + DeformFld[nk][ni][njP1].x*(w01) + DeformFld[nk][niP1][njP1].x*(w11)) + d*(DeformFld[nkP1][ni][nj].x*(w00) + DeformFld[nkP1][niP1][nj].x*(w10) + DeformFld[nkP1][ni][njP1].x*(w01) + DeformFld[nkP1][niP1][njP1].x*(w11)) )/norm ;

    (*Displace_subVoxel).y = ( d1*(DeformFld[nk][ni][nj].y*(w00) + DeformFld[nk][niP1][nj].y*(w10) + DeformFld[nk][ni][njP1].y*(w01) + DeformFld[nk][niP1][njP1].y*(w11)) + d*(DeformFld[nkP1][ni][nj].y*(w00) + DeformFld[nkP1][niP1][nj].y*(w10) + DeformFld[nkP1][ni][njP1].y*(w01) + DeformFld[nkP1][niP1][njP1].y*(w11)) )/norm ;

    (*Displace_subVoxel).z = ( d1*(DeformFld[nk][ni][nj].z*(w00) + DeformFld[nk][niP1][nj].z*(w10) + DeformFld[nk][ni][njP1].z*(w01) + DeformFld[nk][niP1][njP1].z*(w11)) + d*(DeformFld[nkP1][ni][nj].z*(w00) + DeformFld[nkP1][niP1][nj].z*(w10) + DeformFld[nkP1][ni][njP1].z*(w01) + DeformFld[nkP1][niP1][njP1].z*(w11)) )/norm ;
}


void EstimateInverseDeformationField( int SampleNum, Fvector3d ***DeformFld, int image_size, int z_size, char filename[80] ) {
    int s, i, j, k, x, y, z ;
    FILE  *fp;
    Fvector3d     ***ReversedField, ***ReversedField_BigSpace ;
    float         ***TotalWeights, max, min ;
    float  ii, jj, kk, interval ;
    int    temp, GreyValue, MappedV ;
    Fvector3d  Mdl_subVoxel, Displace_subVoxel ;
    
    
    interval = 1.0/(SampleNum*2+1) ;
    printf("interval=%f SampleNum=%d\n", interval, SampleNum) ;
    
    /* reverse field */
    TotalWeights           = Falloc3d(image_size+2*SHIFT, image_size+2*SHIFT, Req_z_size+2*SHIFT);
    ReversedField_BigSpace = Fvector3dalloc3d(image_size+2*SHIFT, image_size+2*SHIFT, Req_z_size+2*SHIFT);
    ReversedField          = Fvector3dalloc3d(image_size, image_size, Req_z_size);
    for(k=0; k<Req_z_size+2*SHIFT; k++)
        for(i=0; i<image_size+2*SHIFT; i++)
            for(j=0; j<image_size+2*SHIFT; j++) {
        TotalWeights[k][i][j] = 0 ;
        
        ReversedField_BigSpace[k][i][j].x = 0 ;
        ReversedField_BigSpace[k][i][j].y = 0 ;
        ReversedField_BigSpace[k][i][j].z = 0 ;
            }
    
    /* estimation ... */
    for(k=0; k<z_size; k++) {
        printf("z=%d\n", k) ;
        for(i=0; i<image_size; i++)
            for(j=0; j<image_size; j++) {
            for(z=-SampleNum; z<=SampleNum; z++)
                for(x=-SampleNum; x<=SampleNum; x++)
                    for(y=-SampleNum; y<=SampleNum; y++) {
                Mdl_subVoxel.x = x*interval + i ;
                Mdl_subVoxel.y = y*interval + j ;
                Mdl_subVoxel.z = z*interval + k ;
                
                /* Get the greyvalue interpolation */
                /*GreyValue = InterpolatedIntensity(Mdl_subVoxel.x, Mdl_subVoxel.y, Mdl_subVoxel.z, ObjOriginalImg, image_size, z_size) ;
                 * MappedV = GreyValue ;*/
                
                InterpolatedDisplacement(&Displace_subVoxel, Mdl_subVoxel.x, Mdl_subVoxel.y, Mdl_subVoxel.z, DeformFld, image_size, z_size) ;
                
                ii = Mdl_subVoxel.x + Displace_subVoxel.x ;
                jj = Mdl_subVoxel.y + Displace_subVoxel.y ;
                kk = Mdl_subVoxel.z + Displace_subVoxel.z ;
                
                IterativelyEstimateInverseDeformationField(Displace_subVoxel, ii, jj, kk, ReversedField_BigSpace, TotalWeights, image_size, Req_z_size) ;
                    }
            }
    }
    
    
    /* normalize ... */
    for(k=0; k<Req_z_size; k++)
        for(i=0; i<image_size; i++)
            for(j=0; j<image_size; j++)
                if( TotalWeights[k+SHIFT][i+SHIFT][j+SHIFT]>0 ) {
        ReversedField[k][i][j].x = ReversedField_BigSpace[k+SHIFT][i+SHIFT][j+SHIFT].x/(-TotalWeights[k+SHIFT][i+SHIFT][j+SHIFT]) ;
        ReversedField[k][i][j].y = ReversedField_BigSpace[k+SHIFT][i+SHIFT][j+SHIFT].y/(-TotalWeights[k+SHIFT][i+SHIFT][j+SHIFT]) ;
        ReversedField[k][i][j].z = ReversedField_BigSpace[k+SHIFT][i+SHIFT][j+SHIFT].z/(-TotalWeights[k+SHIFT][i+SHIFT][j+SHIFT]) ;
                }
                else {
        ReversedField[k][i][j].x = OUTSIDE ;
        ReversedField[k][i][j].y = OUTSIDE ;
        ReversedField[k][i][j].z = OUTSIDE ;
                }
    
    /* save */
    WriteDeformationField(filename, ReversedField, image_size, Req_z_size) ;
    
    /* free */
    Ffree3d(TotalWeights, Req_z_size+2*SHIFT, image_size+2*SHIFT) ;
    Fvector3dfree3d(ReversedField_BigSpace, Req_z_size+2*SHIFT, image_size+2*SHIFT) ;
    Fvector3dfree3d(ReversedField, Req_z_size, image_size) ;
}


/**
 * @brief Print the program name and its version to stdout.
 *
 * The revision number is extracted from an svn keyword string of the form
 * "$Rev: 156 $": the text between the first and the second space is printed
 * after "0.0.".  If the keyword does not contain a space (for example when
 * svn has not expanded it), "unknown" is printed instead of dereferencing a
 * null pointer; if only the second space is missing, everything after the
 * first space is printed.
 *
 * @param nameOfBinary  Name to print in front of the version; not modified.
 */
void reportVersion(char *nameOfBinary) {
    char fullVersion[] = "$Rev: 1 $"; /* populated automatically by svn */
    const char *shortVersion = "unknown";
    char *firstSpace = strchr(fullVersion, ' ');

    if (firstSpace != NULL) {
        /* Skip the leading "$Rev: " ... */
        char *secondSpace = strchr(firstSpace + 1, ' ');

        shortVersion = firstSpace + 1;
        /* ... and cut the string at the space that follows the number. */
        if (secondSpace != NULL)
            *secondSpace = '\0';
    }

    printf("%s version: 0.0.%s \n", nameOfBinary, shortVersion);
}

/**
 * @brief Program entry point.
 *
 * Hands the unmodified command line to ShenMain() and reports success once
 * it returns.
 */
int main(int argc, char *argv[])
{
    ShenMain(argc, argv) ;

    return EXIT_SUCCESS ;
}



/* the following is edited by SHEN */
void ShenMain(int argc, char *argv[]) {
    int           image_size, z_size, nb_size ;
    Fvector3d     ***DeformFld ;
    int           i, j, k, c, num;
    FILE          *fp;
    extern char   *optarg;
    char          filename[80] ;
    float         level, weight, thres ;
    int           SampleNum ;
    
    
    num=3;
    
    if(argc<num){
        //TODO: This needs to be improved a bit to report the version only
        // when a user types exactly --version as the flag and not any other
        // word, but for now, that does the trick in a simple way.
        if(argc == 2) {
            reportVersion(argv[0]);
            exit(1);
        } else {
            show_usage_SHEN() ;
        }
    }
    
    /* input image size */
    image_size= 256;
    z_size    = 124 ;
    HeadInField = NNO ;
    SampleNum = 1 ;
    Req_z_size = -100000 ;
    while((c=getopt(argc-2, argv+2, "h:s:z:X:")) != -1) {
        switch(c) {
            case 'h':
                HeadInField = YYES ;
                break ;
                
            case 's':
                sscanf(optarg, "%d", &SampleNum) ; /* for upsampling vector domain */
                break ;
                
            case 'z':
                sscanf(optarg, "%d", &Req_z_size) ;
                break ;
                
            case 'X':
                sscanf(optarg, "%d", &image_size) ;
                break ;
                
                
            default:
                break;
        }
    }
    
    /* To estimate z_size */
    fp=myopen(argv[1], "r");
    if( HeadInField==YYES ) {
        fread(&image_size, sizeof(int), 1, fp); printf("image_size=%d ", image_size) ;
        fread(&z_size, sizeof(int), 1, fp);     printf("z_size=%d\n", z_size) ;
    }
    else {
        fseek(fp, 0, SEEK_END);
        z_size=ftell(fp)/(image_size*image_size*12);
        rewind(fp);
    }
    fclose(fp);
    printf("vector field size: image_size=%d  z_size=%d\n", image_size, z_size) ;
    
    if( Req_z_size<z_size )
        Req_z_size=z_size;
    printf("Your required slice number is %d\n", z_size) ;
    
    /* vector field */
    DeformFld = Fvector3dalloc3d(image_size, image_size, z_size);
    OpenDeformationField(argv[1], DeformFld, image_size, z_size);
    
    /* reversely perform warping and save the result */
    EstimateInverseDeformationField( SampleNum, DeformFld, image_size, z_size, argv[2] ) ;
}


/**
 * @brief Print the command-line synopsis to stdout and terminate the program
 *        with exit status 1.
 */
void show_usage_SHEN() {
    static const char usage_text[] =
        "USAGE: reversedeformationfield <VectorField> <output-reversed-field>\n"
        "\t -h                   : there is head information in vector field \n"
        "\t -s <int>             : upsample rate(i.e. 1, 2; default: 1) \n"
        "\t -z <int>             : your required z slice numer(should be bigger than z slice number in deformation field) \n"
        "\t -X <int>             : image size in XY plane; default: 256 \n";

    /* The text contains no conversion specifications, so it can be written
     * verbatim. */
    fputs(usage_text, stdout);
    exit(1);
}

