package edu.jhmi.rad.medic.methods;
 import java.io.*;
import java.util.*;
import gov.nih.mipav.view.*;

import edu.jhmi.rad.medic.libraries.*;
import edu.jhmi.rad.medic.utilities.*;

/**
 *
 *  This algorithm handles the main Fuzzy C-Means operations
 *  for longitudinal segmentation of 4D data.
 *	
 *	@version    December 2004
 *	@author     Pierre-Louis Bazin
 *		
 *
 */
 
public class LongitudinalFCM {
		
	// numerical quantities
	// INF replaces 1/d (or d^(1/(1-q))) when a distance term d is not above ZERO
	private static final	float   INF=1e30f;
	private static final	float   ZERO=1e-30f;
	
	// data buffers
	private 	float[][][][]		image;  			// original images, indexed [t][x][y][z]
	private 	float[][][][][]		mems;				// memberships, indexed [t][x][y][z][class]
	private 	float[][]			centroids;			// cluster centroids, indexed [t][cluster]
	private 	float[][][][]		field;  			// inhomogeneity field, indexed [t][x][y][z]; field[t] is null until addInhomogeneityCorrection(t,..)
	private 	float[][][][][]		edges;  			// edge maps, indexed [t][direction 0=X,1=Y,2=Z][x][y][z]; edges[t] is null until addEdgeMap(t,..)
	private 	boolean[][][][]		mask;   			// image mask: true for data points, indexed [t][x][y][z]
	// NOTE(review): dimensions and resolutions are static, i.e. shared by all instances;
	// constructing a second LongitudinalFCM overwrites them for every existing instance.
	private static	int				nx,ny,nz,nt;   		// image dimensions
	private static	float			rx,ry,rz;   		// image resolutions
	
	// parameters
	// classes with index k < clusters are compared to an intensity centroid;
	// classes with k >= clusters use the outlier distance instead
	private 	int 		clusters;
	private 	int 		classes;
	private 	float[] 	smoothing;		// per-time weight of the spatial neighbourhood term
    private 	float 		fuzziness;		// FCM exponent; the value 2 selects the specialised code paths
    private 	float[] 	temporal;		// per-time weight of the temporal neighbourhood term
    private		float[]		outlier;		// per-time distance assigned to the outlier classes
				
	// transformation
	private		boolean		useRegistration;	// NOTE(review): not read or set by the methods up to computeGeneralTransformedCentroids; check the rest of the file
	private		float[]		maskVal;			// only referenced by commented-out interpolation calls in the methods above
	private		float[][][]	transforms;		// array of 3x4 transform matrices for each image, applied as X' = RX+T in physical units

	// computation variables
	private		float[]         prev;		// scratch buffer: memberships of the current voxel before its update
	private		float[][]		wt;			// polynomials for Cubic Lagrangian
	private		float[]			Imax, Imin; // image boundaries for interpolation
        
	// computation flags
	private 	boolean 		isWorking;		// true once the working arrays have been allocated
	private 	boolean 		isCompleted;	// NOTE(review): not set by the methods up to computeGeneralTransformedCentroids; check the rest of the file
	private 	boolean[] 		useEdges;		// per time point: weight the spatial term with the edge map
	private 	boolean[] 		useField;		// per time point: multiply intensities by the inhomogeneity field
	
	// for debug and display
	ViewUserInterface			UI;
    ViewJProgressBar            progressBar;	// advanced during the membership sweeps when verbose is set
	static final boolean		debug=true;		// prints timings and centroids
	static final boolean		verbose=true;	// enables progress bar updates
	
