package edu.jhmi.rad.medic.algorithms;

import edu.jhmi.rad.medic.dialogs.*;
import edu.jhmi.rad.medic.methods.*;
import edu.jhmi.rad.medic.utilities.MedicUtilPublic;
import gov.nih.mipav.model.algorithms.*;
import gov.nih.mipav.model.structures.*;
import gov.nih.mipav.view.*;
import java.awt.*;
import java.awt.event.*;
import java.io.*;
import java.lang.*;
import java.util.*;
import javax.swing.*;
import javax.swing.border.*;
import javax.swing.event.*;
 
/**
 *
 *   Expectation-Maximisation method for tissue segmentation
 *   This one estimates a separate variance for each class
 *   <p>
 *	 
 *
 *	@version    November 2006
 *	@author     Pilou Bazin
 *  @see 		JDialogTissueSegmentation
 *  @see		SegmentationEMpv
 *		
 *
*/
public class AlgorithmEMpvSegmentation extends AlgorithmBase {

    // Fuzzy images require 1 image for each class
    // Hard images 1 image with assigned clusters
    private ModelImage          destImage[];
    private ModelImage          srcImage;
    private int                 destNum=0;
    private ViewUserInterface   userInterface;
    private float				Imin;
    private float				Imax;
    
     // information on the parts of the image to process
    private     float       backThreshold;
    private     boolean     cropBackground;
	private		boolean		useRelative;
    private		boolean		adaptiveSmoothing = false;
    private     int         nxmin,nymin,nzmin,nxmax,nymax,nzmax;
    private     int         x0,xN,y0,yN,z0,zN;
    private     int         X0,XN,Y0,YN,Z0,ZN;
    private     int         mx,my,mz;

    // image size
	private int         clusters, classes; // classes include the outliers, clusters does not
	private int 		nx,ny,nz,dim;

    // algorithm parameters
	private		String		initMode;
	private		String		outputType;
	
	private 	float 		smoothing;
	private     int 		iterations;
	private     float   	maxDistance;
	
	private		boolean		addOutliers;
	private		float		outlier;        
	
	private		int			correctInhomogeneity;
	private		int			fieldDegree;
	
	private		boolean		useEdges;
	private		float		edgeSmoothness, edgeContrast, edgePrior;

	private     boolean		verbose = true;
	private     boolean		debug = false;
	
    /**
    *	Constructor for 3D images in which changes are placed in a predetermined destination image.
    *   @param destImg_      Image model where result image is to stored.
    *   @param srcImg_       Source image model.
    */
	public AlgorithmEMpvSegmentation(ModelImage[] destImg_, ModelImage srcImg_, int destNum_,
							String init_, String output_,
							int nClasses_, int nIterMax_, float distMax_, float smooth_,
                            boolean addOut_, float outVal_, 
							String correct_, int poly_,
							boolean edges_, float eSmooth_, float eContr_, float ePr_,
                            boolean cropBack_, float cropVal_, boolean useRel_, boolean adaptSmooth_) {
								
        super(null, srcImg_);
 	srcImage = srcImg_;
        destImage = destImg_;
        userInterface = ViewUserInterface.getReference();
		destNum = destNum_;
		
		initMode = init_;
        outputType = output_;
        
        cropBackground = cropBack_;
		backThreshold = cropVal_;
        useRelative = useRel_;
		adaptiveSmoothing = adaptSmooth_;
         
		addOutliers = addOut_;
		outlier = outVal_;
		        
		clusters = nClasses_;
		classes = nClasses_;
        if (addOutliers) classes++;
		
		iterations = nIterMax_;
        maxDistance = distMax_;
		smoothing = smooth_;
        
		if (correct_.equals("image")) 
			correctInhomogeneity = InhomogeneityCorrection.IMAGE;
		else if (correct_.equals("centroids")) 
			correctInhomogeneity = InhomogeneityCorrection.CENTROIDS;
		else if (correct_.equals("separate")) 
			correctInhomogeneity = InhomogeneityCorrection.SEPARATE;
		else 
			correctInhomogeneity = InhomogeneityCorrection.NONE;
		
		if (correctInhomogeneity!=InhomogeneityCorrection.NONE) fieldDegree = poly_;
		else fieldDegree = 0;
		
		useEdges = edges_;
		edgeSmoothness = eSmooth_;
		edgeContrast = eContr_;
		edgePrior = ePr_;
        
		if (debug) {
			MedicUtilPublic.displayMessage("image: \n");
			MedicUtilPublic.displayMessage(""+srcImage.getImageName()+"\n");
		}
	}

