package edu.jhu.ece.iacl.algorithms.SpineSeg;
 
import java.io.*;
import java.util.*;

import javax.vecmath.Point3f;
import javax.vecmath.Point3i;

import Jama.Matrix;
import gov.nih.mipav.model.structures.ModelImage;
import gov.nih.mipav.view.*;

import edu.jhmi.rad.medic.libraries.*;
import edu.jhmi.rad.medic.methods.DigitalHomeomorphism;
import edu.jhmi.rad.medic.methods.FastMarching;
import edu.jhmi.rad.medic.utilities.*;
import edu.jhmi.rad.medic.structures.*;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.algorithms.volume.TransformVolume;
import edu.jhu.ece.iacl.jist.io.ImageDataReaderWriter;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataByte;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.plugins.registration.MedicAlgorithmApproxHomeomorphicDef;
import edu.jhu.ece.iacl.plugins.registration.MedicAlgorithmTransformVolume;

/**
 *
 *  This class handles full structure atlas information:
 *	shape, topology, relations, etc.
 *
 *	@version    March 2012
 *	@author     Pierre-Louis Bazin
 *	@author     Min Chen
 *		
 *
 */
 
public class SpineDeformableAtlas {
	/** Raw CVS revision keyword; the bare revision number is derived from it in {@code revnum}. */
	private static final String cvsversion = "$Revision: 1.1 $"; 
	/** Revision number: {@code cvsversion} with "Revision: ", "$" and spaces removed. */
	public static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", ""); 

	/**
	 *	Returns the revision number of this class.
	 *
	 *	@return the value of {@code revnum}
	 */
	public static String get_version() {
	   return revnum;
	}


	// structures: basic information	
	private		int					classes;			// number of structures in the atlas
	private		String[]			name;				// their names
	private		byte[]				label;				// their labels
	private		String[]			topology;			// their topology type
	//private		int					transformMode;	// the type of transform (from possible ones below)
	//private static final	int   	NONE = 0;
	//private static final	int   	DEFORMABLE = 3;
	
	// shape maps
	//private		boolean[]			hasShape;			// flag to notify which structures use a shape atlas
	private		ImageData[]		shapeAtlases;				// the shape images, one per structure
	//private		String[]			shapeFile;			// the shape filenames
	//private		int[]				nsx,nsy,nsz;		// the dimensions for each shape image
	//private		float[][]				center;
	//private 	float[]				rsx,rsy,rsz; 		// shape resolutions
	//private		int[]				minx,miny,minz;		// the lowest image coordinate with non zero prior for each structure
	//private		int[]				maxx,maxy,maxz;		// the highest image coordinate with non zero prior for each structure
	private		float[]				shapeConsistency;	// the amount of variability in the shape (0: very variable, 1: fixed)
	private		float[]				smoothingConsistency;	// the amount of variability in the smoothing (0: very variable, 1: fixed)
	//private		boolean[][]			shapeCoupling;		// the amount of variability in the shape (0: very variable, 1: fixed)
	//private		boolean[]			registeredShape;
	//private		float				shapeSlope = 2.0f;			// amount of slope recovery to use
	// label and index of the structure named "Background"/"background"; assigned while the atlas file is parsed
	private 	byte 				topBgLabel;
	private 	int 				bgIndex;
	// topology template
	private		ImageDataByte			topologyAtlas;		// the topology template for the segmentation
	private		ImageDataByte			cordAtlas;		// the topology template with just the spinal cord
	
	// Intensity Image Atlas 
	private ImageData intensityImageAtlas; // MRI corresponding to the topology atlas
	
	// intensity models
	private		boolean[]			hasIntensity;	// which intensity models are available (indexed by modality)
	private		float[][]			intensity;		// the intensity models [modality][structure], normalised between 0 and 1
	
	// intensity variance models
	private		boolean[]			hasIntensityVariance;	// which intensity variance models are available (indexed by modality)
	private		float[][]			intensityVariance;		// the intensity variance models [modality][structure]; 1.0 until read from the atlas file
	
	// intensity models in use..
	public		static final int	T1_SPGR = 0;
	public		static final int	T2 = 1;
	public		static final int	FLAIR = 2;
	public		static final int	T1_MPRAGE = 3;
	public		static final int	T1_RAW = 4;
	public		static final int	PD = 5;
	public		static final int	PDFSE = 6;
	public		static final int	DIR = 7;
	public		static final int	MT_KKI = 8;
	public		static final int	T1_GE = 9;
	public		static final int	T1_NIH = 10;
	public		static final int	T2_NIH = 11;
	public		static final int	INTENSITY = 12;	// the number of possible intensity models
	
	// modality weighting
	private		float[][]			modweight;		// the weighting for modality and obj / lesion (temporary)
	private		static final int	OBJTYPES = 2;	// number of different object types (for now obj and lesion)


	private		float[][]			optimizedFactor;	// the factors for optimized distances of CSF/GM/WM
	private	static final 	int		OPTIMIZED = 3;	// number of optimized factors stored per modality
	
	// lesions models
	private		boolean				hasLesions;		// whether there is a lesion model
	private		float[]				lesion;			// the lesion model for each intensity
	
	// for debug and display
	private static final boolean		debug=true;
	private static final boolean		verbose=true;
	
	
	private ImageData referenceImg; // image that needs to be segmented
	private int XN,YN,ZN; // rows, columns and slices of the reference image
	/**
	 *	constructor: load the atlas information from a file and resample it
	 *	onto the grid of the image to segment.
	 *	<p>
	 *	The atlas files follow a certain template; 
	 *  the separator between numbers is a tab, not a space.
	 *	<p>
	 *	After loading, the topology atlas (nearest neighbour) and every shape
	 *	atlas (trilinear) are passed through TransformVolume.transform with an
	 *	identity matrix and the resolutions / dimensions of the reference image.
	 *	A cord-only atlas is then built from the voxels labelled 5 in the
	 *	resampled topology atlas.
	 *
	 *	@param filename	path of the atlas description file
	 *	@param refImage	image to segment; supplies target dimensions, resolutions, header and name
	 */
	public SpineDeformableAtlas(String filename, ImageData refImage){

		referenceImg = refImage;
		XN = referenceImg.getRows();
		YN = referenceImg.getCols();
		ZN = referenceImg.getSlices();
		loadAtlas(filename);

		// target sampling grid, taken from the reference image
		float[] refRes = referenceImg.getHeader().getDimResolutions();
		Point3f targetRes = new Point3f(refRes[0], refRes[1], refRes[2]);
		Point3i targetDims = new Point3i(referenceImg.getRows(), referenceImg.getCols(), referenceImg.getSlices());

		// identity transform: only the sampling grid changes
		Matrix identity = new Matrix(4,4);
		for(int d=0;d<4;d++) identity.set(d,d,1);

		// topology atlas: labels, so nearest-neighbour interpolation
		ModelImage scratch = new ImageDataInt(topologyAtlas).getModelImageCopy();
		topologyAtlas = new ImageDataByte(TransformVolume.transform(scratch,TransformVolume.Interpolation.Nearest_Neighbor, identity, targetRes, targetDims));
		scratch.disposeLocal();
		topologyAtlas.setHeader(referenceImg.getHeader());
		topologyAtlas.setName(referenceImg.getName()+"topology_def");

		// cord-only atlas: keep label 5 of the resampled topology atlas
		cordAtlas = topologyAtlas.mimic();
		for(int x = 0; x < XN; x++){
			for(int y = 0; y < YN; y++){
				for(int z = 0; z < ZN; z++){
					if(topologyAtlas.getInt(x, y, z) == 5) cordAtlas.set(x, y, z, 5);
				}
			}
		}

		// shape atlases: continuous values, so trilinear interpolation
		for (int s = 0; s < shapeAtlases.length; s++) {
			scratch = shapeAtlases[s].getModelImageCopy();
			shapeAtlases[s] = new ImageDataFloat(TransformVolume.transform(scratch,
					TransformVolume.Interpolation.Trilinear, identity, targetRes, targetDims));
			shapeAtlases[s].setHeader(referenceImg.getHeader());
			shapeAtlases[s].setName(referenceImg.getName()+"shape_def_"+s);
			scratch.disposeLocal();
		}
	}
	
