/**
 * @file  rigidtransform.cxx
 * @brief Originally part of HAMMER.
 *
 * Copyright (c) 2001, 2012, 2013 University of Pennsylvania.
 *
 * This file is part of DTI-DROID.
 *
 * DTI-DROID is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * DTI-DROID is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with DTI-DROID.  If not, see <http://www.gnu.org/licenses/>.
 *
 * Contact: SBIA Group <sbia-software at uphs.upenn.edu>
 */

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <getopt.h>

#include "mvcd.h"
#include "cres.h"
#include "matrixSHEN.h"

#include <dtidroid/basis.h>


// acceptable in .cxx file
using namespace dtidroid;


// ===========================================================================
// help
// ===========================================================================

// ---------------------------------------------------------------------------
/* Print the version banner followed by the command-line synopsis to stdout.
 * -X and -t take an integer value; the listed defaults are the values main()
 * assigns before option parsing. */
void show_usage() {
    print_version("rigidtransform", "", "");
    printf("\n");
    printf("USAGE: rigidtransform image_A image_B transformed_image_A transformation <options>\n"
           "    \t -X <int>         : XY size (default: 256)\n"
           "    \t -g               : general case, no constraints on the directions of the principal axes\n"
           "    \t -T <string>      : another image file, in order to operate the same transformation on it\n"
           "    \t -R               : Rigid body transformation only\n"
           "    \t -B               : for the brain application case where the brain pose is roughly oriented!\n"
           "    \t -F               : Don't fix digital resampling errors!\n"
           "    \t -t <int>         : threshold for intensities to be focused (default: 10)\n");
}

// ===========================================================================
// constants and global variables
// ===========================================================================

/* Integer truth values used for the option flags below. */
#define FALSE   0
#define TRUE    1
#define NNO     0
#define YYES    1

/* Option switches: assigned in main(), read by the processing functions. */
int     generalCase ;                           /* set by -g */
int     RigidTransformOnly ;                    /* set by -R */
int     BrainCaseWhereOrientDiffLT90degree ;    /* set by -B */
int     FixDigitalSamplingProb ;                /* YYES by default, NNO with -F */
int     threshold ;                             /* -t: voxels must exceed this value */

/* Optional third volume (-T) that receives the same transformation. */
char    OtherImgFile[180] ;
int     OperateSameTransformation ;
unsigned char ***data3 ;
unsigned char ***data3_to_2 ;
int     z_size3 ;

// ===========================================================================
// forward declarations of functions
// ===========================================================================

/* Resample data1 onto the grid of data2 (result in data1_to_2) and write the
 * estimated transformation to `filename`. */
void TransformImgA2ImgB(int image_size,
                        unsigned char ***data1, int z_size1,
                        unsigned char ***data2, int z_size2,
                        unsigned char ***data1_to_2,
                        char filename[180]) ;

/* Weighted centre of the voxels above the global threshold. */
void CenterOf3DImage(Fvector3d *center, unsigned char ***img,
                     int image_size, int z_size) ;

/* Weighted 3x3 second-order central moments about `center`. */
void MomentsOf3DImage(Matrix *momentsMatrix3x3, Fvector3d center,
                      unsigned char ***img, int image_size, int z_size) ;

/* Binary dump of the matrix followed by both centres. */
void Write_Transformation(Matrix *Transform, Fvector3d center1,
                          Fvector3d center2, char filename[180]) ;

// ===========================================================================
// main
// ===========================================================================

/* Read a raw 8-bit volume, stored as consecutive image_size x image_size
 * slices, from the already-open stream fp (path is used only in messages).
 * image_size must be positive. The slice count is derived from the file size
 * and stored in *z_size; trailing bytes that do not fill a slice are ignored.
 * Returns NULL (with *z_size = 0) when the size cannot be determined or the
 * file holds no complete slice. A short read terminates the program. */