    /**
    *	Prepares this class for destruction.
    */
	public void finalize(){
	    destImage   = null;
	    srcImage    = null;
        System.gc();
        super.finalize();
	}

    /**
    *	Constructs a string of the contruction parameters and outputs the string to the messsage frame if the logging
    *   procedure is turned on.
    *	@param destinationFlag	If true the log includes the name of the destination flag.
    */
	/*
    private void constructLog(boolean destinationFlag) {
        if ( destinationFlag == false) {
            historyString = new String( "EMpvSeg(" + ")");
        }
        else  {
            historyString = new String( "EMpvSeg(" + ")");
        }
        historyString += "\n";  // append a newline onto the string
        writeLog();
    }
	*/

    /**
    *   Starts the algorithm.
    */
	public void runAlgorithm() {

        if (srcImage  == null) {
            displayError("Source Image is null");
            //notifyListeners(this);
            return;
        }
        if (destImage  == null) {
            displayError("Destination Image is null");
            //notifyListeners(this);
            return;
        }

        // start the timer to compute the elapsed time
        setStartTime();

        if (destImage != null){     // if there exists a destination image
            //constructLog(true);
			
			float[] buffer;
            if (srcImage.getNDims() == 2) {
			   try {
					// image length is length in 2 dims
					int length = srcImage.getExtents()[0] 
                                * srcImage.getExtents()[1];
                    
                    // retrieve all data: for each volume
                    buffer = new float[length];
                    srcImage.exportData(0,length, buffer); // locks and releases lock
				} catch (IOException error) {
					buffer = null;
					errorCleanUp("Algorithm: source image locked", true);
					return;
				} catch (OutOfMemoryError e){
					buffer = null;
					errorCleanUp("Algorithm: Out of memory creating process buffer", true);
					return;
				}
		
				// init dimensions
				nx = srcImage.getExtents()[0];
				ny = srcImage.getExtents()[1];
				nz = 1;
				dim = 2;
				
                // main algorithm
				calcSegmentation(buffer);
				
			} else if (srcImage.getNDims() == 3) {
			   try {
					// image length is length in 3 dims
					int length = srcImage.getExtents()[0] 
                                * srcImage.getExtents()[1] * srcImage.getExtents()[2];
                    
                    // retrieve all data: for each volume
                    buffer = new float[length];
                    srcImage.exportData(0,length, buffer); // locks and releases lock
 				} catch (IOException error) {
					buffer = null;
					errorCleanUp("Algorithm: source image locked", true);
					return;
				} catch (OutOfMemoryError e){
					buffer = null;
					errorCleanUp("Algorithm: Out of memory creating process buffer", true);
					return;
				}
		
				// init dimensions
				nx = srcImage.getExtents()[0];
				ny = srcImage.getExtents()[1];
				nz = srcImage.getExtents()[2];
				dim = 3;
				
                // main algorithm
				calcSegmentation(buffer);
				
			}
        }

        // compute the elapsed time
        computeElapsedTime();

    } // end run()