	/**
	 *	Replaces the atlases with volumes that were deformed beforehand.
	 *	<p>
	 *	A non-null topology volume is converted to bytes and becomes the
	 *	topology atlas. A non-null prior volume is split into one shape atlas
	 *	per structure: shape q is read from 4th-dimension component
	 *	{0, 2, 3, 1}[q] of the prior volume, 0.05 is subtracted and the result is
	 *	clamped at 0. The component table has four entries, so this only works
	 *	for up to four shape atlases.
	 *
	 *	@param deformedTopologyAtlas	pre-deformed topology labels, or null to keep the current ones
	 *	@param deformedPriors			pre-deformed 4D shape priors, or null to keep the current ones
	 */
	public void loadPreDeformedAtlases(ImageData deformedTopologyAtlas, ImageData deformedPriors) {

		if(deformedTopologyAtlas != null){
			topologyAtlas = new ImageDataByte(deformedTopologyAtlas);
			topologyAtlas.setHeader(referenceImg.getHeader());
			topologyAtlas.setName(referenceImg.getName()+"topology_def");
		}

		//hack for setting priors - Build this into algorithm correctly
		if(deformedPriors == null) return;

		final int rows = deformedPriors.getRows();
		final int cols = deformedPriors.getCols();
		final int slices = deformedPriors.getSlices();
		//relates order from atlas to input atlas(fix this)
		final int[] priorComponent = {0, 2, 3, 1};

		for (int s = 0; s < shapeAtlases.length; s++) {
			shapeAtlases[s] = new ImageDataFloat(rows,cols,slices);

			for(int x = 0; x < rows; x++){
				for(int y = 0; y < cols; y++){
					for(int z = 0; z < slices; z++){
						shapeAtlases[s].set(x,y,z,Math.max(0,deformedPriors.getFloat(x, y, z, priorComponent[s])-.05));
					}
				}
			}
			shapeAtlases[s].setHeader(referenceImg.getHeader());
			shapeAtlases[s].setName(referenceImg.getName()+"shape_def_"+s);
		}
	}
	
	/**
	 *	Rebuilds the topology atlas by growing layers around the cord atlas.
	 *	<p>
	 *	Background voxels (value 0) within 3 voxels of the cord receive label 2,
	 *	then background voxels within 10 voxels of the result receive label 10
	 *	(the surrounding region around the csf+cord). After each dilation the
	 *	non-background voxels on the outer shell of the volume are relabelled.
	 *	<p>
	 *	No copy is made: afterwards topologyAtlas and cordAtlas refer to the
	 *	same image, so the cord atlas contains the added layers as well.
	 */
	public void buildTopologyAtlasFromCordAtlas(){
		final int background = 0;
		final int innerLabel = 2;
		final int innerRadius = 3;
		final int outerLabel = 10;
		final int outerRadius = 10;

		topologyAtlas = cordAtlas;

		// first layer around the cord
		dilateWithNewLabel(topologyAtlas, innerRadius, background, innerLabel);
		fillboundary(topologyAtlas, innerLabel, 2, background);

		// surrounding region: 10 voxels around the csf+cord
		dilateWithNewLabel(topologyAtlas, outerRadius, background, outerLabel);
		fillboundary(topologyAtlas, outerLabel, 1, background);
	}
	
	/**
	 *	Relabels the non-background voxels lying in the outer shell of a volume.
	 *	<p>
	 *	For each of the six faces, every voxel within depthToFill layers of the
	 *	face whose value differs from bgVal is set to labelToFill. Background
	 *	voxels are left untouched. The image is modified in place.
	 *
	 *	@param imgToFill	image to modify
	 *	@param labelToFill	label written to the non-background shell voxels
	 *	@param depthToFill	thickness of the shell, in voxels
	 *	@param bgVal		value identifying background voxels
	 */
	private void fillboundary(ImageData imgToFill, int labelToFill, int depthToFill,int bgVal){
		final int rows = imgToFill.getRows();
		final int cols = imgToFill.getCols();
		final int slices = imgToFill.getSlices();

		// shell along the third axis (first and last slices)
		for(int x = 0; x < rows; x++){
			for(int y = 0; y < cols; y++){
				for(int d = 0; d < depthToFill; d++){
					final int far = slices-d-1;
					if(imgToFill.getInt(x, y, d) != bgVal) imgToFill.set(x,y,d,labelToFill);
					if(imgToFill.getInt(x, y, far) != bgVal) imgToFill.set(x,y,far,labelToFill);
				}
			}
		}

		// shell along the second axis
		for(int x = 0; x < rows; x++){
			for(int d = 0; d < depthToFill; d++){
				final int far = cols-d-1;
				for(int z = 0; z < slices; z++){
					if(imgToFill.getInt(x, d, z) != bgVal) imgToFill.set(x,d,z,labelToFill);
					if(imgToFill.getInt(x, far, z) != bgVal) imgToFill.set(x,far,z,labelToFill);
				}
			}
		}

		// shell along the first axis
		for(int d = 0; d < depthToFill; d++){
			final int far = rows-d-1;
			for(int y = 0; y < cols; y++){
				for(int z = 0; z < slices; z++){
					if(imgToFill.getInt(d, y, z) != bgVal) imgToFill.set(d,y,z,labelToFill);
					if(imgToFill.getInt(far, y, z) != bgVal) imgToFill.set(far,y,z,labelToFill);
				}
			}
		}
	}
	
