package edu.jhmi.rad.medic.methods;

import java.io.*;
import java.util.*;
import gov.nih.mipav.view.*;

import edu.jhmi.rad.medic.libraries.*;
import edu.jhmi.rad.medic.utilities.*;
import edu.jhmi.rad.medic.structures.*;

/**
 *
 *  This class handles full structure atlas information:
 *	shape, objType, relations, etc.
 *
 *	@version    June 2006
 *	@author     Pierre-Louis Bazin
 *		
 *
 */
 
public class DotsAtlas {
	
	// structures: basic information	
	private		int					classes;			// number of structures in the atlas
	private		String[]			name;				// their names
	private		byte[]				label;				// their labels
	private		String[]			objType;			// their objType type
	private		String				atlasFile;			// the atlas file
	
	// atlas quantities
	private 	int 				nix,niy,niz; 			// image dimensions
	private 	float 				rix,riy,riz; 			// image resolutions
	private		int					orient,orix,oriy,oriz;		// image orientations
	
	// spatial transformations
	private		float[]				transform;		// the transform parameters to get into image space
	private		float[][]			rotation;		// the associated rotation matrix
	private		float[][][]			shapeTransform; // the global transform matrix (XI = A XP) for each shape
	private		int					Nd;				// transform dimension
	private		float				maxscale = 4.0f;
	private		float				x0i,y0i,z0i;		// the center of the image
	private		float				scalingFactor = 0.1f; // maximum variation due to scaling
	
	// shape maps
	private		boolean[]			hasShape;			// flag to notify which structures use a shape atlas
	private		float[][]			shape;				// the shape images
	private		String[]			shapeFile;			// the shape filenames
	private		boolean[]			hasDir;			// flag to notify which structures use a shape atlas
	private		float[][][]			direction;				// the shape images
	private		String[]			dirFile;			// the shape filenames
	private		int[]				nsx,nsy,nsz;		// the dimensions for each shape image
	private 	float[]				rsx,rsy,rsz; 		// shape resolutions
	private		float				shapeScale;			//	the slope of the sigmoid prior based on distance functions
	private		int					labelSamples;		// number of samples in the shape/distasnce/contact/direction model
	private		int[]				minx,miny,minz;		// the lowest image coordinate with non zero prior for each structure
	private		int[]				maxx,maxy,maxz;		// the highest image coordinate with non zero prior for each structure
	private		float[][]			center;
	private		float[][]			tractOverlap;		// the amount of variability in the shape (0: very variable, 1: fixed)
	private		boolean				priorTractOverlap = false;
	
	// objType template
	private		boolean				hasTopology;	// flag to notify if there is a objType template
	private		byte[]				template;		// the objType template for the segmentation
	private		String				templateFile;	// the objType template for the segmentation
	private		int					ntx,nty,ntz;	// the objType atlas dimensions
	private 	float 				rtx,rty,rtz; 	// image resolutions
	private		float				x0t,y0t,z0t;	// the center of the objType image
	
	// for registration
	private 	float[]			famap;
	private 	float[][]		mems;
	private 	short[][]		lbmems;
	private		int				nbest;
	private		int				dataSource;
	private	static final int	FA=1;
	private	static final int	MEMS=2;
	
	private		ParametricTransform		transformModel;	// the type of transform (from possible ones below)
	private		int					transformMode;	// the type of transform (from possible ones below)
	private static final	int   	NONE = 0;
	private static final	int   	PARAMETRIC = 1;
	private static final	int   	DEFORMABLE = 2;
	private		boolean[]			registeredShape;
	
	// non-linear registration
	private 	DotsWarping		demons;
	
	// Levenberg Marquardt parameters
	private		float		chisq,oldchisq;			// the chi square error value
	private		float		lambda	=	1.0f;		// the Levenberg-Marquardt parameter for Levenberg Marquardt estimation
	//private		float[]		hessian, gradient;			// the coefficients to compute 
	private		float		chiPrecision = 1e-3f;	// the lower limit on the chi square value 
	private		float		lfactor = 1.5f;			// the multiplicative factor for the adaptive term
	private		int			itSupp = 10;				// maximum of steps if the cost function is not improving
	private		int			itMax,itPlus,Nturn;		// counters for various loops
	private static final	float   INIT_LAMBDA = 1;
	private		float		minEdiff = 1e-6f;		// the minimum variation of energy to require a better alignment
	private		float		minLambda = 0.001f;		// the minimum variation of energy to require a better alignment
	private		int			subsample = 3;			// scale for the registration: just subsample the volume 
	private		int 		levels = 1;				// number of image scales (pyramid)
	private		int			offset = 0;				// offset used in subsampling (cyclic)
	private 	boolean		precompute=true;
	
	// preset computation arrays for speed up
	private		float[]		XP;
	
	// constants
	private static final	float	PI2 = (float)(Math.PI/2.0);
	
	// for debug and display
	private static final boolean		debug=true;
	private static final boolean		verbose=true;
	
	
	// numerics
	private static final	float   INF=1e30f;
	private static final	float   ZERO=1e-30f;

	// convenience tags
	private static final	int   X=0;
	private static final	int   Y=1;
	private static final	int   Z=2;
	private static final	int   T=3;

	/**
	 *	Constructor: builds an atlas from a description file.
	 *	<p>
	 *	The file is parsed by {@link #loadAtlas(String)}; values on a line are
	 *	separated by tabs, not spaces. The atlas starts without any spatial
	 *	transform (mode NONE, no parameters).
	 *
	 *	@param filename	path of the atlas description file
	 */
	public DotsAtlas(String filename) {
		// no transformation until one is explicitly set up
		this.transform = null;
		this.Nd = 0;
		this.transformMode = NONE;
		
		// may be overwritten by the "Shape Atlas" section of the file
		this.labelSamples = 0;
		loadAtlas(filename);
		
		// preset coordinate array (see the XP field)
		this.XP = new float[3];
	}
	
	/**
	 *	Constructor: creates an empty atlas with a given number of structures.
	 *	<p>
	 *	Every per-structure array is allocated here, so the setters
	 *	({@link #setShapeInfo}, {@link #setShape}, ...) and the queries
	 *	{@link #isRegistered(int)} and {@link #getOverlap(int,int)} work before
	 *	any data is loaded. The defaults mirror those set by {@link #loadAtlas(String)}:
	 *	- no shape or direction maps;
	 *	- every structure except the first flagged as registered;
	 *	- an overlap of 1.0 between all structures.
	 *
	 *	@param Nc_	number of structures in the atlas
	 */
	public DotsAtlas(int Nc_) {
		
		classes = Nc_;
		
		transformMode = NONE;
		Nd = 0;
		transform = null;
		
		labelSamples = 0;
		
		// structure identification
		name = new String[classes];
		label = new byte[classes];
		objType = new String[classes];
		
		// shape and direction maps (filled later through the setters)
		hasTopology = false;
		hasShape = new boolean[classes];
		Arrays.fill(hasShape, false);
		shape = new float[classes][];
		shapeFile = new String[classes];
		hasDir = new boolean[classes];
		Arrays.fill(hasDir, false);
		direction = new float[classes][3][];
		dirFile = new String[classes];
		
		// shape grid geometry: setShapeInfo() writes dimensions AND resolutions
		nsx = new int[classes];
		nsy = new int[classes];
		nsz = new int[classes];
		rsx = new float[classes];
		rsy = new float[classes];
		rsz = new float[classes];
		minx = new int[classes];
		miny = new int[classes];
		minz = new int[classes];
		maxx = new int[classes];
		maxy = new int[classes];
		maxz = new int[classes];
		center = new float[classes][3];
		
		// registration flags and overlap prior: same defaults as loadAtlas()
		registeredShape = new boolean[classes];
		Arrays.fill(registeredShape, true);
		if (classes>0) registeredShape[0] = false;
		tractOverlap = new float[classes][classes];
		for (int n=0;n<classes;n++) Arrays.fill(tractOverlap[n], 1.0f);
		
		// preset coordinate array, as in the file-based constructor
		XP = new float[3];
	}
	
	/**
	 *	Clean-up hook: drops the references to the shape maps and to the
	 *	topology template.
	 *	<p>
	 *	This method does not call {@code System.gc()}. Forcing a collection from
	 *	inside a finalizer only adds a costly, blocking hint while the object
	 *	itself is not yet reclaimable. Finalization is deprecated since Java 9,
	 *	so callers should not rely on this method being run at all.
	 */
	public final void finalize() {
		shape = null;
		template = null;
	}
	
	/** @return the number of structures in the atlas */
	public final int getNumber() {
		return this.classes;
	}
	/** @return the structure names (the internal array, not a copy) */
	public final String[] getNames() {
		return this.name;
	}
	/** @return the object type string of each structure (the internal array) */
	public final String[] getObjectType() {
		return this.objType;
	}
	/** @return the label value of each structure (the internal array) */
	public final byte[] getLabels() {
		return this.label;
	}
	
	/** Sets the name of structure {@code id}. */
	public final void setName(int id, String txt) {
		this.name[id] = txt;
	}
	/** Sets the label value of structure {@code id}. */
	public final void setLabel(int id, byte val) {
		this.label[id] = val;
	}
	/** Sets the object type string ({@code objType}) of structure {@code id}. */
	public final void setTopology(int id, String txt) {
		this.objType[id] = txt;
	}
	
	/** Sets the shape scale parameter ({@code shapeScale}). */
	public final void setShapeScale(float s_) {
		this.shapeScale = s_;
	}
	
	/** @return the topology template image (the internal array) */
	public final byte[] getTemplate() {
		return this.template;
	}
	/** @return the shape map of structure {@code n} (the internal array) */
	public final float[] getShape(int n) {
		return this.shape[n];
	}
	/** @return all shape maps, one per structure (the internal array) */
	public final float[][] getShapes() {
		return this.shape;
	}
	
	/** Replaces the topology template image. */
	public final void setTemplate(byte[] t) {
		this.template = t;
	}
	/** Replaces all shape maps at once. */
	public final void setShapes(float[][] s) {
		this.shape = s;
	}
	/** Replaces the shape map of structure {@code n}. */
	public final void setShape(int n, float[] s) {
		this.shape[n] = s;
	}
	
	/** @return the overlap prior between structures {@code id1} and {@code id2} */
	public final float getOverlap(int id1, int id2) {
		return this.tractOverlap[id1][id2];
	}
	
	/** @return the number of label samples of the shape model */
	public final int getLabelSamples() {
		return this.labelSamples;
	}
	
	/** @return the current transform parameters (the internal array, null until set) */
	public final float[] getTransform() {
		return this.transform;
	}
	/** Replaces the transform parameters. */
	public final void setTransform(float[] trans) {
		this.transform = trans;
	}
	
	/** @return whether structure {@code k} is flagged as registered */
	public final boolean isRegistered(int k) {
		boolean[] flags = this.registeredShape;
		return flags[k];
	}
	
	/** @return true when the atlas uses the non-linear (DEFORMABLE) transform mode */
	public final boolean isDeformable() {
		boolean deformable = (this.transformMode == DEFORMABLE);
		return deformable;
	}
	
	/**
	 *	Exports the non-linear deformation field.
	 *
	 *	@return the field exported by the warping object in DEFORMABLE mode,
	 *			null in any other transform mode
	 */
	public final float[][] exportDeformationField() {
		if (this.transformMode != DEFORMABLE) {
			return null;
		}
		return this.demons.exportTransformField();
	}

	/**
	 *	Exports the current parametric transform as a dense displacement field.
	 *	<p>
	 *	For every image voxel (x,y,z), the image-to-template matrix of
	 *	{@code transformModel} is applied. The voxel position rescaled from
	 *	image to template resolution is then subtracted.
	 *
	 *	@return a 3 x (nix*niy*niz) array holding one displacement component per
	 *			axis, voxels indexed as x + nix*y + nix*niy*z
	 */
	final public float[][] exportRigidDeformationField() {
		final int nxyz = nix*niy*niz;
		float[][] field = new float[3][nxyz];
		
		// 3x4 matrix: linear part in columns X..Z, constant term in column T
		float[][] m = new float[3][4];
		transformModel.precomputeImageToTemplateMatrix(m, transform, transformModel.computeRotation(transform), 1.0f);
		
		// z-major traversal so that the running index equals x + nix*y + nix*niy*z
		int idx = 0;
		for (int z=0;z<niz;z++) {
			for (int y=0;y<niy;y++) {
				for (int x=0;x<nix;x++) {
					field[X][idx] = m[X][X]*x + m[X][Y]*y + m[X][Z]*z + m[X][T] - x*rix/rtx;
					field[Y][idx] = m[Y][X]*x + m[Y][Y]*y + m[Y][Z]*z + m[Y][T] - y*riy/rty;
					field[Z][idx] = m[Z][X]*x + m[Z][Y]*y + m[Z][Z]*z + m[Z][T] - z*riz/rtz;
					idx++;
				}
			}
		}
		return field;
	}
			
