package edu.jhmi.rad.medic.methods;

import gov.nih.mipav.view.*;
import gov.nih.mipav.model.structures.ModelImage;
import gov.nih.mipav.model.structures.ModelStorageBase;
import gov.nih.mipav.model.algorithms.AlgorithmMorphology25D;


import edu.jhmi.rad.medic.utilities.*;

/**
 * Gaussian mixture model (GMM) segmentation of multi-channel images, fitted with the
 * expectation-maximisation (EM) algorithm and guided by a deformable statistical atlas.
 * <p>
 * The way mixing coefficients enter the model is selected by the {@code mixingMode} string
 * passed to the constructor:
 * <ul>
 * <li>{@code "Atlas"}, or any mode starting with {@code "Adaptive"}: per-voxel coefficients
 *     taken from the atlas prior array ({@link #getAtlasPriors()}; in {@code "Atlas"} mode it is
 *     refreshed by {@link #computeAtlas()}). NOTE(review): in {@code "Adaptive..."} modes nothing
 *     in this class fills that array, so the caller presumably does it through
 *     {@link #getAtlasPriors()} — confirm.</li>
 * <li>{@code "Equal"}: every class gets {@code 1/classes} (see {@link #computePriors()}).</li>
 * <li>any other mode (e.g. {@code "Class dependent"}): global coefficients re-estimated from
 *     the memberships by {@link #computePriors()}.</li>
 * </ul>
 * An optional smoothing term penalises memberships that disagree with the 6-connected
 * neighbours, and an optional multiplicative inhomogeneity field can be applied per channel.
 *
 * @author Navid Shiee
 */
public class AtlasEMSegmentation  {

	// numerical quantities
	private static final    float   PI    =3.14159265358979f;
	/** log(2*PI): normalisation constant of the Gaussian log-density (was log(PI) by mistake). */
	private static final    float   Log2PI = (float)Math.log(2.0*PI);

	// invalid label flag, must be <0
	private	static	final	byte	EMPTY = -1;

	/** Class names excluded from the "WM" output of {@link #generateDuraRemovalOutputs()}. */
	private static final String[] DURA_NON_WM_NAMES = {
		"SulcalCSF", "Sulcal-CSF", "CorticalGM", "Cerebrum-GM", "CerebralGM",
		"Cerebellum-GM", "CerebellarGM", "GM", "Background" };
	/** Class names accumulated into the "GM" output of {@link #generateDuraRemovalOutputs()}. */
	private static final String[] DURA_GM_NAMES = {
		"CorticalGM", "Cerebrum-GM", "CerebralGM", "Cerebellum-GM", "GM", "CerebellarGM" };

	// data and membership buffers
	private 	float[][][][] 	images;  			// original images, indexed [channel][x][y][z]
	private     boolean[][][]   mask;               // foreground mask, indexed [x][y][z]
	private 	float[][][][]   mems;   			// memberships, indexed [x][y][z][class]
	private 	byte[][][]		classification;  	// hard classification (template labels)

	private 	float[][] 		centroid;   		// mean intensity, indexed [channel][class]
	private     float[][][]     covariance;         // covariance, indexed [class][channel][channel]
	private     float[]         mixingCoef;         // global mixing coefficients, indexed [class]

	private 	float			res= 1.0f;          // image resolution (stored only)
	// Dimensions are per instance: as statics, a second instance overwrote the first one's sizes.
	private 	int 		nx,ny,nz,nc;   		// image dimensions and number of channels
	private 	int 	  	classes;			// total number of classes (i.e. except outliers)
	private		byte[]			templateLabel;		// template label for each class index
	private		float[]			Ihigh;				// per-channel intensity maximum (stored only)
	private     float[]         Iscale;             // per-channel scale used by displayCentroids()

	private 	String  mixing_coefficients_type;   // how to estimate the mixing coefficients
	private		double  Z;                          // never assigned here; 3rd value of computeMemberships()

	// parameters
	private 	float   		smoothing;   	// neighbourhood smoothing weight
	private 	float[][][][]   atlas_prior; 	// per-voxel mixing coefficients, [class][x][y][z]
	private 	byte[][][]		segmentation;   // most likely class index per voxel, or EMPTY

