package edu.jhmi.rad.medic.methods;

import java.io.*;
import java.util.*;

import gov.nih.mipav.view.*;
import gov.nih.mipav.model.structures.jama.*;

import edu.jhmi.rad.medic.libraries.*;
import edu.jhmi.rad.medic.utilities.*;
 
/**
 *
 *  This algorithm handles the inhomogeneity correction
 *	for classification algorithms like EM and K-means.
 *	<p>
 *	The method is based on estimating the parameters of a low degree polynomial.
 *
 *
 *	@version    December 2004
 *	@author     Pierre-Louis Bazin
 *	@see		SegmentationFCM
 *	@see		FuzzyToads
 *	@see		LightMarchingSegmentation
 *
 */
 
public class InhomogeneityCorrectionEM {
		
	// numerical quantities
	private static final	float   INF=1e30f;
	private static final	float   ZERO=1e-30f;
	
	// data buffers
	private 	float[][][]			image;  			// original image
	private 	float[][][][]		mems;				// membership function
	private 	float[]				centroids;			// cluster centroids
	private 	float[]				variances;			// cluster variances
	private 	float[][][]			field;  			// inhomogeneity field
	private 	float[][][][]		fields;  			// separate inhomogeneity fields
	private 	float[][]			transform = null;	// transformation matrix
	private 	boolean[][][]		mask;   			// image mask: true for data points
	private 	byte[][][]			classification;   	// image classification: for discarding masked objects, if used
	// geometry is per instance: when these were static, two instances working on
	// differently sized images silently overwrote each other's dimensions
	private		int			nx,ny,nz;   		// image dimensions
	private		float		rx,ry,rz;   		// image resolutions
	private		int			dimensions;			// image dimensions (2D or 3D)
	
	// parameters
	private 	int 		clusters;   	// number of clusters
	private 	int 		classes;    	// number of classes in original membership: > clusters if outliers
	private		int			degree;			// polynomial function degree
	private		byte[]		templateLabel;	// the label for each class, if used
	private		String[]	objType;		// the type of object for each class, if used
	private		boolean		useObj;			// whether or not to use object type information
				
	// computation variables
	private float[]			pol;		// the array for the polynomial basis (sized for the 3D basis)
	private	int				Np;			// the number of 3D polynomial coefficients for 'degree'
	
	private	int				subsample = 3;	// the sub-sampling to speed up the computations
	
	// type of correction: image field, centroid field or separate centroid fields
	public static final	int   	NONE = 0;
	public static final	int		IMAGE = 1;
	public static final	int   	CENTROIDS = 2;
	public static final	int   	SEPARATE = 3;
	private	int					correctionType;	
	
	// computation flags
	private 	boolean 		isWorking;
	private 	boolean 		isCompleted;
	
	// for debug and display
	ViewUserInterface			UI;
	ViewJProgressBar            progressBar;
	static final boolean		debug=false;
	static final boolean		verbose=false;
	