	/**
	 *	Grows the labelled region of an image into the background, marking the
	 *	newly covered voxels with a separate label.
	 *	<p>
	 *	A background voxel is set to newLabel when the cube of half-width
	 *	dilationSize centred on it (clipped to the image) contains at least one
	 *	voxel that is neither bgVal nor newLabel. Voxels already carrying
	 *	newLabel never act as seeds, so modifying the image in place does not
	 *	make the growth cascade.
	 *
	 *	@param imgTodilate	image to modify in place
	 *	@param dilationSize	half-width of the cubic neighbourhood, in voxels
	 *	@param bgVal		value identifying background voxels
	 *	@param newLabel		label written to the background voxels that are reached
	 */
	private void dilateWithNewLabel(ImageData imgTodilate, int dilationSize, int bgVal, int newLabel){
		final int rows = imgTodilate.getRows();
		final int cols = imgTodilate.getCols();
		final int slices = imgTodilate.getSlices();

		for(int i = 0; i < rows; i++){
			for(int j = 0; j < cols; j++){
				for(int k = 0; k < slices; k++){
					if(imgTodilate.getInt(i, j, k) != bgVal) continue;

					// neighbourhood clipped to the image bounds
					final int xLo = Math.max(0, i-dilationSize), xHi = Math.min(rows-1, i+dilationSize);
					final int yLo = Math.max(0, j-dilationSize), yHi = Math.min(cols-1, j+dilationSize);
					final int zLo = Math.max(0, k-dilationSize), zHi = Math.min(slices-1, k+dilationSize);

					boolean seedNearby = false;
					scan:
					for(int x = xLo; x <= xHi; x++){
						for(int y = yLo; y <= yHi; y++){
							for(int z = zLo; z <= zHi; z++){
								final int value = imgTodilate.getInt(x, y, z);
								if(value != bgVal && value != newLabel){
									seedNearby = true;
									break scan;
								}
							}
						}
					}
					if(seedNearby) imgTodilate.set(i, j, k, newLabel);
				}
			}
		}
	}
	
	
	/**
	 *	Recomputes every shape prior from a deformation field.
	 *	<p>
	 *	The topology atlas is split into one binary mask per structure, a signed
	 *	distance function (capped at 300) is computed for each mask with
	 *	FastMarching, and the masks and distances are handed to
	 *	calcShapePriorFromDefField together with the deformation field.
	 *
	 *	@param atlasToSubDefField	deformation field from atlas to subject
	 */
	public void setShapePriorFromDefField(ImageData atlasToSubDefField){
		final float maxDistFloat = 300f;

		boolean[][][][] masks = splitLabelsIntoBool(topologyAtlas);

		//calculate signed Distance Functions(For Memberships)
		float[][][][] signedDist = new float[classes][][][];
		for(int c = 0; c < classes; c++){
			// indicator volume: 0 inside the structure, 1 outside
			float[][][] indicator = new float[XN][YN][ZN];
			for (int i = 0; i < XN; i++){
				for (int j = 0; j < YN; j++){
					for (int k = 0; k < ZN; k++){
						indicator[i][j][k] = masks[c][i][j][k] ? 0f : 1f;
					}
				}
			}
			//calc distance transform
			signedDist[c] = FastMarching.signedDistanceFunction(indicator,XN,YN,ZN,maxDistFloat);
		}

		//Calculate Full Membership from DefField
		calcShapePriorFromDefField(atlasToSubDefField, masks, signedDist);
	}
	
	
	/**
	 *	Splits a label image into one binary mask per atlas structure.
	 *
	 *	@param origLabels	label image, read over the reference grid (XN x YN x ZN)
	 *	@return masks indexed [structure][x][y][z]; an entry is true where the
	 *			voxel value equals that structure's label
	 */
	private boolean[][][][] splitLabelsIntoBool(ImageData origLabels){
		boolean[][][][] masks = new boolean[label.length][XN][YN][ZN];

		for(int s = 0; s < classes; s++){
			final byte wanted = label[s];
			final boolean[][][] mask = masks[s];
			for(int x = 0; x < XN; x++){
				for(int y = 0; y < YN; y++){
					for(int z = 0; z < ZN; z++){
						mask[x][y][z] = (origLabels.getByte(x, y, z) == wanted);
					}
				}
			}
		}
		return masks;
	}	
	
	
/**
 *	Fills every shape atlas with the membership computed from a deformation field.
 *	<p>
 *	For each structure the membership volume is obtained from calcMembership
 *	(the structure with label 0 is treated as background), 0.05 is subtracted
 *	and the result, clamped at 0, is written into the matching shape atlas.
 *
 *	@param defField			deformation field
 *	@param splitLabelMap	binary masks, indexed [structure][x][y][z]
 *	@param distTrans		signed distance volumes, indexed [structure][x][y][z]
 */
private void calcShapePriorFromDefField(ImageData defField, boolean[][][][] splitLabelMap, float[][][][] distTrans){

		for (int c = 0; c < classes; c++){
			final boolean background = (label[c]==0);
			final float[][][] membership = calcMembership(defField, splitLabelMap[c], distTrans[c], background);

			for(int x = 0; x < XN; x++){
				for(int y = 0; y < YN; y++){
					for(int z = 0; z < ZN; z++){
						shapeAtlases[c].set(x,y,z,Math.max(0,membership[x][y][z]-.05));
					}
				}
			}
		}
		//reminder - component gets flipped
	}
	
	/**
	 *	Computes the membership volume of one structure from a deformation field.
	 *	<p>
	 *	Steps:
	 *	<ol>
	 *	<li>Collect the boundary voxels of the structure: voxels of labelMap
	 *		with a neighbour (3x3x3 neighbourhood) that is outside the structure
	 *		or outside the volume.</li>
	 *	<li>For every voxel, displace its position by components 0,1,2 of
	 *		origDef and compute a raw value there. Without a distance transform
	 *		the value is 0 if the displaced position falls in the structure,
	 *		otherwise the squared distance to the closest boundary voxel. With
	 *		a distance transform the value is the interpolated distance (or,
	 *		outside the volume, -1 for the background and the distance to the
	 *		closest boundary voxel for other structures), multiplied by the
	 *		squared magnitude of the displacement.</li>
	 *	<li>Normalise: negative values become 1, other values v become
	 *		1 - log(v+1)/log(max+1), where max is the largest raw value.</li>
	 *	</ol>
	 *	When log(max+1) is 0 (no positive raw value, e.g. a zero deformation
	 *	field) every voxel gets membership 1 instead of the NaN that the
	 *	division 0/0 would produce.
	 *
	 *	@param origDef		deformation field; components 0,1,2 are read as x,y,z displacements
	 *	@param labelMap		binary mask of the structure
	 *	@param distTrans	signed distance volume of the structure, or null to use distances to the boundary voxels
	 *	@param isBg			true when the structure is the background
	 *	@return the membership volume, indexed [x][y][z]
	 */
	private float[][][] calcMembership(ImageData origDef, boolean[][][] labelMap, float[][][] distTrans, boolean isBg){

		float[][][] labelMembership = new float[XN][YN][ZN];

		//find boundaries of objects
		List<int[]> objList = new ArrayList<int[]>();
		for(int i=0; i < XN; i++) for(int j =0; j < YN; j++) for(int k =0; k < ZN; k++){
			if(!labelMap[i][j][k]) continue;

			boolean isboundary = false;
			outerLoop:
			for (int ii=-1;ii<=1;ii++) for (int jj=-1;jj<=1;jj++) for (int kk=-1;kk<=1;kk++) {
				boolean outside = i+ii < 0 || i+ii > XN-1 || j+jj < 0 || j+jj > YN-1 || k+kk < 0 || k+kk > ZN-1;
				// a neighbour outside the volume or outside the object makes this a boundary voxel
				if(outside || !labelMap[i+ii][j+jj][k+kk]){
					isboundary = true;
					break outerLoop;
				}
			}
			if(isboundary) objList.add(new int[]{i,j,k});
		}

		System.out.format("Number of voxels in Object:" +objList.size() +"\n");

		// Wrap the distance transform only when there is one, so that the
		// distTrans == null mode below stays usable (the wrapper presumably
		// cannot be built from null -- TODO confirm against ImageDataFloat).
		ImageData distTransImage = (distTrans == null) ? null : new ImageDataFloat(distTrans);

		float maxDistance = 0;
		for(int i=0; i < XN; i++){
			for(int j =0; j < YN; j++){
				for(int k =0; k < ZN; k++){
					final float dx = origDef.getFloat(i, j, k,0);
					final float dy = origDef.getFloat(i, j, k,1);
					final float dz = origDef.getFloat(i, j, k,2);
					final float x = i+dx;
					final float y = j+dy;
					final float z = k+dz;

					if(distTrans == null){
						if(RegistrationUtilities.NNInterpolationBool(labelMap, XN, YN, ZN, x, y, z)){
							//if deform to the target, then distance is zero
							labelMembership[i][j][k] = 0;
						}else{//else find closest boundary point of object
							labelMembership[i][j][k] = minDistanceSquareToLabels(x, y, z, objList);
						}
					}
					else{
						if(x < 0 || x > XN-1 || y < 0 || y > YN - 1|| z < 0 || z > ZN-1){
							if(isBg){
								labelMembership[i][j][k] = -1; //Boundary case for background, flag as full membership
							}else{
								labelMembership[i][j][k] = (float)Math.sqrt(minDistanceSquareToLabels(x, y, z, objList));
							}
						}else{
							labelMembership[i][j][k] = (float)RegistrationUtilities.TrilinearInterpolation(distTransImage, XN, YN, ZN, x, y, z);
						}

						// weight by the squared magnitude of the displacement
						labelMembership[i][j][k] = labelMembership[i][j][k] * distSquare(dx,dy,dz,0,0,0);
					}

					if(labelMembership[i][j][k] > maxDistance){
						maxDistance = labelMembership[i][j][k];
					}
				}
			}
		}

		maxDistance = (float)Math.log(maxDistance + 1);
		// log(max+1) == 0 means no raw value is positive; dividing by it would give NaN
		final boolean degenerate = (maxDistance == 0);

		for(int i=0; i < XN; i++)
			for(int j =0; j < YN; j++)
				for(int k =0; k < ZN; k++){
					if(degenerate || labelMembership[i][j][k] < 0){//If flagged (or nothing to normalise by), then force to 1;
						labelMembership[i][j][k] = 1;
					}else
						labelMembership[i][j][k] = (float)(1 - Math.log(labelMembership[i][j][k]+1)/maxDistance); 
				}

		return labelMembership;
	}
	