static unsigned char ***ReadRawVolume(FILE *fp, const char *path, int image_size, int *z_size) {
    const long      slice_bytes = (long)image_size * image_size ;
    long            file_bytes ;
    unsigned char ***vol ;
    int             k, i ;

    *z_size = 0 ;
    if( fseek(fp, 0, SEEK_END)!=0 || (file_bytes = ftell(fp))<0 )
        return NULL ;
    rewind(fp) ;
    if( file_bytes/slice_bytes < 1 )
        return NULL ;
    *z_size = (int)(file_bytes/slice_bytes) ;

    vol = UCalloc3d(image_size, image_size, *z_size) ;
    for(k=0;k<*z_size;k++)
        for(i=0;i<image_size;i++)
            if( fread(vol[k][i], 1, image_size, fp)!=(size_t)image_size ) {
                fprintf(stderr, "short read from %s\n", path) ;
                exit(1) ;
            }
    return vol ;
}

/* Open `path` with myopen(), read it via ReadRawVolume() and close it.
 * Terminates the program if the file cannot be opened or holds no slice. */
static unsigned char ***LoadVolumeOrDie(char *path, int image_size, int *z_size) {
    FILE            *fp ;
    unsigned char ***vol ;

    fp = myopen(path, "rb") ;
    if( fp==NULL ) {
        fprintf(stderr, "cannot open %s\n", path) ;
        exit(1) ;
    }
    vol = ReadRawVolume(fp, path, image_size, z_size) ;
    fclose(fp) ;
    if( vol==NULL ) {
        fprintf(stderr, "%s does not contain a complete %dx%d slice\n", path, image_size, image_size) ;
        exit(1) ;
    }
    return vol ;
}

/* Write z_size slices of `vol` to `path` as raw bytes; any open, write or
 * close failure terminates the program. */
static void WriteRawVolume(char *path, unsigned char ***vol, int image_size, int z_size) {
    FILE *fp ;
    int   k, i ;

    fp = myopen(path, "wb") ;
    if( fp==NULL ) {
        fprintf(stderr, "cannot open %s for writing\n", path) ;
        exit(1) ;
    }
    for(k=0;k<z_size;k++)
        for(i=0;i<image_size;i++)
            if( fwrite(vol[k][i], 1, image_size, fp)!=(size_t)image_size ) {
                fprintf(stderr, "write error on %s\n", path) ;
                exit(1) ;
            }
    if( fclose(fp)!=0 ) {
        fprintf(stderr, "write error on %s\n", path) ;
        exit(1) ;
    }
}

/* Entry point: rigidtransform image_A image_B transformed_image_A transformation [options].
 * Loads A and B, lets TransformImgA2ImgB() resample A onto B's grid (and the
 * optional -T volume along with it), then writes the resampled volume(s). */