    /**
    *	produces a RFCM fuzzy segmentation of the input images
    */
    private void calcSegmentation(float img[]){
		boolean[][][]			objectMask;
		float[][][]				image;
		int                     x,y,z,t,k,m,l;
		float                   dist;
		int                     indx;
        float[]                 buffer;
        byte[]                 	bytebuffer;
		int[]					id;
        SegmentationEMpv			segmentation;
		InhomogeneityCorrection	correction;
		ClusterSearch			search;

		int mod;
		int progress;

		boolean stop,stopReg;
		int n,Nt;
		float distance;
		
        if (verbose) MedicUtilPublic.displayMessage("\n -- EMpv Segmentation --\n");
		
		fireProgressStateChanged("initialization...");
        
		// increase the dimension to make boundaries
		nx = nx+2; ny = ny+2; nz = nz+2;
		// pre-processing : expand boundaries, so that we don't have to worry for them
        try {
            image = expandImage(img);
            img = null;
        } catch (OutOfMemoryError e) {
            img = null; image = null;
            errorCleanUp("Algorithm: Out of memory creating process buffer", true);
            setCompleted(false);
            return;
        }

		srcImage.calcMinMax();
		Imin = (float)srcImage.getMin();
		Imax = (float)srcImage.getMax();
		
		// parameter normalization
		if (useRelative) {
			//if (!adaptiveSmoothing) smoothing = smoothing*(Imax-Imin)*(Imax-Imin);
			backThreshold = Imin + backThreshold*(Imax-Imin);
			outlier = outlier*(Imax-Imin);
		} else {
			//if (!adaptiveSmoothing) smoothing = smoothing*smoothing;
		}
		
		// if use mask, process only parts where intensity > min (mask stripped areas)
        objectMask = createObjectMask(image, backThreshold);
		computeProcessingBoundaries(objectMask);
		mx = xN-x0+1+2;
        my = yN-y0+1+2;
        mz = zN-z0+1+2;
        // rescale images
		if (debug) MedicUtilPublic.displayMessage("shrink images..\n");
		image = reduceImageSize(image); 
        objectMask = reduceImageSize(objectMask); 
        if (debug) MedicUtilPublic.displayMessage("new dimensions: "+mx+"x"+my+"x"+mz+"\n");

        //***************************************//
		//*          MAIN ALGORITHM             *//
		//***************************************//
 
		// record time
		long start_time = System.currentTimeMillis();
		long inner_loop_time;

		// create the appropriate segmentation algorithm
		segmentation = new SegmentationEMpv(image, objectMask, mx, my, mz,
											classes, clusters, smoothing, outlier,
											userInterface, getProgressChangeListener());

		// create auxiliary algorithms
		if (correctInhomogeneity!=InhomogeneityCorrection.NONE) {
			// create the inhomogeneity correction algorithm
			correction = new InhomogeneityCorrection(fieldDegree,  correctInhomogeneity,
													image, objectMask, 
													segmentation.getMemberships(),
													segmentation.getCentroids(),
													mx,my,mz,1.0f,1.0f,1.0f,
													classes,clusters,dim,
													getProgressChangeListener());
			// integrate it into the segmentation
			if (correctInhomogeneity==InhomogeneityCorrection.SEPARATE) {
				segmentation.addSeparateInhomogeneityCorrection(correction.getFields());
			} else {
				segmentation.addInhomogeneityCorrection(correction.getField(),correctInhomogeneity);
			}
		} else correction = null;
		
        // initial guess for centroids
		if (initMode.equals("modes")) {
			search = new ClusterSearch(clusters);
			search.computeHistogram(image,objectMask,mx,my,mz);
			search.findCentroids();
		
			segmentation.importCentroids(search.exportCentroids());

			search.finalize();
			search = null;	
		} else if (initMode.equals("range")) {
			float[] cent = new float[clusters];
			cent[0] = Imin + 0.5f*(Imax-Imin)/(float)clusters;
			for (k=1;k<clusters;k++)
				cent[k] = cent[k-1] + (Imax-Imin)/(float)clusters;
						
			segmentation.importCentroids(cent);
		} else if (initMode.equals("manual")) {
		}
		// initial guess for the variance
		segmentation.setVariance( ((Imax-Imin)/(float)(clusters+1))*((Imax-Imin)/(float)(clusters+1)) );
		segmentation.setPriors();
        
        // initialize the segmentation
		segmentation.setMRF(0.0f);
		segmentation.computeMemberships();
		if (adaptiveSmoothing) {
			segmentation.setMRF(segmentation.computeEnergyFactor(smoothing));
			if (verbose) MedicUtilPublic.displayMessage("smoothing "+(Math.sqrt(segmentation.getMRF())/(Imax-Imin))+"\n");
		} else {
			segmentation.setMRF(smoothing);
			if (verbose) MedicUtilPublic.displayMessage("smoothing "+segmentation.getMRF()+"\n");
		}
		// main iterations: compute the classes on the image
		
		// two steps: first iterates until mild convergence without inhomogeneity correction
		if (debug) MedicUtilPublic.displayMessage("adapting inhomogeneity..\n");
					
		float maxDist0 = (float)Math.sqrt(maxDistance);
		int iter0 = (int)Math.sqrt(iterations);
		distance = 0.0f;
		Nt = 0;
		n = 0;
		int Niterations = 1;
		for (int d=0;(d<=fieldDegree) && (Niterations<=iterations);d++) {
			stop = false;
			n = 0;
			// for stability sake: low smoothing at first
			if (!adaptiveSmoothing) segmentation.setMRF(smoothing*d/(float)fieldDegree);
			while ((!stop) && (!threadStopped)) {
				fireProgressStateChanged("Initial iteration " + Niterations + " (max: " + distance + ")");
				fireProgressStateChanged(Math.round( (float)n/(float)iterations)*100);
				
				if (verbose) MedicUtilPublic.displayMessage("Initial iteration " + Niterations + " (max: " + distance + ")\n");
				
				// update external fields
				if (correctInhomogeneity!=InhomogeneityCorrection.NONE) {
					// fixed degree
					correction.computeCorrectionField(d);
				}
				
				// update centroids
				segmentation.computeCentroids();
				//  Assume variances are initially equal        
				segmentation.computeVariance(); 
				// update membership
				distance = segmentation.computeMemberships();
				
				if (adaptiveSmoothing) {
					segmentation.setMRF(segmentation.computeEnergyFactor(smoothing));
					if (verbose) MedicUtilPublic.displayMessage("smoothing "+(Math.sqrt(segmentation.getMRF())/(Imax-Imin))+"\n");
				}
				// check for segmentation convergence 
				n++;
				Niterations++;
				if (n >= iter0) stop = true;
				if (Niterations >= iterations) stop = true;
				if (distance < maxDist0) stop = true;            
				// if (distance < maxDistance) stop = true;            
			}
			if (d<fieldDegree) Nt += n;
		}
		
 		// Loop allowing variances to vary separately for each class first
		if (debug) MedicUtilPublic.displayMessage("main loop..\n");
			
		stop = false;
		if (Niterations >= iterations) stop = true;
		while ((!stop) && (!threadStopped)) {
			fireProgressStateChanged("Final iteration " + Niterations + " (max: " + distance + ")");
			fireProgressStateChanged(Math.round( (float)n/(float)iterations)*100);
			
			if (verbose) MedicUtilPublic.displayMessage("Final iteration " + Niterations + " (max: " + distance + ")\n");
			

			// update external fields
			if (correctInhomogeneity!=InhomogeneityCorrection.NONE) {
				// fixed degree
				correction.computeCorrectionField(fieldDegree);
			}
			
			// update centroids
			segmentation.computeCentroids();
			segmentation.computeVariances();
        
			// update membership
			distance = segmentation.computeMemberships();
			
			if (adaptiveSmoothing) {
				segmentation.setMRF(segmentation.computeEnergyFactor(smoothing));
				if (verbose) MedicUtilPublic.displayMessage("smoothing "+(Math.sqrt(segmentation.getMRF())/(Imax-Imin))+"\n");
			}
			// check for segmentation convergence 
			n++;
			Niterations++;
			if (Niterations > iterations) stop = true;
			if (distance < maxDistance) stop = true;            
		}


		segmentation.setMRF(smoothing);
		stop = false;
		if (Niterations >= iterations) stop = true;
		while ((!stop) && (!threadStopped)) {
			fireProgressStateChanged("Final iteration " + Niterations + " (max: " + distance + ")");
			fireProgressStateChanged(Math.round( (float)n/(float)iterations)*100);
			
			if (verbose) MedicUtilPublic.displayMessage("Final iteration " + Niterations + " (max: " + distance + ")\n");
			

			// update external fields
			if (correctInhomogeneity!=InhomogeneityCorrection.NONE) {
				// fixed degree
				correction.computeCorrectionField(fieldDegree);
			}
			
			// update centroids
			segmentation.computePriors();
			segmentation.computeCentroids();
			segmentation.computeVariances();
        
			// update membership
			distance = segmentation.computeMemberships();
			
			if (adaptiveSmoothing) {
				segmentation.setMRF(segmentation.computeEnergyFactor(smoothing));
				if (verbose) MedicUtilPublic.displayMessage("smoothing "+(Math.sqrt(segmentation.getMRF())/(Imax-Imin))+"\n");
			}
			// check for segmentation convergence 
			n++;
			Niterations++;
			if (Niterations > iterations) stop = true;
			if (distance < maxDistance) stop = true;            
		}
		Nt += n;

		// order the classes in increasing order
		id = segmentation.computeCentroidOrder();
        
        // debug
		if (verbose) MedicUtilPublic.displayMessage("total iterations: "+Nt+", total time: (milliseconds): " + (System.currentTimeMillis()-start_time)); 
		
        // extract results (segmentation and/or classes) and store them in destImage[]
		fireProgressStateChanged("creating result images...");
		
        int Ndest = 0;
        try {            
			if (!outputType.equals("hard_segmentation")) {
				fireProgressStateChanged("memberships...");
				for (k=0;k<classes;k++) {
					buffer = bufferFromImage(segmentation.exportMemberships()[k]);
					destImage[id[k+1]-1].importData(0, buffer, true);
					Ndest++;
				}
			}
			if (!outputType.equals("fuzzy_segmentation")) {
				fireProgressStateChanged("final classification...");
				bytebuffer = orderedBufferFromImage(segmentation.exportHardClassification(), id);
                destImage[Ndest].importData(0, bytebuffer, true);
				Ndest++;
			}
			segmentation.finalize();
			segmentation = null;
			if (outputType.equals("all_result_images")) {
				if ( (correctInhomogeneity!=InhomogeneityCorrection.NONE)
					&& (correctInhomogeneity!=InhomogeneityCorrection.SEPARATE) ) {
					fireProgressStateChanged("gain field...");
					buffer = bufferFromImage(correction.exportField());
					destImage[Ndest].importData(0, buffer, true);
					Ndest++;
					correction.finalize();
					correction = null;
				} else if (correctInhomogeneity==InhomogeneityCorrection.SEPARATE) {
					fireProgressStateChanged("gain field...");
					for (k=0;k<clusters;k++) {
						buffer = bufferFromImage(correction.exportFields(k));
						destImage[Ndest].importData(0, buffer, true);
						Ndest++;
					}
					correction.finalize();
					correction = null;
				}
			}
			bytebuffer = null;
			buffer = null;
		} catch (OutOfMemoryError e) {
            bytebuffer = null;
			buffer = null;
            errorCleanUp("Algorithm: Out of memory creating hard classification", true);
            finalize();
            setCompleted(false);
            return;
        } catch (IOException error) {
            errorCleanUp("Algorithm: export problem to destImage[]", true);
            finalize();
			setCompleted(false);
            return;
        }
        
        setCompleted(true);
    } // calcSegmentation
	