	/**
	 *	Squared Euclidean distance between the points (x1,y1,z1) and (x2,y2,z2).
	 */
	private float distSquare(float x1, float y1, float z1, float x2, float y2, float z2){
		final float dx = x1-x2;
		final float dy = y1-y2;
		final float dz = z1-z2;
		return dx*dx + dy*dy + dz*dz;
	}
	
	/**
	 *	Squared distance from a point to the closest voxel of a list.
	 *
	 *	@param x		first coordinate of the point
	 *	@param y		second coordinate of the point
	 *	@param z		third coordinate of the point
	 *	@param objList	voxel coordinates, each stored as {x, y, z}
	 *	@return the smallest squared distance, or Float.MAX_VALUE when the list is empty
	 */
	private float minDistanceSquareToLabels(float x, float y, float z, List<int[]> objList){
		float best = Float.MAX_VALUE;
		for (int[] voxel : objList) {
			final float candidate = distSquare(x,y,z,voxel[0],voxel[1],voxel[2]);
			if (candidate < best) best = candidate;
		}
		return best;
	}
	
	/**
	 *	Deforms the topology atlas with a deformation field, using the
	 *	nearest-neighbour mode of RegistrationUtilities.DeformImage3D, and
	 *	stamps the result with the header and name of the reference image.
	 *
	 *	@param atlasToSubDefField	deformation field from atlas to subject
	 */
	public void applyDefFieldToTopologyAtlas(ImageData atlasToSubDefField){
		final ImageDataFloat field = new ImageDataFloat(atlasToSubDefField);

		topologyAtlas = new ImageDataByte(
				RegistrationUtilities.DeformImage3D(topologyAtlas, field,
						RegistrationUtilities.InterpolationType.NEAREST_NEIGHTBOR));

		topologyAtlas.setHeader(referenceImg.getHeader());
		topologyAtlas.setName(referenceImg.getName()+"topology_def");
	}
	
	/**
	 *	Deforms the cord-only atlas with a deformation field, using the
	 *	nearest-neighbour mode of RegistrationUtilities.DeformImage3D, and
	 *	stamps the result with the header and name of the reference image.
	 *
	 *	@param atlasToSubDefField	deformation field from atlas to subject
	 */
	public void applyDefFieldToCordAtlas(ImageData atlasToSubDefField){
		final ImageDataFloat field = new ImageDataFloat(atlasToSubDefField);

		cordAtlas = new ImageDataByte(
				RegistrationUtilities.DeformImage3D(cordAtlas, field,
						RegistrationUtilities.InterpolationType.NEAREST_NEIGHTBOR));

		cordAtlas.setHeader(referenceImg.getHeader());
		cordAtlas.setName(referenceImg.getName()+"cord_def");
	}
		
	/**
	 *	Deforms every shape prior with a deformation field, using the trilinear
	 *	mode of RegistrationUtilities.DeformImage3D, and stamps each result
	 *	with the header and name of the reference image.
	 *
	 *	@param atlasToSubDefField	deformation field from atlas to subject
	 */
	public void applyDefFieldToShapePriors(ImageData atlasToSubDefField){
		final ImageDataFloat field = new ImageDataFloat(atlasToSubDefField);

		int s = 0;
		while (s < shapeAtlases.length) {
			shapeAtlases[s] = RegistrationUtilities.DeformImage3D(
					shapeAtlases[s], field, RegistrationUtilities.InterpolationType.TRILINEAR);
			shapeAtlases[s].setHeader(referenceImg.getHeader());
			shapeAtlases[s].setName(referenceImg.getName()+"shape_def_"+s);
			s++;
		}
	}
	
	/**
	 *	read template image (the image must be in bytes)
	 *
	 *	@param filename	path of the image file
	 *	@return the image read from the file, converted to ImageDataByte
	 */
	private ImageDataByte loadTemplateImage(String filename) {
		ImageDataReaderWriter rw = ImageDataReaderWriter.getInstance();
		File f = new File( filename );
		ImageDataByte out = new ImageDataByte(rw.read(f));
		// print, not format: the image name is data and may contain '%',
		// which would be parsed as a format specifier and throw
		System.out.print("Loaded toplogy atlas"+ out.getName() +"\n");
		return out;
	}
	/**
	 *	read shape image (the image must be in float)
	 *
	 *	@param filename	path of the image file
	 *	@return the image read from the file, converted to ImageDataFloat
	 */
	private final ImageDataFloat loadShapeImage(String filename){
		// read the raw data
		ImageDataReaderWriter rw = ImageDataReaderWriter.getInstance();
		File f = new File( filename );
		ImageDataFloat out = new ImageDataFloat(rw.read(f));
		// print, not format: the image name is data and may contain '%',
		// which would be parsed as a format specifier and throw
		System.out.print("Loaded shaped atlas"+ out.getName() +"\n");
		return out;
	}
	/**
	 *	read any image
	 *
	 *	@param filename	path of the image file
	 *	@return the image as returned by the reader, without type conversion
	 */
	private final ImageData loadImage(String filename){
		// read the raw data
		ImageDataReaderWriter rw = ImageDataReaderWriter.getInstance();
		File f = new File( filename );
		ImageData out = rw.read(f);
		// print, not format: the image name is data and may contain '%',
		// which would be parsed as a format specifier and throw
		System.out.print("Loaded atlas"+ out.getName() +"\n");
		return out;
	}
	