	/** @return the structure labels converted to float, in structure order */
	final public float[] exportLabels() {
		float[] values = new float[classes];
		int k = 0;
		while (k < classes) {
			values[k] = label[k];
			k++;
		}
		return values;
	}
	/*
	final public void importShape(int n, float[][][] img, int nsx_, int nsy_, int nsz_) {
		shape[n] = null;
		nsx[n] = nsx_;
		nsy[n] = nsy_;
		nsz[n] = nsz_;
		shape[n] = new float[nsx[n]][nsy[n]][nsz[n]];
		for (int x=0;x<nsx[n];x++) for (int y=0;y<nsy[n];y++) for (int z=0;z<nsz[n];z++) 
			shape[n][x][y][z] = img[x][y][z];
	}
	
	final public void importShapes(float[][][][] img, int nsx_, int nsy_, int nsz_) {
		for (int k=0;k<classes;k++) {
			nsx[k] = nsx_;
			nsy[k] = nsy_;
			nsz[k] = nsz_;
			shape[k] = new float[nsx[k]][nsy[k]][nsz[k]];
		}
		for (int k=0;k<classes;k++) for (int x=0;x<nsx[k];x++) for (int y=0;y<nsy[k];y++) for (int z=0;z<nsz[k];z++) 
			shape[k][x][y][z] = img[k][x][y][z];
	}
	*/
	/** @return the topology template dimensions {ntx, nty, ntz} (a new array) */
	public final int[] getTemplateDim() {
		return new int[] { this.ntx, this.nty, this.ntz };
	}
	
	/** @return the dimensions of the shape map of structure {@code n} (a new array) */
	public final int[] getShapeDim(int n) {
		return new int[] { this.nsx[n], this.nsy[n], this.nsz[n] };
	}
	
	/** @return the image dimensions {nix, niy, niz} (a new array) */
	public final int[] getImageDim() {
		return new int[] { this.nix, this.niy, this.niz };
	}
	
	/** @return the image resolutions {rix, riy, riz} (a new array) */
	public final float[] getImageRes() {
		return new float[] { this.rix, this.riy, this.riz };
	}
	
	/** @return the image orientation codes {orient, orix, oriy, oriz} (a new array) */
	public final int[] getImageOrient() {
		return new int[] { this.orient, this.orix, this.oriy, this.oriz };
	}
	/** @return whether a topology template has been loaded */
	public final boolean hasTopology() {
		return this.hasTopology;
	}
	/** @return whether structure {@code id} has a shape map loaded from the atlas file */
	public final boolean hasShape(int id) {
		return this.hasShape[id];
	}
	
	
	/** Increments the label sample counter. */
	public final void addLabelSample() {
		this.labelSamples += 1;
	}
	
    /**
	 *	Stores the geometry of the image to segment and recomputes the image
	 *	center (half of each dimension, in voxels).
	 *	Parameters, in order:
	 *	- nix_, niy_, niz_: image dimensions;
	 *	- rix_, riy_, riz_: image resolutions;
	 *	- orient_, orix_, oriy_, oriz_: image orientation codes.
	 *	When debugging is enabled, the image and atlas geometries are printed.
	 */
	final public void setImageInfo(int nix_, int niy_, int niz_, float rix_, float riy_, float riz_, int orient_, int orix_, int oriy_, int oriz_) {
		this.nix = nix_;
		this.niy = niy_;
		this.niz = niz_;
		this.rix = rix_;
		this.riy = riy_;
		this.riz = riz_;
		this.orient = orient_;
		this.orix = orix_;
		this.oriy = oriy_;
		this.oriz = oriz_;
		
		// center of the image grid
		this.x0i = this.nix/2.0f;
		this.y0i = this.niy/2.0f;
		this.z0i = this.niz/2.0f;
		
		if (!debug) return;
		System.out.print("dimensions: "+nix+", "+niy+", "+niz+"\n");
		System.out.print("resolutions: "+rix+", "+riy+", "+riz+"\n");
		System.out.print("orientation: "+orient+" | "+orix+", "+oriy+", "+oriz+"\n");
		System.out.print("Atlas\n");
		System.out.print("dimensions: "+ntx+", "+nty+", "+ntz+"\n");
		System.out.print("resolutions: "+rtx+", "+rty+", "+rtz+"\n");
	}
	
    /**
	 *	Stores the grid geometry of the shape map of structure {@code num}.
	 *	Parameters nsx_, nsy_, nsz_ are the shape map dimensions;
	 *	rsx_, rsy_, rsz_ are its resolutions.
	 */
	final public void setShapeInfo(int num, int nsx_, int nsy_, int nsz_, float rsx_, float rsy_, float rsz_) {
		// dimensions first, then resolutions
		this.nsx[num] = nsx_;
		this.nsy[num] = nsy_;
		this.nsz[num] = nsz_;
		this.rsx[num] = rsx_;
		this.rsy[num] = rsy_;
		this.rsz[num] = rsz_;
	}
	
    /** 
	 *  generate atlas image from information
	 */
	 /*
    final public byte[][][] generateDistanceAtlasClassification() {
		float dist,min,count;
		float[]	energy = new float[classes];
		int best=0;
		byte[][][] img = new byte[nix][niy][niz];
		
		for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
			// compute each class probability : attribute the highest
			min = 1e30f; best = -1;
			img[x][y][z] = label[best];
		}
		return img;
	}
	*/
	
    /**
	 *	Builds a hard classification from the untransformed shape priors.
	 *	<p>
	 *	Each voxel receives the label of the structure with the highest prior.
	 *	Voxels where no prior is strictly positive get label 0. Structures with
	 *	no shape map ({@code shape[k]==null}, e.g. not listed in the
	 *	"Shape Atlas" section) are ignored instead of raising a
	 *	NullPointerException.
	 *	<p>
	 *	NOTE(review): the shape maps are indexed with the image size, so they are
	 *	assumed to share the image grid — TODO confirm.
	 *
	 *	@return the label image, one byte per voxel
	 */
    final public byte[] generateClassification() {
		final int nxyz = nix*niy*niz;
		byte[] img = new byte[nxyz];
		
		for (int xyz=0;xyz<nxyz;xyz++) {
			// attribute the label of the highest prior
			float max = 0.0f;
			int best = -1;
			for (int k=0;k<classes;k++) {
				if (shape[k]!=null && shape[k][xyz]>max) {
					best = k;
					max = shape[k][xyz];
				}
			}
			img[xyz] = (best>-1) ? label[best] : 0;
		}
		return img;
	}
	