	/** brings image in correct array */
	private float[][][] expandImage(float[] image) {
		int 		x,y,z;
		float[][][] 	tmp;
		
		tmp = new float[nx][ny][nz];
		for (x=1;x<nx-1;x++)
			for (y=1;y<ny-1;y++)
				for (z=1;z<nz-1;z++)
					tmp[x][y][z] = image[ (x-1) + (nx-2)*(y-1) + (nx-2)*(ny-2)*(z-1) ];
		
		return tmp;
	}
	
	/** creates a mask for unused data */
	private boolean[][][] createObjectMask(float[][][] image, float val) {
		int 		x,y,z;
		boolean[][][]  	objMask;
        boolean useWholeImage = true;

		// uses only values over the threshold, if mask used
		objMask = new boolean[nx][ny][nz];
		for (x=1;x<nx-1;x++)
			for (y=1;y<ny-1;y++)
				for (z=1;z<nz-1;z++) {
                    if (useWholeImage) {
                        if ( (cropBackground) && (image[x][y][z] <= val) )
                            objMask[x][y][z] = false;
                        else
                            objMask[x][y][z] = true;
                    } else if (mask.get((x-1)+(nx-2)*(y-1)+(nx-2)*(ny-2)*(z-1)) ) {
                        if ( (cropBackground) && (image[x][y][z] <= val) )
                            objMask[x][y][z] = false;
                        else
                            objMask[x][y][z] = true;
                    }
                }
		// remove the boundary from the computations
		for (x=0;x<nx;x++)
			for (y=0;y<ny;y++) {
				objMask[x][y][0] = false;
				objMask[x][y][nz-1] = false;
			}
		for (y=0;y<ny;y++)
			for (z=0;z<nz;z++) {
				objMask[0][y][z] = false;
				objMask[nx-1][y][z] = false;
			}
		for (z=0;z<nz;z++)
			for (x=0;x<nx;x++) {
				objMask[x][0][z] = false;
				objMask[x][ny-1][z] = false;
			}

		return objMask;
	} // createObjectMask
    