	/** 
	 *	load the atlas data from a file. 
	 *  All associated images are loaded at this time
	 *	<p>
	 *	The first line of the file must be exactly
	 *	"Structure Atlas File (edit at your own risks)"; otherwise nothing is
	 *	loaded. Fields are separated by tabs. The sections recognised (by the
	 *	start of their header line) are: "Structures", "Intensity Image Atlas",
	 *	"Topology Atlas", "Shape Atlas", "Intensity Atlas", "Intensity Variance",
	 *	"Lesions", "Modality weights" and "Optimized factors". Any other line is
	 *	ignored. Image file names are resolved relative to the directory of the
	 *	atlas file.
	 *	<p>
	 *	The "Structures" section allocates all per-structure arrays, so it has
	 *	to come before the sections that fill them.
	 *	<p>
	 *	Errors are reported on standard output and end the parsing; whatever
	 *	was read up to that point is kept. The reader is always closed.
	 *
	 *	@param filename	path of the atlas description file
	 */
	final public void loadAtlas(String filename) {
		if (verbose) System.out.println("loading atlas file: "+filename);
		BufferedReader br = null;
		try {
			File f = new File(filename);
			// may be null when filename has no directory part; File(parent, child) accepts that
			String dir = f.getParent();
			br = new BufferedReader(new FileReader(f));
			String line = br.readLine();
			StringTokenizer st;
			String imageFile;
			// Exact corresponding template (an empty file gives a null first line)
			if (line==null || !line.equals("Structure Atlas File (edit at your own risks)")) {
				System.out.println("not a proper Structure Atlas file");
				return;
			}
			line = br.readLine();
			while (line!=null) {
				if (line.startsWith("Structures")) {
					// Structures:	classes	label	topology
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					classes = MipavUtil.getInt(st);
					name = new String[classes];
					label = new byte[classes];
					topology = new String[classes];
					for (int n=0;n<classes;n++) {
						// Name:label:topology
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						name[n] = st.nextToken();
						label[n] = (byte)MipavUtil.getInt(st);
						// compare contents, not references: tokens are never the same object as a literal
						if ("Background".equals(name[n]) || "background".equals(name[n])) {
							bgIndex = n;
							topBgLabel=label[n];
						}
						topology[n] = st.nextToken();
					}
					// allocate other quantities; availability flags start false, weights and factors at 1
					shapeAtlases = new ImageDataFloat[classes];
					hasIntensity = new boolean[INTENSITY];
					intensity = new float[INTENSITY][classes];
					hasIntensityVariance = new boolean[INTENSITY];
					intensityVariance = new float[INTENSITY][classes];
					for (float[] row : intensityVariance) Arrays.fill(row, 1.0f);
					modweight = new float[INTENSITY][OBJTYPES];
					for (float[] row : modweight) Arrays.fill(row, 1.0f);
					optimizedFactor = new float[INTENSITY][OPTIMIZED];
					for (float[] row : optimizedFactor) Arrays.fill(row, 1.0f);
					lesion = new float[INTENSITY];
					shapeConsistency = new float[classes];
					Arrays.fill(shapeConsistency, 1.0f);
					smoothingConsistency = new float[classes];
					Arrays.fill(smoothingConsistency, 1.0f);
					if (debug) System.out.println(displayNames());
				} else
				if (line.startsWith("Intensity Image Atlas")) {
					// File: name
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					imageFile = new File(dir, st.nextToken()).getPath();
					if (debug) System.out.print("file: "+imageFile+"\n");
					intensityImageAtlas = loadImage(imageFile);
				} else
				if (line.startsWith("Topology Atlas")) {
					// File: name
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					imageFile = new File(dir, st.nextToken()).getPath();
					if (debug) System.out.print("file: "+imageFile+"\n");
					topologyAtlas = loadTemplateImage(imageFile);
				} else
				if (line.startsWith("Shape Atlas")) {
					// one record per structure: Structure / File / Dimensions / Resolutions / Center
					line = br.readLine();
					while (line!=null && line.startsWith("Structure:")) {
						// find structure id
						st = new StringTokenizer(line, "\t");
						st.nextToken();
						String title = st.nextToken();
						int id=-1;
						for (int n=0;n<classes;n++) {
							if (title.equals(name[n])) { id = n; break; }
						}
						if (id==-1) {
							// structure not declared in "Structures": skip the four remaining
							// lines of its record so that the next record is still found
							if (debug) System.out.print("Shape: "+title+" (unknown structure, skipped)\n");
							for (int skipped=0;skipped<4;skipped++) line = br.readLine();
						} else {
							if (debug) System.out.print("Shape: "+name[id]+"\n");
							// File: name
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							imageFile = new File(dir, st.nextToken()).getPath();
							// Dimensions: nix niy niz (read but not used)
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							// Resolutions: rsx rsy rsz (read but not used)
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();
							// Center: cx cy cz (read but not used)
							line = br.readLine();
							st = new StringTokenizer(line, "\t");
							st.nextToken();

							shapeAtlases[id] = loadShapeImage(imageFile);
						}
						line = br.readLine();
					}
				} else
				if (line.startsWith("Intensity Atlas")) {
					// Intensity:	intensitySamples
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					// Type value value value...
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					String type = st.nextToken();
					int id = modalityId(type);
					
					// one line per modality; the first line with an unknown type ends the section
					while (id !=-1 ) {
						for (int n=0;n<classes;n++) {
							intensity[id][n] = MipavUtil.getFloat(st);
						}
						hasIntensity[id] = true;
						// search for next intensity profile
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						type = st.nextToken();
						if (debug) System.out.println(type);
						id = modalityId(type);
					}
					if (debug) System.out.println(displayIntensity());
				} else
				if (line.startsWith("Intensity Variance")) {
					// Type value value value...
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					String type = st.nextToken();
					int id = modalityId(type);
					
					while (id !=-1 ) {
						for (int n=0;n<classes;n++) {
							intensityVariance[id][n] = MipavUtil.getFloat(st);
						}
						hasIntensityVariance[id] = true;
						// search for next intensity profile
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						type = st.nextToken();
						if (debug) System.out.println(type);
						id = modalityId(type);
					}
					if (debug) System.out.println(displayIntensityVariance());
				} else
				if (line.startsWith("Lesions")) {
					// Lesions: 
					// T1 val T2 val FLAIR val MPRAGE val
					line = br.readLine();
					if (debug) System.out.println(line);
					st = new StringTokenizer(line, "\t");
					while (st.hasMoreTokens()) {
						String type = st.nextToken();
						if (debug) System.out.print(type+" ");
						int id = modalityId(type);
						if (id>-1) {
							lesion[id] = MipavUtil.getFloat(st);
							if (debug) System.out.println(lesion[id]);
						}
					}
					if (debug) System.out.println("\n");
					hasLesions = true;
					if (debug) System.out.println(displayLesionModel());
				} else
				if (line.startsWith("Modality weights")) {
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					// Type value value
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					String type = st.nextToken();
					if (debug) System.out.println(type);
					int id = modalityId(type);
					
					while (id !=-1 ) {
						for (int n=0;n<OBJTYPES;n++) {
							modweight[id][n] = MipavUtil.getFloat(st);
						}
						// search for next intensity profile
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						type = st.nextToken();
						if (debug) System.out.println(type);
						id = modalityId(type);
					}
					if (debug) System.out.println(displayModalityWeights());
				} else
				if (line.startsWith("Optimized factors")) {
					st = new StringTokenizer(line, "\t");
					st.nextToken();
					// Type value value
					line = br.readLine();
					st = new StringTokenizer(line, "\t");
					String type = st.nextToken();
					if (debug) System.out.println(type);
					int id = modalityId(type);
					
					while (id !=-1 ) {
						for (int n=0;n<OPTIMIZED;n++) {
							optimizedFactor[id][n] = MipavUtil.getFloat(st);
						}
						// search for next intensity profile
						line = br.readLine();
						st = new StringTokenizer(line, "\t");
						type = st.nextToken();
						if (debug) System.out.println(type);
						id = modalityId(type);
					}
					if (debug) System.out.println(displayOptimizedFactors());
				}
				// the line that ended the current section is dropped here; parsing resumes on the next one
				line = br.readLine();
				if (debug) System.out.println(line);
			}		
		}
		catch (FileNotFoundException e) {
			System.out.println(e.getMessage());
		}
		catch (IOException e) {
			System.out.println(e.getMessage());
		} 
		catch (OutOfMemoryError e){
			System.out.println(e.getMessage());
		}
		catch (Exception e) {
			// print the exception itself: a malformed or truncated file surfaces here as an
			// exception whose message is often null
			System.out.println(e);
		}
		finally {
			// closing the BufferedReader also closes the underlying FileReader
			if (br!=null) {
				try {
					br.close();
				} catch (IOException e) {
					System.out.println(e.getMessage());
				}
			}
		}

		if (debug) MedicUtilPublic.displayMessage("initialisation\n");
	}

