package edu.jhmi.rad.medic.methods;

import java.io.*;
import java.util.*;

import gov.nih.mipav.view.*;
import edu.jhmi.rad.medic.utilities.*;

/**
 *
 *  This algorithm handles edge adaptation for segmentation algorithms
 *  on 3D data: it computes an edge field between membership functions
 *	with different options (EM or FCM-like, linear, etc).
 *
 *	@version    December 2004
 *	@author     Pierre-Louis Bazin
 *	@see		SegmentationFCM
 *
 */

public class EdgeAdaptation {
		
	// numerical quantities
	private static final	float   INF=1e30f;
	private static final	float   ZERO=1e-30f;
	/** number of neighbouring edges entering the smoothing term of each edge update */
	private static final	int		NEIGHBORS=12;
	
	// data buffers
	private 	float[][][][]		mems;				// membership functions, indexed [x][y][z][cluster]
	private 	float[][][][]		edges;  			// edge field, indexed [direction][x][y][z]; direction 0/1/2 = X/Y/Z,
														// edge n at (x,y,z) links the voxel to its forward neighbour along n
	// image dimensions: per-instance (they used to be static, so every new
	// instance silently overwrote the dimensions seen by all existing ones)
	private		int			nx,ny,nz;
	
	// parameters
	private 	int 		clusters;   // number of clusters
	private 	int 		classes;    // number of classes in original membership: > clusters if outliers
	private 	float 		smoothing;	// MRF smoothing on the edges
	private 	float 		contrast;	// sharpness of the edges
	private 	float 		prior;		// apriori amount of edges
	
	// computation flags
	private 	boolean 		isWorking;
	// NOTE(review): isCompleted is never set to true anywhere in this class
	private 	boolean 		isCompleted;
	
	// for debug and display
	ViewUserInterface			UI;
	ViewJProgressBar            progressBar;
	static final boolean		debug=true;
	static final boolean		verbose=false;
	