int main(int argc, char *argv[]) {
    unsigned char ***data1, ***data1_to_2, ***data2 ;
    int             image_size, z_size1, z_size2, c ;
    FILE            *fp ;
    char            filename[180] ;

    image_size = 256 ;
    OperateSameTransformation = FALSE ;

    /* argv[1..4] are image_A, image_B, transformed_image_A and the
     * transformation file. argv[4] is used unconditionally below, so at least
     * five entries are required; the old "argc<4" test accepted argc==4 and
     * then passed argv[4] == NULL on as a filename. */
    if(argc<5) {
        show_usage();
        exit(1);
    }

    threshold = 10 ;
    generalCase = FALSE ;
    RigidTransformOnly = FALSE ;
    BrainCaseWhereOrientDiffLT90degree = FALSE ;
    FixDigitalSamplingProb = YYES ;

    /* getopt() treats element 0 of the vector it receives as the program
     * name, so handing it argv+4 starts option parsing at argv[5].
     * -T, -t and -X take a value, hence the ':' after each of them
     * (the old string lacked it for 't', leaving optarg unset). */
    while((c=getopt(argc-4, argv+4, "gT:RBFt:X:")) != -1) {
        switch(c) {
            case 'g':
                generalCase = TRUE ;
                break;

            case 'T':
                /* %179s keeps the copy inside OtherImgFile[180] */
                if( sscanf(optarg, "%179s", OtherImgFile)==1 )
                    OperateSameTransformation = TRUE ;
                break ;

            case 'R':
                RigidTransformOnly = TRUE ;
                break;

            case 'B':
                BrainCaseWhereOrientDiffLT90degree = TRUE ;
                break;

            case 'F':
                FixDigitalSamplingProb = NNO ;
                break;

            case 't':
                if( sscanf(optarg, "%d", &threshold)!=1 ) {
                    fprintf(stderr, "invalid -t value: %s\n", optarg) ;
                    exit(1) ;
                }
                break ;

            case 'X':
                /* image_size is used as a divisor when sizing the volumes */
                if( sscanf(optarg, "%d", &image_size)!=1 || image_size<=0 ) {
                    fprintf(stderr, "invalid -X value: %s\n", optarg) ;
                    exit(1) ;
                }
                break;

            default:
                break;
        }
    }
    printf("threshold=%d\n", threshold) ;
    printf("generalCase=%d\n",  generalCase) ;
    if( OperateSameTransformation==TRUE ) printf("Same transformation on the second image!\n") ;

    /***** allocate memory and read image data *****/
    /* data 1 */
    printf("%s!\n", argv[1]) ;
    data1 = LoadVolumeOrDie(argv[1], image_size, &z_size1) ;

    /* data 2, plus the output volume that shares its grid */
    printf("%s!\n", argv[2]) ;
    data2      = LoadVolumeOrDie(argv[2], image_size, &z_size2) ;
    data1_to_2 = UCalloc3d(image_size, image_size, z_size2) ;

    /* data 3 (optional): silently skipped in the past when unreadable;
     * now a message explains why it is skipped. */
    if( OperateSameTransformation==TRUE ) {
        printf("%s!\n", OtherImgFile) ;
        data3 = NULL ;
        fp = fopen(OtherImgFile, "rb") ;
        if( fp!=NULL ) {
            data3 = ReadRawVolume(fp, OtherImgFile, image_size, &z_size3) ;
            fclose(fp) ;
        }
        if( data3!=NULL )
            data3_to_2 = UCalloc3d(image_size, image_size, z_size2) ;
        else {
            printf("cannot use %s; it will not be transformed\n", OtherImgFile) ;
            OperateSameTransformation = FALSE ;
        }
    }

    /* transform (also fills data3_to_2 when OperateSameTransformation) */
    TransformImgA2ImgB(image_size, data1, z_size1, data2, z_size2, data1_to_2, argv[4]) ;

    /* saving ... */
    printf("%s\n", argv[3]) ;
    WriteRawVolume(argv[3], data1_to_2, image_size, z_size2) ;

    if( OperateSameTransformation==TRUE ) {
        /* "<OtherImgFile>.rig" may exceed filename[180]; refuse to truncate */
        int written = snprintf(filename, sizeof filename, "%s.rig", OtherImgFile) ;
        if( written<0 || (size_t)written>=sizeof filename ) {
            fprintf(stderr, "output name for %s is too long\n", OtherImgFile) ;
            exit(1) ;
        }
        printf("%s\n", filename) ;
        WriteRawVolume(filename, data3_to_2, image_size, z_size2) ;
    }
    return 0;
}

// ===========================================================================
// auxiliary functions
// ===========================================================================

/* Weight assigned to a voxel of the given intensity by the centre/moment
 * computations: the value 204 weighs 128, every other value weighs 255. */
float PixelLevelProcessing(unsigned char level) {
    const float full_weight    = 255.f ;
    const float reduced_weight = 128.f ;

    return (level == 204) ? reduced_weight : full_weight ;
}