	/**
	 *	display the atlas data: one line per listed modality, giving its
	 *	availability flag followed by the intensity of every structure.
	 *
	 *	@return the formatted text
	 */
	final public String displayIntensity() {
		// printed name and row index in the intensity tables, in output order
		final String[] tags = {"T1_SPGR", "T2", "FLAIR", "T1_MPRAGE", "T1_RAW", "PD", "PDFSE", "MT(KKI)", "T1(GE)"};
		final int[] rows = {0, 1, 2, 3, 4, 5, 6, 8, 9};

		StringBuilder text = new StringBuilder("Intensity \n");
		for (int m=0;m<rows.length;m++) {
			final int r = rows[m];
			text.append(tags[m]).append(" (").append(hasIntensity[r]).append(") : ");
			for (int k=0;k<classes;k++) text.append(intensity[r][k]).append(" ");
			text.append("\n");
		}
		return text.toString();
	}
	/**
	 *	display the atlas data: one line per listed modality, giving the
	 *	availability flag of its variance model followed by the intensity
	 *	variance of every structure.
	 *
	 *	@return the formatted text
	 */
	final public String displayIntensityVariance() {
		// printed name and row index in the variance tables, in output order
		final String[] tags = {"T1_SPGR", "T2", "FLAIR", "T1_MPRAGE", "T1_RAW", "PD", "PDFSE", "MT(KKI)", "T1(GE)"};
		final int[] rows = {0, 1, 2, 3, 4, 5, 6, 8, 9};

		StringBuilder text = new StringBuilder("Intensity Variance \n");
		for (int m=0;m<rows.length;m++) {
			final int r = rows[m];
			// every row reports the variance flag (the MT(KKI) and T1(GE) rows
			// used to report the flag of the intensity model instead)
			text.append(tags[m]).append(" (").append(hasIntensityVariance[r]).append(") : ");
			for (int k=0;k<classes;k++) text.append(intensityVariance[r][k]).append(" ");
			text.append("\n");
		}
		return text.toString();
	}
	/**
	 * Formats the modality weights as a multi-line text report.
	 * <p>
	 * After a "modality Weights" header line, rows 0 to 4 of
	 * <code>modweight</code> are printed one per line, each prefixed by its
	 * label and listing the first <code>OBJTYPES</code> values separated by
	 * single spaces.
	 *
	 * @return the formatted report
	 */
	final public String displayModalityWeights() {
		// label of each reported row; the array position is the modweight row
		final String[] rowLabel = {"T1", "T2", "FLAIR", "MPRAGE", "MPRAGE_RAW"};

		StringBuilder report = new StringBuilder("modality Weights \n");
		for (int m=0;m<rowLabel.length;m++) {
			report.append(rowLabel[m]).append(" : ");
			for (int k=0;k<OBJTYPES;k++) {
				report.append(modweight[m][k]).append(" ");
			}
			report.append("\n");
		}
		return report.toString();
	}
	/**
	 * Formats the optimized factors as a multi-line text report.
	 * <p>
	 * After an "optimized factors" header line, the reported rows of
	 * <code>optimizedFactor</code> (indices 0 to 4, 8 and 9) are printed one
	 * per line, each prefixed by its label and listing the first
	 * <code>OPTIMIZED</code> values separated by single spaces.
	 *
	 * @return the formatted report
	 */
	final public String displayOptimizedFactors() {
		// row labels and the matching first index into optimizedFactor
		final String[] rowLabel = {"T1", "T2", "FLAIR", "MPRAGE", "MPRAGE_RAW", "MT(KKI)", "T1(GE)"};
		final int[] rowIndex = {0, 1, 2, 3, 4, 8, 9};

		StringBuilder report = new StringBuilder("optimized factors \n");
		for (int r=0;r<rowIndex.length;r++) {
			final int m = rowIndex[r];
			report.append(rowLabel[r]).append(" : ");
			for (int k=0;k<OPTIMIZED;k++) {
				report.append(optimizedFactor[m][k]).append(" ");
			}
			report.append("\n");
		}
		return report.toString();
	}
	/**
	 * Formats the lesion model as a two-line text report.
	 * <p>
	 * The first line shows the <code>hasLesions</code> flag; the second lists
	 * <code>lesion[0]</code> to <code>lesion[6]</code>, each preceded by its
	 * label and separated by ", ".
	 *
	 * @return the formatted report
	 */
	final public String displayLesionModel() {
		// label of each reported entry; the array position is the lesion index
		final String[] modalityLabel = {"T1", "T2", "FLAIR", "MPRAGE", "MPRAGE_RAW", "PD", "PDFSE"};

		StringBuilder report = new StringBuilder("Lesion model ("+hasLesions+") : \n");
		for (int m=0;m<modalityLabel.length;m++) {
			if (m>0) {
				report.append(", ");
			}
			report.append(modalityLabel[m]).append(" ").append(lesion[m]);
		}
		report.append("\n");
		return report.toString();
	}
	/**
	 * Formats the list of structures as a multi-line text report.
	 * <p>
	 * After a "Structures" header line, each of the <code>classes</code>
	 * structures gets one line with its name, its topology type in
	 * parentheses, a tab, and its label value.
	 *
	 * @return the formatted report
	 */
	final public String displayNames() {
		StringBuilder report = new StringBuilder("Structures \n");
		int k = 0;
		while (k<classes) {
			report.append(name[k]+" ("+topology[k]+")	"+label[k]+"\n");
			k++;
		}
		return report.toString();
	}
	/** display the atlas data */
	/*
	final public String displayShapeCouplings() {
		String output = "Shape Prior Couplings \n";
		
		output += "	";
		for (int k=0;k<classes;k++) output += k+"	";
		for (int k=0;k<classes;k++) {
			output += "\n"+k+"	";
			for (int l=0;l<classes;l++) {
				if (shapeCoupling[k][l]) output += "1	";
				else output += "0	";
			}
		}
		output += "\n";
		
		return output;	
	}
	// display the atlas data 
	final public String displayRegisteredShapes() {
		String output = "Registered Shapes (0/1: true/false) \n";
		
		for (int k=0;k<classes;k++) output += k+"	";
		output += "\n";
		for (int k=0;k<classes;k++) {
			if (registeredShape[k]) output += "1	";
			else output += "0	";
		}
		output += "\n";
		
		return output;	
	}*/
	/**
	 * Formats the shape priors as a two-line text report: a "Shape Priors"
	 * header line, then the first <code>classes</code> values of
	 * <code>shapeConsistency</code> separated by single spaces.
	 *
	 * @return the formatted report
	 */
	final public String displayShapePriors() {
		StringBuilder report = new StringBuilder("Shape Priors \n");
		for (int k=0;k<classes;k++) report.append(shapeConsistency[k]).append(" ");
		return report.append("\n").toString();
	}
	/**
	 * Formats the smoothing priors as a two-line text report: a
	 * "Smoothing Priors" header line, then the first <code>classes</code>
	 * values of <code>smoothingConsistency</code> separated by single spaces.
	 *
	 * @return the formatted report
	 */
	final public String displaySmoothingPriors() {
		StringBuilder report = new StringBuilder("Smoothing Priors \n");
		for (int k=0;k<classes;k++) report.append(smoothingConsistency[k]).append(" ");
		return report.append("\n").toString();
	}
    