	/**
	 *  Constructor.
	 *	Note: the membership array is only linked, not copied.
	 *	If the edge field cannot be allocated, the object is left with
	 *	isWorking()==false and no edge field.
	 *
	 *	@param mems_		membership functions, indexed [x][y][z][cluster]
	 *	@param nx_			image size along X
	 *	@param ny_			image size along Y
	 *	@param nz_			image size along Z
	 *	@param classes_		number of classes in the original membership
	 *	@param clusters_	number of clusters used in the data term
	 *	@param smoothing_	weight of the neighbouring-edge (MRF) term
	 *	@param contrast_	weight of the term on the edge's own current value
	 *	@param prior_		constant added to the cost of the "nup" state
	 *	@param UI_			user interface (stored only)
	 *	@param bar_			progress bar, only used when verbose is enabled
	 */
	public EdgeAdaptation(float[][][][] mems_,
					int nx_, int ny_, int nz_,
					int classes_, int clusters_,
					float smoothing_, float contrast_, float prior_,
					ViewUserInterface UI_, ViewJProgressBar bar_) {
		
		mems = mems_;
		
		nx = nx_;
		ny = ny_;
		nz = nz_;
		
		classes = classes_;
		clusters = clusters_;
		
		smoothing = smoothing_;
		contrast = contrast_;
		prior = prior_;
		
		UI = UI_;
		progressBar = bar_;
		
		// allocate the edge field
		try {
			edges = new float[3][nx][ny][nz];
		} catch (OutOfMemoryError e){
			isWorking = false;
			finalize();
			System.out.println(e.getMessage());
			return;
		}
		isWorking = true;

		// every edge starts at 1.0 (which exportEdges() maps to 0.0)
		for (int n=0;n<3;n++) for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) {
			Arrays.fill(edges[n][x][y], 1.0f);
		}
		if (debug) MedicUtilPublic.displayMessage("FCM:initialisation\n");
	}

	/** clean-up: drops the reference to the edge field and requests a garbage collection */
	public final void finalize() {
		edges = null;
		System.gc();
	}
	
	/** @return the internal edge field (not a copy), indexed [direction][x][y][z] */ 
	public final float[][][][] getEdges() { return edges; }
	/** replaces the linked membership array (not copied) */ 
	public final void setMemberships(float[][][][] mem) { mems = mem; }
	/** changes the MRF smoothing weight */
	public final void setMRF(float smooth) { smoothing = smooth; }
	/**
	 *	Copies the given edge field into this object's edge field.
	 *	@param edg	edge field indexed [direction][x][y][z], at least [3][nx][ny][nz]
	 */ 
	public final void importEdges(float[][][][] edg) { 
		for (int n=0;n<3;n++) for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) {
			System.arraycopy(edg[n][x][y], 0, edges[n][x][y], 0, nz);
		}
	}
	
	/** @return true if the edge field was successfully allocated */
	public final boolean isWorking() { return isWorking; }
	/** @return the completion flag (never set to true by this class) */
	public final boolean isCompleted() { return isCompleted; }
	
	/** 
	 *  Computes the edge map given the memberships.
	 *	<p>
	 *	Performs one in-place sweep over the interior voxels (indices 1..n-2 on
	 *	each axis), updating the X, then Y, then Z edge of each voxel. Values
	 *	written earlier in the sweep are read by later updates.
	 */
	final public void computeEdgeMap() {
		int progress = 0;
		// 1 percent of the volume; at least 1 so the modulo below can never divide by zero
		int mod = Math.max(1, nx*ny*nz/100);
		// scratch buffer for the neighbouring edge values, reused for every update
		float[] ngb = new float[NEIGHBORS];

		long inner_loop_time = System.currentTimeMillis();
		for (int x=1;x<nx-1;x++) for (int y=1;y<ny-1;y++) for (int z=1;z<nz-1;z++) {
			progress++;
			if ( verbose && progressBar!=null && (progress%mod==0) )
				progressBar.updateValue(Math.round( (float)progress/(float)mod), false);
			
			// order matters: the Y and Z updates read the freshly updated X edge, etc.
			updateXEdge(x, y, z, ngb);
			updateYEdge(x, y, z, ngb);
			updateZEdge(x, y, z, ngb);
		}
		if (debug) System.out.print("inner loop time: (milliseconds): " + (System.currentTimeMillis()-inner_loop_time) +"\n"); 
	} // computeEdgeMap

	/** Updates the X edge (x,y,z)->(x+1,y,z); ngb is scratch space of length NEIGHBORS. */
	private void updateXEdge(int x, int y, int z, float[] ngb) {
		// same direction, Y neighbours
		ngb[0]  = edges[0][x][y-1][z];
		ngb[1]  = edges[0][x][y+1][z];
		// same direction, Z neighbours
		ngb[2]  = edges[0][x][y][z-1];
		ngb[3]  = edges[0][x][y][z+1];
		// Y edges touching either end of the edge
		ngb[4]  = edges[1][x][y-1][z];
		ngb[5]  = edges[1][x][y][z];
		ngb[6]  = edges[1][x+1][y-1][z];
		ngb[7]  = edges[1][x+1][y][z];
		// Z edges touching either end of the edge
		ngb[8]  = edges[2][x][y][z-1];
		ngb[9]  = edges[2][x][y][z];
		ngb[10] = edges[2][x+1][y][z-1];
		ngb[11] = edges[2][x+1][y][z];
		edges[0][x][y][z] = updatedEdge(dataTerm(x,y,z, x+1,y,z), edges[0][x][y][z], ngb);
	}

	/** Updates the Y edge (x,y,z)->(x,y+1,z); ngb is scratch space of length NEIGHBORS. */
	private void updateYEdge(int x, int y, int z, float[] ngb) {
		// same direction, X neighbours
		ngb[0]  = edges[1][x-1][y][z];
		ngb[1]  = edges[1][x+1][y][z];
		// same direction, Z neighbours
		ngb[2]  = edges[1][x][y][z-1];
		ngb[3]  = edges[1][x][y][z+1];
		// X edges touching either end of the edge
		ngb[4]  = edges[0][x-1][y][z];
		ngb[5]  = edges[0][x][y][z];
		ngb[6]  = edges[0][x-1][y+1][z];
		ngb[7]  = edges[0][x][y+1][z];
		// Z edges touching either end of the edge
		ngb[8]  = edges[2][x][y][z-1];
		ngb[9]  = edges[2][x][y][z];
		ngb[10] = edges[2][x][y+1][z-1];
		ngb[11] = edges[2][x][y+1][z];
		edges[1][x][y][z] = updatedEdge(dataTerm(x,y,z, x,y+1,z), edges[1][x][y][z], ngb);
	}

	/** Updates the Z edge (x,y,z)->(x,y,z+1); ngb is scratch space of length NEIGHBORS. */
	private void updateZEdge(int x, int y, int z, float[] ngb) {
		// same direction, Y neighbours
		ngb[0]  = edges[2][x][y-1][z];
		ngb[1]  = edges[2][x][y+1][z];
		// same direction, X neighbours
		ngb[2]  = edges[2][x-1][y][z];
		ngb[3]  = edges[2][x+1][y][z];
		// Y edges touching either end of the edge
		ngb[4]  = edges[1][x][y-1][z];
		ngb[5]  = edges[1][x][y][z];
		ngb[6]  = edges[1][x][y-1][z+1];
		ngb[7]  = edges[1][x][y][z+1];
		// X edges touching either end of the edge
		ngb[8]  = edges[0][x-1][y][z];
		ngb[9]  = edges[0][x][y][z];
		ngb[10] = edges[0][x-1][y][z+1];
		ngb[11] = edges[0][x][y][z+1];
		edges[2][x][y][z] = updatedEdge(dataTerm(x,y,z, x,y,z+1), edges[2][x][y][z], ngb);
	}

	/**
	 *	"Data" term between voxel (x,y,z) and voxel (xn,yn,zn): the sum over all
	 *	cluster pairs k!=m of mems[x][y][z][k]^2 * mems[xn][yn][zn][m]^2.
	 */
	private float dataTerm(int x, int y, int z, int xn, int yn, int zn) {
		float[] here = mems[x][y][z];
		float[] there = mems[xn][yn][zn];
		float sum = 0.0f;
		for (int k=0;k<clusters;k++) for (int m=0;m<clusters;m++) if (m!=k) {
			sum += here[k]*here[k]*there[m]*there[m];
		}
		return sum;
	}

	/**
	 *	Computes the new value of one edge.
	 *	@param data			data term across the edge (see dataTerm)
	 *	@param current		current value of the edge
	 *	@param neighbors	values of the neighbouring edges; every entry is used
	 *	@return num/(num+nup), where num and nup are the inverses of the two
	 *			accumulated costs (INF when a cost does not exceed ZERO)
	 */
	private float updatedEdge(float data, float current, float[] neighbors) {
		float ngp = 0.0f, ngm = 0.0f;
		for (float v : neighbors) {
			ngp += v*v;
			ngm += (1.0f-v)*(1.0f-v);
		}
		float count = neighbors.length;
		if (count>0.0f) {
			ngm = smoothing*ngm/count;
			ngp = smoothing*ngp/count;
		}
		float num = data + ngm + contrast*(1.0f-current)*(1.0f-current);
		float nup = prior + ngp + contrast*current*current;

		num = (num>ZERO) ? 1.0f/num : INF;
		nup = (nup>ZERO) ? 1.0f/nup : INF;

		return num/(num+nup);
	}

	/** 
	 *	Exports the edge map.
	 *	@return a new array, indexed [direction][x][y][z], holding 1 - edges for every entry
	 */
	public final float[][][][] exportEdges() {
		float[][][][]	Edges = new float[3][nx][ny][nz];
		
		for (int n=0;n<3;n++) for (int x=0;x<nx;x++) for (int y=0;y<ny;y++) for (int z=0;z<nz;z++) {
			Edges[n][x][y][z] = (1.0f-edges[n][x][y][z]);
		}
		return Edges;
	} // exportEdges

}