/* Compute the weighted centre of mass of the voxels whose value exceeds the
 * global `threshold`, store it in *center and print it.
 * Coordinates follow the layout img[k][i][j]: x <- i, y <- j, z <- k.
 * Each voxel weighs PixelLevelProcessing(value)/255.
 *
 * Sums are accumulated in double. A float running sum over a whole volume
 * (millions of voxels) stops absorbing small contributions once it passes
 * 2^24, which biased the centre. When no voxel passes the threshold, the
 * total weight is zero. Instead of producing NaN (which then propagates into
 * the transformation), the geometric centre of the grid is used and a
 * warning is printed. */
void CenterOf3DImage(Fvector3d *center, unsigned char ***img, int image_size, int z_size) {
    int    i, j, k ;
    double sum_x = 0.0, sum_y = 0.0, sum_z = 0.0, m000 = 0.0 ;

    for(k=0;k<z_size;k++)
        for(i=0;i<image_size;i++)
            for(j=0;j<image_size;j++) {
                const unsigned char v = img[k][i][j] ;
                double w ;

                if( v<=threshold )
                    continue ;
                w = PixelLevelProcessing(v)/255. ;   /* evaluated once per voxel */
                sum_x += i*w ;
                sum_y += j*w ;
                sum_z += k*w ;
                m000  += w ;
            }

    if( m000>0.0 ) {
        center->x = sum_x/m000 ;
        center->y = sum_y/m000 ;
        center->z = sum_z/m000 ;
    }
    else {
        printf("warning: no voxel above threshold %d; using the grid centre\n", threshold) ;
        center->x = (image_size-1)/2.0 ;
        center->y = (image_size-1)/2.0 ;
        center->z = (z_size-1)/2.0 ;
    }
    printf("center=(%f, %f, %f)\n", center->x, center->y, center->z) ;
}


/* Fill momentsMatrix3x3 with the weighted second-order central moments of the
 * voxels whose value exceeds the global `threshold`, taken about `center`:
 *
 *     | xx xy xz |
 *     | xy yy yz |  / total weight
 *     | xz yz zz |
 *
 * Here x <- i, y <- j, z <- k for img[k][i][j], and each voxel weighs
 * PixelLevelProcessing(value). The matrix is printed to stdout.
 *
 * Accumulation is done in double: float running sums over a full volume lose
 * most of their significant digits. If no voxel passes the threshold, a
 * warning is printed and the matrix is set to zero instead of dividing by
 * zero. */
void MomentsOf3DImage(Matrix *momentsMatrix3x3, Fvector3d center, unsigned char ***img, int image_size, int z_size) {
    int    i, j, k ;
    double x2 = 0.0, y2 = 0.0, z2 = 0.0 ;
    double xy = 0.0, xz = 0.0, yz = 0.0 ;
    double m000 = 0.0 ;

    printf("moments:\n") ;
    for(k=0;k<z_size;k++)
        for(i=0;i<image_size;i++)
            for(j=0;j<image_size;j++) {
                const unsigned char v = img[k][i][j] ;
                double w, dx, dy, dz ;

                if( v<=threshold )
                    continue ;
                w  = PixelLevelProcessing(v) ;   /* evaluated once per voxel */
                dx = i-center.x ;
                dy = j-center.y ;
                dz = k-center.z ;

                x2 += dx*dx*w ;
                y2 += dy*dy*w ;
                z2 += dz*dz*w ;

                xy += dx*dy*w ;
                xz += dx*dz*w ;
                yz += dy*dz*w ;

                m000 += w ;
            }

    if( m000>0.0 ) {
        x2 /= m000 ;
        y2 /= m000 ;
        z2 /= m000 ;
        xy /= m000 ;
        xz /= m000 ;
        yz /= m000 ;
    }
    else
        printf("warning: no voxel above threshold %d; moments set to zero\n", threshold) ;

    /* symmetric matrix */
    momentsMatrix3x3->data[0][0] = x2 ;
    momentsMatrix3x3->data[0][1] = xy ;
    momentsMatrix3x3->data[0][2] = xz ;

    momentsMatrix3x3->data[1][0] = xy ;
    momentsMatrix3x3->data[1][1] = y2 ;
    momentsMatrix3x3->data[1][2] = yz ;

    momentsMatrix3x3->data[2][0] = xz ;
    momentsMatrix3x3->data[2][1] = yz ;
    momentsMatrix3x3->data[2][2] = z2 ;

    for(i=0; i<3; i++) {
        for(j=0; j<3; j++)
            printf("%5.3f ", momentsMatrix3x3->data[i][j]) ;
        printf("\n") ;
    }
}