	/**
	 * Maps a modality name to its modality index constant.
	 * <p>
	 * Matching is exact and case-sensitive. Several modalities accept two
	 * spellings ("T1"/"T1_SPGR", "MPRAGE"/"T1_MPRAGE", "MPRAGE_RAW"/"T1_RAW").
	 *
	 * @param type the modality name; must not be null
	 * @return the matching modality constant, or -1 when the name is not recognized
	 */
	public static final int modalityId(String type) {
		if (type.equals("T1") || type.equals("T1_SPGR")) return T1_SPGR;
		if (type.equals("T2")) return T2;
		if (type.equals("FLAIR")) return FLAIR;
		if (type.equals("MPRAGE") || type.equals("T1_MPRAGE")) return T1_MPRAGE;
		if (type.equals("MPRAGE_RAW") || type.equals("T1_RAW")) return T1_RAW;
		if (type.equals("PD")) return PD;
		if (type.equals("PDFSE")) return PDFSE;
		if (type.equals("DIR")) return DIR;
		if (type.equals("MT(KKI)")) return MT_KKI;
		if (type.equals("T1(GE)")) return T1_GE;
		if (type.equals("T1(NIH)")) return T1_NIH;
		if (type.equals("T2(NIH)")) return T2_NIH;
		// no known modality name matched
		return -1;
	}
	
	/**
	 * Samples every shape atlas at one voxel.
	 *
	 * @param x voxel x coordinate
	 * @param y voxel y coordinate
	 * @param z voxel z coordinate
	 * @return a new array of length <code>classes</code> whose k-th entry is
	 *         the value of <code>shapeAtlases[k]</code> at (x,y,z)
	 */
	public final float[] getShapeArray(int x, int y, int z) {
		final float[] samples = new float[classes];
		for (int k=0;k<classes;k++) {
			samples[k] = shapeAtlases[k].getFloat(x,y,z);
		}
		return samples;
	}
	
	/**
	 * Replaces the topology atlas with its cropped version.
	 * <p>
	 * The current <code>topologyAtlas</code> is passed to
	 * <code>preprocess.cropImage</code> together with <code>topBgLabel</code>,
	 * the result is wrapped in a new <code>ImageDataByte</code>, and that new
	 * atlas is then handed to <code>preprocess.updateTransformedTemplate</code>.
	 *
	 * @param preprocess the preprocessing object that performs the crop
	 */
	public final void cropTemplate(SpinePreprocess preprocess) {
		topologyAtlas = new ImageDataByte(
				preprocess.cropImage(topologyAtlas, topBgLabel));
		preprocess.updateTransformedTemplate(topologyAtlas);
	}
	
	/**
	 * Replaces every shape atlas with its cropped version.
	 * <p>
	 * Each entry of <code>shapeAtlases</code> is passed to
	 * <code>preprocess.cropImage</code> together with the value of
	 * <code>getShapeBgProb(k)</code> for that entry, and the result is
	 * wrapped in a new <code>ImageDataFloat</code>.
	 *
	 * @param preprocess the preprocessing object that performs the crop
	 */
	public final void cropShapes(SpinePreprocess preprocess) {
		for (int k=0; k < shapeAtlases.length; k++) {
			final float background = getShapeBgProb(k);
			shapeAtlases[k] = new ImageDataFloat(preprocess.cropImage(shapeAtlases[k], background));
		}
	}
	
	/**
	 * Returns the background value used for shape atlas <code>k</code>.
	 *
	 * @param k index of the shape atlas
	 * @return 1.0 when <code>k</code> equals <code>bgIndex</code>, otherwise 0.0
	 */
	public float getShapeBgProb(int k){
		return (k == bgIndex) ? 1.0f : 0.0f;
	}
	
	/**
	 * Returns the value of <code>topBgLabel</code>, the label that
	 * <code>cropTemplate</code> passes to the crop of the topology atlas.
	 *
	 * @return the current <code>topBgLabel</code>
	 */
	public byte getTopBgLabel(){
		return topBgLabel;
	}
	
	/**
	 * Returns the smoothing priors.
	 *
	 * @return the internal <code>smoothingConsistency</code> array (not a copy)
	 */
	public final float[] getSmoothingPriors() {
		return smoothingConsistency;
	}
	/**
	 * Returns the shape priors.
	 *
	 * @return the internal <code>shapeConsistency</code> array (not a copy)
	 */
	public final float[] getShapePriors() {
		return shapeConsistency;
	}
	/**
	 * Returns the structure names.
	 *
	 * @return the internal <code>name</code> array (not a copy)
	 */
	public final String[] getNames() {
		return name;
	}

	/**
	 * Returns the structure labels.
	 *
	 * @return the internal <code>label</code> array (not a copy)
	 */
	public final byte[] getLabels() {
		return label;
	}
	/**
	 * Returns the cord atlas.
	 *
	 * @return the current <code>cordAtlas</code> reference
	 */
	public final ImageDataByte getCordAtlas() {
		return cordAtlas;
	}
	/**
	 * Returns the topology atlas, which this class exposes as the template.
	 *
	 * @return the current <code>topologyAtlas</code> reference
	 */
	public final ImageDataByte getTemplate() {
		return topologyAtlas;
	}
	/**
	 * Returns the intensity image atlas.
	 *
	 * @return the current <code>intensityImageAtlas</code> reference
	 */
	public final ImageData getIntensityImageAtlas() {
		return intensityImageAtlas;
	}
	/**
	 * Returns one shape atlas.
	 *
	 * @param n index into <code>shapeAtlases</code>
	 * @return the n-th shape atlas
	 */
	public final ImageData getShape(int n) {
		return shapeAtlases[n];
	}
	/**
	 * Returns all shape atlases.
	 *
	 * @return the internal <code>shapeAtlases</code> array (not a copy)
	 */
	public final ImageData[] getShapes() {
		return shapeAtlases;
	}
	/**
	 * Returns the number of structures in the atlas.
	 *
	 * @return the value of <code>classes</code>
	 */
	public final int getNumber() {
		return classes;
	}
	/**
	 * Returns the topology type of each structure.
	 *
	 * @return the internal <code>topology</code> array (not a copy)
	 */
	public final String[] getTopology() {
		return topology;
	}
	//final public boolean isDeformable() { return (transformMode==DEFORMABLE); }
    