	// atlas prior
	private		DemonToadDeformableAtlas	atlas;          //statistical atlas

	// option flags
	private 	boolean 		useField;           		// multiply images by the field when set
	private 	float[][][][]	field;  					// inhomogeneity field, per channel

	// computation flags
	private 	boolean 		isWorking;
	private 	boolean 		isCompleted;

	// for debug and display
	ViewUserInterface			UI;
	ViewJProgressBar            progressBar;
	static final boolean		debug=true;
	static final boolean		verbose=true;

	/**
	 * Creates the segmentation and allocates all working buffers.
	 * <p>
	 * Memberships and centroids start at zero, covariances at the identity, and every voxel of
	 * the segmentation at {@code EMPTY}. The foreground mask keeps the voxels where every channel
	 * is strictly positive; it is then post-processed with {@code AlgorithmMorphology25D}
	 * (commented as "fill holes"). If allocation fails, {@link #isWorking()} returns false.
	 *
	 * @param images_    input images, indexed [channel][x][y][z]
	 * @param atlas_     statistical atlas providing the number of classes and their labels
	 * @param nx_        image size along x
	 * @param ny_        image size along y
	 * @param nz_        image size along z
	 * @param nc_        number of channels
	 * @param res_       image resolution (stored only)
	 * @param smooth_    smoothing strength; divided by {@code 6 * classes} internally
	 * @param mixingMode mixing coefficient mode (see the class documentation)
	 * @param UI_        MIPAV user interface
	 * @param bar_       progress bar (stored only)
	 */
	public AtlasEMSegmentation(float[][][][] images_, DemonToadDeformableAtlas atlas_,
			int nx_, int ny_, int nz_, int nc_, float res_,
			float smooth_, 
			String mixingMode, 
			ViewUserInterface UI_, ViewJProgressBar bar_) {
		images = images_;
		atlas = atlas_;
		nx = nx_;
		ny = ny_;
		nz = nz_;
		nc = nc_;
		res = res_;
		classes = atlas.getNumber();
		templateLabel = atlas.getLabels();

		// coefficient (assumes normalized images)
		smoothing = smooth_/6.0f/classes;
		System.out.println("Smoothing is : " + smoothing);

		mixing_coefficients_type = mixingMode;

		UI = UI_;
		progressBar = bar_;

		useField = false;
		field = new float[nc][][][];   // entries stay null until addInhomogeneityCorrection()

		// init all the arrays
		try {
			mems = new float[nx][ny][nz][classes];
			classification = new byte[nx][ny][nz];
			mask = new boolean[nx][ny][nz];
			Ihigh = new float[nc];
			Iscale = new float[nc];
			centroid = new float[nc][classes];
			covariance = new float[classes][nc][nc];
			mixingCoef = new float[classes];
			atlas_prior = new float[classes][nx][ny][nz];
			segmentation = new byte[nx][ny][nz]; 
		} catch (OutOfMemoryError e){
			isWorking = false;
			finalize();
			System.out.println(e.getMessage());
			return;
		}
		isWorking = true;

		// New Java arrays are zero-filled, so only the non-zero initial values are set here.
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
			segmentation[x][y][z] = EMPTY;
		for (int c=0;c<nc;c++) {
			Ihigh[c] = 1.0f;
			Iscale[c] = 1.0f;
		}
		// identity covariance for every class
		for (int k=0;k<classes;k++)
			for (int c=0;c<nc;c++)
				covariance[k][c][c] = 1.0f;

		computeForegroundMask();

		if (debug) ViewUserInterface.getReference().setGlobalDataText("initialization\n");
	}

	/**
	 * Fills {@link #mask}. A voxel is foreground when every channel is strictly positive.
	 * The mask is then passed through {@code AlgorithmMorphology25D} (commented as "fill holes").
	 */
	private void computeForegroundMask() {
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			boolean inside = true;
			for (int c=0; c<nc && inside; c++)
				inside = images[c][x][y][z] > 0.0f;
			mask[x][y][z] = inside;
		}