void TransformImgA2ImgB(int image_size, unsigned char ***data1, int z_size1, unsigned char ***data2, int z_size2, unsigned char ***data1_to_2, char filename[180]) {
    int        i, j, k, x, y, z ;
    Matrix     *Data1_momentsMatrix3x3, *Data2_momentsMatrix3x3, *TransferedMtrx, *Transform, *temp1, *temp2 ;
    Fvector3d  center1, center2;
    float      delta_0, delta_1, delta_2 ;
    float      diff ;
    
    /* eigenvector */
    Matrix *EigenVector1, *EigenVector2;
    float  *EigenValue1,  *EigenValue2, *current_position, *transformed_position;
    Matrix *EigenValueMatrix1, *EigenValueMatrix2, *Eigen1Matrix, *Eigen2Matrix;
    
    
    /* create  metrices*/
    CreateMatrix(&Data1_momentsMatrix3x3,  3, 3);
    CreateMatrix(&Data2_momentsMatrix3x3,  3, 3);
    CreateMatrix(&TransferedMtrx,            3, 3);
    CreateMatrix(&Transform,               3, 3);
    CreateMatrix(&temp1,                   3, 3);
    CreateMatrix(&temp2,                   3, 3);
    
    /* centers of images */
    CenterOf3DImage(&center1, data1, image_size, z_size1) ;
    CenterOf3DImage(&center2, data2, image_size, z_size2) ;
    
    /* moments */
    MomentsOf3DImage(Data1_momentsMatrix3x3, center1, data1, image_size, z_size1) ;
    MomentsOf3DImage(Data2_momentsMatrix3x3, center2, data2, image_size, z_size2) ;
    
    
    /* eigenvectors and eigenvalues */
    CreateMatrix(&EigenVector1,  3, 3);
    CreateMatrix(&EigenVector2,  3, 3);
    EigenValue1 = vectorSHEN(0, 3-1) ;
    EigenValue2 = vectorSHEN(0, 3-1) ;
    current_position= vectorSHEN(0, 3-1) ;
    transformed_position= vectorSHEN(0, 3-1) ;
    CreateMatrix(&EigenValueMatrix1,  3, 3);
    CreateMatrix(&EigenValueMatrix2,  3, 3);
    CreateMatrix(&Eigen1Matrix,  3, 3);
    CreateMatrix(&Eigen2Matrix,  3, 3);
    
    
    /* compute eigenvectors and eigenvalues */
    Mat_Calculate_EigenVectors_EigenValues(Data1_momentsMatrix3x3, EigenValue1, EigenVector1, FALSE) ;
    Mat_Calculate_EigenVectors_EigenValues(Data2_momentsMatrix3x3, EigenValue2, EigenVector2, FALSE) ;
    
    Mat_Print(EigenVector1);
    Mat_Print(EigenVector2);
    
    
    /* To make sure that the calculated eigenvectors in images A and B will not reversed. */
    if( BrainCaseWhereOrientDiffLT90degree==TRUE || generalCase==TRUE ) {
        /* E1*E2^T -> positive */
        for(i=0; i<3; i++) {
            diff = 0 ;
            for(j=0; j<3; j++)
                diff += EigenVector1->data[i][j]*EigenVector2->data[i][j] ;
            if( diff<0 ) {
                printf("******* reversed on the %dth eigenvector!\n", i) ;
                for(j=0; j<3; j++)
                    EigenVector1->data[i][j] *= -1.0 ;  /* reverse it, since geometric moment cannot capture the orientation information*/
            }
        }
        Mat_Print(EigenVector1);
    }
    
    
    /* EigenValue to EigenValueMatrix */
    for(i=0; i<3; i++)
        for(j=0; j<3; j++)
            if(i==j) {
                EigenValueMatrix1->data[i][j]= sqrt(EigenValue1[i]) ;
                EigenValueMatrix2->data[i][j]= 1./sqrt(EigenValue2[i]) ;
            }
            else {
                EigenValueMatrix1->data[i][j]= 0;
                EigenValueMatrix2->data[i][j]= 0;
            }
    
    /* get transformation */
    Mat_A_equal_BxC(temp1, EigenVector1, EigenValueMatrix1) ;             /*A=BxC*/
    Mat_A_equal_BxC(temp2, temp1, EigenValueMatrix2) ;             /*A=BxC*/
    for(i=0; i<3; i++)
        for(j=0; j<3; j++)
            TransferedMtrx->data[i][j]=EigenVector2->data[j][i] ;
    Mat_A_equal_BxC(Transform, temp2, TransferedMtrx) ;
    
    
    if(generalCase==FALSE) {
        for(i=0; i<3; i++)
            for(j=0; j<3; j++) {
                if(i!=j)
                    Transform->data[i][j]=0 ;
                else
                    Transform->data[i][j] *= -1.0 ;
            }
        
        delta_0 = Data1_momentsMatrix3x3->data[0][0]/Data2_momentsMatrix3x3->data[0][0] ;
        delta_1 = Data1_momentsMatrix3x3->data[1][1]/Data2_momentsMatrix3x3->data[1][1] ;
        delta_2 = Data1_momentsMatrix3x3->data[2][2]/Data2_momentsMatrix3x3->data[2][2] ;
        
        printf("\ndelta %f %f %f \n", delta_0, delta_1, delta_2) ;
        Transform->data[0][0] = sqrt(delta_0) ;
        Transform->data[1][1] = sqrt(delta_1) ;
        Transform->data[2][2] = sqrt(delta_2) ;
        
        if( FixDigitalSamplingProb==YYES ) {
            Transform->data[0][0] = sqrt(delta_0) ;
            Transform->data[1][1] = sqrt(delta_1) ;
            Transform->data[2][2] = sqrt(delta_2) ;
        }
        
        Mat_Print(Transform);
    }
    
    
    if(RigidTransformOnly==TRUE) {
        /* get transformation */
        Mat_A_equal_BxC(temp1, EigenVector1, EigenValueMatrix1) ;             /*A=BxC*/
        Mat_A_equal_BxC(temp2, temp1, EigenValueMatrix2) ;             /*A=BxC*/
        for(i=0; i<3; i++)
            for(j=0; j<3; j++) {
                Eigen1Matrix->data[i][j]=EigenVector1->data[i][j] ;
                Eigen2Matrix->data[i][j]=EigenVector2->data[j][i] ;
            }
        Mat_A_equal_BxC(Transform, Eigen1Matrix, Eigen2Matrix) ;
        Mat_Print(Transform);
    }
    
    
    /* transform on data1 */
    for(k=0;k<z_size2;k++)
        for(i=0;i<image_size;i++)
            for(j=0;j<image_size;j++) {
                current_position[0] = (i-center2.x) ;
                current_position[1] = (j-center2.y) ;
                current_position[2] = (k-center2.z) ;
                
                Mat_times_Vector( transformed_position, Transform, current_position) ; /* T3 =T1xT2 x Average */
                x = static_cast<int>(transformed_position[0]+center1.x +0.5);
                y = static_cast<int>(transformed_position[1]+center1.y +0.5);
                z = static_cast<int>(transformed_position[2]+center1.z +0.5);
                
                if(x>=0 && x<image_size && y>=0 && y<image_size && z>=0 && z<z_size1)
                    data1_to_2[k][i][j] = data1[z][x][y];
                else
                    data1_to_2[k][i][j] = 0 ;
            }
    
    /* transform on data3: optional */
    if( OperateSameTransformation==TRUE )
        for(k=0;k<z_size2;k++)
            for(i=0;i<image_size;i++)
                for(j=0;j<image_size;j++) {
                    current_position[0] = (i-center2.x) ;
                    current_position[1] = (j-center2.y) ;
                    current_position[2] = (k-center2.z) ;
                    
                    Mat_times_Vector( transformed_position, Transform, current_position) ; /* T3 =T1xT2 x Average */
                    x = static_cast<int>(transformed_position[0]+center1.x +0.5);
                    y = static_cast<int>(transformed_position[1]+center1.y +0.5);
                    z = static_cast<int>(transformed_position[2]+center1.z +0.5);
                    
                    if(x>=0 && x<image_size && y>=0 && y<image_size && z>=0 && z<z_size3)
                        data3_to_2[k][i][j] = data3[z][x][y];
                    else
                        data3_to_2[k][i][j] = 0 ;
                }
    printf("\nFinally calculated transformation:\n") ;
    Mat_Print(Transform); printf("\n\n\n") ;
    Write_Transformation( Transform, center1, center2, filename) ; /* July 2001 */
    
    
    /* free */
    FreeMatrix(Data1_momentsMatrix3x3);
    FreeMatrix(Data2_momentsMatrix3x3);
    FreeMatrix(TransferedMtrx);
    FreeMatrix(Transform);
    FreeMatrix(temp1);
    FreeMatrix(temp2);
    FreeMatrix(EigenVector1) ;
    FreeMatrix(EigenVector2) ;
    FreeMatrix(EigenValueMatrix1) ;
    FreeMatrix(EigenValueMatrix2) ;
    FreeMatrix(Eigen1Matrix) ;
    FreeMatrix(Eigen2Matrix) ;
    
    free_vectorSHEN(EigenValue1, 0, 3-1) ;
    free_vectorSHEN(EigenValue2, 0, 3-1) ;
    free_vectorSHEN(current_position, 0, 3-1) ;
    free_vectorSHEN(transformed_position, 0, 3-1) ;
}