    /**
	 *	Builds a hard classification from the shape priors mapped into image space.
	 *	<p>
	 *	For each voxel (x,y,z) and structure k, the shape coordinates are
	 *	obtained from one of three sources:
	 *	- {@code fastImageToShapeCoordinates} when {@code precompute} is set;
	 *	- otherwise the deformation field in DEFORMABLE mode;
	 *	- otherwise the parametric transform.
	 *	The prior is then linearly interpolated. The voxel gets the label of the
	 *	highest prior, or 0 if no prior is strictly positive. Structures without
	 *	a shape map ({@code shape[k]==null}) are skipped instead of raising a
	 *	NullPointerException.
	 *
	 *	@return the label image, indexed as x + nix*y + nix*niy*z
	 */
    final public byte[] generateTransformedClassification() {
		byte[] img = new byte[nix*niy*niz];
		
		for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
			float max = 0.0f;
			int best = -1;
			for (int k=0;k<classes;k++) {
				// a structure without shape map cannot win the vote
				if (shape[k]==null) continue;
				
				float[] coord;
				if (precompute) coord = fastImageToShapeCoordinates(k,x,y,z);
				else if (transformMode==DEFORMABLE) coord = demons.getCurrentMapping(x,y,z);
				else coord = transformModel.imageToTemplate(x,y,z,transform,rotation, 1.0f);
				
				float val = ImageFunctions.linearInterpolation(shape[k],0.0f,coord[0],coord[1],coord[2],nsx[k],nsy[k],nsz[k]);
				if (val>max) {
					best = k;
					max = val;
				}
			}
			img[x+y*nix+z*nix*niy] = (best>-1) ? label[best] : 0;
		}
		return img;
	}
	
    /**
	 *	Resamples every shape map into image space under the current transform.
	 *	<p>
	 *	Coordinates come from one of three sources:
	 *	- {@code fastImageToShapeCoordinates} when {@code precompute} is set;
	 *	- otherwise the deformation field in DEFORMABLE mode;
	 *	- otherwise the parametric transform.
	 *	Values are linearly interpolated with 0.0f passed as the interpolation
	 *	default (presumably the value outside the shape grid — TODO confirm in
	 *	ImageFunctions). Structures without a shape map ({@code shape[k]==null})
	 *	yield an all-zero image instead of a NullPointerException.
	 *
	 *	@return one image per structure, voxels indexed as x + nix*y + nix*niy*z
	 */
    final public float[][] generateTransformedShapes() {
		float[][] img = new float[classes][nix*niy*niz];
		
		for (int k=0;k<classes;k++) {
			// nothing to resample: the image stays at zero
			if (shape[k]==null) continue;
			
			for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
				float[] coord;
				if (precompute) coord = fastImageToShapeCoordinates(k,x,y,z);
				else if (transformMode==DEFORMABLE) coord = demons.getCurrentMapping(x,y,z);
				else coord = transformModel.imageToTemplate(x,y,z,transform,rotation, 1.0f);
				
				img[k][x+y*nix+z*nix*niy] = ImageFunctions.linearInterpolation(shape[k],0.0f,coord[0],coord[1],coord[2],nsx[k],nsy[k],nsz[k]);
			}
		}
		return img;
	}
	
	/**
	 *	Lists the structures of the atlas as text.
	 *	The output starts with a "Structures " header line, followed by one line
	 *	per structure in the form "name (objType)&lt;TAB&gt;label".
	 *
	 *	@return the formatted listing
	 */
	final public String displayNames() {
		StringBuilder output = new StringBuilder("Structures \n");
		for (int k=0;k<classes;k++) {
			output.append(name[k]).append(" (").append(objType[k]).append(")\t").append(label[k]).append('\n');
		}
		return output.toString();
	}
	/**
	 *	Lists the registration flag of every structure.
	 *	The output has a header line, then a tab-separated line of structure
	 *	indices, then a tab-separated line with 1 for a registered structure and
	 *	0 otherwise. The header legend matches those values (0 = false, 1 = true).
	 *
	 *	@return the formatted listing
	 */
	final public String displayRegisteredShapes() {
		StringBuilder output = new StringBuilder("Registered Shapes (0/1: false/true) \n");
		
		for (int k=0;k<classes;k++) output.append(k).append('\t');
		output.append('\n');
		for (int k=0;k<classes;k++) output.append(registeredShape[k] ? "1\t" : "0\t");
		output.append('\n');
		
		return output.toString();
	}
	/** display the atlas data */
	/*
	final public String displayCenters() {
		String output = "Centers \n";
		
		for (int k=0;k<classes;k++) {
			output += name[k]+" ( ";
			for (int i=0;i<3;i++) output += center[k][i]+" ";
			output += ")\n";
		}
		
		return output;	
	}
	*/
	/**
     *  compute the center of each object in a hard segmentation
     */
	/* 
    public final void findCenters(byte img[][][]) {
		float count;
		
		for (int k=0;k<classes;k++) {
			for (int n=0;n<3;n++) center[k][n] = 0.0f;
			count = 0.0f;
			for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
				if (img[x][y][z]==label[k]) {
					center[k][0] += x;
					center[k][1] += y;
					center[k][2] += z;
					count++;
				}
			}
			if (count>0) {
				center[k][0] = center[k][0]/count;
				center[k][1] = center[k][1]/count;
				center[k][2] = center[k][2]/count;
			}
		}
		return;
	}
	*/
	/**
     *  compute the center of each object in the shape priors
     */
	 /*
    public final void findShapeCenters() {
		float count;
		
		for (int k=0;k<classes;k++) {
			for (int n=0;n<3;n++) center[k][n] = 0.0f;
			count = 0.0f;
			for (int x=0;x<nsx[k];x++) for (int y=0;y<nsy[k];y++) for (int z=0;z<nsz[k];z++) {
				center[k][0] += shape[k][x][y][z]*x;
				center[k][1] += shape[k][x][y][z]*y;
				center[k][2] += shape[k][x][y][z]*z;
				count += shape[k][x][y][z];
			}
			if (count>0) {
				center[k][0] = center[k][0]/count;
				center[k][1] = center[k][1]/count;
				center[k][2] = center[k][2]/count;
			}
		}
		return;
	}
	*/
	/** 
	 *  allocate the space for an atlas shape image
	 */
	 /*
    final public void createBlankShapeAtlas() {
		shape = new float[classes][nix][niy][niz];
		for (int n=0;n<classes;n++) {
			nsx[n] = nix; nsy[n] = niy; nsz[n] = niz;
			rsx[n] = rix; rsy[n] = riy; rsz[n] = riz;
			minx[n] = 0; miny[n] = 0; minz[n] = 0;
			maxx[n] = nix; maxy[n] = niy; maxz[n] = niz;
			center[n][0] = 0.0f;
			center[n][1] = 0.0f;
			center[n][2] = 0.0f;
			hasShape[n] = false;
			shapeFile[n] = "atlas_"+name[n]+".raw";
		}
	}
	*/
	
	/** 
	 *  compute atlas quantities from a new image.
	 *	Image dimensions and resolutions must match.
	 */
	/* 
    final public void addImageToShapeAtlas(byte[][][] img, float slope) {
        
		// check for existence
		System.out.println("create Atlas ? "+labelSamples);
		if (labelSamples==0) createBlankShapeAtlas();
		shapeScale = slope;
		
		System.out.println("init");
		float[][][] tmp = new float[nix][niy][niz];
        for (int k=0;k<classes;k++) {
			for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
				tmp[x][y][z] = 0.0f;
			}
			
			System.out.println("boundariy");
			for (int x=1;x<nix-1;x++) for (int y=1;y<niy-1;y++) for (int z=1;z<niz-1;z++) {
				if (  (img[x+1][y][z]!=img[x][y][z])
					||(img[x-1][y][z]!=img[x][y][z])
					||(img[x][y+1][z]!=img[x][y][z])
					||(img[x][y-1][z]!=img[x][y][z])
					||(img[x][y][z+1]!=img[x][y][z])
					||(img[x][y][z-1]!=img[x][y][z]) ) {
						if (img[x][y][z]==label[k]) {
							tmp[x][y][z] = 0.5f;
							hasShape[k] = true;
						}
				}
			}
			if (hasShape[k]) {
			
				// compute linear membership
				float factor = 0.5f/slope;
				int D = Numerics.ceil(slope);
				
				System.out.println("propagation "+D+", "+factor);
				for (int d=0;d<D;d++) {
					for (int x=1;x<nix-1;x++) for (int y=1;y<niy-1;y++) for (int z=1;z<niz-1;z++) {
						if (tmp[x][y][z]==0) {
							if (tmp[x+1][y][z]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x+1][y][z]-factor,0.0f);
							if (tmp[x-1][y][z]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x-1][y][z]-factor,0.0f);
							if (tmp[x][y+1][z]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x][y+1][z]-factor,0.0f);
							if (tmp[x][y-1][z]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x][y-1][z]-factor,0.0f);
							if (tmp[x][y][z+1]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x][y][z+1]-factor,0.0f);
							if (tmp[x][y][z-1]-factor>tmp[x][y][z]) tmp[x][y][z] = Numerics.max(tmp[x][y][z-1]-factor,0.0f);
						}
					}
				}
				for (int x=1;x<nix-1;x++) for (int y=1;y<niy-1;y++) for (int z=1;z<niz-1;z++) {
					if (img[x][y][z]==label[k]) tmp[x][y][z] = 1.0f-tmp[x][y][z];
				}
				
				System.out.println("transfer to shape atlas");
				
				// transfer into prior map
				for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
					// add to the prior map
					shape[k][x][y][z] = ( labelSamples*shape[k][x][y][z] + tmp[x][y][z] ) / (labelSamples + 1.0f );
				}
			}
		}
		return;
    }
	*/
	/** normalize the shape atlas between 0 and 1: the sum over all priors must be 1 eveywhere */
	/*
	final public void normalizeShapeAtlas() {
		for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
			float sum = 0;
			for (int k=0;k<classes;k++) {
				sum += shape[k][x][y][z];
			}
			if (sum>0) {
				for (int k=0;k<classes;k++) {
					shape[k][x][y][z] = shape[k][x][y][z]/sum;
				}
			}
		}
	}
    */
	
	/**
	 *	Reads a raw topology template image, stored as one byte per voxel.
	 *	<p>
	 *	Up to Nx*Ny*Nz bytes are read. The read loops until the buffer is full or
	 *	EOF is reached, because a single read() may return fewer bytes; a shorter
	 *	file leaves the remaining voxels at 0. The stream is always closed.
	 *	On an I/O error a message is printed and the buffer is returned as it
	 *	stands. The result is null if the file could not be opened.
	 *
	 *	@param filename	path of the raw image file
	 *	@param Nx	image dimension along X (likewise Ny, Nz)
	 *	@return the voxel values, indexed as x + Nx*y + Nx*Ny*z
	 */
	private byte[] loadTemplateImage(String filename, int Nx, int Ny, int Nz) {
		byte[] buffer = null;
		FileInputStream fis = null;
		try {
			fis = new FileInputStream(new File(filename));
			buffer = new byte[Nx*Ny*Nz];
			int offset = 0;
			while (offset<buffer.length) {
				int n = fis.read(buffer, offset, buffer.length-offset);
				if (n<0) break;
				offset += n;
			}
		} catch (IOException io) {
			System.out.println("i/o pb: "+io.getMessage());
		} finally {
			if (fis!=null) {
				try {
					fis.close();
				} catch (IOException io) {
					System.out.println("i/o pb: "+io.getMessage());
				}
			}
		}
		return buffer;
	}
	
	/**
	 *	Reads a raw shape image stored as little-endian 32-bit floats.
	 *	<p>
	 *	The read loops until Nx*Ny*Nz floats are read or EOF is reached, because a
	 *	single read() may return fewer bytes. The stream is always closed.
	 *	On an I/O error, including a missing file, a message is printed and the
	 *	unread voxels are 0. A missing file previously caused a
	 *	NullPointerException instead.
	 *
	 *	@param filename	path of the raw image file
	 *	@param Nx	image dimension along X (likewise Ny, Nz)
	 *	@return the decoded voxel values, one float per voxel
	 */
	private final float[] loadShapeImage(String filename, int Nx, int Ny, int Nz) {
		final int nxyz = Nx*Ny*Nz;
		// allocated before opening so a failure yields zeros, not a null buffer
		byte[] buffer = new byte[4*nxyz];
		FileInputStream fis = null;
		try {
			fis = new FileInputStream(new File(filename));
			int offset = 0;
			while (offset<buffer.length) {
				int n = fis.read(buffer, offset, buffer.length-offset);
				if (n<0) break;
				offset += n;
			}
		} catch (IOException io) {
			System.out.println("i/o pb: "+io.getMessage());
		} finally {
			if (fis!=null) {
				try {
					fis.close();
				} catch (IOException io) {
					System.out.println("i/o pb: "+io.getMessage());
				}
			}
		}
		// convert to the image format
		float[] img = new float[nxyz];
		for (int xyz=0;xyz<nxyz;xyz++) {
			int b1 = buffer[4*xyz+0] & 0xff;
			int b2 = buffer[4*xyz+1] & 0xff;
			int b3 = buffer[4*xyz+2] & 0xff;
			int b4 = buffer[4*xyz+3] & 0xff;
			// little endian: first byte is the least significant
			img[xyz] = Float.intBitsToFloat((b4 << 24) | (b3 << 16) | (b2 << 8) | b1);
		}
		return img;
	}
	
	/**
	 *	Reads a raw 3-component direction image stored as little-endian 32-bit floats.
	 *	<p>
	 *	The file holds the three components one after the other: all voxels of
	 *	component 0, then component 1, then component 2. The read loops until
	 *	the buffer is full or EOF is reached. The stream is always closed.
	 *	On an I/O error, including a missing file, a message is printed and the
	 *	unread values are 0. A missing file previously caused a
	 *	NullPointerException instead.
	 *
	 *	@param filename	path of the raw image file
	 *	@param Nx	image dimension along X (likewise Ny, Nz)
	 *	@return a 3 x (Nx*Ny*Nz) array, one row per component
	 */
	private final float[][] loadDirectionImage(String filename, int Nx, int Ny, int Nz) {
		final int nxyz = Nx*Ny*Nz;
		// allocated before opening so a failure yields zeros, not a null buffer
		byte[] buffer = new byte[12*nxyz];
		FileInputStream fis = null;
		try {
			fis = new FileInputStream(new File(filename));
			int offset = 0;
			while (offset<buffer.length) {
				int n = fis.read(buffer, offset, buffer.length-offset);
				if (n<0) break;
				offset += n;
			}
		} catch (IOException io) {
			System.out.println("i/o pb: "+io.getMessage());
		} finally {
			if (fis!=null) {
				try {
					fis.close();
				} catch (IOException io) {
					System.out.println("i/o pb: "+io.getMessage());
				}
			}
		}
		// convert to the image format
		float[][] img = new float[3][nxyz];
		for (int d=0;d<3;d++) for (int xyz=0;xyz<nxyz;xyz++) {
			int pos = 4*(xyz+nxyz*d);
			int b1 = buffer[pos+0] & 0xff;
			int b2 = buffer[pos+1] & 0xff;
			int b3 = buffer[pos+2] & 0xff;
			int b4 = buffer[pos+3] & 0xff;
			// little endian: first byte is the least significant
			img[d][xyz] = Float.intBitsToFloat((b4 << 24) | (b3 << 16) | (b2 << 8) | b1);
		}
		return img;
	}
	
	/**
	 *	Loads the atlas description from a text file. All associated images
	 *	(topology template, shape and direction maps) are loaded at this time.
	 *	<p>
	 *	The first line must be exactly
	 *	"Structure Atlas File (edit at your own risks)". The following lines are
	 *	tab-separated and grouped into sections identified by their first words:
	 *	"Structures", "Topology Atlas", "Shape Atlas", "Registered Shapes" and
	 *	"Overlap matrix". Image file names are resolved relative to the directory
	 *	containing the atlas file.
	 *	<p>
	 *	Problems are reported on standard output. Parsing stops at the first
	 *	error, and fields read before it keep their new values. {@code atlasFile}
	 *	is only updated when the whole file was read. The reader is closed in
	 *	every case.
	 *
	 *	@param filename	path of the atlas description file
	 */
	final public void loadAtlas(String filename) {
		if (verbose) System.out.println("loading atlas file: "+filename);
		BufferedReader br = null;
		try {
			File f = new File(filename);
			// getParent() is null for a bare file name: images are then
			// resolved against the working directory instead of "null/..."
			String dir = f.getParent();
			String prefix = (dir==null) ? "" : dir+File.separator;
			br = new BufferedReader(new FileReader(f));
			String line = br.readLine();
			StringTokenizer st;
			String imageFile, dirImageFile;
			// Exact corresponding template (a null line means an empty file)
			if (!"Structure Atlas File (edit at your own risks)".equals(line)) {
				System.out.println("not a proper Structure Atlas file");
				return;
			}
			line = br.readLine();
			while (line!=null) {
				if (line.startsWith("Structures")) {
					// Structures:	classes
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					classes = MipavUtil.getInt(st);
					name = new String[classes];
					label = new byte[classes];
					objType = new String[classes];
					for (int n=0;n<classes;n++) {
						// Name	label	objType
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						name[n] = st.nextToken();
						label[n] = (byte)MipavUtil.getInt(st);
						objType[n] = st.nextToken();
					}
					// allocate other quantities
					hasTopology = false;
					hasShape = new boolean[classes];
					for (int n=0;n<classes;n++) hasShape[n] = false;
					hasDir = new boolean[classes];
					for (int n=0;n<classes;n++) hasDir[n] = false;
					shape = new float[classes][];
					direction = new float[classes][3][];
					nsx = new int[classes];
					nsy = new int[classes];
					nsz = new int[classes];
					rsx = new float[classes];
					rsy = new float[classes];
					rsz = new float[classes];
					minx = new int[classes];
					miny = new int[classes];
					minz = new int[classes];
					maxx = new int[classes];
					maxy = new int[classes];
					maxz = new int[classes];
					center = new float[classes][3];
					shapeFile = new String[classes];
					for (int n=0;n<classes;n++) shapeFile[n] = null;
					dirFile = new String[classes];
					for (int n=0;n<classes;n++) dirFile[n] = null;
					templateFile = null;
					registeredShape = new boolean[classes];
					for (int n=0;n<classes;n++) registeredShape[n] = true;
					registeredShape[0] = false;
					tractOverlap = new float[classes][classes];
					for (int n=0;n<classes;n++) for (int m=0;m<classes;m++) tractOverlap[n][m] = 1.0f;
					if (debug) System.out.println(displayNames());
				} else if (line.startsWith("Topology Atlas")) {
					// File:	name
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					imageFile = prefix+st.nextToken();
					if (debug) System.out.print("file: "+imageFile+"\n");
					// Dimensions:	ntx	nty	ntz
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					ntx = MipavUtil.getInt(st);
					nty = MipavUtil.getInt(st);
					ntz = MipavUtil.getInt(st);
					x0t = ntx/2.0f;
					y0t = nty/2.0f;
					z0t = ntz/2.0f;
					if (debug) System.out.print("dims: "+ntx+" "+nty+" "+ntz+"\n");
					// Resolutions:	rtx	rty	rtz
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					rtx = MipavUtil.getFloat(st);
					rty = MipavUtil.getFloat(st);
					rtz = MipavUtil.getFloat(st);
					if (debug) System.out.print("res: "+rtx+"x"+rty+"x"+rtz+"\n");
					template = loadTemplateImage(imageFile, ntx, nty, ntz);
					hasTopology = true;
					templateFile = imageFile;
				} else if (line.startsWith("Shape Atlas")) {
					// Shape Atlas:	labelSamples
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					labelSamples = MipavUtil.getInt(st);
					// Dimensions:	nix	niy	niz
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					nix = MipavUtil.getInt(st);
					niy = MipavUtil.getInt(st);
					niz = MipavUtil.getInt(st);
					x0i = nix/2.0f;
					y0i = niy/2.0f;
					z0i = niz/2.0f;
					if (debug) System.out.print("image dim: "+nix+"x"+niy+"x"+niz+"\n");
					// Resolutions:	rix	riy	riz
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					rix = MipavUtil.getFloat(st);
					riy = MipavUtil.getFloat(st);
					riz = MipavUtil.getFloat(st);
					if (debug) System.out.print("image res: "+rix+"x"+riy+"x"+riz+"\n");
					// Orientations:	orient	orix	oriy	oriz
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					orient = MipavUtil.getInt(st);
					orix = MipavUtil.getInt(st);
					oriy = MipavUtil.getInt(st);
					oriz = MipavUtil.getInt(st);
					if (debug) System.out.print("image orient: "+orient+"|"+orix+"x"+oriy+"x"+oriz+"\n");
					line = br.readLine();
					// each entry: a "Structure:" line followed by five description lines
					while (line!=null && line.startsWith("Structure:")) {
						// find structure id
						st = new StringTokenizer(line, "\t");
						st.nextToken();
						String title = st.nextToken();
						int id=-1;
						for (int n=0;n<classes;n++) {
							if (title.equals(name[n])) { id = n; break; }
						}
						if (id==-1) {
							// unknown structure: skip its five description lines
							// (Proba, Dir, Dimensions, Resolutions, Center)
							for (int i=0;i<5;i++) line = br.readLine();
						} else {
							if (debug) System.out.print("Shape: "+name[id]+"\n");
							// Proba:	name
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							imageFile = prefix+st.nextToken();
							// Dir:	name
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							dirImageFile = prefix+st.nextToken();
							// Dimensions:	nsx	nsy	nsz
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							nsx[id] = MipavUtil.getInt(st);
							nsy[id] = MipavUtil.getInt(st);
							nsz[id] = MipavUtil.getInt(st);
							if (debug) System.out.print("dim: "+nsx[id]+"x"+nsy[id]+"x"+nsz[id]+"\n");
							// Resolutions:	rsx	rsy	rsz
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							rsx[id] = MipavUtil.getFloat(st);
							rsy[id] = MipavUtil.getFloat(st);
							rsz[id] = MipavUtil.getFloat(st);
							if (debug) System.out.print("res: "+rsx[id]+"x"+rsy[id]+"x"+rsz[id]+"\n");
							// Center:	cx	cy	cz
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							center[id][0] = MipavUtil.getFloat(st);
							center[id][1] = MipavUtil.getFloat(st);
							center[id][2] = MipavUtil.getFloat(st);
							if (debug) System.out.print("center: "+center[id][0]+"x"+center[id][1]+"x"+center[id][2]+"\n");
							// min, max : initial values
							minx[id] = 0; miny[id] = 0; minz[id] = 0;
							maxx[id] = nsx[id]; maxy[id] = nsy[id]; maxz[id] = nsz[id];
							
							shape[id] = loadShapeImage(imageFile, nsx[id], nsy[id], nsz[id]);
							hasShape[id] = true;
							shapeFile[id] = imageFile;
							
							direction[id] = loadDirectionImage(dirImageFile, nsx[id], nsy[id], nsz[id]);
							hasDir[id] = true;
							dirFile[id] = dirImageFile;
						}
						line = br.readLine();
					}
					// NOTE(review): 'line' now holds the first line after the shape entries,
					// but the readLine() at the end of the outer loop discards it, so a
					// section header placed directly after the last entry is not parsed.
				} else if (line.startsWith("Registered Shapes")) {
					if (debug) System.out.println(line);
					// flags: one 0/1 value per structure (a non 0/1 line is skipped once)
					line = br.readLine();
					if (!line.startsWith("0") && !line.startsWith("1")) line = br.readLine();
					if (debug) System.out.println(line);
					st = new StringTokenizer(line, "\t");
					for (int n=0;n<classes;n++) {
						registeredShape[n] = (MipavUtil.getInt(st)==1);
					}
					if (debug) System.out.println(displayRegisteredShapes());
				} else if (line.startsWith("Overlap matrix")) {
					// list of labels (for convenience, ignored)
					line = br.readLine();
					// Label	[0 1]	[0 1]	... (size: classes x classes)
					for (int n=0;n<classes;n++) {
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						st.nextToken(); // row label, unused
						for (int m=0;m<classes;m++) {
							tractOverlap[n][m] = MipavUtil.getFloat(st);
						}
					}
					// make sure the matrix is symmetric
					for (int n=0;n<classes;n++) {
						for (int m=0;m<classes;m++) {
							tractOverlap[n][m] = Numerics.max(tractOverlap[n][m],tractOverlap[m][n]);
						}
					}
					priorTractOverlap = true;
					if (debug) System.out.println(displayTractOverlap());
				}
				line = br.readLine();
				if (debug) System.out.println(line);
			}
			atlasFile = filename;
		}
		catch (FileNotFoundException e) {
			System.out.println(e.getMessage());
		}
		catch (IOException e) {
			System.out.println(e.getMessage());
		}
		catch (OutOfMemoryError e) {
			System.out.println(e.getMessage());
		}
		catch (Exception e) {
			System.out.println(e.getMessage());
		}
		finally {
			// closing the BufferedReader also closes the underlying FileReader
			if (br!=null) {
				try {
					br.close();
				} catch (IOException e) {
					System.out.println(e.getMessage());
				}
			}
		}

		if (debug) MedicUtilPublic.displayMessage("initialisation\n");
	}

	
	/**
	 * Samples every structure's shape prior at image voxel (x,y,z) under the
	 * current registration, writing one value per structure into val.
	 * Uses the precomputed-matrix path when precompute is set, and otherwise
	 * evaluates the transform model (or deformable mapping) directly.
	 *
	 * @param val output array, one entry per structure
	 */
	public final void getTransformedShape(float[] val, int x, int y, int z) {
		if (!precompute) {
			getRegularTransformedShape(val,x,y,z);
			return;
		}
		getFastTransformedShape(val,x,y,z);
	}
	
	/**
	 * Samples the shape priors at image voxel (x,y,z) by evaluating the
	 * transform model (parametric mode) or the demons mapping (deformable mode).
	 * Outside a structure's bounding box the value is 1 for structure 0 and 0
	 * for all other structures; the same values are the interpolation fall-back.
	 * The bounding-box maxima set by computeTransformedShapeBoundingBox() are
	 * inclusive, so only coordinates strictly above them count as outside.
	 *
	 * @param val output array, one entry per structure
	 */
	private final void getRegularTransformedShape(float[] val, int x, int y, int z) {
		if (transformMode==NONE) {
			// no transform: the shape grid is indexed directly with image coordinates
			for (int k=0;k<classes;k++) val[k] = shape[k][x+y*nsx[k]+z*nsx[k]*nsy[k]];
			return;
		}
		
		float[] XP = null;
		boolean noCoordinates = true;
		for (int k=0;k<classes;k++) {
			float outsideValue = (k==0) ? 1.0f : 0.0f;
			if ( (x<minx[k]) || (x>maxx[k]) || (y<miny[k]) || (y>maxy[k]) || (z<minz[k]) || (z>maxz[k]) ) {
				val[k] = outsideValue;
			} else {
				if (transformMode==DEFORMABLE) {
					XP = demons.getCurrentMapping(x,y,z);
				} else if (noCoordinates) {
					// the parametric transform is shared by all structures: compute it once
					XP = transformModel.imageToTemplate(x,y,z,transform,rotation,1.0f);
					noCoordinates = false;
				}
				val[k] = ImageFunctions.linearInterpolation(shape[k],outsideValue,XP[0],XP[1],XP[2],nsx[k],nsy[k],nsz[k]);
			}
		}
	}
	
	/**
	 * Samples the shape priors at image voxel (x,y,z) using the precomputed
	 * per-structure matrices (via fastImageToShapeCoordinates).
	 * Outside a structure's bounding box the value is 1 for structure 0 and 0
	 * for all other structures; the same values are the interpolation fall-back.
	 * The bounding-box maxima set by computeTransformedShapeBoundingBox() are
	 * inclusive, so only coordinates strictly above them count as outside.
	 * NOTE(review): the coordinates computed for the first in-box structure are
	 * reused for the others, which assumes every shapeTransform[k] holds the same
	 * matrix (as computeTransformedShapeBoundingBox() sets them) — confirm.
	 *
	 * @param val output array, one entry per structure
	 */
	private final void getFastTransformedShape(float[] val, int x, int y, int z) {
		if (transformMode==NONE) {
			// no transform: the shape grid is indexed directly with image coordinates
			for (int k=0;k<classes;k++) val[k] = shape[k][x+y*nsx[k]+z*nsx[k]*nsy[k]];
			return;
		}
		
		float[] XP = null;
		boolean noCoordinates = true;
		for (int k=0;k<classes;k++) {
			float outsideValue = (k==0) ? 1.0f : 0.0f;
			if ( (x<minx[k]) || (x>maxx[k]) || (y<miny[k]) || (y>maxy[k]) || (z<minz[k]) || (z>maxz[k]) ) {
				val[k] = outsideValue;
			} else {
				if (noCoordinates) {
					XP = fastImageToShapeCoordinates(k,x,y,z);
					noCoordinates = false;
				}
				val[k] = ImageFunctions.linearInterpolation(shape[k],outsideValue,XP[0],XP[1],XP[2],nsx[k],nsy[k],nsz[k]);
			}
		}
	}
	
	/**
	 * Samples each structure's prior direction at image voxel (x,y,z) and maps
	 * it back into image space (demons mapping in deformable mode, otherwise
	 * the parametric transform).
	 * Directions are read with nearest-neighbour lookup, clamped to the shape
	 * grid. Outside a structure's bounding box the direction is set to 0.
	 * The bounding-box maxima set by computeTransformedShapeBoundingBox() are
	 * inclusive, so only coordinates strictly above them count as outside.
	 *
	 * @param dir output, dir[k][0..2] for each structure k
	 */
	public final void getTransformedDirection(float[][] dir, int x, int y, int z) {
		if (transformMode==NONE) {
			// no transform: the direction grids are indexed directly with image coordinates
			for (int k=0;k<classes;k++) for (int d=0;d<3;d++)
				dir[k][d] = direction[k][d][x+y*nsx[k]+z*nsx[k]*nsy[k]];
			return;
		}
		
		float[] XP = null;
		boolean noCoordinates = true;
		for (int k=0;k<classes;k++) {
			if ( (x<minx[k]) || (x>maxx[k]) || (y<miny[k]) || (y>maxy[k]) || (z<minz[k]) || (z>maxz[k]) ) {
				for (int d=0;d<3;d++) dir[k][d] = 0.0f;
				continue;
			}
			// atlas location: the parametric transform is shared by all structures
			if (transformMode==DEFORMABLE) {
				XP = demons.getCurrentMapping(x,y,z);
			} else if (noCoordinates) {
				XP = transformModel.imageToTemplate(x,y,z,transform,rotation,1.0f);
				noCoordinates = false;
			}
			// nearest-neighbour lookup, clamped to the shape grid
			int id = Numerics.round(Numerics.bounded(XP[0],0,nsx[k]-1)) 
					+ nsx[k]*Numerics.round(Numerics.bounded(XP[1],0,nsy[k]-1)) 
					+ nsx[k]*nsy[k]*Numerics.round(Numerics.bounded(XP[2],0,nsz[k]-1));
			for (int d=0;d<3;d++) dir[k][d] = direction[k][d][id];
			
			// map the atlas direction back into image space
			if (transformMode==DEFORMABLE) {
				demons.mapDirection(dir[k],x,y,z);
			} else {
				transformModel.templateToImageDirection(dir[k],x,y,z,transform,rotation,1.0f);
			}
		}
	}
	
	/**
	 * Recomputes, for every structure k, the image-space bounding box of the
	 * voxels where its transformed shape prior is positive.
	 * The resulting bounds are inclusive: minx[k] <= x <= maxx[k] (same for y, z).
	 * When no voxel is positive the box is left empty (min = image size, max = 0).
	 * With precompute set, shapeTransform[k] is first refreshed from the current
	 * transform and rotation.
	 */
	public final void computeTransformedShapeBoundingBox() {
		for (int k=0;k<classes;k++) {
			if (precompute) transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],transform,rotation,1.0f);
			
			int loX = nix, loY = niy, loZ = niz;
			int hiX = 0, hiY = 0, hiZ = 0;
			
			for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
				float[] pos = precompute
						? fastImageToShapeCoordinates(k,x,y,z)
						: transformModel.imageToTemplate(x,y,z,transform,rotation,1.0f);
				
				float prior = ImageFunctions.linearInterpolation(shape[k],0.0f,pos[0],pos[1],pos[2],nsx[k],nsy[k],nsz[k]);
				if (prior>0) {
					loX = Math.min(loX,x);
					loY = Math.min(loY,y);
					loZ = Math.min(loZ,z);
					hiX = Math.max(hiX,x);
					hiY = Math.max(hiY,y);
					hiZ = Math.max(hiZ,z);
				}
			}
			
			minx[k] = loX; miny[k] = loY; minz[k] = loZ;
			maxx[k] = hiX; maxy[k] = hiY; maxz[k] = hiZ;
		}
	}
	
	
	/**
	 * Resamples the label template into image space with the current parametric
	 * transform (nearest-neighbour, with label[0] passed as the fall-back value),
	 * then replaces template with the resampled image.
	 * Prints a warning when any resampled voxel holds a value not listed in label[].
	 */
	public final void computeTransformedTemplate() {
		final int sliceSize = nix*niy;
		byte[] resampled = new byte[sliceSize*niz];
		boolean unknownLabel = false;
		
		for (int x=0;x<nix;x++) for (int y=0;y<niy;y++) for (int z=0;z<niz;z++) {
			int xyz = x + y*nix + z*sliceSize;
			float[] pos = transformModel.imageToTemplate(x,y,z,transform,rotation,1.0f);
			byte lb = ImageFunctions.nearestNeighborInterpolation(template,label[0],pos[0],pos[1],pos[2],ntx,nty,ntz);
			resampled[xyz] = lb;
			
			// sanity check: the value must be one of the atlas labels
			boolean known = false;
			for (int k=0;k<classes && !known;k++) known = (lb==label[k]);
			if (!known) unknownLabel = true;
		}
		if (unknownLabel) System.out.println("warning: incorrect label image \n");

		template = resampled;
	}
	
	
	/**
	 * Accumulates the gradient and a diagonal Hessian approximation of the
	 * registration energy with respect to the transform parameters trans, over
	 * the subsampled image grid and all structures of type "bundle".
	 * Performs a single evaluation; the transform itself is not modified, but
	 * with precompute set, shapeTransform[k] of bundle structures is overwritten
	 * with the matrix for trans.
	 *
	 * @param hessian  output, diagonal Hessian approximation (length Nd)
	 * @param gradient output, gradient (length Nd)
	 * @param trans    transform parameters at which to evaluate
	 * @return the normalised energy cost/norm, or 0 when no prior overlaps the grid
	 */
	final private float computeRegistrationCoefficients(float[] hessian, float[] gradient, float[] trans) {
		float[][] rot = null, dRa = null, dRb = null, dRc = null;
		float[] dprior = new float[Nd];
		float cost = 0.0f;
		float norm = 0.0f;
		
		for (int i=0;i<Nd;i++) {
			hessian[i] = 0.0f;
			gradient[i] = 0.0f;
		}
		
		// set up rotation parameters and their derivatives
		if (transformModel.useRotation()) {
			RotationMatrix R = transformModel.computeRotationMatrix(trans);
			rot = R.getMatrix();
			dRa = R.derivatives(1.0f, 0.0f, 0.0f);
			dRb = R.derivatives(0.0f, 1.0f, 0.0f);
			dRc = R.derivatives(0.0f, 0.0f, 1.0f);
		}
		if (precompute) for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) 
			transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,1.0f);
		
		if (debug) System.out.println(displayTransform(trans));	
			
		// main loop over the subsampled grid
		for (int x=offset;x<nix;x+=subsample) for (int y=offset;y<niy;y+=subsample) for (int z=offset;z<niz;z+=subsample) {
			for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) {
				// position of the voxel in shape coordinates
				float[] Xi;
				if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
				else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,1.0f);
				
				float priorT = ImageFunctions.linearInterpolation(shape[k],0.0f,Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
				
				// zero prior: no contribution to cost, norm or derivatives
				if (priorT>0) {
					int xyz = x + y*nix + z*nix*niy;
					float weight = imageRegistrationWeight(xyz,k);
					
					// spatial derivatives of the prior
					float dPx = ImageFunctions.linearInterpolationXderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
					float dPy = ImageFunctions.linearInterpolationYderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
					float dPz = ImageFunctions.linearInterpolationZderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
					
					// derivatives of the shape coordinates w.r.t. the parameters
					float[][] dXi = transformModel.imageToTemplateDerivatives(x,y,z,trans,rot,dRa,dRb,dRc,1.0f);
					
					// chain rule: derivative of the prior w.r.t. each parameter
					for (int i=0;i<Nd;i++) {
						dprior[i] = dPx*dXi[0][i] + dPy*dXi[1][i] + dPz*dXi[2][i];
					}
					cost += registrationCost(weight,priorT,x,y,z,k);
					norm += registrationNorm(weight,priorT,x,y,z,k);
					for (int i=0;i<Nd;i++) {
						gradient[i] += registrationCostGradient(weight,priorT,dprior,x,y,z,k,i);
						hessian[i] += registrationCostHessian(weight,priorT,dprior,x,y,z,k,i);
					}
				}
			}
		}
		if (cost>ZERO) {
			for (int i=0;i<Nd;i++) {
				gradient[i] = gradient[i]/norm;
				hessian[i]  = hessian[i]/norm;
			}
		} else {
			// no signal: zero gradient and unit Hessian give a null update
			for (int i=0;i<Nd;i++) {
				gradient[i] = 0.0f;
				hessian[i]  = 1.0f;
			}
		}

		if (debug) {
			System.out.println(displayVector(gradient));
			System.out.println(displayVector(hessian));
		}
		// guard against 0/0 when no prior overlaps the grid
		return cost/Numerics.max(ZERO,norm);
	} // computeRegistrationCoefficients
    
	/**
	 * Evaluates the normalised registration energy at transform parameters trans,
	 * over the subsampled image grid and all structures of type "bundle":
	 * the sum of registrationCost divided by the sum of registrationNorm.
	 * With precompute set, shapeTransform[k] of bundle structures is overwritten
	 * with the matrix for trans.
	 *
	 * @param trans transform parameters at which to evaluate
	 * @return the energy, or 0 when no prior overlaps the grid
	 */
	final private float computeRegistrationEnergy(float[] trans) {
		float[][] rot = null;
		float cost = 0.0f;
		float norm = 0.0f;
		
		// set up rotation parameters
		if (transformModel.useRotation()) {
			RotationMatrix R = transformModel.computeRotationMatrix(trans);
			rot = R.getMatrix();
		}
		if (precompute) for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) 
			transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,1.0f);
		
		if (debug) System.out.println(displayTransform(trans));	
		
		// main loop over the subsampled grid
		for (int x=offset;x<nix;x+=subsample) for (int y=offset;y<niy;y+=subsample) for (int z=offset;z<niz;z+=subsample) {
			for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) {
				// position of the voxel in shape coordinates
				float[] Xi;
				if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
				else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,1.0f);
				
				float priorT = ImageFunctions.linearInterpolation(shape[k],0.0f,Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
				
				// zero prior: no contribution
				if (priorT>0) {
					int xyz = x + y*nix + z*nix*niy;
					float weight = imageRegistrationWeight(xyz,k);
					
					cost += registrationCost(weight,priorT,x,y,z,k);
					norm += registrationNorm(weight,priorT,x,y,z,k);
				}
			}
		}
		// guard against 0/0 when no prior overlaps the grid
		return cost/Numerics.max(ZERO,norm);
	} // computeRegistrationEnergy
    
	/**
	 * Pyramid-level version of computeRegistrationCoefficients: accumulates the
	 * gradient and a diagonal Hessian approximation of the registration energy
	 * for the "bundle" structures on subsampled data. The data weight is fa^2.
	 * With precompute set, shapeTransform[k] of bundle structures is overwritten
	 * with the matrix for trans at this scale.
	 *
	 * @param hessian  output, diagonal Hessian approximation (length Nd)
	 * @param gradient output, gradient (length Nd)
	 * @param trans    transform parameters at which to evaluate
	 * @param fa       subsampled FA-derived map, npx*npy*npz voxels
	 * @param shp      subsampled shape priors, one per structure
	 * @param npx      x dimension of the subsampled image (likewise npy, npz)
	 * @param nspx     x dimension of each subsampled shape (likewise nspy, nspz)
	 * @param scale    subsampling factor passed to the transform model
	 * @return the normalised energy cost/norm, or 0 when no prior overlaps the grid
	 */
	final private float computeScaledRegistrationCoefficients(float[] hessian, float[] gradient, float[] trans, float[] fa, float[][] shp, int npx, int npy, int npz, int[] nspx, int[] nspy, int[] nspz, float scale) {
		float[][] rot = null, dRa = null, dRb = null, dRc = null;
		float[] dprior = new float[Nd];
		float cost = 0.0f;
		float norm = 0.0f;
		
		for (int i=0;i<Nd;i++) {
			hessian[i] = 0.0f;
			gradient[i] = 0.0f;
		}
		
		// set up rotation parameters and their derivatives
		if (transformModel.useRotation()) {
			RotationMatrix R = transformModel.computeRotationMatrix(trans);
			rot = R.getMatrix();
			dRa = R.derivatives(1.0f, 0.0f, 0.0f);
			dRb = R.derivatives(0.0f, 1.0f, 0.0f);
			dRc = R.derivatives(0.0f, 0.0f, 1.0f);
		}
		if (precompute) for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) 
			transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,scale);
		
		// main loop over the subsampled grid
		for (int x=offset;x<npx;x+=subsample) for (int y=offset;y<npy;y+=subsample) for (int z=offset;z<npz;z+=subsample) {
			for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) {
				// position of the voxel in (subsampled) shape coordinates
				float[] Xi;
				if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
				else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,scale);
				
				float priorT = ImageFunctions.linearInterpolation(shp[k],0.0f,Xi[0],Xi[1],Xi[2],nspx[k],nspy[k],nspz[k]);
				
				// zero prior: no contribution to cost, norm or derivatives
				if (priorT>0) {
					int xyz = x + y*npx + z*npx*npy;
					float weight = fa[xyz]*fa[xyz];
					
					// spatial derivatives of the prior
					float dPx = ImageFunctions.linearInterpolationXderivative(shp[k],Xi[0],Xi[1],Xi[2],nspx[k],nspy[k],nspz[k]);
					float dPy = ImageFunctions.linearInterpolationYderivative(shp[k],Xi[0],Xi[1],Xi[2],nspx[k],nspy[k],nspz[k]);
					float dPz = ImageFunctions.linearInterpolationZderivative(shp[k],Xi[0],Xi[1],Xi[2],nspx[k],nspy[k],nspz[k]);
					
					// derivatives of the shape coordinates w.r.t. the parameters
					float[][] dXi = transformModel.imageToTemplateDerivatives(x,y,z,trans,rot,dRa,dRb,dRc,scale);
					
					// chain rule: derivative of the prior w.r.t. each parameter
					for (int i=0;i<Nd;i++) {
						dprior[i] = dPx*dXi[0][i] + dPy*dXi[1][i] + dPz*dXi[2][i];
					}
					cost += registrationCost(weight,priorT,x,y,z,k);
					norm += registrationNorm(weight,priorT,x,y,z,k);
					for (int i=0;i<Nd;i++) {
						gradient[i] += registrationCostGradient(weight,priorT,dprior,x,y,z,k,i);
						hessian[i] += registrationCostHessian(weight,priorT,dprior,x,y,z,k,i);
					}
				}
			}
		}
		// guard all divisions against 0/0 when no prior overlaps the grid
		float safeNorm = Numerics.max(ZERO,norm);
		for (int i=0;i<Nd;i++) {
			gradient[i] /= safeNorm;
			hessian[i] /= safeNorm;
		}
		return cost/safeNorm;
	} // computeScaledRegistrationCoefficients
    
	/**
	 * Pyramid-level version of computeRegistrationEnergy: evaluates the
	 * normalised registration energy at transform parameters trans for the
	 * "bundle" structures on subsampled data, with data weight fa^2.
	 * With precompute set, shapeTransform[k] of bundle structures is overwritten
	 * with the matrix for trans at this scale.
	 *
	 * @param trans transform parameters at which to evaluate
	 * @param fa    subsampled FA-derived map, npx*npy*npz voxels
	 * @param shp   subsampled shape priors, one per structure
	 * @param npx   x dimension of the subsampled image (likewise npy, npz)
	 * @param nspx  x dimension of each subsampled shape (likewise nspy, nspz)
	 * @param scale subsampling factor passed to the transform model
	 * @return the energy, or 0 when no prior overlaps the grid
	 */
	final private float computeScaledRegistrationEnergy(float[] trans, float[] fa, float[][] shp, int npx, int npy, int npz, int[] nspx, int[] nspy, int[] nspz, float scale) {
		float[][] rot = null;
		float cost = 0.0f;
		float norm = 0.0f;
		
		// set up rotation parameters
		if (transformModel.useRotation()) {
			RotationMatrix R = transformModel.computeRotationMatrix(trans);
			rot = R.getMatrix();
		}
		if (precompute) for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) 
			transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,scale);
		
		// main loop over the subsampled grid
		for (int x=offset;x<npx;x+=subsample) for (int y=offset;y<npy;y+=subsample) for (int z=offset;z<npz;z+=subsample) {
			for (int k=0;k<classes;k++) if (objType[k].equals("bundle") ) {
				// position of the voxel in (subsampled) shape coordinates
				float[] Xi;
				if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
				else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,scale);
				
				float priorT = ImageFunctions.linearInterpolation(shp[k],0.0f,Xi[0],Xi[1],Xi[2],nspx[k],nspy[k],nspz[k]);
				
				// zero prior: no contribution
				if (priorT>0) {
					int xyz = x + y*npx + z*npx*npy;
					float weight = fa[xyz]*fa[xyz];
					
					cost += registrationCost(weight,priorT,x,y,z,k);
					norm += registrationNorm(weight,priorT,x,y,z,k);
				}
			}
		}
		// guard against 0/0 when no prior overlaps the grid
		return cost/Numerics.max(ZERO,norm);
	} // computeScaledRegistrationEnergy
    
    /**
     * Single-structure version of computeRegistrationCoefficients: accumulates
     * the gradient and a diagonal Hessian approximation of the registration
     * energy of structure k alone at transform parameters trans.
     * With precompute set, shapeTransform[k] is overwritten with the matrix for trans.
     *
     * @param hessian  output, diagonal Hessian approximation (length Nd)
     * @param gradient output, gradient (length Nd)
     * @param trans    transform parameters at which to evaluate
     * @param k        index of the structure to register
     * @return the normalised energy cost/norm, or 0 when the prior does not overlap the grid
     */
    final private float computeSingleRegistrationCoefficients(float[] hessian, float[] gradient, float[] trans, int k) {
        float[][] rot = null, dRa = null, dRb = null, dRc = null;
        float[] dprior = new float[Nd];
        float cost = 0.0f;
        float norm = 0.0f;
        
        for (int i=0;i<Nd;i++) {
            hessian[i] = 0.0f;
            gradient[i] = 0.0f;
        }
        
        // set up rotation parameters and their derivatives
        if (transformModel.useRotation()) {
            RotationMatrix R = transformModel.computeRotationMatrix(trans);
            rot = R.getMatrix();
            dRa = R.derivatives(1.0f, 0.0f, 0.0f);
            dRb = R.derivatives(0.0f, 1.0f, 0.0f);
            dRc = R.derivatives(0.0f, 0.0f, 1.0f);
        }
        // refresh the matrix for every k (including 0) so that the coordinates
        // from fastImageToShapeCoordinates match the derivatives taken at trans
        if (precompute) transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,1.0f);
        
        // main loop over the subsampled grid
        for (int x=offset;x<nix;x+=subsample) for (int y=offset;y<niy;y+=subsample) for (int z=offset;z<niz;z+=subsample) {
            // position of the voxel in shape coordinates
            float[] Xi;
            if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
            else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,1.0f);
            
            float priorT = ImageFunctions.linearInterpolation(shape[k],0.0f,Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
            
            // zero prior: no contribution to cost, norm or derivatives
            if (priorT>0) {
                int xyz = x + y*nix + z*nix*niy;
                float weight = imageRegistrationWeight(xyz,k);
                
                // spatial derivatives of the prior
                float dPx = ImageFunctions.linearInterpolationXderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
                float dPy = ImageFunctions.linearInterpolationYderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
                float dPz = ImageFunctions.linearInterpolationZderivative(shape[k],Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
                
                // derivatives of the shape coordinates w.r.t. the parameters
                float[][] dXi = transformModel.imageToTemplateDerivatives(x,y,z,trans,rot,dRa,dRb,dRc,1.0f);
                
                // chain rule: derivative of the prior w.r.t. each parameter
                for (int i=0;i<Nd;i++) {
                    dprior[i] = dPx*dXi[0][i] + dPy*dXi[1][i] + dPz*dXi[2][i];
                }
                cost += registrationCost(weight,priorT,x,y,z,k);
                norm += registrationNorm(weight,priorT,x,y,z,k);
                for (int i=0;i<Nd;i++) {
                    gradient[i] += registrationCostGradient(weight,priorT,dprior,x,y,z,k,i);
                    hessian[i] += registrationCostHessian(weight,priorT,dprior,x,y,z,k,i);
                }
            }
        }
        // guard all divisions against 0/0 when the prior does not overlap the grid
        float safeNorm = Numerics.max(ZERO,norm);
        for (int i=0;i<Nd;i++) {
            gradient[i] /= safeNorm;
            hessian[i] /= safeNorm;
        }
        return cost/safeNorm;
    } // computeSingleRegistrationCoefficients
    
    /**
     * Single-structure version of computeRegistrationEnergy: evaluates the
     * normalised registration energy of structure k alone at transform
     * parameters trans. With precompute set, shapeTransform[k] is overwritten
     * with the matrix for trans.
     *
     * @param trans transform parameters at which to evaluate
     * @param k     index of the structure to register
     * @return the energy, or 0 when the prior does not overlap the grid
     */
    final private float computeSingleRegistrationEnergy(float[] trans, int k) {
        float[][] rot = null;
        float cost = 0.0f;
        float norm = 0.0f;
        
        // set up rotation parameters
        if (transformModel.useRotation()) {
            RotationMatrix R = transformModel.computeRotationMatrix(trans);
            rot = R.getMatrix();
        }
        // refresh the matrix for every k (including 0) so that
        // fastImageToShapeCoordinates evaluates the prior at trans
        if (precompute) transformModel.precomputeImageToTemplateMatrix(shapeTransform[k],trans,rot,1.0f);
        
        // main loop over the subsampled grid
        for (int x=offset;x<nix;x+=subsample) for (int y=offset;y<niy;y+=subsample) for (int z=offset;z<niz;z+=subsample) {
            // position of the voxel in shape coordinates
            float[] Xi;
            if (precompute) Xi = fastImageToShapeCoordinates(k,x,y,z);
            else Xi = transformModel.imageToTemplate(x,y,z,trans,rot,1.0f);
            
            float priorT = ImageFunctions.linearInterpolation(shape[k],0.0f,Xi[0],Xi[1],Xi[2],nsx[k],nsy[k],nsz[k]);
            
            // zero prior: no contribution
            if (priorT>0) {
                int xyz = x + y*nix + z*nix*niy;
                float weight = imageRegistrationWeight(xyz,k);
                
                cost += registrationCost(weight,priorT,x,y,z,k);
                norm += registrationNorm(weight,priorT,x,y,z,k);
            }
        }
        // guard against 0/0 when the prior does not overlap the grid
        return cost/Numerics.max(ZERO,norm);
    } // computeSingleRegistrationEnergy
    
	/**
	 * Cost contributed by one voxel of one structure: w * pT^2.
	 * The position (x,y,z) and structure index k are not used by this formula.
	 */
	final private float registrationCost(float w, float pT, int x, int y, int z, int k) {
		float weightedPrior = w*pT;
		return weightedPrior*pT;
	}
	
	/**
	 * Derivative of registrationCost w.r.t. transform parameter i:
	 * 2 * w * pT * dpT[i]. The position and structure index are unused.
	 */
	final private float registrationCostGradient(float w, float pT, float[] dpT, int x, int y, int z, int k, int i) {
		float factor = 2.0f*w*pT;
		return factor*dpT[i];
	}
	
	/**
	 * Diagonal second-order term for transform parameter i: 2 * w * dpT[i]^2
	 * (the pT * second-derivative term is not included).
	 * The prior value, position and structure index are unused.
	 */
	final private float registrationCostHessian(float w, float pT, float[] dpT, int x, int y, int z, int k, int i) {
		float slope = dpT[i];
		return 2.0f*w*slope*slope;
	}
	
	/**
	 * Normalisation contributed by one voxel of one structure: pT^2.
	 * Independent of the data weight w, so cost/norm is a prior-weighted mean of w.
	 */
	final private float registrationNorm(float w, float pT, int x, int y, int z, int k) {
		float prior = pT;
		return prior*prior;
	}
   	
	/**
	 * Data weight of voxel xyz for structure k.
	 * FA data source: the squared FA value famap[xyz].
	 * Membership data source: mems[0][xyz] when the best label lbmems[0][xyz]
	 * equals k, or when it is a pair code (>= 100, decoded as 100*b1 + b2) with
	 * b1 or b2 equal to k; 0 otherwise.
	 */
	final private float imageRegistrationWeight(int xyz, int k) {
		if (dataSource==FA) return famap[xyz]*famap[xyz];
		
		int best = lbmems[0][xyz];
		boolean matches;
		if (best==k) {
			matches = true;
		} else if (best>=100) {
			int second = best%100;
			int first = (best-second)/100;
			matches = (first==k || second==k);
		} else {
			matches = false;
		}
		return matches ? mems[0][xyz] : 0.0f;
	}
				
	/**
	 * Performs one backtracking line-search step of gradient ascent on the
	 * full-resolution registration energy.
	 * Each trial moves parameter n by lambda/hessian[n]*gradient[n]. An improving
	 * trial is accepted and lambda is multiplied by lfactor. Otherwise lambda is
	 * divided by lfactor and the search retries, giving up after more than
	 * itSupp failures.
	 *
	 * @return relative energy gain (E-E0)/E of the accepted step, or 0 when no
	 *         trial improved the energy
	 */
	final private float registerGradientDescent() {
		float[] gradient = new float[Nd];
		float[] hessian  = new float[Nd];
		float[] trial = new float[Nd];
		
		// coefficients and energy at the current transform
		float E0 = computeRegistrationCoefficients(hessian, gradient, transform);
		float E = E0;
		
		boolean accepted = false;
		boolean stop = false;
		int iter = 0;
		while (!stop) {
			for (int n=0; n<Nd; n++) 
				trial[n] = transform[n] + lambda/Numerics.max(ZERO,hessian[n])*gradient[n];
			
			float Etrial = computeRegistrationEnergy(trial);
			
			if (debug) System.out.print( "a: "+lambda+"("+E0+")->("+Etrial+")\n");

			// the energy is maximised
			if ( Etrial > E0 ) {
				for (int n=0; n<Nd; n++) 
					transform[n] = trial[n];
				E = Etrial;
				accepted = true;
				lambda = lambda*lfactor;
				stop = true;
				if (verbose) MedicUtilPublic.displayMessage(displayTransform(transform));
			} else {
				lambda = lambda/lfactor;
				iter++;
				if (iter>itSupp) {
					stop = true;
					if (debug) System.out.print( "stop search\n");
				}
			}
		}
		
		// no accepted step means no gain (also avoids 0/0 when E0 is 0)
		return accepted ? (E-E0)/E : 0.0f;
	} // registerGradientDescent

	/**
	 * Performs one backtracking line-search step of gradient ascent on the
	 * registration energy at one pyramid level (subsampled data fa and shapes shp).
	 * Each trial moves parameter n by lambda/hessian[n]*gradient[n]. An improving
	 * trial is accepted and lambda is multiplied by lfactor. Otherwise lambda is
	 * divided by lfactor and the search retries, giving up after more than
	 * itSupp failures.
	 *
	 * @param fa    subsampled FA-derived map
	 * @param shp   subsampled shape priors, one per structure
	 * @param npx   x dimension of the subsampled image (likewise npy, npz)
	 * @param nspx  x dimension of each subsampled shape (likewise nspy, nspz)
	 * @param scale subsampling factor of this level
	 * @return relative energy gain (E-E0)/E of the accepted step, or 0 when no
	 *         trial improved the energy
	 */
	final private float registerScaledGradientDescent(float[] fa, float[][] shp, int npx, int npy, int npz, int[] nspx, int[] nspy, int[] nspz, float scale) {
		float[] gradient = new float[Nd];
		float[] hessian  = new float[Nd];
		float[] trial = new float[Nd];
		
		// coefficients and energy at the current transform
		float E0 = computeScaledRegistrationCoefficients(hessian, gradient, transform, fa, shp, npx, npy, npz, nspx, nspy, nspz, scale);
		float E = E0;
		
		boolean accepted = false;
		boolean stop = false;
		int iter = 0;
		while (!stop) {
			for (int n=0; n<Nd; n++) 
				trial[n] = transform[n] + lambda/Numerics.max(ZERO,hessian[n])*gradient[n];
			
			float Etrial = computeScaledRegistrationEnergy(trial, fa, shp, npx, npy, npz, nspx, nspy, nspz, scale);
			
			if (debug) System.out.print( "a: "+lambda+"("+E0+")->("+Etrial+")\n");

			// the energy is maximised
			if ( Etrial > E0 ) {
				for (int n=0; n<Nd; n++) 
					transform[n] = trial[n];
				E = Etrial;
				accepted = true;
				lambda = lambda*lfactor;
				stop = true;
				if (debug) {
					MedicUtilPublic.displayMessage(displayTransform(transform)+" (E:"+E+")\n");
					if (verbose) System.out.print(displayTransform(transform)+" (E:"+E+")\n");
				} else {
					MedicUtilPublic.displayMessage(".");
					if (verbose) System.out.print(".");
				}
			} else {
				// rejected trial: its energy must not be reported as the result
				lambda = lambda/lfactor;
				iter++;
				if (iter>itSupp) {
					stop = true;
					if (debug) System.out.print( "stop search\n");
				}
			}
		}
		
		// no accepted step means no gain (also avoids 0/0 when E0 is 0)
		return accepted ? (E-E0)/E : 0.0f;
	} // registerScaledGradientDescent

    /**
	 *	Registers the shape priors to the current data.
	 *	Deformable mode runs itMax demons iterations; any other mode runs the
	 *	full-resolution parametric registration of registerAllShapes().
	 */
	public final void registerShapes() {
		if (transformMode!=DEFORMABLE) {
			registerAllShapes();
			return;
		}
		for (int t=0;t<itMax;t++) demons.registerImageToTarget();
	}
    /**
	 *	Full-resolution parametric registration.
	 *	Repeats registerGradientDescent() until itMax steps have run, the relative
	 *	gain is no greater than minEdiff, or lambda is no greater than minLambda.
	 *	Then refreshes the rotation matrix (if the model uses one) and the shape
	 *	bounding boxes.
	 */
	public final void registerAllShapes() {
		if (debug) MedicUtilPublic.displayMessage("registration\n");
		
		// reset the line-search state
		lambda = INIT_LAMBDA;
		Nturn = 0;
		itPlus = 0;
		
		float gain = 1;
		int step = 0;
		while (step<itMax && gain>minEdiff && lambda>minLambda) {
			gain = registerGradientDescent();
			step++;
		}
		
		if (transformModel.useRotation()) rotation = transformModel.computeRotation(transform);
		computeTransformedShapeBoundingBox();
	} // registerAllShapes
 
    /**
	 *	Coarse-to-fine parametric registration on a subsampling pyramid.
	 *	For level l from levels down to 2:
	 *	- the FA map (converted by priorFromFA) and every shape prior are
	 *	  subsampled by scale = 2^(l-1);
	 *	- registerScaledGradientDescent() is repeated until itMax steps have run,
	 *	  the gain is no greater than minEdiff, or lambda is no greater than minLambda.
	 *	Level 1 (full resolution) is not processed here. Afterwards the rotation
	 *	matrix and the shape bounding boxes are refreshed.
	 *	NOTE(review): reads famap, which the membership-based
	 *	updateShapeRegistration sets to null — presumably FA data only; confirm.
	 */
	public final void registerShapesPyramid() {
		float[] halfFA = null;
		float[][] halfShape = new float[classes][];
		int[] nspx = new int[classes];
		int[] nspy = new int[classes];
		int[] nspz = new int[classes];
		
		if (debug) MedicUtilPublic.displayMessage("registration \n");
		if (debug) System.out.print("structure atlas registration \n");
		
		// subsampling factor of the coarsest level: 2^(levels-1)
		float scale = 1;
		for (int l=1;l<levels;l++) scale = 2*scale;
		
		for (int level=levels; level>1; level--) {
			if (debug) MedicUtilPublic.displayMessage("level "+level+"\n");
			if (debug) System.out.print("level "+level+"\n");
			
			// subsampled image and data prior
			int npx = Numerics.floor(nix/scale);
			int npy = Numerics.floor(niy/scale);
			int npz = Numerics.floor(niz/scale);
			int factor = (int)scale;
			halfFA = priorFromFA(ImageFunctions.subsample(famap,nix,niy,niz,factor),npx,npy,npz);
			
			// subsampled shape priors
			for (int k=0;k<classes;k++) {
				nspx[k] = Numerics.floor(nsx[k]/scale);
				nspy[k] = Numerics.floor(nsy[k]/scale);
				nspz[k] = Numerics.floor(nsz[k]/scale);
				halfShape[k] = ImageFunctions.subsample(shape[k],nsx[k],nsy[k],nsz[k],factor);
			}
			
			// gradient ascent at this level
			lambda = INIT_LAMBDA;
			Nturn = 0;
			itPlus = 0;
			float gain = 1;
			for (int step=0; step<itMax && gain>minEdiff && lambda>minLambda; step++) {
				gain = registerScaledGradientDescent(halfFA, halfShape, npx, npy, npz, nspx, nspy, nspz, scale);
			}
			
			// next (finer) level
			scale = scale/2.0f;
		}
		halfFA = null;
		halfShape = null;
		System.gc();
		
		// refresh derived quantities for the final transform
		if (transformModel.useRotation()) rotation = transformModel.computeRotation(transform);
		computeTransformedShapeBoundingBox();
		
		MedicUtilPublic.displayMessage("\n");
		if (verbose) System.out.print("\n");
	} // registerShapesPyramid
 
	/**
	 *	Sets up parametric shape registration driven by an FA map.
	 *	The image centre (x0i,y0i,z0i) is passed to a new ParametricTransform of
	 *	type transformType_, all transform parameters start at 0, and linear
	 *	transforms enable the precomputed per-structure matrices.
	 *
	 *	@param fa_            FA map used as registration data
	 *	@param transformType_ transform type name for ParametricTransform
	 *	@param iter_          iteration limit (itMax)
	 *	@param lvl_           number of pyramid levels
	 */
	public final void initShapeRegistration(float[] fa_, String transformType_, int iter_, int lvl_) {
		// registration data: FA only, no memberships
		dataSource = FA;
		famap = fa_;
		mems = null;
		lbmems = null;
		nbest = 0;
		
		// iteration settings
		itMax = iter_;
		levels = lvl_;
		subsample = 3;
		
		// image centre
		x0i = nix/2.0f; 
		y0i = niy/2.0f; 
		z0i = niz/2.0f; 
		
		// single parametric transform
		transformMode = PARAMETRIC;
		MedicUtilPublic.displayMessage("transform: "+transformType_+" (single transform)\n");
		transformModel = new ParametricTransform(transformType_, x0i,y0i,z0i, rix,riy,riz, nix,niy,niz, x0t,y0t,z0t, rtx,rty,rtz, ntx,nty,ntz);
		Nd = transformModel.getDimension();
		
		// all parameters start at zero (Java zero-initialises new arrays)
		transform = new float[Nd];
		if (transformModel.useRotation()) rotation = transformModel.computeRotation(transform);
		
		// only linear transforms can be precomputed as matrices
		precompute = transformModel.isLinear();
		if (precompute) {
			shapeTransform = new float[classes][3][4];
			precomputeTransformMatrix(1.0f);
		} else {
			shapeTransform = null;
		}
	}
		
	/**
	 *	Re-initialises registration for FA-driven data and switches the transform type.
	 *	A type prefixed with "deformable_" switches to demons registration
	 *	(DotsWarping), initialised from the matrix of the current parametric
	 *	transform. Any other type stays parametric, and the current parameters are
	 *	converted with changeTransformType.
	 *
	 *	@param fa_            FA map used as registration data
	 *	@param transformType_ new transform type, optionally prefixed by "deformable_"
	 *	@param iter_          iteration limit (itMax, also passed to DotsWarping)
	 *	@param lvl_           number of pyramid levels
	 */
	public final void updateShapeRegistration(float[] fa_, String transformType_, int iter_, int lvl_) {
		// registration data: FA only, no memberships
		famap = fa_;
		mems = null;
		lbmems = null;
		nbest = 0;
		dataSource = FA;
		
		itMax = iter_;
		levels = lvl_;
		subsample = 3;
		
		// change the transformation type: compute the new parameters
		String oldType = transformModel.getTransformType();
		final String deformablePrefix = "deformable_";
		int newMode;
		if (transformType_.startsWith(deformablePrefix)) {
			// strip the whole prefix, as the membership-based overload does
			transformType_ = transformType_.substring(deformablePrefix.length());
			newMode = DEFORMABLE;
			MedicUtilPublic.displayMessage("transform: "+transformType_+" (deformable)\n");
		} else {
			newMode = PARAMETRIC;
			MedicUtilPublic.displayMessage("transform: "+transformType_+" (parametric)\n");
		}
		
		// NOTE(review): in deformable mode the model is built from the stripped type,
		// whereas the membership-based overload builds a "rigid" model — confirm intent
		transformModel = new ParametricTransform(transformType_, x0i,y0i,z0i, rix,riy,riz, nix,niy,niz, x0t,y0t,z0t, rtx,rty,rtz, ntx,nty,ntz);
		
		if (newMode==PARAMETRIC) {
			String newType = transformType_;
			
			// convert the current parameters to the new type
			transform = transformModel.changeTransformType(transform, oldType, newType);
			
			// update other parameters
			Nd = transformModel.getDimension();
			transformMode = newMode;
		} else {
			// initial mapping for the demons: the current transform as a 3x4 matrix
			float[][] mat =new float[3][4];
			transformModel.precomputeImageToTemplateMatrix(mat, transform, transformModel.computeRotation(transform), 1.0f);
			
			// start Demons
			demons = new DotsWarping(shape, classes, fa_, null, null, 0, registeredShape,
												ntx, nty,ntz, rtx, rty, rtz, nix, niy, niz, rix, riy, riz,
												1.0f, 1.0f, true, 0.0f, 4.0f, iter_, iter_,
												DotsWarping.GAUSS_DIFFUSION,
												DotsWarping.SYMMETRIC,
												DotsWarping.COMPOSITIVE,
												mat);
			
			transformMode = newMode;
		}
		// only linear parametric transforms can be precomputed as matrices
		if (transformMode==PARAMETRIC && transformModel.isLinear()) {
			precompute = true;
			shapeTransform = new float[classes][3][4];
			precomputeTransformMatrix(1.0f);
		} else {
			precompute = false;
			shapeTransform = null;
		}
	}
		
	/**
	 *	Re-initializes the shape registration so that it is driven by membership
	 *	functions, optionally switching the transformation model.
	 *
	 *	@param mems_			membership values used as the registration data source
	 *	@param lb_				labels associated with the memberships
	 *	@param nb_				number of memberships kept per voxel (stored as nbest)
	 *	@param transformType_	transform name; a "deformable_" prefix selects a DotsWarping
	 *							registration on top of a rigid ParametricTransform, with the
	 *							suffix choosing the regularization ("gauss_diffusion",
	 *							"gauss_fluid" or "gauss_mixed", defaulting to fluid);
	 *							otherwise the name is the parametric transform type
	 *	@param iter_			number of iterations
	 *	@param lvl_				number of levels
	 *	@param dSmooth_			smoothing parameter passed to DotsWarping
	 *	@param dScale_			scale parameter passed to DotsWarping
	 */
	public final void updateShapeRegistration(float[][] mems_, short[][] lb_, int nb_, 
												String transformType_, int iter_, int lvl_, float dSmooth_, float dScale_) {
		// switch the data source from FA to memberships
		dataSource = MEMS;
		famap = null;
		mems = mems_;
		lbmems = lb_;
		nbest = nb_;
		
		itMax = iter_;
		levels = lvl_;
		subsample = 3;
		
		// read before transformModel is replaced below
		final String previousType = transformModel.getTransformType();
		
		if (!transformType_.startsWith("deformable_")) {
			MedicUtilPublic.displayMessage("transform: "+transformType_+" parametric\n");
			transformModel = new ParametricTransform(transformType_, x0i,y0i,z0i, rix,riy,riz, nix,niy,niz, x0t,y0t,z0t, rtx,rty,rtz, ntx,nty,ntz);
			
			// convert the current parameters into the requested type
			transform = transformModel.changeTransformType(transform, previousType, transformType_);
			Nd = transformModel.getDimension();
			transformMode = PARAMETRIC;
		} else {
			final String regularization = transformType_.substring(11);
			MedicUtilPublic.displayMessage("transform: "+regularization+" deformable\n");
			transformModel = new ParametricTransform("rigid", x0i,y0i,z0i, rix,riy,riz, nix,niy,niz, x0t,y0t,z0t, rtx,rty,rtz, ntx,nty,ntz);
			
			// current parametric transform, handed to the warping as its initial matrix
			final float[][] initialMatrix = new float[3][4];
			transformModel.precomputeImageToTemplateMatrix(initialMatrix, transform, transformModel.computeRotation(transform), 1.0f);
			MedicUtilPublic.displayMessage("deformable transform init: "+transformModel.displayTransform(transform)+"\n");
			
			final int warpType;
			if (regularization.equals("gauss_diffusion")) warpType = DotsWarping.GAUSS_DIFFUSION;
			else if (regularization.equals("gauss_mixed")) warpType = DotsWarping.GAUSS_MIXED;
			else warpType = DotsWarping.GAUSS_FLUID;
			
			demons = new DotsWarping(shape, classes, null, mems_, lb_, nbest, registeredShape,
										ntx, nty, ntz, rtx, rty, rtz, nix, niy, niz, rix, riy, riz,
										dSmooth_, 1.0f, false, 0.0f, dScale_, iter_, iter_,
										warpType, DotsWarping.MOVING, DotsWarping.COMPOSITIVE,
										initialMatrix);
			demons.initializeTransform();
			
			transformMode = DEFORMABLE;
		}
		
		final boolean parametric = (transformMode==PARAMETRIC);
		
		// refresh the cached rotation for parametric models that use one
		if (parametric && transformModel.useRotation())
			rotation = transformModel.computeRotation(transform);
		
		// only linear parametric transforms can be pre-computed as matrices
		precompute = parametric && transformModel.isLinear();
		if (precompute) {
			shapeTransform = new float[classes][3][4];
			precomputeTransformMatrix(1.0f);
		} else {
			shapeTransform = null;
		}
	}
	
	/** 
	 *	Normalizes the shape priors into memberships: at every voxel of the shape grid
	 *	the priors of all classes are divided by their sum, so they add up to one.
	 *	Voxels whose sum is not above ZERO are left unchanged.
	 *	The grid size of the first shape (nsx[0] x nsy[0] x nsz[0]) is used for every class.
	 */
	public final void normalizeShapePriors() {
		final int nxyz = nsx[0]*nsy[0]*nsz[0];
		// walk the flat index x+nsx*y+nsx*nsy*z directly: x varies fastest in memory,
		// so this is contiguous instead of striding by nsx*nsy on every step
		for (int xyz=0;xyz<nxyz;xyz++) {
			float sum=0.0f;
			for (int k=0;k<classes;k++) sum += shape[k][xyz];
			if (sum>ZERO) {
				for (int k=0;k<classes;k++) shape[k][xyz] /= sum;
			}
		}
	}
	
	/** 
	 *	Scales the shape prior of every class whose objType is "fiber" by the given
	 *	factor; the priors of all other classes are left unchanged.
	 *	The grid size of the first shape (nsx[0] x nsy[0] x nsz[0]) is used for every class.
	 *
	 *	@param factor	the multiplier applied to each fiber prior value
	 */
	public final void changeFiberShapePriors(float factor) {
		final int nxyz = nsx[0]*nsy[0]*nsz[0];
		for (int k=0;k<classes;k++) {
			// the class type does not depend on the voxel: test it once per class
			if (!"fiber".equals(objType[k])) continue;
			for (int xyz=0;xyz<nxyz;xyz++) shape[k][xyz] *= factor;
		}
	}
	
	/**
	 *	Maps image voxel (x,y,z) into the coordinate frame of shape s using the
	 *	pre-computed 3x4 matrix shapeTransform[s] (3x3 linear part plus offset column).
	 *	shapeTransform must have been allocated and filled beforehand.
	 */
	private final float[] fastImageToShapeCoordinates(int s, int x,int y,int z) {
		final float[][] m = shapeTransform[s];
		final float[] coords = new float[3];
		for (int row=0;row<3;row++) {
			coords[row] = m[row][0]*x + m[row][1]*y + m[row][2]*z + m[row][3];
		}
		return coords;
	}
	/**
	 *	Fills shapeTransform[s] for every class with the image-to-template matrix of
	 *	the current transform at the given scale. Does nothing unless precompute is set.
	 *
	 *	@param scale	scale factor passed to precomputeImageToTemplateMatrix
	 */
	public final void precomputeTransformMatrix(float scale) {
		if (precompute) {
			final float[][] rot = transformModel.useRotation() ? transformModel.computeRotation(transform) : null;
			for (int s=0;s<classes;s++)
				transformModel.precomputeImageToTemplateMatrix(shapeTransform[s], transform, rot, scale);
		}
	}

	/**
	 *	Formats the first Nd entries of a transform parameter vector as
	 *	"transform: (p0, p1, ...)\n"; an Nd of 0 yields "transform: ()\n".
	 *
	 *	@param trans	the transform parameters (at least Nd entries)
	 *	@return			the formatted line
	 */
	public final String displayTransform(float[] trans) {
		StringBuilder info = new StringBuilder("transform: (");
		for (int n=0;n<Nd;n++) {
			if (n>0) info.append(", ");
			info.append(trans[n]);
		}
		info.append(")\n");
		return info.toString();
	}
	
	/**
	 *	Formats one transform per class, one line each (see displayTransform).
	 *
	 *	@param trans	transform parameters indexed by class (at least `classes` rows)
	 *	@return			the concatenated lines
	 */
	public final String displayMultiTransform(float[][] trans) {
		StringBuilder info = new StringBuilder();
		for (int k=0;k<classes;k++) info.append(displayTransform(trans[k]));
		return info.toString();
	}
	
	/**
	 *	Formats the tract overlap prior as a tab-separated classes x classes table:
	 *	a header row of column indices, then one row per class prefixed by its index.
	 *
	 *	@return	the formatted table
	 */
	final public String displayTractOverlap() {
		StringBuilder output = new StringBuilder("Tract Overlap Prior \n");
		
		// header row: column indices
		output.append("	");
		for (int k=0;k<classes;k++) output.append(k).append("	");
		
		// one row per class
		for (int k=0;k<classes;k++) {
			output.append("\n").append(k).append("	");
			for (int l=0;l<classes;l++) {
				output.append(tractOverlap[k][l]).append("	"); 
			}
		}
		output.append("\n");
		
		return output.toString();
	}
	
	/**
	 *	Formats a vector as "vector: (v0, v1, ...)\n"; an empty vector yields
	 *	"vector: ()\n" instead of throwing.
	 *
	 *	@param vect	the values to display
	 *	@return		the formatted line
	 */
	public final String displayVector(float[] vect) {
		StringBuilder info = new StringBuilder("vector: (");
		for (int n=0;n<vect.length;n++) {
			if (n>0) info.append(", ");
			info.append(vect[n]);
		}
		info.append(")\n");
		return info.toString();
	}
	
	/**
	 *	Builds a spatial prior from an FA map. Each voxel first gets a rational function
	 *	of its FA value (zero at FA = 0, shaped by faMin = 0.25 and faMax = 0.75). That value
	 *	is then down-weighted by its distance d from the prior-weighted centroid, using the
	 *	factor dM/(dM+d^2) with dM = (nx^2+ny^2+nz^2)/48.
	 *	If the FA response is zero everywhere, the centroid is undefined and the all-zero
	 *	prior is returned as is.
	 *
	 *	@param fa	FA values indexed as x + nx*y + nx*ny*z
	 *	@param nx	grid size along x
	 *	@param ny	grid size along y
	 *	@param nz	grid size along z
	 *	@return		the prior, with the same size and indexing as fa
	 */
	private final float[] priorFromFA(float[] fa, int nx, int ny, int nz) {
		float[] priorFA = new float[nx*ny*nz];
		float x0=0, y0=0, z0=0, sum=0;
		float faMin = 0.25f;
		float faMax = 0.75f;
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			int xyz = x+nx*y+nx*ny*z;
			priorFA[xyz] = (faMin+faMax)/(faMax*faMax*PI2)*fa[xyz]*fa[xyz]*faMax*faMax
								/(faMin*faMin + fa[xyz]*fa[xyz])/(faMax*faMax+fa[xyz]*fa[xyz]);
			x0 += priorFA[xyz]*x;
			y0 += priorFA[xyz]*y;
			z0 += priorFA[xyz]*z;
			sum += priorFA[xyz];
		}
		// no response anywhere (e.g. an all-zero FA map): dividing by sum would turn the
		// centroid into NaN and propagate NaN into every voxel below
		if (sum<=0) return priorFA;
		
		x0 /= sum;
		y0 /= sum;
		z0 /= sum;
		
		// attenuate the prior with squared distance from the centroid
		float dM = (nx*nx+ny*ny+nz*nz)/48.0f;
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			int xyz = x+nx*y+nx*ny*z;
			float d0 = (x-x0)*(x-x0)+(y-y0)*(y-y0)+(z-z0)*(z-z0);
			priorFA[xyz] *= dM/(dM+d0);
		}
		return priorFA;
	}

}