		ModelImage temp_mask = new ModelImage(ModelStorageBase.BOOLEAN, new int[]{nx,ny,nz},
		"lesions");
		AlgorithmMorphology25D morph25 = null;
		try {
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
				temp_mask.set(x,y,z,mask[x][y][z]);
			//Fill holes
			morph25 = new AlgorithmMorphology25D(temp_mask, 1, 1,
					13, 1, 1, 1, 1,
					true);
			morph25.runAlgorithm();
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
				mask[x][y][z] = temp_mask.getBoolean(x,y,z);
		} finally {
			// release the MIPAV objects even if the morphology step fails
			if (morph25 != null) morph25.finalize();
			temp_mask.disposeLocal();
		}
	}

	/**
	 * Drops the references to the large buffers (images, memberships, centroids, covariances).
	 * It is called explicitly, including from the constructor when allocation fails.
	 */
	public void finalize() {
		images = null;
		mems = null;
		centroid = null;
		covariance = null;
	}

	/**
	 *	Drops the reference to the input images and requests a garbage collection.
	 */
	public final void cleanUp() {
		images = null;
		System.gc();
	}

	/** @return the centroids of channel {@code c}, indexed by class (internal array). */
	public final float[] getCentroids(int c) { return centroid[c]; }
	/** @return all centroids, indexed [channel][class] (internal array). */
	public final float[][] getCentroids() { return centroid; }
	/** Replaces the centroid array (stored by reference), indexed [channel][class]. */
	public final void setCentroids(float[][] cent) { centroid = cent; }

	/** @return the memberships, indexed [x][y][z][class] (internal array). */
	public final float[][][][] getMemberships() { return mems; }
	/** Copies {@code m_}, indexed [x][y][z][class], into the internal membership buffer. */
	public final void setMemberships(float[][][][] m_) { 
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
			System.arraycopy(m_[x][y][z], 0, mems[x][y][z], 0, classes);
	}

	/** @return the current segmentation (class index per voxel, or -1), internal array. */
	public final byte[][][] getSegmentation(){return segmentation;}
	/** Replaces the global mixing coefficients (stored by reference). */
	public final void setmixingCoef ( float[] mix) { mixingCoef = mix; }

	/** Sets every global mixing coefficient to {@code 1/classes}. */
	public final void initializePriors() {
		for (int k=0;k<classes;k++)
			mixingCoef[k] = 1.0f/classes;
	}

	/** Replaces all covariance matrices (stored by reference), indexed [class][channel][channel]. */
	public final void setCovariance(float[][][] cov) {covariance = cov; }
	/** Replaces the covariance matrix of class {@code k} (stored by reference). */
	public final void setCovariance(float[][] cov, int k) {covariance[k] = cov; }
	/**
	 * Sets every covariance matrix to a diagonal one.
	 *
	 * @param diag_ diagonal entries, indexed [channel][class]
	 */
	public final void setDiagCovariance(float[][] diag_){
		for (int k=0;k<classes;k++)
			for (int c1=0;c1<nc;c1++)
				for (int c2=0;c2<nc;c2++)
					covariance[k][c1][c2] = (c1==c2) ? diag_[c1][k] : 0.0f;
	}

	/** @return the covariance matrices, indexed [class][channel][channel] (internal array). */
	public final float[][][] getCovariance () { return covariance; }

	/** @return the hard classification (template labels), internal array. */
	public final byte[][][] getClassification() { return classification; }

	/** Stores the per-channel intensity maxima (not used by the computations in this class). */
	public final void setIntensityMax(float[] I) { Ihigh = I; }
	/** Stores the per-channel scales used by {@link #displayCentroids()}. */
	public final void setIntensityScale(float[] I) { Iscale = I; }

	/** Resets the classification to the atlas template (stored by reference). */
	public final void resetClassification() { 
		classification = atlas.getTemplate();
	}

	/** Registers the inhomogeneity field of channel {@code c} and enables field correction. */
	public final void addInhomogeneityCorrection(float[][][] field_, int c) {
		field[c] = field_;
		useField = true;
	}

	/** @return false if the constructor failed to allocate its buffers. */
	public final boolean isWorking() { return isWorking; }
	/** @return the completion flag (never set to true within this class). */
	public final boolean isCompleted() { return isCompleted; }

	/** @return the per-voxel atlas priors, indexed [class][x][y][z] (internal array). */
	public final float[][][][] getAtlasPriors() { return atlas_prior;}

	/** @return a deep copy of the atlas priors, indexed [class][x][y][z]. */
	public float[][][][] exportAtlasPriors(){
		float[][][][] temp = new float[classes][nx][ny][nz];
		for (int k=0;k<classes;k++) for (int x=0;x<nx;x++) for (int y=0;y<ny;y++)
			System.arraycopy(atlas_prior[k][x][y], 0, temp[k][x][y], 0, nz);
		return temp;
	}

	/** @return a copy of the atlas prior of class {@code k}, indexed [x][y][z]. */
	public float[][][] exportAtlasPriors(int k){
		float[][][] temp = new float[nx][ny][nz];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++)
			System.arraycopy(atlas_prior[k][x][y], 0, temp[x][y], 0, nz);
		return temp;
	}

	/** 
	 * E-step: recomputes the posterior membership of every masked voxel.
	 * <p>
	 * For each class k the unnormalised membership is
	 * {@code exp(-(0.5*(v-mu_k)^T Cov_k^-1 (v-mu_k) + smoothing*ngb_k)) * prior_k / sqrt(|det Cov_k|)}.
	 * Here {@code v} holds the (optionally field-corrected) intensities. {@code ngb_k} is the sum
	 * of the other classes' memberships over the in-bounds 6-neighbours; neighbours already
	 * visited in this sweep contribute their updated values. {@code prior_k} is the atlas prior
	 * or the global mixing coefficient, depending on the mode. Classes are skipped (membership 0)
	 * unless the mode is "Equal"/"Class dependent" or the atlas prior is positive. Memberships
	 * outside the mask are set to 0, and {@link #updateSegmentation()} is called at the end.
	 *
	 * @return {@code {energy, maximum absolute membership change, Z}}
	 */
	final public float[] computeMemberships() {
		final boolean atlasMode = mixing_coefficients_type.equals("Atlas")
				|| mixing_coefficients_type.startsWith("Adaptive");
		final boolean everyClassAllowed = mixing_coefficients_type.equals("Equal")
				|| mixing_coefficients_type.equals("Class dependent");

		float[] energy = new float[classes];   // per-class log-likelihood terms of the current voxel
		float[] oldMems = new float[classes];  // memberships of the current voxel before the update
		float[] value = new float[nc];         // (field-corrected) intensities of the current voxel
		double totalEnergy = 0.0;              // double: accumulated over every masked voxel and class
		float distance = 0.0f;                 // largest absolute membership change

		// Inverse covariances and |det| do not depend on the voxel: compute them once.
		System.out.print("Computing inverse Covariance...");
		float[][][] inverse = new float[classes][nc][nc];
		float[]  det  =new float[classes];
		for (int k = 0; k<classes; k++){
			det[k] = Math.abs(Numerics.invertMatrix(covariance[k], inverse[k],nc));
		}
		System.out.print("Done!\n");

		long currentTime = System.currentTimeMillis();
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			float[] mem = mems[x][y][z];
			if (!mask[x][y][z]) {
				for (int k=0;k<classes;k++) mem[k] = 0.0f;
				continue;
			}

			for (int c=0;c<nc;c++)
				value[c] = useField ? field[c][x][y][z]*images[c][x][y][z] : images[c][x][y][z];

			float den = 1e-30f;   // keeps the normalisation finite when every class is skipped
			for (int k=0;k<classes;k++){
				oldMems[k] = mem[k];
				float prior = atlasMode ? atlas_prior[k][x][y][z] : mixingCoef[k];

				// log of the Gaussian normaliser times the prior (0 where the atlas gives no support)
				if (atlasMode && prior == 0)
					energy[k] = 0;
				else
					energy[k] = (float)(-(nc/2.0f) * Log2PI - 0.5 * Math.log(det[k]) + Math.log(prior));

				if (everyClassAllowed || atlas_prior[k][x][y][z] > 0) {
					// 0.5 * (v-mu)^T Cov^{-1} (v-mu)
					float exponent = 0.0f;
					for (int c1=0;c1<nc;c1++){
						float row = 0.0f;
						for (int c2=0;c2<nc;c2++)
							row += inverse[k][c1][c2] * (value[c2] - centroid[c2][k]);
						exponent += row * (value[c1] - centroid[c1][k]);
					}
					exponent /= 2.0f;

					//spatial smoothing
					if (smoothing > 0.0f)
						exponent += smoothing * neighbourhoodDisagreement(x, y, z, k);

					energy[k] -= exponent;
					mem[k] = (float)(Math.exp(-exponent)*prior/Math.sqrt(det[k]));
					den += mem[k];
				} else {
					mem[k] = 0.0f;
				}
			}

			// normalization
			if (den > 0.0f) {
				for (int k=0;k<classes;k++) mem[k] = mem[k]/den;
			} else {
				for (int k=0;k<classes;k++) mem[k] = 0.0f;
			}

			for (int k=0;k<classes;k++){
				// 0*log(0) is taken as 0; otherwise unsupported classes turn the total into NaN
				double logTerm = (oldMems[k] == 0.0f) ? 0.0 : Math.log(mem[k]) * oldMems[k];
				totalEnergy += energy[k] + logTerm;
				float change = Math.abs(mem[k]-oldMems[k]);
				if (change > distance) distance = change;
			}
		}

		MedicUtilPublic.displayMessage("Max membership change: " + distance+"\n");

		updateSegmentation();
		System.out.println("Posteriors time :" + ((System.currentTimeMillis()-currentTime)/1000 + " seconds"));
		return new float[]{(float)totalEnergy, distance, (float)Z};

	}// computeMembershipsWithPrior

	/**
	 * Sums the memberships of every class other than {@code k} over the in-bounds 6-neighbours
	 * of (x,y,z), in the order X+, X-, Y+, Y-, Z+, Z-. Out-of-volume neighbours are skipped;
	 * reading them used to throw for masked voxels on the border.
	 */
	private float neighbourhoodDisagreement(int x, int y, int z, int k) {
		float ngb = 0.0f;
		if (x+1 < nx) ngb = addOtherClasses(ngb, mems[x+1][y][z], k);
		if (x > 0)    ngb = addOtherClasses(ngb, mems[x-1][y][z], k);
		if (y+1 < ny) ngb = addOtherClasses(ngb, mems[x][y+1][z], k);
		if (y > 0)    ngb = addOtherClasses(ngb, mems[x][y-1][z], k);
		if (z+1 < nz) ngb = addOtherClasses(ngb, mems[x][y][z+1], k);
		if (z > 0)    ngb = addOtherClasses(ngb, mems[x][y][z-1], k);
		return ngb;
	}

	/** Adds {@code mem[m]} for every class {@code m != k} to {@code sum}, in increasing m order. */
	private float addOtherClasses(float sum, float[] mem, int k) {
		for (int m=0;m<classes;m++) if (m != k) sum += mem[m];
		return sum;
	}

	/**
	 * Computes the expected complete-data log-likelihood of channel 0 under the atlas priors.
	 * Only masked voxels and classes with a non-zero atlas prior contribute. It uses the raw
	 * channel-0 intensities (no inhomogeneity field) and {@code sqrt(covariance[k][0][0])} as
	 * the class standard deviation.
	 *
	 * @return {@code sum mem_k * (log prior_k - 0.5 log(2 pi) - 0.5 log(sigma_k^2) - (v-mu_k)^2 / (2 sigma_k^2))}
	 */
	public float computeEnergy(){
		float[] sigma = new float[classes];
		for (int k=0;k<classes;k++)
			sigma[k] = (float)Math.sqrt(covariance[k][0][0]);

		double energy = 0.0;
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]) {
			for (int k=0;k<classes;k++) if (atlas_prior[k][x][y][z] != 0) {
				float diff = images[0][x][y][z] - centroid[0][k];
				// -0.5*log(sigma^2) == -log(sigma); the original used -0.5*log(sigma)
				energy += mems[x][y][z][k] * (Math.log(atlas_prior[k][x][y][z])
						- 0.5 * Log2PI
						- Math.log(sigma[k])
						- diff*diff/(2*sigma[k]*sigma[k]));
			}
		}
		return (float)energy;
	}

	/**
	 * In "Atlas" mode only: regenerates the atlas priors from the atlas's transformed shapes and
	 * normalises them to sum to one over classes at every masked voxel with non-zero total.
	 * Other modes leave the priors untouched.
	 */
	public void computeAtlas(){
		if (!mixing_coefficients_type.equals("Atlas")) return;

		//get updated atlases
		atlas.precomputeTransformMatrix(1.0f);
		atlas_prior = atlas.generateTransformedShapes();
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]) {
			float norm = 0;
			for (int m=0;m<classes;m++) norm += atlas_prior[m][x][y][z];
			if (norm != 0)
				for (int m=0;m<classes;m++) atlas_prior[m][x][y][z] /= norm;
		}
	}

	/**
	 * M-step for the global mixing coefficients. In "Equal" mode every class gets
	 * {@code 1/classes}. Otherwise each coefficient is the mean membership of that class over
	 * the masked voxels; if the mask is empty, the coefficients are set to 0 and an error is
	 * printed.
	 */
	public final void computePriors(){

		System.out.println("Computing Mixing Coefficients....");

		if (mixing_coefficients_type.equals("Equal")){
			// float division: the former 1/classes was integer division and yielded 0
			for (int k=0;k<classes;k++)
				mixingCoef[k] = 1.0f/classes;
		}else{
			// counted as long: a float counter stops increasing at 2^24 voxels
			long count = 0;
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
				if (mask[x][y][z]) count++;

			for (int k=0;k<classes;k++){
				float sum = 0.0f;
				for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
					if (mask[x][y][z]) sum += mems[x][y][z][k];
				if (count != 0) {
					mixingCoef[k] = sum / count;
				} else {
					mixingCoef[k] = 0;
					System.out.println("Error computing mixing coefficients");
				}
			}
		} 
		System.out.println("Mixing Coefficients Done!");
	}

	/**
	 * M-step for the class means. Each centroid is the membership-weighted mean of the
	 * (optionally field-corrected) intensities over the masked voxels. A class with zero total
	 * weight keeps its previous centroid. The method name's spelling is kept for callers.
	 */
	public final void computeCentriods(){
		System.out.println("Computing Centroids....");
		for (int c=0;c<nc;c++){
			for (int k=0;k<classes;k++){
				float num = 0.0f, den = 0.0f;
				for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]){
					float w = mems[x][y][z][k];
					if (useField)
						num += w * field[c][x][y][z] * images[c][x][y][z];
					else
						num += w * images[c][x][y][z];
					den += w;
				}
				if (den != 0.0f)
					centroid[c][k] = num/den;
			}
		}
		System.out.println("Centroids Done!");
		System.out.println(displayCentroids());
	}

	/**
	 * Computes a per-channel, per-class weighted standard deviation over the whole volume
	 * (the mask is not applied). The weights are {@code mem^2 + 1e-30}. The stored covariance
	 * matrices are not modified.
	 *
	 * @return standard deviations, indexed [channel][class]
	 */
	public final float[][] computeDiagonalCovariances() {
		float[][] stddev = new float[nc][classes];
		for (int c=0;c<nc;c++) {
			for (int k=0;k<classes;k++) {
				float num = 0, den = 0;
				for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
					float mem = mems[x][y][z][k];
					float diff = useField
							? (field[c][x][y][z]*images[c][x][y][z]-centroid[c][k])
							: (images[c][x][y][z]-centroid[c][k]);
					num += (mem*mem+1e-30f)*diff*diff;
					den += (mem*mem+1e-30f);
				}
				stddev[c][k] = (den>0.0) ? (float)Math.sqrt(num/den) : 0.0f;
			}
		}
		return stddev;
    } // computeVariances

	/**
	 * M-step for the covariance matrices. Each is the membership-weighted covariance of the
	 * (optionally field-corrected) intensities around the current centroids, over the masked
	 * voxels. A class with zero total weight keeps the un-normalised (all-zero) matrix.
	 */
	public final void computeCovariance(){
		System.out.println("Computing Covariance Matrices...");
		float[] value = new float[nc];
		for (int k=0;k<classes;k++){
			float[][] cov = covariance[k];
			for (int c1=0;c1<nc;c1++)
				for (int c2=0;c2<nc;c2++)
					cov[c1][c2] = 0.0f;

			float den = 0.0f;
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]){
				float w = mems[x][y][z][k];
				for (int c=0;c<nc;c++)
					value[c] = useField ? images[c][x][y][z]*field[c][x][y][z] : images[c][x][y][z];
				for (int c1=0;c1<nc;c1++)
					for (int c2=0;c2<nc;c2++)
						cov[c1][c2] += (value[c1] - centroid[c1][k])
								*(value[c2] - centroid[c2][k])
								*w;
				// once per voxel: it used to be added nc*nc times, shrinking Cov by nc^2
				den += w;
			}
			if (den != 0.0f)
				for (int c1=0;c1<nc;c1++)
					for (int c2=0;c2<nc;c2++)
						cov[c1][c2] /= den;
		}
		System.out.println("Covariance Matrices  Done!");
	}

	/**
	 *	Formats the centroids: for each channel, one line scaled by {@code Iscale} and one raw line.
	 */
	public final String displayCentroids() {
		StringBuilder output = new StringBuilder("centroids \n");
		for (int c=0;c<nc;c++) {
			for (int k=0;k<classes;k++) output.append(" | ").append(centroid[c][k] * Iscale[c]);
			output.append("\n");
			for (int k=0;k<classes;k++) output.append(" | ").append(centroid[c][k]);
			output.append("\n");
		}
		return output.toString();
	}

	/** Formats the global mixing coefficients on one line. */
	public final String displayMixingCoefficients(){
		StringBuilder output = new StringBuilder("Mixing Coef : \n");
		for (int k=0;k<classes;k++)
			output.append(mixingCoef[k]).append(" | ");
		return output.toString();
	}

	/** Formats every class covariance matrix, one row per line, with a blank line between classes. */
	public final String displayCovariance(){
		StringBuilder output = new StringBuilder("Covariance : \n");
		for (int k=0;k<classes;k++){
			for (int c1=0;c1<nc;c1++){
				for (int c2=0;c2<nc;c2++)
					output.append(covariance[k][c1][c2]).append("  ");
				output.append("\n");
			}
			output.append("\n");
		}
		return output.toString();
	}

	/**
	 * Maps a segmentation class index to its template label. {@code EMPTY} (and any other
	 * negative value) maps to 0, matching the "no class" label used by
	 * {@link #updateClassificationFromMemberships()}.
	 */
	private byte labelOf(byte seg) {
		return (seg < 0) ? 0 : templateLabel[seg];
	}

	/**
	 *	Recomputes the classification from the current segmentation.
	 *
	 *	@return the number of voxels whose label changed
	 */
	public final int updateClassificationFromLabels() {
		int classDiff = 0;
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			byte prev = classification[x][y][z];
			classification[x][y][z] = labelOf(segmentation[x][y][z]);
			if (prev != classification[x][y][z]) classDiff++;		
		}
		return classDiff;
	}

	/**
	 *	Sets each voxel of the classification to the template label of the class with the largest
	 *	positive membership (ties keep the lowest class), or 0 if no membership is positive.
	 */
	public final void updateClassificationFromMemberships() {
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			int best = -1;
			float bestMem = 0;
			for (int k=0;k<classes;k++) if (mems[x][y][z][k] > bestMem) {
				best = k;
				bestMem = mems[x][y][z][k];
			}
			classification[x][y][z] = (best == -1) ? 0 : templateLabel[best];
		}
	}

	/**
	 * Sets each voxel of the segmentation to the index of its largest membership (ties keep the
	 * lowest index). It is set to {@code EMPTY} when that membership is not positive.
	 */
	public final void updateSegmentation(){
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			float[] mem = mems[x][y][z];
			int best = 0;
			for (int k=1;k<classes;k++) if (mem[k] > mem[best]) best = k;
			segmentation[x][y][z] = (mem[best] > 0) ? (byte)best : EMPTY;
		}
	}

	/** 
	 *	@return a copy of the memberships, re-indexed as [class][x][y][z]
	 */
	public final float[][][][] exportMemberships() {
		float[][][][]	Mems = new float[classes][nx][ny][nz];
		for (int k=0;k<classes;k++)
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
				Mems[k][x][y][z] = mems[x][y][z][k];
		return Mems;
	} // exportMemberships

	/**
	 * @return a new array with the template label of each voxel's segmented class
	 *         (0 for unsegmented voxels)
	 */
	public final byte[][][] exportClassification() {
		byte[][][]	Classif = new byte[nx][ny][nz];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
			Classif[x][y][z] = labelOf(segmentation[x][y][z]);
		return Classif;
	} // exportClassification

	/**
	 * Builds a new label image from the memberships: the template label of the class with the
	 * largest positive membership, or 0 where no membership is positive.
	 */
	final public byte[][][] generateClassification() {
		byte[][][] img = new byte[nx][ny][nz];
		for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			// compute each class probability : attribute the highest
			float max = 0;
			int best = -1;
			for (int k=0;k<classes;k++) {
				if (mems[x][y][z][k] > max) {
					best = k;
					max = mems[x][y][z][k];
				}
			}
			img[x][y][z] = (best > -1) ? templateLabel[best] : 0;
		}
		return img;
	}

	/** @return true if {@code name} equals one of {@code options}. */
	private static boolean isOneOf(String name, String[] options) {
		for (int i=0;i<options.length;i++)
			if (name.equals(options[i])) return true;
		return false;
	}

	/**
	 * Builds three volumes, all scaled by 255:
	 * <ol start="0">
	 * <li>the summed memberships of the GM classes;</li>
	 * <li>the summed memberships of every class that is not CSF, GM or background;</li>
	 * <li>a binary mask of voxels segmented as one of those classes or as cerebellar GM,
	 *     post-processed with {@code AlgorithmMorphology25D} (commented as "fill holes").</li>
	 * </ol>
	 * Only masked voxels receive non-zero memberships.
	 *
	 * @return volumes indexed [output][x][y][z]
	 */
	public final float[][][][] generateDuraRemovalOutputs(){
		// new Java arrays are zero-filled, so no explicit initialisation is needed
		float[][][][]	Mems = new float[3][nx][ny][nz];
		String[] names = atlas.getNames();

		// WM (everything else except background and cerebellum-GM 
		for (int k=0;k<classes;k++) {
			if (isOneOf(names[k], DURA_NON_WM_NAMES)) continue;
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]){
				Mems[1][x][y][z] += mems[x][y][z][k];
				if (segmentation[x][y][z]==k) Mems[2][x][y][z] = 1.0f;
			}
		}

		// GM (Cerebellum and Cortical)
		for (int k=0;k<classes;k++) {
			if (!isOneOf(names[k], DURA_GM_NAMES)) continue;
			// looked up once per class instead of once per voxel
			int cerebellumGmId = atlas.getNameID("Cerebellum-GM");
			int cerebellarGmId = atlas.getNameID("CerebellarGM");
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) if (mask[x][y][z]){
				Mems[0][x][y][z] += mems[x][y][z][k];
				//include all the cerebellum in WM mask
				if (segmentation[x][y][z] == cerebellumGmId
						|| segmentation[x][y][z] == cerebellarGmId)
					Mems[2][x][y][z] = 1.0f;
			}
		}

		ModelImage WMmask = new ModelImage(ModelStorageBase.UBYTE, new int[]{nx,ny,nz},
		"lesions");
		AlgorithmMorphology25D morph25 = null;
		try {
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++)
				WMmask.set(x,y,z,(byte)Mems[2][x][y][z]);
			//Fill holes
			morph25 = new AlgorithmMorphology25D(WMmask, 1, 1,
					13, 1, 1, 10, 1,
					true);
			morph25.runAlgorithm();
			for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) { 
				if (WMmask.getShort(x,y,z)==1) Mems[2][x][y][z]=255.0f;
				Mems[0][x][y][z] *= 255.0f;
				Mems[1][x][y][z] *= 255.0f;
			}
		} finally {
			// release the MIPAV objects even if the morphology step fails
			if (morph25 != null) morph25.finalize();
			WMmask.disposeLocal();
		}
		return Mems;
	} // generateDuraRemovalOutputs

}