/* July 2001 */
/* Write the transformation to `filename` (binary, truncated if it exists):
 *   Transform->height * Transform->width doubles in row-major order,
 *   followed by the raw bytes of center1 and then center2.
 * Matrix elements are widened to double before writing.
 * An allocation, open, write or close failure prints a message and exits
 * with status 1, as the original did for an unopenable file. */
void Write_Transformation( Matrix *Transform, Fvector3d center1, Fvector3d center2, char filename[180]) {
    FILE   *fp ;
    double *Temp_transf ;
    int     i, j, Tm, Tn ;
    size_t  count ;

    Tm    = Transform->height ;
    Tn    = Transform->width ;
    count = (size_t)Tm * (size_t)Tn ;

    Temp_transf = (double *)calloc( count, sizeof(double) ) ;
    if( Temp_transf==NULL ) {
        printf("cannot allocate memory for the transformation\n") ;
        exit(1) ;
    }

    for(i=0; i<Tm; i++)
        for(j=0; j<Tn; j++)
            Temp_transf[i*Tn+j] = Transform->data[i][j] ;

    /* "wb" already positions at the start; no fseek needed */
    if( (fp=fopen(filename, "wb"))==NULL ) {
        printf("cannot open file\n");
        free(Temp_transf) ;
        exit(1);
    }

    if( fwrite(Temp_transf, sizeof(double),    count, fp)!=count
     || fwrite(&center1,    sizeof(Fvector3d), 1,     fp)!=1
     || fwrite(&center2,    sizeof(Fvector3d), 1,     fp)!=1 ) {
        printf("cannot write transformation to %s\n", filename) ;
        fclose(fp) ;
        free(Temp_transf) ;
        exit(1) ;
    }

    /* fclose flushes buffered data; a failure here means a truncated file */
    if( fclose(fp)!=0 ) {
        printf("cannot write transformation to %s\n", filename) ;
        free(Temp_transf) ;
        exit(1) ;
    }

    free(Temp_transf) ;
}