	/**
	 *  constructor
	 *	note: all images passed to the algorithm are just linked, not copied
	 */
	public LongitudinalFCM(float[][][][] image_, boolean [][][][] mask_, 
					int nx_, int ny_, int nz_, int nt_,
					float rx_, float ry_, float rz_, 
					int nclasses_, int nclusters_,
					float smoothing_, float temp_, float fuzzy_,
                    float outlierRatio_, float[] scal_, float[] maskVal_,
					float[] min_, float[] max_, 
					ViewUserInterface UI_, ViewJProgressBar bar_) {
						
		image = image_;
		mask = mask_;
		nx = nx_;
		ny = ny_;
		nz = nz_;
		nt = nt_;
		rx = rx_;
		ry = ry_;
		rz = rz_;
		classes = nclasses_;
		clusters = nclusters_;
		smoothing = new float[nt];
		for (int t=0;t<nt;t++) smoothing[t] = smoothing_*scal_[t]*scal_[t];
        temporal = new float[nt];
		for (int t=0;t<nt;t++) temporal[t] = temp_*scal_[t]*scal_[t];
		fuzziness = fuzzy_;
        outlier = new float[nt];
		for (int t=0;t<nt;t++) outlier[t] = outlierRatio_*scal_[t];
		maskVal = maskVal_;
		
		Imin = new float[nt];
		Imax = new float[nt];
		for (int t=0;t<nt;t++) Imin[t] = min_[t];
		for (int t=0;t<nt;t++) Imax[t] = max_[t];
		
		UI = UI_;
        progressBar = bar_;
		
		// init all the arrays
		try {
			mems = new float[nt][nx][ny][nz][classes];
			centroids = new float[nt][clusters];
			prev = new float[classes];
			transforms = new float[nt][3][4];
			field = new float[nt][][][];
			edges = new float[nt][][][][];
			useField = new boolean[nt];
			useEdges = new boolean[nt];
			wt = ImageFunctions.setup3DCubicLagrangianInterpolation();		        
		} catch (OutOfMemoryError e){
			isWorking = false;
            finalize();
			System.out.println(e.getMessage());
			return;
		}
		isWorking = true;

		// init values
		for (int t=0;t<nt;t++) {
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
				for (int k=0;k<classes;k++) {
					mems[t][x][y][z][k] = 0.0f;
				}
			}
			for (int k=0;k<clusters;k++) {
				centroids[t][k] = 0.0f;
			}
			for (int i=0;i<3;i++) {
				for (int j=0;j<4;j++) transforms[t][i][j] = 0.0f;
				transforms[t][i][i] = 1.0f;
			}
			// options
			useField[t] = false;
			useEdges[t] = false;
		}
		if (debug) MedicUtilPublic.displayMessage("initialization\n");
	}

	/**
	 *  Releases the working buffers (memberships, centroids, scratch buffer and
	 *  transforms) and requests a garbage collection.
	 *  Also called directly by the constructor when allocation fails.
	 */
	final public void finalize() {
		// drop the references so the arrays become collectable
		transforms = null;
		prev = null;
		centroids = null;
		mems = null;
		System.gc();
	}
	
    /** Returns the whole membership array, indexed [t][x][y][z][class] (live reference, not a copy). */
    public final float[][][][][] getMemberships() {
        return this.mems;
    }
    /** Returns the memberships of time point t, indexed [x][y][z][class] (live reference). */
    public final float[][][][] getMemberships(int t) {
        final float[][][][] atTime = this.mems[t];
        return atTime;
    }
    /** Links (without copying) Mem_ as the memberships of time point t. */
    public final void setMemberships(int t, float[][][][] Mem_) {
        this.mems[t] = Mem_;
    }
    /**
     * Copies Mems (indexed [x][y][z][class]) element by element into the memberships
     * of time point t, so later changes to Mems do not affect this object.
     */
    public final void importMemberships(int t, float[][][][] Mems) {
        int cls = 0;
        while (cls < classes) {
            for (int i=0;i<nx;i++) {
                for (int j=0;j<ny;j++) {
                    for (int l=0;l<nz;l++) {
                        mems[t][i][j][l][cls] = Mems[i][j][l][cls];
                    }
                }
            }
            cls++;
        }
    }
    /** Returns all centroids, indexed [t][cluster] (live reference). */
    public final float[][] getCentroids() {
        return this.centroids;
    }
    /** Returns the centroids of time point t (live reference). */
    public final float[] getCentroids(int t) {
        final float[] atTime = this.centroids[t];
        return atTime;
    }
	
    /** Returns the 3x4 transform [R|T] of time point t (live reference). */
    public final float[][] getTransform(int t) {
        return this.transforms[t];
    }
	/** Copies the 3x4 matrix trans, row by row, into the transform of time point t. */
	public final void importTransform(int t, float[][] trans) {
		for (int row=0; row<3; row++) {
			for (int col=0; col<4; col++) {
				transforms[t][row][col] = trans[row][col];
			}
		}
	}

    /**
     * Links (without copying) cent as the centroids of time point t and, in debug
     * mode, prints them.
     */
	final public void setCentroids(int t, float[] cent) {
		centroids[t] = cent;
		if (!debug) return;
		final float[] c = centroids[t];
		MedicUtilPublic.displayMessage("centroids: ("+c[0]);
		for (int k=1;k<clusters;k++) MedicUtilPublic.displayMessage(", "+c[k]);
		MedicUtilPublic.displayMessage(")\n");
	}
    /**
     * Copies the first clusters values of cent into the centroids of time point t
     * and, in debug mode, prints them.
     */
	final public void importCentroids(int t, float[] cent) {
		int k = 0;
		while (k < clusters) {
			centroids[t][k] = cent[k];
			k++;
		}
		if (!debug) return;
		MedicUtilPublic.displayMessage("centroids: ("+centroids[t][0]);
		for (int c=1;c<clusters;c++) MedicUtilPublic.displayMessage(", "+centroids[t][c]);
		MedicUtilPublic.displayMessage(")\n");
	}
    /** Replaces (links) the per-time spatial and temporal weights. */
    public final void setMRF(float[] smooth, float[] temp) {
        this.smoothing = smooth;
        this.temporal = temp;
    }
    /** Links field_ (indexed [x][y][z]) as the multiplicative field of time point t and enables it. */
    public final void addInhomogeneityCorrection(int t, float[][][] field_) {
        this.field[t] = field_;
        this.useField[t] = true;
    }
    /** Links edges_ (indexed [direction 0=X,1=Y,2=Z][x][y][z]) as the edge map of time point t and enables it. */
    public final void addEdgeMap(int t, float[][][][] edges_) {
        this.edges[t] = edges_;
        this.useEdges[t] = true;
    }
	/** @return true when the working arrays were allocated successfully */
	public final boolean isWorking() {
		return this.isWorking;
	}
	/** @return the completion flag */
	public final boolean isCompleted() {
		return this.isCompleted;
	}
	
    /** 
	 *  compute the FCM membership functions given the centroids
	 *	with the different options (outliers, field, edges, MRF).
	 *	<p>
	 *	Fast path for fuzziness == 2; other values are delegated to
	 *	computeGeneralMemberships(). Only interior voxels (indices 1..n-2 on each
	 *	axis) are visited. For a voxel inside mask[t], the un-normalised membership
	 *	of class k is 1/D_k, with
	 *	D_k = (I - c_k)^2   (I multiplied by the field if enabled; outlier[t]^2 for k &gt;= clusters)
	 *	    + smoothing[t] * mean of u_m^2 (m != k) over the masked 6-neighbours (edge-weighted if enabled)
	 *	    + temporal[t]  * mean of u_m^2 (m != k) over the masked voxels at t-1 and t+1.
	 *	INF is used when D_k is not above ZERO. Memberships are then normalised to sum
	 *	to one; voxels outside the mask are set to zero.
	 *
	 *	@return the largest absolute change of a membership value during this sweep
	 */
    final public float computeMemberships() {
        float distance,dist;
        int x,y,z,k,m,t;
        int progress, mod;
        long inner_loop_time;
        float den,num;
        float neighbors, ngb;
        
        if (fuzziness!=2) return computeGeneralMemberships();
		
		distance = 0.0f;
		progress = 0;
        // report progress every 1% of the volume; clamp to 1 so that volumes with
        // fewer than 100 voxels do not divide by zero in progress%mod
        mod = Math.max(1, nx*ny*nz/100);

        inner_loop_time = System.currentTimeMillis();
        for (x=1;x<nx-1;x++) for (y=1;y<ny-1;y++) for (z=1;z<nz-1;z++) {
			progress++;
			// the progress bar is optional: skip the update when none was given
			if ( verbose && progressBar!=null && (progress%mod==0) )
                progressBar.updateValue(Math.round( (float)progress/(float)mod),false);

			for (t=0;t<nt;t++) {
				if ( mask[t][x][y][z] ) {
					den = 0;
					// remember the previous values
					for (k=0;k<classes;k++) prev[k] = mems[t][x][y][z][k];

					for (k=0;k<classes;k++) {

						// data term
						if (k<clusters) {
							if (useField[t]) num = (field[t][x][y][z]*image[t][x][y][z]-centroids[t][k])
												  *(field[t][x][y][z]*image[t][x][y][z]-centroids[t][k]);
							else num = (image[t][x][y][z]-centroids[t][k])
									  *(image[t][x][y][z]-centroids[t][k]);
						} else {
							num = outlier[t]*outlier[t];
						}

						// spatial smoothing
						if (smoothing[t] > 0.0f) { 
							ngb = 0.0f;  
							neighbors = 0.0f;
							// case by case	: X+
							if (mask[t][x+1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x][y][z]*mems[t][x+1][y][z][m]*mems[t][x+1][y][z][m];
								else ngb += mems[t][x+1][y][z][m]*mems[t][x+1][y][z][m];
								neighbors ++;
							}
							// case by case	: X-
							if (mask[t][x-1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x-1][y][z]*mems[t][x-1][y][z][m]*mems[t][x-1][y][z][m];
								else ngb += mems[t][x-1][y][z][m]*mems[t][x-1][y][z][m];
								neighbors ++;
							}
							// case by case	: Y+
							if (mask[t][x][y+1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y][z]*mems[t][x][y+1][z][m]*mems[t][x][y+1][z][m];
								else ngb += mems[t][x][y+1][z][m]*mems[t][x][y+1][z][m];
								neighbors ++;
							}
							// case by case	: Y-
							if (mask[t][x][y-1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y-1][z]*mems[t][x][y-1][z][m]*mems[t][x][y-1][z][m];
								else ngb += mems[t][x][y-1][z][m]*mems[t][x][y-1][z][m];
								neighbors ++;
							}
							// case by case	: Z+
							if (mask[t][x][y][z+1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z]*mems[t][x][y][z+1][m]*mems[t][x][y][z+1][m];
								else ngb += mems[t][x][y][z+1][m]*mems[t][x][y][z+1][m];
								neighbors ++;
							}
							// case by case	: Z-
							if (mask[t][x][y][z-1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z-1]*mems[t][x][y][z-1][m]*mems[t][x][y][z-1][m];
								else ngb += mems[t][x][y][z-1][m]*mems[t][x][y][z-1][m];
								neighbors ++;
							}
							if (neighbors>0.0) num = num + smoothing[t]*ngb/neighbors;
						}
						// time neighbor correlation
						if (temporal[t]>0.0f) {
							neighbors = 0.0f;
							ngb = 0.0f;
							if (t>0) {
								if (mask[t-1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += mems[t-1][x][y][z][m]*mems[t-1][x][y][z][m];
									neighbors ++;
								}
							}					
							if (t<nt-1) {
								if (mask[t+1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += mems[t+1][x][y][z][m]*mems[t+1][x][y][z][m];
									neighbors ++;
								}
							}					
							if (neighbors>0.0) num = num + temporal[t]*ngb/neighbors;
						}
						// invert the result
						if (num>ZERO) num = 1.0f/num;
						else num = INF;

						mems[t][x][y][z][k] = num;
						den += num;
					}

					// normalization
					for (k=0;k<classes;k++) {
						mems[t][x][y][z][k] = mems[t][x][y][z][k]/den;

						// compute the maximum distance
						dist = Math.abs(mems[t][x][y][z][k]-prev[k]);
						if (dist > distance) distance = dist;
					}
				} else {
					for (k=0;k<classes;k++) 
						mems[t][x][y][z][k] = 0.0f;
				}
			}
		}
        if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-inner_loop_time) +"\n"); 

        return distance;
    } // computeMemberships
    
    /** 
	 *  compute the FCM membership functions given the centroids
	 *	with the different options (outliers, field, edges, MRF),
	 *	for a general fuzziness exponent q.
	 *	<p>
	 *	Only interior voxels (indices 1..n-2 on each axis) are visited. For a voxel
	 *	inside mask[t], the un-normalised membership of class k is D_k^(1/(1-q)), with
	 *	D_k = (I - c_k)^2   (I multiplied by the field if enabled; outlier[t]^2 for k &gt;= clusters)
	 *	    + smoothing[t] * mean of u_m^q (m != k) over the masked 6-neighbours (edge-weighted if enabled)
	 *	    + temporal[t]  * mean of u_m^q (m != k) over the masked voxels at t-1 and t+1.
	 *	INF is used when D_k is not above ZERO. Memberships are then normalised to sum
	 *	to one; voxels outside the mask are set to zero.
	 *
	 *	@return the largest absolute change of a membership value during this sweep
	 */
    final public float computeGeneralMemberships() {
        float distance,dist;
        int x,y,z,k,m,t;
        int progress, mod;
        long inner_loop_time;
        float den,num;
        float neighbors, ngb;
        
        distance = 0.0f;
		progress = 0;
        // report progress every 1% of the volume; clamp to 1 so that volumes with
        // fewer than 100 voxels do not divide by zero in progress%mod
        mod = Math.max(1, nx*ny*nz/100);

        inner_loop_time = System.currentTimeMillis();
        for (x=1;x<nx-1;x++) for (y=1;y<ny-1;y++) for (z=1;z<nz-1;z++) {
			progress++;
			// the progress bar is optional: skip the update when none was given
			if ( verbose && progressBar!=null && (progress%mod==0) )
                progressBar.updateValue(Math.round( (float)progress/(float)mod),false);

			for (t=0;t<nt;t++) {
				if ( mask[t][x][y][z] ) {
					den = 0;
					// remember the previous values
					for (k=0;k<classes;k++) prev[k] = mems[t][x][y][z][k];

					for (k=0;k<classes;k++) {

						// data term
						if (k<clusters) {
							if (useField[t]) num = (field[t][x][y][z]*image[t][x][y][z]-centroids[t][k])
												  *(field[t][x][y][z]*image[t][x][y][z]-centroids[t][k]);
							else num = (image[t][x][y][z]-centroids[t][k])
									  *(image[t][x][y][z]-centroids[t][k]);
						} else {
							num = outlier[t]*outlier[t];
						}

						// spatial smoothing
						if (smoothing[t] > 0.0f) { 
							ngb = 0.0f;  
							neighbors = 0.0f;
							// case by case	: X+
							if (mask[t][x+1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x][y][z]*Math.pow(mems[t][x+1][y][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x+1][y][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: X-
							if (mask[t][x-1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x-1][y][z]*Math.pow(mems[t][x-1][y][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x-1][y][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Y+
							if (mask[t][x][y+1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y][z]*Math.pow(mems[t][x][y+1][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y+1][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Y-
							if (mask[t][x][y-1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y-1][z]*Math.pow(mems[t][x][y-1][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y-1][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Z+
							if (mask[t][x][y][z+1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z]*Math.pow(mems[t][x][y][z+1][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y][z+1][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Z-
							if (mask[t][x][y][z-1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z-1]*Math.pow(mems[t][x][y][z-1][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y][z-1][m],fuzziness);
								neighbors ++;
							}
							if (neighbors>0.0) num = num + smoothing[t]*ngb/neighbors;
						}
						// time neighbor correlation
						if (temporal[t]>0.0f) {
							neighbors = 0.0f;
							ngb = 0.0f;
							if (t>0) {
								if (mask[t-1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += Math.pow(mems[t-1][x][y][z][m],fuzziness);
									neighbors ++;
								}
							}					
							if (t<nt-1) {
								if (mask[t+1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += Math.pow(mems[t+1][x][y][z][m],fuzziness);
									neighbors ++;
								}
							}					
							if (neighbors>0.0) num = num + temporal[t]*ngb/neighbors;
						}
						// invert the result: u_k is proportional to D_k^(1/(1-q))
						if (num>ZERO) num = (float)Math.pow(num,1.0f/(1.0f-fuzziness) );
						else num = INF;

						mems[t][x][y][z][k] = num;
						den += num;
					}

					// normalization
					for (k=0;k<classes;k++) {
						mems[t][x][y][z][k] = mems[t][x][y][z][k]/den;

						// compute the maximum distance
						dist = Math.abs(mems[t][x][y][z][k]-prev[k]);
						if (dist > distance) distance = dist;
					}
				} else {
					for (k=0;k<classes;k++) 
						mems[t][x][y][z][k] = 0.0f;
				}
			}
		}
        if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-inner_loop_time) +"\n"); 

        return distance;
    } // computeGeneralMemberships
    
    /** 
	 *  compute the FCM membership functions given the centroids
	 *	with the different options (outliers, field, edges, MRF)
	 *	including a transform function.
	 *	<p>
	 *	Fast path for fuzziness == 2; other values are delegated to
	 *	computeGeneralTransformedMemberships(). Memberships are stored on the
	 *	reference grid: for each interior voxel X where the reference mask mask[0]
	 *	is set, the intensity of image t is sampled at X' = RX+T (transforms[t],
	 *	physical units) with cubic Lagrangian interpolation. The membership update
	 *	is the same as in computeMemberships(), with spatial neighbours tested
	 *	against mask[0]. Voxels outside mask[0] get zero memberships.
	 *
	 *	@return the largest absolute change of a membership value during this sweep
	 */
    final public float computeTransformedMemberships() {
        float distance,dist,imgT;
        float xT,yT,zT;
        int x,y,z,k,m,t;
        int progress, mod;
        long inner_loop_time;
        float den,num;
        float neighbors, ngb;
        
		if (fuzziness!=2) return computeGeneralTransformedMemberships();
		
        distance = 0.0f;
		progress = 0;
        // report progress every 1% of the volume; clamp to 1 so that volumes with
        // fewer than 100 voxels do not divide by zero in progress%mod
        mod = Math.max(1, nx*ny*nz/100);

        inner_loop_time = System.currentTimeMillis();
        for (x=1;x<nx-1;x++) for (y=1;y<ny-1;y++) for (z=1;z<nz-1;z++) {
			progress++;
			// the progress bar is optional: skip the update when none was given
			if ( verbose && progressBar!=null && (progress%mod==0) )
                progressBar.updateValue(Math.round( (float)progress/(float)mod),false);

			for (t=0;t<nt;t++) {
				// memberships live on the reference grid, hence the reference mask
				if (mask[0][x][y][z]) {
					// position of this reference voxel in image t: X' = RX+T
					// (computed in physical units, then converted back to voxel units)
					xT = (transforms[t][0][0]*x*rx + transforms[t][0][1]*y*ry + transforms[t][0][2]*z*rz + transforms[t][0][3])/rx;
					yT = (transforms[t][1][0]*x*rx + transforms[t][1][1]*y*ry + transforms[t][1][2]*z*rz + transforms[t][1][3])/ry;
					zT = (transforms[t][2][0]*x*rx + transforms[t][2][1]*y*ry + transforms[t][2][2]*z*rz + transforms[t][2][3])/rz;

					// interpolated intensity of image t at the transformed position
					imgT = ImageFunctions.cubicLagrangianInterpolation3D(image[t],wt,Imin[t],Imax[t],xT,yT,zT,nx,ny,nz);

					den = 0;
					// remember the previous values
					for (k=0;k<classes;k++) prev[k] = mems[t][x][y][z][k];

					for (k=0;k<classes;k++) {

						// data term
						if (k<clusters) {
							if (useField[t]) num = (field[t][x][y][z]*imgT-centroids[t][k])
												  *(field[t][x][y][z]*imgT-centroids[t][k]);
							else num = (imgT-centroids[t][k])
									  *(imgT-centroids[t][k]);
						} else {
							num = outlier[t]*outlier[t];
						}

						// spatial smoothing (neighbours tested against the reference mask)
						if (smoothing[t] > 0.0f) { 
							ngb = 0.0f;  
							neighbors = 0.0f;
							// case by case	: X+
							if (mask[0][x+1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x][y][z]*mems[t][x+1][y][z][m]*mems[t][x+1][y][z][m];
								else ngb += mems[t][x+1][y][z][m]*mems[t][x+1][y][z][m];
								neighbors ++;
							}
							// case by case	: X-
							if (mask[0][x-1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x-1][y][z]*mems[t][x-1][y][z][m]*mems[t][x-1][y][z][m];
								else ngb += mems[t][x-1][y][z][m]*mems[t][x-1][y][z][m];
								neighbors ++;
							}
							// case by case	: Y+
							if (mask[0][x][y+1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y][z]*mems[t][x][y+1][z][m]*mems[t][x][y+1][z][m];
								else ngb += mems[t][x][y+1][z][m]*mems[t][x][y+1][z][m];
								neighbors ++;
							}
							// case by case	: Y-
							if (mask[0][x][y-1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y-1][z]*mems[t][x][y-1][z][m]*mems[t][x][y-1][z][m];
								else ngb += mems[t][x][y-1][z][m]*mems[t][x][y-1][z][m];
								neighbors ++;
							}
							// case by case	: Z+
							if (mask[0][x][y][z+1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z]*mems[t][x][y][z+1][m]*mems[t][x][y][z+1][m];
								else ngb += mems[t][x][y][z+1][m]*mems[t][x][y][z+1][m];
								neighbors ++;
							}
							// case by case	: Z-
							if (mask[0][x][y][z-1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z-1]*mems[t][x][y][z-1][m]*mems[t][x][y][z-1][m];
								else ngb += mems[t][x][y][z-1][m]*mems[t][x][y][z-1][m];
								neighbors ++;
							}
							if (neighbors>0.0) num = num + smoothing[t]*ngb/neighbors;
						}
						// time neighbor correlation
						// NOTE(review): this tests the native masks mask[t-1]/mask[t+1] while the
						// memberships on the reference grid are governed by mask[0]; confirm intent
						if (temporal[t]>0.0f) {
							neighbors = 0.0f;
							ngb = 0.0f;
							if (t>0) {
								if (mask[t-1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += mems[t-1][x][y][z][m]*mems[t-1][x][y][z][m];
									neighbors ++;
								}
							}					
							if (t<nt-1) {
								if (mask[t+1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += mems[t+1][x][y][z][m]*mems[t+1][x][y][z][m];
									neighbors ++;
								}
							}					
							if (neighbors>0.0) num = num + temporal[t]*ngb/neighbors;
						}
						// invert the result
						if (num>ZERO) num = 1.0f/num;
						else num = INF;

						mems[t][x][y][z][k] = num;
						den += num;
					}

					// normalization
					for (k=0;k<classes;k++) {
						mems[t][x][y][z][k] = mems[t][x][y][z][k]/den;

						// compute the maximum distance
						dist = Math.abs(mems[t][x][y][z][k]-prev[k]);
						if (dist > distance) distance = dist;
					}
				} else {
					for (k=0;k<classes;k++) 
						mems[t][x][y][z][k] = 0.0f;
				}
			}
		}
        if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-inner_loop_time) +"\n"); 

        return distance;
    } // computeTransformedMemberships
    
    /** 
	 *  compute the FCM membership functions given the centroids
	 *	with the different options (outliers, field, edges, MRF)
	 *	including a transform function, for a general fuzziness exponent q.
	 *	<p>
	 *	Memberships are stored on the reference grid: for each interior voxel X
	 *	where the reference mask mask[0] is set, the intensity of image t is
	 *	sampled at X' = RX+T (transforms[t], physical units) with cubic Lagrangian
	 *	interpolation. The membership update is the same as in
	 *	computeGeneralMemberships(), with spatial neighbours tested against mask[0].
	 *	Voxels outside mask[0] get zero memberships.
	 *
	 *	@return the largest absolute change of a membership value during this sweep
	 */
    final public float computeGeneralTransformedMemberships() {
        float distance,dist,imgT;
        float xT,yT,zT;
        int x,y,z,k,m,t;
        int progress, mod;
        long inner_loop_time;
        float den,num;
        float neighbors, ngb;
        
        distance = 0.0f;
		progress = 0;
        // report progress every 1% of the volume; clamp to 1 so that volumes with
        // fewer than 100 voxels do not divide by zero in progress%mod
        mod = Math.max(1, nx*ny*nz/100);

        inner_loop_time = System.currentTimeMillis();
        for (x=1;x<nx-1;x++) for (y=1;y<ny-1;y++) for (z=1;z<nz-1;z++) {
			progress++;
			// the progress bar is optional: skip the update when none was given
			if ( verbose && progressBar!=null && (progress%mod==0) )
                progressBar.updateValue(Math.round( (float)progress/(float)mod),false);

			for (t=0;t<nt;t++) {
				// memberships live on the reference grid, hence the reference mask
				if (mask[0][x][y][z]) {
					// position of this reference voxel in image t: X' = RX+T
					// (computed in physical units, then converted back to voxel units)
					xT = (transforms[t][0][0]*x*rx + transforms[t][0][1]*y*ry + transforms[t][0][2]*z*rz + transforms[t][0][3])/rx;
					yT = (transforms[t][1][0]*x*rx + transforms[t][1][1]*y*ry + transforms[t][1][2]*z*rz + transforms[t][1][3])/ry;
					zT = (transforms[t][2][0]*x*rx + transforms[t][2][1]*y*ry + transforms[t][2][2]*z*rz + transforms[t][2][3])/rz;

					// interpolated intensity of image t at the transformed position
					imgT = ImageFunctions.cubicLagrangianInterpolation3D(image[t],wt,Imin[t],Imax[t],xT,yT,zT,nx,ny,nz);

					den = 0;
					// remember the previous values
					for (k=0;k<classes;k++) prev[k] = mems[t][x][y][z][k];

					for (k=0;k<classes;k++) {

						// data term
						if (k<clusters) {
							if (useField[t]) num = (field[t][x][y][z]*imgT-centroids[t][k])
												  *(field[t][x][y][z]*imgT-centroids[t][k]);
							else num = (imgT-centroids[t][k])
									  *(imgT-centroids[t][k]);
						} else {
							num = outlier[t]*outlier[t];
						}

						// spatial smoothing (neighbours tested against the reference mask)
						if (smoothing[t] > 0.0f) { 
							ngb = 0.0f;  
							neighbors = 0.0f;
							// case by case	: X+
							if (mask[0][x+1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x][y][z]*Math.pow(mems[t][x+1][y][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x+1][y][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: X-
							if (mask[0][x-1][y][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][0][x-1][y][z]*Math.pow(mems[t][x-1][y][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x-1][y][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Y+
							if (mask[0][x][y+1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y][z]*Math.pow(mems[t][x][y+1][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y+1][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Y-
							if (mask[0][x][y-1][z]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][1][x][y-1][z]*Math.pow(mems[t][x][y-1][z][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y-1][z][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Z+
							if (mask[0][x][y][z+1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z]*Math.pow(mems[t][x][y][z+1][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y][z+1][m],fuzziness);
								neighbors ++;
							}
							// case by case	: Z-
							if (mask[0][x][y][z-1]) for (m=0;m<classes;m++) if (m!=k) {
								if (useEdges[t]) ngb += edges[t][2][x][y][z-1]*Math.pow(mems[t][x][y][z-1][m],fuzziness);
								else ngb += Math.pow(mems[t][x][y][z-1][m],fuzziness);
								neighbors ++;
							}
							if (neighbors>0.0) num = num + smoothing[t]*ngb/neighbors;
						}
						// time neighbor correlation
						// NOTE(review): this tests the native masks mask[t-1]/mask[t+1] while the
						// memberships on the reference grid are governed by mask[0]; confirm intent
						if (temporal[t]>0.0f) {
							neighbors = 0.0f;
							ngb = 0.0f;
							if (t>0) {
								if (mask[t-1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += Math.pow(mems[t-1][x][y][z][m],fuzziness);
									neighbors ++;
								}
							}					
							if (t<nt-1) {
								if (mask[t+1][x][y][z]) for (m=0;m<classes;m++) if (m!=k) {
									ngb += Math.pow(mems[t+1][x][y][z][m],fuzziness);
									neighbors ++;
								}
							}					
							if (neighbors>0.0) num = num + temporal[t]*ngb/neighbors;
						}
						// invert the result: u_k is proportional to D_k^(1/(1-q))
						if (num>ZERO) num = (float)Math.pow(num,1.0f/(1.0f-fuzziness) );
						else num = INF;

						mems[t][x][y][z][k] = num;
						den += num;
					}

					// normalization
					for (k=0;k<classes;k++) {
						mems[t][x][y][z][k] = mems[t][x][y][z][k]/den;

						// compute the maximum distance
						dist = Math.abs(mems[t][x][y][z][k]-prev[k]);
						if (dist > distance) distance = dist;
					}
				} else {
					for (k=0;k<classes;k++) 
						mems[t][x][y][z][k] = 0.0f;
				}
			}
		}
        if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-inner_loop_time) +"\n"); 

        return distance;
    } // computeGeneralTransformedMemberships
    
    /**
	 * Recomputes every centroid from the current memberships (fuzziness == 2 path):
	 * centroid[t][k] = sum(u_k^2 * I) / sum(u_k^2) over the voxels of mask[t],
	 * where I is the intensity, multiplied by the field when one was added for t.
	 * A cluster with zero total weight gets centroid 0. Other fuzziness values are
	 * handled by computeGeneralCentroids(). In debug mode the centroids are printed.
	 */
    final public void computeCentroids() {
		if (fuzziness!=2) {
			computeGeneralCentroids();
			return;
		}

		for (int t=0;t<nt;t++) {
			// all clusters are accumulated in one sweep; each cluster still sums
			// its voxels in the same x,y,z order
			final float[] weightedSum = new float[clusters];
			final float[] weightTotal = new float[clusters];
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
				if (!mask[t][x][y][z]) continue;
				for (int k=0;k<clusters;k++) {
					final float u = mems[t][x][y][z][k];
					final float w = u*u;
					if (useField[t]) weightedSum[k] += w*field[t][x][y][z]*image[t][x][y][z];
					else weightedSum[k] += w*image[t][x][y][z];
					weightTotal[k] += w;
				}
			}
			for (int k=0;k<clusters;k++) {
				centroids[t][k] = (weightTotal[k]>0.0) ? weightedSum[k]/weightTotal[k] : 0.0f;
			}
			if (debug) {
				MedicUtilPublic.displayMessage("centroids: ("+centroids[t][0]);
				for (int k=1;k<clusters;k++) MedicUtilPublic.displayMessage(", "+centroids[t][k]);
				MedicUtilPublic.displayMessage(")\n");
			}
		}
    } // computeCentroids
    
    /**
	 * Recomputes every centroid for a general fuzziness exponent q:
	 * centroid[t][k] = sum(u_k^q * I) / sum(u_k^q) over the voxels of mask[t],
	 * where I is the intensity, multiplied by the field when one was added for t.
	 * A cluster with zero total weight gets centroid 0. In debug mode the
	 * centroids are printed.
	 */
    final public void computeGeneralCentroids() {
		for (int t=0;t<nt;t++) {
			// all clusters are accumulated in one sweep; each cluster still sums
			// its voxels in the same x,y,z order
			final float[] weightedSum = new float[clusters];
			final float[] weightTotal = new float[clusters];
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
				if (!mask[t][x][y][z]) continue;
				for (int k=0;k<clusters;k++) {
					final float w = (float)Math.pow(mems[t][x][y][z][k],fuzziness);
					if (useField[t]) weightedSum[k] += w*field[t][x][y][z]*image[t][x][y][z];
					else weightedSum[k] += w*image[t][x][y][z];
					weightTotal[k] += w;
				}
			}
			for (int k=0;k<clusters;k++) {
				centroids[t][k] = (weightTotal[k]>0.0) ? weightedSum[k]/weightTotal[k] : 0.0f;
			}
			if (debug) {
				MedicUtilPublic.displayMessage("centroids: ("+centroids[t][0]);
				for (int k=1;k<clusters;k++) MedicUtilPublic.displayMessage(", "+centroids[t][k]);
				MedicUtilPublic.displayMessage(")\n");
			}
		}
    } // computeGeneralCentroids
    
    /**
	 * Recomputes the centroids from the current memberships (fuzziness == 2 path),
	 * with the intensities of image t resampled at X' = RX+T (transforms[t]):
	 * centroid[t][k] = sum(u_k^2 * I') / sum(u_k^2) over the voxels where the
	 * reference mask mask[0] is set. I' is the cubic-Lagrangian interpolated
	 * intensity, multiplied by the field when one was added for t. A cluster with
	 * zero total weight gets centroid 0. Other fuzziness values are handled by
	 * computeGeneralTransformedCentroids(). In debug mode the centroids are printed.
	 * <p>
	 * The interpolated intensity does not depend on the cluster, so it is computed
	 * once per masked voxel and shared by all clusters rather than once per cluster.
	 * The transform is only evaluated inside the mask. Each cluster still sums its
	 * voxels in the same x,y,z order.
	 * (Assumes ImageFunctions.cubicLagrangianInterpolation3D has no side effects — TODO confirm.)
	 */
    final public void computeTransformedCentroids() {
		if (fuzziness!=2) {
			computeGeneralTransformedCentroids();
			return;
		}
        
		for (int t=0;t<nt;t++) {
			final float[] num = new float[clusters];
			final float[] den = new float[clusters];
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
				if (!mask[0][x][y][z]) continue;

				// compute the local position: X' = RX+T (physical units, back to voxel units)
				final float xT = (transforms[t][0][0]*x*rx + transforms[t][0][1]*y*ry + transforms[t][0][2]*z*rz + transforms[t][0][3])/rx;
				final float yT = (transforms[t][1][0]*x*rx + transforms[t][1][1]*y*ry + transforms[t][1][2]*z*rz + transforms[t][1][3])/ry;
				final float zT = (transforms[t][2][0]*x*rx + transforms[t][2][1]*y*ry + transforms[t][2][2]*z*rz + transforms[t][2][3])/rz;

				// interpolate once per voxel, shared by every cluster
				final float imgT = ImageFunctions.cubicLagrangianInterpolation3D(image[t],wt,Imin[t],Imax[t],xT,yT,zT,nx,ny,nz);

				for (int k=0;k<clusters;k++) {
					final float u = mems[t][x][y][z][k];
					final float w = u*u;
					if (useField[t]) num[k] += w*field[t][x][y][z]*imgT;
					else num[k] += w*imgT;
					den[k] += w;
				}
			}
			for (int k=0;k<clusters;k++) {
				if (den[k]>0.0) {
					centroids[t][k] = num[k]/den[k];
				} else {
					centroids[t][k] = 0.0f;
				}
			}
			if (debug) {
				MedicUtilPublic.displayMessage("centroids: ("+centroids[t][0]);
				for (int k=1;k<clusters;k++) MedicUtilPublic.displayMessage(", "+centroids[t][k]);
				MedicUtilPublic.displayMessage(")\n");
			}
		}
    } // computeTransformedCentroids
    
    /**
     * Computes the class centroids for every time point in the registered case,
     * for an arbitrary fuzziness exponent.
     * <p>
     * Each reference voxel inside the time-0 mask is mapped through
     * transforms[t] (X' = R X + T on resolution-scaled coordinates, converted
     * back to voxel units); image[t] is sampled there by cubic Lagrangian
     * interpolation and weighted by mems^fuzziness (times field[t] when
     * useField[t] is set). A class whose total weight is zero gets a centroid
     * of 0.
     */
    final public void computeGeneralTransformedCentroids() {
        for (int t=0;t<nt;t++) {
            float[][] mat = transforms[t];
            for (int k=0;k<clusters;k++) {
                float numerator = 0;
                float denominator = 0;
                for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
                    // local position in image t: X' = R X + T, back in voxel units
                    float px = (mat[0][0]*x*rx + mat[0][1]*y*ry + mat[0][2]*z*rz + mat[0][3])/rx;
                    float py = (mat[1][0]*x*rx + mat[1][1]*y*ry + mat[1][2]*z*rz + mat[1][3])/ry;
                    float pz = (mat[2][0]*x*rx + mat[2][1]*y*ry + mat[2][2]*z*rz + mat[2][3])/rz;

                    if (mask[0][x][y][z]) {
                        float intensity = ImageFunctions.cubicLagrangianInterpolation3D(image[t],wt,Imin[t],Imax[t],px,py,pz,nx,ny,nz);
                        float weight = (float)Math.pow(mems[t][x][y][z][k],fuzziness);

                        numerator += useField[t] ? weight*field[t][x][y][z]*intensity : weight*intensity;
                        denominator += weight;
                    }
                }
                if (denominator>0.0) centroids[t][k] = numerator/denominator;
                else centroids[t][k] = 0.0f;
            }
            if (debug) {
                MedicUtilPublic.displayMessage("centroids: ("+centroids[t][0]);
                for (int k=1;k<clusters;k++) MedicUtilPublic.displayMessage(", "+centroids[t][k]);
                MedicUtilPublic.displayMessage(")\n");
            }
        }
    } // computeGeneralTransformedCentroids
    
	/** 
	 *	Returns the hard (maximum-membership) classification of time point t.
	 *  Voxels inside mask[t] get label k+1 for the class k (over all classes,
	 *  not only clusters) with the largest strictly positive membership; voxels
	 *  outside the mask, or whose memberships are all <= 0, get label 0.
	 *
	 *  @param t  time point index
	 *  @return   label image of size nx x ny x nz
	 */
	public final byte[][][] exportHardClassification(int t) {
		// Java zero-fills new arrays, so 0 ("no class") is the default label
		byte[][][] labels = new byte[nx][ny][nz];

		for (int vx=0;vx<nx;vx++) {
			for (int vy=0;vy<ny;vy++) {
				for (int vz=0;vz<nz;vz++) {
					if (!mask[t][vx][vy][vz]) continue;

					int winner = -1;
					float top = 0.0f;
					for (int c=0;c<classes;c++) {
						float m = mems[t][vx][vy][vz][c];
						if (m > top) {
							top = m;
							winner = c;
						}
					}
					if (winner != -1) labels[vx][vy][vz] = (byte)(winner+1);
				}
			}
		}
		return labels;
	} // exportHardClassification

	/** 
	 *	Returns the hard (maximum-membership) classification of time point t
	 *  in the case with registration.
	 *  <p>
	 *  Memberships and the mask are read directly on the reference grid
	 *  (mems[t][x][y][z], mask[0]), so no coordinate transform is applied here.
	 *  Voxels inside mask[0] get label k+1 for the class k (over all classes,
	 *  not only clusters) with the largest strictly positive membership; all
	 *  other voxels get label 0.
	 *
	 *  @param t  time point index
	 *  @return   label image of size nx x ny x nz
	 */
	public final byte[][][] exportTransformedHardClassification(int t) {
		// zero-filled by Java: 0 is the "no class" label
		byte[][][]	classification = new byte[nx][ny][nz];

		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			// the time-0 mask defines the region of interest on the reference grid
			if (!mask[0][x][y][z]) continue;

			int best = -1;
			float bestmem = 0.0f;
			for (int k=0;k<classes;k++) {
				if (mems[t][x][y][z][k] > bestmem) {
					best = k;
					bestmem = mems[t][x][y][z][k];
				}
			}
			if (best != -1) classification[x][y][z] = (byte)(best+1);
		}
		return classification;
	} // exportTransformedHardClassification

	/** 
	 *	Resamples image[t] onto the reference grid (case with registration).
	 *  Each reference voxel is mapped through transforms[t] (X' = R X + T,
	 *  evaluated on resolution-scaled coordinates and converted back to voxel
	 *  units). Where ImageFunctions.maskInterpolation(mask[t], ...) accepts the
	 *  mapped position, image[t] is cubic-Lagrangian interpolated there;
	 *  everywhere else the output stays 0.
	 *
	 *  @param t  time point index
	 *  @return   the transformed image, nx x ny x nz
	 */
	public final float[][][] exportTransformedImage(int t) {
		// Java zero-fills new arrays, so rejected positions are already 0
		float[][][] resampled = new float[nx][ny][nz];
		float[][] m = transforms[t];

		for (int i=0;i<nx;i++) {
			for (int j=0;j<ny;j++) {
				for (int l=0;l<nz;l++) {
					float xs = (m[0][0]*i*rx + m[0][1]*j*ry + m[0][2]*l*rz + m[0][3])/rx;
					float ys = (m[1][0]*i*rx + m[1][1]*j*ry + m[1][2]*l*rz + m[1][3])/ry;
					float zs = (m[2][0]*i*rx + m[2][1]*j*ry + m[2][2]*l*rz + m[2][3])/rz;

					if (ImageFunctions.maskInterpolation(mask[t],xs,ys,zs,nx,ny,nz)) {
						resampled[i][j][l] = ImageFunctions.cubicLagrangianInterpolation3D(image[t],wt,Imin[t],Imax[t],xs,ys,zs,nx,ny,nz);
					}
				}
			}
		}
		return resampled;
	} // exportTransformedImage

	/** 
	 *	Exports the membership functions of time point t in class-major layout:
	 *  result[k][x][y][z] = mems[t][x][y][z][k] for every k < classes.
	 *
	 *  @param t  time point index
	 *  @return   a new array of size classes x nx x ny x nz
	 */
	public final float[][][][] exportMembership(int t) {
		float[][][][] out = new float[classes][nx][ny][nz];

		for (int x=0;x<nx;x++) {
			for (int y=0;y<ny;y++) {
				for (int z=0;z<nz;z++) {
					for (int k=0;k<classes;k++) {
						out[k][x][y][z] = mems[t][x][y][z][k];
					}
				}
			}
		}
		return out;
	} // exportMembership
	
    /** 
	 *	Returns the inverse of transforms[t] as a 4x4 homogeneous matrix in the
	 *  Mipav transformation-matrix style (maps image t back to the reference /
	 *  classification grid).
	 *  <p>
	 *  transforms[t] is the 3x4 matrix [R | T] applied elsewhere in this class
	 *  as X' = R X + T. The inverse is built as [R^t | -R^t T], which is exact
	 *  only when R is orthonormal (a pure rotation).
	 *  NOTE(review): confirm the registration never produces scaling or shear.
	 *  NOTE(review): an earlier comment warned about an unusual TransMatrix
	 *  convention but was left unfinished; confirm the expected convention.
	 *
	 *  @param t  time point index
	 *  @return   4x4 matrix whose bottom row is (0, 0, 0, 1)
	 */
	public final double[][] convertTransform(int t) {
		double[][]	mat = new double[4][4];
		float[][]	trans = transforms[t];

		// rotation part: R0 = R^t
		for (int i=0;i<3;i++) for (int j=0;j<3;j++) {
			mat[i][j] = (double)trans[j][i];
		}
		// translation part: T0 = -R^t T
		for (int i=0;i<3;i++) {
			double shift = 0.0;
			for (int j=0;j<3;j++) {
				shift += -mat[i][j]*(double)trans[j][3];
			}
			mat[i][3] = shift;
		}
		// bottom row (0,0,0,1): mat[3][0..2] are already 0.0 from allocation
		mat[3][3] = 1.0;

		return mat;
	} // convertTransform
	
	/** 
	 *	Creates a relabeling table that orders the clusters by increasing
	 *  centroid value at time point 0.
	 *  <p>
	 *  The returned array id has length classes+1:
	 *  <ul>
	 *  <li>id[0] = 0;</li>
	 *  <li>id[r+1] = (index+1) of the cluster with the r-th smallest
	 *      centroids[0] value, for r < clusters (ties keep index order);</li>
	 *  <li>id[k+1] = k+1 for clusters <= k < classes, so extra classes
	 *      (outliers, etc.) keep their order.</li>
	 *  </ul>
	 *  Chosen clusters are tracked explicitly rather than overwritten with a
	 *  sentinel, so the result is always a permutation even if a centroid is
	 *  NaN or very large; NaN centroids are ordered last.
	 *
	 *  @return the id table described above
	 */
	public final int[] computeCentroidOrder() {
		int[]		id = new int[classes+1];
		boolean[]	used = new boolean[clusters];

		// label 0 maps to itself
		id[0] = 0;

		// selection sort over the time-0 centroids (clusters is small)
		for (int rank=0;rank<clusters;rank++) {
			int lowest = -1;
			for (int l=0;l<clusters;l++) {
				if (used[l]) continue;
				// Float.compare is a total order: equal values keep the first index, NaN sorts last
				if (lowest==-1 || Float.compare(centroids[0][l], centroids[0][lowest]) < 0) {
					lowest = l;
				}
			}
			used[lowest] = true;
			id[rank+1] = lowest+1;
		}
		// keep order for other class types (outliers, etc)
		for (int k=clusters;k<classes;k++)
			id[k+1] = k+1;

		return id;
	} // computeCentroidOrder
	
}