    /** sets the processing lower and upper boundaries */
    private void computeProcessingBoundaries(boolean[][][] objMask) {
		int 		x,y,z;
        
        x0 = nx;
        xN = 0;
        y0 = ny;
        yN = 0;
        z0 = nz;
        zN = 0;
        for (x=0;x<nx;x++)
			for (y=0;y<ny;y++)
				for (z=0;z<nz;z++) {
                    if (objMask[x][y][z]) {
                        if (x < x0) x0 = x;
                        if (x > xN) xN = x;
                        if (y < y0) y0 = y;
                        if (y > yN) yN = y;
                        if (z < z0) z0 = z;
                        if (z > zN) zN = z;
                    }
                }
				
        // debug
        System.out.print("boundaries: ["+x0+","+xN+"] ["+y0+","+yN+"] ["+z0+","+zN+"]\n");
        
        return;
    }

    /** create smaller image (for saving memory) */
	private float[][][] reduceImageSize(float[][][] image) {
		float[][][] smaller = new float[mx][my][mz];

		for (int x=0;x<mx;x++) for (int y=0;y<my;y++) for (int z=0;z<mz;z++) {
			smaller[x][y][z] = 0.0f;
		}
		for (int x=x0;x<=xN;x++) {
            for (int y=y0;y<=yN;y++) {
                for (int z=z0;z<=zN;z++) {
					smaller[x-x0+1][y-y0+1][z-z0+1] = image[x][y][z];
				}
			}
		}
		return smaller;
	}
	private boolean[][][] reduceImageSize(boolean[][][] image) {
		boolean[][][] smaller = new boolean[mx][my][mz];

		for (int x=0;x<mx;x++) for (int y=0;y<my;y++) for (int z=0;z<mz;z++) {
			smaller[x][y][z] = false;
		}
		for (int x=x0;x<=xN;x++) {
            for (int y=y0;y<=yN;y++) {
                for (int z=z0;z<=zN;z++) {
					smaller[x-x0+1][y-y0+1][z-z0+1] = image[x][y][z];
				}
			}
		}
		return smaller;
	}