	/**
	 * Groups the structures that share the same intensity prior, per channel.
	 * <p>
	 * For a channel whose modality is recognized by <code>modalityId</code>
	 * and flagged in <code>hasIntensity</code>, each structure receives the
	 * 1-based index of the first structure with an equal intensity prior, so
	 * structures with identical priors share one group id. For any other
	 * channel every structure is its own group (<code>k+1</code>).
	 *
	 * @param modality modality name of each channel; the first <code>nc</code> entries are read
	 * @param nc number of channels
	 * @return an <code>nc</code> x <code>classes</code> array of 1-based group ids
	 */
	final public int[][] getIntensityGroups(String[] modality, int nc) {
		// a freshly allocated int array is zero-filled, and 0 is the
		// "not yet grouped" marker tested below
		int[][] group = new int[nc][classes];
		for (int n=0;n<nc;n++) {
			// resolve the modality once per channel
			final int m = modalityId(modality[n]);
			if ( (m>-1) && hasIntensity[m] ) {
				for (int k=0;k<classes;k++) {
					for (int l=k;l<classes;l++) {
						// the first structure with this intensity names the group;
						// later matches never overwrite an assigned id
						if (intensity[m][l]==intensity[m][k] && group[n][l]==0) {
							group[n][l] = k+1;
						}
					}
				}
			} else {
				for (int k=0;k<classes;k++) group[n][k] = k+1;
			}
		}
		return group;
	}
	
	/**
	 * Collects the intensity variance priors for a set of channels.
	 * <p>
	 * A channel whose modality is recognized by <code>modalityId</code> gets
	 * the matching row of <code>intensityVariance</code>; any other channel
	 * gets 1.0 for every structure.
	 * NOTE(review): unlike the other prior getters in this class, no
	 * <code>has*</code> flag is consulted here; confirm this is intended.
	 *
	 * @param modality modality name of each channel; the first <code>nc</code> entries are read
	 * @param nc number of channels
	 * @return a new <code>nc</code> x <code>classes</code> array of variance priors
	 */
	final public float[][] 	getIntensityVariancePriors(String[] modality, int nc) {
		float[][] variancePrior = new float[nc][classes];
		for (int n=0;n<nc;n++) {
			final int id = modalityId(modality[n]);
			final boolean recognized = (id>-1);
			for (int k=0;k<classes;k++) {
				if (recognized) {
					variancePrior[n][k] = intensityVariance[id][k];
				} else {
					variancePrior[n][k] = 1.0f;
				}
			}
		}
		return variancePrior;
	}
	
	/**
	 * Collects the modality weights for a set of channels.
	 * <p>
	 * A channel whose modality is recognized by <code>modalityId</code> and
	 * flagged in <code>hasIntensity</code> gets the matching row of
	 * <code>modweight</code>; any other channel gets 1.0 for every object
	 * type. When <code>debug</code> is set, the first two weights of each
	 * channel are printed to standard output.
	 *
	 * @param modality modality name of each channel; the first <code>nc</code> entries are read
	 * @param nc number of channels
	 * @return a new <code>nc</code> x <code>OBJTYPES</code> array of weights
	 */
	final public float[][] getModalityWeights(String[] modality, int nc) {
		float[][] weights = new float[nc][OBJTYPES];
		for (int n=0;n<nc;n++) {
			final int id = modalityId(modality[n]);
			// the flag is only looked up for a valid modality index
			final boolean available = (id>-1) && hasIntensity[id];
			for (int k=0;k<OBJTYPES;k++) {
				if (available) {
					weights[n][k] = modweight[id][k];
				} else {
					weights[n][k] = 1.0f;
				}
			}
		}
		if (debug) {
			for (int n=0;n<nc;n++) {
				System.out.print("modweight("+n+") = "+weights[n][0]+", "+weights[n][1]+"\n");
			}
		}

		return weights;
	}
	
	/**
	 * Returns the optimized factors for one modality.
	 * <p>
	 * When the modality is recognized by <code>modalityId</code> and flagged
	 * in <code>hasIntensity</code>, the matching row of
	 * <code>optimizedFactor</code> is copied; otherwise every factor is 1.0.
	 * When <code>debug</code> is set, the first three factors are printed to
	 * standard output.
	 *
	 * @param modality the modality name
	 * @return a new array of <code>OPTIMIZED</code> factors
	 */
	final public float[] getOptimizedFactor(String modality) {
		float[] factors = new float[OPTIMIZED];
		final int id = modalityId(modality);
		// the flag is only looked up for a valid modality index
		final boolean available = (id>-1) && hasIntensity[id];
		for (int k=0;k<OPTIMIZED;k++) {
			if (available) {
				factors[k] = optimizedFactor[id][k];
			} else {
				factors[k] = 1.0f;
			}
		}
		if (debug) {
			System.out.print("optimizedfactor = "+factors[0]+", "+factors[1]+", "+factors[2]+"\n");
		}

		return factors;
	}
    
    /**
	 * Collects the intensity priors for a set of channels.
	 * <p>
	 * A channel whose modality is recognized by <code>modalityId</code> and
	 * flagged in <code>hasIntensity</code> gets the matching row of
	 * <code>intensity</code>; any other channel falls back to the label value
	 * of each structure.
	 *
	 * @param modality modality name of each channel; the first <code>nc</code> entries are read
	 * @param nc number of channels
	 * @return a new <code>nc</code> x <code>classes</code> array of intensity priors
	 */
    final public float[][] 	getIntensityPriors(String[] modality, int nc) {
		float[][] intensityPrior = new float[nc][classes];
		for (int n=0;n<nc;n++) {
			final int id = modalityId(modality[n]);
			// the flag is only looked up for a valid modality index
			final boolean available = (id>-1) && hasIntensity[id];
			for (int k=0;k<classes;k++) {
				if (available) {
					intensityPrior[n][k] = intensity[id][k];
				} else {
					intensityPrior[n][k] = label[k];
				}
			}
		}
		return intensityPrior;
	}
    
    /**
	 * Collects the lesion priors for a set of channels.
	 * <p>
	 * When <code>hasLesions</code> is false every prior is 0.0 and the
	 * modality names are not inspected. Otherwise a channel whose modality is
	 * recognized by <code>modalityId</code> and flagged in
	 * <code>hasIntensity</code> gets the matching <code>lesion</code> entry,
	 * and any other channel gets 0.0.
	 *
	 * @param modality modality name of each channel; the first <code>nc</code> entries are read
	 * @param nc number of channels
	 * @return a new array of <code>nc</code> lesion priors
	 */
    final public float[] 	getLesionPriors(String[] modality, int nc) {
		float[] lesionPrior = new float[nc];
		for (int n=0;n<nc;n++) {
			float value = 0.0f;
			if (hasLesions) {
				final int id = modalityId(modality[n]);
				if ( (id>-1) && hasIntensity[id] ) {
					value = lesion[id];
				}
			}
			lesionPrior[n] = value;
		}
		return lesionPrior;
	}
    
	/**
	 * Clean-up: drops the references to the intensity priors, the intensity
	 * image atlas, the topology atlas and the shape atlases, then requests a
	 * garbage collection through <code>System.gc()</code>.
	 */
	public final void finalize() {
		intensity = null;
		intensityImageAtlas = null;
		topologyAtlas = null;
		shapeAtlases = null;
		System.gc();
	}
	
}