	/**
	 *  constructor for K-means and EM
	 *	note: all images passed to the algorithm are just linked, not copied.
	 *	<p>
	 *	If <code>deg_</code> is outside [1,4] or the field arrays cannot be
	 *	allocated, the object is left in a non-working state
	 *	(see {@link #isWorking()}) and {@link #computeCorrectionField} returns -1.
	 *	<p>
	 *	The field(s) start as the constant 1 (no correction) and the transform
	 *	starts as the identity.
	 */
	public InhomogeneityCorrectionEM(int deg_, int type_,
					 float[][][] image_, boolean [][][] mask_, 
					 float[][][][] mems_, float[] cent_,
					 float[] var_,
					 int nx_, int ny_, int nz_,
					 float rx_, float ry_, float rz_,
					 int classes_, int clusters_, int dim_,
					 ViewUserInterface UI_, ViewJProgressBar bar_) {
		
		image = image_;
		classification = null;
		mask = mask_;
		mems = mems_;
		centroids = cent_;
		variances = var_;
		
		degree = deg_;
		// only degrees 1 to 4 have a hard-coded Chebyshev basis
		if (degree<1 || degree>4) {
			isWorking = false;
			return;
		}
		// the basis buffer is sized for 3D, which also covers the smaller 2D basis
		Np = polynomialSize(degree, 3);
		
		correctionType = type_;
		
		nx = nx_;
		ny = ny_;
		nz = nz_;
		
		rx = rx_;
		ry = ry_;
		rz = rz_;
		
		dimensions = dim_;
		
		classes = classes_;
		clusters = clusters_;
		useObj = false;
		templateLabel = null;
		objType = null;
		
		UI = UI_;
		progressBar = bar_;
		
		// init all the new arrays
		try {
			if (correctionType==SEPARATE) {
				fields = new float[nx][ny][nz][clusters];
				field = null;
			} else {
				field = new float[nx][ny][nz];
				fields = null;
			}
			pol = new float[Np];
			transform = new float[3][4];
		} catch (OutOfMemoryError e){
			isWorking = false;
			finalize();
			System.out.println(e.getMessage());
			return;
		}
		isWorking = true;

		// init values: a neutral (multiplicative identity) field
		if (correctionType==SEPARATE) {
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
				Arrays.fill(fields[x][y][z], 1.0f);
			}
		} else {
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) {
				Arrays.fill(field[x][y], 1.0f);
			}
		}
		// identity 3x4 transform
		for (int i=0;i<3;i++) {
			Arrays.fill(transform[i], 0.0f);
			transform[i][i] = 1.0f;
		}
		if (debug) MedicUtilPublic.displayMessage("IC:initialisation\n");
	}

	/** clean-up: release the field arrays, the basis buffer and the transform */
	public final void finalize() {
		field = null;
		fields = null;
		pol = null;
		transform = null;
		System.gc();
	}
	
	/** accessor for computed data (shared field; null in SEPARATE mode) */ 
	public final float[][][] getField() { return field; }
	/** accessor for computed data (per-cluster fields; null unless SEPARATE mode) */
	public final float[][][][] getFields() { return fields; }
	/** copies the first 3x4 entries of <code>trans_</code> into the internal transform */
	public final void importTransform(float[][] trans_) { 
		for (int i=0;i<3;i++) for (int j=0;j<4;j++) transform[i][j] = trans_[i][j]; 
	}
	/** links a new membership array (not copied) */
	public final void setMemberships(float[][][][] mem) { mems = mem; }
	/** links a new mask (not copied) */
	public final void setMask(boolean[][][] msk) { mask = msk; }
	/** links a new image (not copied) */
	public final void setImage(float[][][] img) { image = img; }
	/** computation flags */
	public final boolean isWorking() { return isWorking; }
	/** computation flags */
	public final boolean isCompleted() { return isCompleted; }
	
	/** 
	 *  compute the inhomogeneity field using a Chebyshev polynomial basis
	 *	(main function to call)
	 *
	 *	@param deg		requested polynomial degree; clamped to the constructor degree
	 *	@param eqVars	if true, all clusters are weighted by their membership only;
	 *					otherwise memberships are divided by twice the cluster variance
	 *	@return the summed absolute change of the field (see the field transfer step),
	 *			0 when the effective degree is not positive,
	 *			-1 when the configuration is not supported or the object is not working
	 */
	final public float computeCorrectionField(int deg, boolean eqVars) {
		if (!isWorking) return -1;
		if (correctionType==IMAGE) {
			if (dimensions==2) return computeNormalized2DPolynomialImageField(deg, eqVars);
			else return computeNormalized3DPolynomialImageField(deg, eqVars);
		} else if (correctionType==CENTROIDS) {
			if (dimensions==3) return computeNormalized3DPolynomialCentroidField(deg, eqVars);
		} else if (correctionType==SEPARATE) {
			if (dimensions==3) return computeNormalized3DPolynomialSeparateField(deg, eqVars);
		}
		return -1;
	}
	
	/** 
	 *  3D image field: the field multiplies the image intensities
	 *  (normal equations weighted by image^2).
	 */
	final private float computeNormalized3DPolynomialImageField(int deg, boolean eqVars) {
		return fitPolynomialField(deg, eqVars, 3, false, -1);
	}
	
	/** 
	 *  2D image field: the field multiplies the image intensities
	 *  (normal equations weighted by image^2), using the 2D basis.
	 */
	final private float computeNormalized2DPolynomialImageField(int deg, boolean eqVars) {
		return fitPolynomialField(deg, eqVars, 2, false, -1);
	}

	/** 
	 *  3D centroid field: one field multiplies all the centroids
	 *  (normal equations weighted by centroid^2).
	 */
	final private float computeNormalized3DPolynomialCentroidField(int deg, boolean eqVars) {
		return fitPolynomialField(deg, eqVars, 3, true, -1);
	}

	/** 
	 *  3D separate fields: one independent centroid field per cluster.
	 *  @return the sum of the per-cluster change measures
	 */
	final private float computeNormalized3DPolynomialSeparateField(int deg, boolean eqVars) {
		float diff = 0.0f;
		for (int k=0;k<clusters;k++) {
			diff += fitPolynomialField(deg, eqVars, 3, true, k);
		}
		return diff;
	}

	/**
	 *	Fits one polynomial field and stores it, normalized to unit mean.
	 *
	 *	@param deg				requested degree (clamped to the constructor degree)
	 *	@param eqVars			equal-variance weighting flag
	 *	@param dims				2 or 3: which Chebyshev basis to use
	 *	@param scaleCentroids	true: field multiplies the centroids (weights use c_k^2);
	 *							false: field multiplies the image (weights use I^2)
	 *	@param cluster			index of the separate field to fit, or -1 for the shared field
	 *	@return the change measure computed by {@link #transferToField}, 0 if the effective degree is not positive
	 */
	private float fitPolynomialField(int deg, boolean eqVars, int dims, boolean scaleCentroids, int cluster) {
		long start = System.currentTimeMillis();
		
		if (deg>degree) deg = degree;
		if (deg<=0) return 0.0f;
		// basis size follows the dimension actually used (2D basis is smaller than Np)
		int np = polynomialSize(deg, dims);
		
		JamaMatrix coeff = solvePolynomialCoefficients(deg, np, dims, eqVars, scaleCentroids, cluster);
		float diff = transferToField(coeff, np, deg, dims, cluster);
		
		if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-start) +"\n"); 
		return diff;
	}

	/**
	 *	Builds and solves the weighted least-squares normal equations on the
	 *	sub-sampled valid voxels: factor * coeff = img, with
	 *	img[n] = sum w_k I c_k pol[n] and factor[n][m] = sum w_k s_k pol[n] pol[m],
	 *	where s_k is c_k^2 (centroid model) or I^2 (image model).
	 */
	private JamaMatrix solvePolynomialCoefficients(int deg, int np, int dims, boolean eqVars,
													boolean scaleCentroids, int cluster) {
		double[] rhs = new double[np];
		double[][] lhs = new double[np][np];
		
		// shared field: all clusters contribute; separate field: only its own cluster
		int kfirst = (cluster<0) ? 0 : cluster;
		int klast = (cluster<0) ? clusters : cluster+1;
		// 2D data is not sub-sampled along z
		int zstep = (dims==2) ? 1 : subsample;
		
		int progress = 0;
		int mod = Math.max(1, nx*ny*nz/10); // progress report interval (10 percent), never 0
		
		for (int x=0;x<nx;x+=subsample) for (int y=0;y<ny;y+=subsample) for (int z=0;z<nz;z+=zstep) {
			progress++;
			if ( (verbose) && (progress%mod==0) )
				progressBar.updateValue(10*Math.round( (float)progress/(float)mod),false);
			
			if (!isValidPoint(x,y,z,cluster)) continue;
			
			// the basis products do not depend on k: collapse the cluster sum first
			float intensity = image[x][y][z];
			double target = 0.0;
			double scale = 0.0;
			for (int k=kfirst;k<klast;k++) {
				if (useObj && !isObjectClass(k)) continue;
				float w;
				if (eqVars) w = mems[x][y][z][k];
				else w = mems[x][y][z][k]/(2*variances[k]);
				target += w*intensity*centroids[k];
				if (scaleCentroids) scale += w*centroids[k]*centroids[k];
				else scale += w*intensity*intensity;
			}
			
			computeBasis(x,y,z,deg,dims);
			for (int n=0;n<np;n++) {
				rhs[n] += target*pol[n];
				for (int m=0;m<np;m++) {
					lhs[n][m] += scale*pol[n]*pol[m];
				}
			}
		}
		
		// normalize by the volume size: the solution is unchanged, magnitudes stay moderate
		double volume = (double)nx*(double)ny*(double)nz;
		JamaMatrix img = new JamaMatrix(np,1,0.0f);
		JamaMatrix factor = new JamaMatrix(np,np,0.0f);
		for (int n=0;n<np;n++) {
			img.set(n,0, rhs[n]/volume);
			for (int m=0;m<np;m++)
				factor.set(n,m, lhs[n][m]/volume);
		}
		// coeff = factor^-1 img
		return factor.solve(img);
	}

	/**
	 *	Evaluates the fitted polynomial on every valid voxel, normalizes it by its mean
	 *	over the valid voxels and sets invalid voxels to 1 (neutral), so that a
	 *	moving mask never leaves zeros in the field.
	 *
	 *	@return the sum over all voxels of |new raw value - previous stored value|.
	 *			NOTE(review): as in the original code the new value is compared
	 *			before normalization, and masked voxels contribute |0 - previous|.
	 */
	private float transferToField(JamaMatrix coeff, int np, int deg, int dims, int cluster) {
		double[] c = new double[np];
		for (int n=0;n<np;n++) c[n] = coeff.get(n,0);
		
		float min = INF;
		float max = -INF;
		float mean = 0.0f;
		float den = 0.0f;
		float diff = 0.0f;
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			float previous = getFieldValue(x,y,z,cluster);
			float value = 0.0f;
			if (isValidPoint(x,y,z,cluster)) {
				computeBasis(x,y,z,deg,dims);
				for (int n=0;n<np;n++) {
					value += c[n]*pol[n];
				}
				// independent checks: the first value must update both bounds
				if (value < min) min = value;
				if (value > max) max = value;
				mean += value;
				den++;
			}
			setFieldValue(x,y,z,cluster,value);
			diff += Math.abs(value-previous);
		}
		// with no valid voxel every value is 0 and becomes 1 below: no division needed
		if (den>0) mean = mean/den;
		
		// normalize and fill the masked area with 1 rather than 0
		// in case of moving masks
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			float value = getFieldValue(x,y,z,cluster);
			if (value==0) setFieldValue(x,y,z,cluster,1.0f);
			else setFieldValue(x,y,z,cluster,value/mean);
		}
		
		if (verbose) {
			MedicUtilPublic.displayMessage("IC: degree "+deg+", <min|avg|max> : <"+min+"|"+mean+"|"+max+">\n");
		}
		return diff;
	}

	/** reads the shared field (cluster &lt; 0) or the separate field of one cluster */
	private float getFieldValue(int x, int y, int z, int cluster) {
		if (cluster<0) return field[x][y][z];
		else return fields[x][y][z][cluster];
	}

	/** writes the shared field (cluster &lt; 0) or the separate field of one cluster */
	private void setFieldValue(int x, int y, int z, int cluster, float value) {
		if (cluster<0) field[x][y][z] = value;
		else fields[x][y][z][cluster] = value;
	}

	/**
	 *	Whether a voxel takes part in the fit.
	 *	Without object information this is the mask. With it, the voxel label must
	 *	match an "obj"/"bg" class: the given cluster, or any cluster when cluster &lt; 0.
	 */
	private boolean isValidPoint(int x, int y, int z, int cluster) {
		if (!useObj) return mask[x][y][z];
		if (cluster>=0) {
			return classification[x][y][z]==templateLabel[cluster] && isObjectClass(cluster);
		}
		for (int k=0;k<clusters;k++) {
			if (classification[x][y][z]==templateLabel[k] && isObjectClass(k)) return true;
		}
		return false;
	}

	/** true for classes of type "obj" or "bg" (the ones used for the correction) */
	private boolean isObjectClass(int k) {
		return objType[k].equals("obj") || objType[k].equals("bg");
	}

	/** fills pol[] with the 2D or 3D basis at one voxel */
	private void computeBasis(int x, int y, int z, int deg, int dims) {
		if (dims==2) compute2DChebyshevCoefficients(x,y,z,deg);
		else compute3DChebyshevCoefficients(x,y,z,deg);
	}

	/**
	 *	Number of basis functions of a full polynomial of the given degree:
	 *	(d+1)(d+2)/2 in 2D, (d+1)(d+2)(d+3)/6 in 3D; 0 for d &lt; 1, degree capped at 4.
	 */
	private static int polynomialSize(int deg, int dims) {
		if (deg<1) return 0;
		if (deg>4) deg = 4;
		if (dims==2) return (deg+1)*(deg+2)/2;
		else return (deg+1)*(deg+2)*(deg+3)/6;
	}

	/** maps an index in [0,n-1] to [0,1]; a single-sample axis maps to 0 instead of NaN */
	private static float normalizedCoordinate(int i, int n) {
		if (n>1) return i/(float)(n-1);
		else return 0.0f;
	}

	/** 
	 *	compute the 3D Chebyshev polynomial basis for one point.
	 *	Each degree extends the previous one, so lower-degree entries keep their index.
	 */
	private final void compute3DChebyshevCoefficients(int x, int y, int z, int deg) {
		Arrays.fill(pol, 0.0f);
		if (deg<1) return;
		
		float px = normalizedCoordinate(x, nx);
		float py = normalizedCoordinate(y, ny);
		float pz = normalizedCoordinate(z, nz);

		// degree 1: constant, x
		pol[0] = 1.0f;
		pol[1] = px;
		pol[2] = py;
		pol[3] = pz;
		if (deg<2) return;
		// degree 2: T2(x), xy
		pol[4] = 2.0f*px*px - 1.0f;
		pol[5] = 2.0f*py*py - 1.0f;
		pol[6] = 2.0f*pz*pz - 1.0f;
		pol[7] = pol[1]*pol[2];
		pol[8] = pol[2]*pol[3];
		pol[9] = pol[3]*pol[1];
		if (deg<3) return;
		// degree 3: T3(x), T2(x)y, xyz
		pol[10] = 4.0f*px*px*px - 3.0f*px;
		pol[11] = 4.0f*py*py*py - 3.0f*py;
		pol[12] = 4.0f*pz*pz*pz - 3.0f*pz;
		pol[13] = pol[4]*pol[2];
		pol[14] = pol[1]*pol[5];
		pol[15] = pol[5]*pol[3];
		pol[16] = pol[2]*pol[6];
		pol[17] = pol[6]*pol[1];
		pol[18] = pol[3]*pol[4];
		pol[19] = pol[1]*pol[2]*pol[3];
		if (deg<4) return;
		// degree 4: T4(x), T3(x)y, T2(x)T2(y), T2(x)yz
		pol[20] = 8.0f*px*px*px*px - 8.0f*px*px + 1.0f;
		pol[21] = 8.0f*py*py*py*py - 8.0f*py*py + 1.0f;
		pol[22] = 8.0f*pz*pz*pz*pz - 8.0f*pz*pz + 1.0f;
		pol[23] = pol[10]*pol[2];
		pol[24] = pol[10]*pol[3];
		pol[25] = pol[11]*pol[3];
		pol[26] = pol[11]*pol[1];
		pol[27] = pol[12]*pol[1];
		pol[28] = pol[12]*pol[2];
		pol[29] = pol[4]*pol[5];
		pol[30] = pol[5]*pol[6];
		pol[31] = pol[6]*pol[4];
		pol[32] = pol[4]*pol[2]*pol[3];
		pol[33] = pol[1]*pol[5]*pol[3];
		pol[34] = pol[1]*pol[2]*pol[6];
	}
	
	/** 
	 *	compute the 2D Chebyshev polynomial basis for one point
	 *	(z is accepted for symmetry with the 3D version but not used).
	 */
	private final void compute2DChebyshevCoefficients(int x, int y, int z, int deg) {
		Arrays.fill(pol, 0.0f);
		if (deg<1) return;
		
		float px = normalizedCoordinate(x, nx);
		float py = normalizedCoordinate(y, ny);

		// degree 1: constant, x
		pol[0] = 1.0f;
		pol[1] = px;
		pol[2] = py;
		if (deg<2) return;
		// degree 2: T2(x), xy
		pol[3] = 2.0f*px*px - 1.0f;
		pol[4] = 2.0f*py*py - 1.0f;
		pol[5] = pol[1]*pol[2];
		if (deg<3) return;
		// degree 3: T3(x), T2(x)y
		pol[6] = 4.0f*px*px*px - 3.0f*px;
		pol[7] = 4.0f*py*py*py - 3.0f*py;
		pol[8] = pol[3]*pol[2];
		pol[9] = pol[1]*pol[4];
		if (deg<4) return;
		// degree 4: T4(x), T3(x)y, T2(x)T2(y)
		pol[10] = 8.0f*px*px*px*px - 8.0f*px*px + 1.0f;
		pol[11] = 8.0f*py*py*py*py - 8.0f*py*py + 1.0f;
		pol[12] = pol[6]*pol[2];
		pol[13] = pol[7]*pol[1];
		pol[14] = pol[3]*pol[4];
	}
	
	/** 
	 *	export a copy of the shared field 
	 */
	public final float[][][] exportField() {
		float[][][]	copy = new float[nx][ny][];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) {
			copy[x][y] = field[x][y].clone();
		}
		return copy;
	} // exportField

	/** 
	 *	export a copy of all the separate fields 
	 */
	public final float[][][][] exportFields() {
		float[][][][]	copy = new float[nx][ny][nz][];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			copy[x][y][z] = fields[x][y][z].clone();
		}
		return copy;
	} // exportFields

	/** 
	 *	export a copy of the separate field for a single cluster k
	 */
	public final float[][][] exportFields(int k) {
		float[][][]	copy = new float[nx][ny][nz];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			copy[x][y][z] = fields[x][y][z][k];
		}
		return copy;
	} // exportFields

}