	/** retrieve original size from smaller image */
	private float[] bufferFromImage(float[][][] image) {
		float[] larger = new float[(nx-2)*(ny-2)*(nz-2)];

		for (int n=0;n<(nx-2)*(ny-2)*(nz-2);n++)
			larger[n] = 0.0f;
		
		for (int x=x0;x<=xN;x++) {
            for (int y=y0;y<=yN;y++) {
                for (int z=z0;z<=zN;z++) {
					larger[(x-1)+(nx-2)*(y-1)+(nx-2)*(ny-2)*(z-1)] = image[x-x0+1][y-y0+1][z-z0+1];
				}
			}
		}
		return larger;
	}

	/** retrieve original size from smaller image */
	private byte[] orderedBufferFromImage(byte[][][] image, int[] id) {
		byte[] larger = new byte[(nx-2)*(ny-2)*(nz-2)];

		for (int n=0;n<(nx-2)*(ny-2)*(nz-2);n++)
			larger[n] = 0;
		
		for (int x=x0;x<=xN;x++) {
            for (int y=y0;y<=yN;y++) {
                for (int z=z0;z<=zN;z++) {
					larger[(x-1)+(nx-2)*(y-1)+(nx-2)*(ny-2)*(z-1)] = 0;
					for (int k=0;k<classes+1;k++)
						if (image[x-x0+1][y-y0+1][z-z0+1]==id[k]) 
							larger[(x-1)+(nx-2)*(y-1)+(nx-2)*(ny-2)*(z-1)] = (byte)k;
				}
			}
		}
		return larger;
	}

}
