/**
 * Source file for the SPECTRE2009 skull-stripping calculation.
 */
package edu.jhu.ece.iacl.algorithms.skull_strip;

import java.io.IOException;

import edu.jhmi.rad.medic.libraries.Morphology;
import edu.jhmi.rad.medic.libraries.ObjectProcessing;
import edu.jhmi.rad.medic.methods.FastMarching;
import edu.jhmi.rad.medic.utilities.Numerics;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.data.BinaryMinHeap;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataUByte;
import edu.jhu.ece.iacl.jist.structures.image.MaskVolume18;
import edu.jhu.ece.iacl.jist.structures.image.MaskVolume6;
import edu.jhu.ece.iacl.jist.structures.image.VoxelInt;
import edu.jhu.ece.iacl.jist.structures.image.VoxelFloat;
import edu.jhu.ece.iacl.jist.structures.image.VoxelIndexed;

import gov.nih.mipav.model.algorithms.*;
import gov.nih.mipav.model.structures.ModelImage;

/**
 * SPECTRE2009 skull stripping.
 * <p>
 * {@code solve} builds an initial brain mask from a tissue segmentation and a
 * prior volume. It then repeats the following until the mask volume is within
 * [0.9, 1.2] times the prior volume, the erosion size drops below zero, or the
 * iteration limit is reached:
 * <ol>
 * <li>erode the mask;</li>
 * <li>keep the largest 6-connected component;</li>
 * <li>grow it back along a signed distance map (hill descent);</li>
 * <li>apply a dilate / fill-holes / erode closing.</li>
 * </ol>
 * <p>
 * All per-run state lives in instance fields, so a single instance must not run
 * {@code solve} concurrently.
 *
 * @author Aaron Carass (aaron_carass@jhu.edu)
 */
public class SPECTRE2009 extends AbstractCalculation {
	/** Contrast of the input image; selects the segmentation label ranges used. */
	public enum ImageModality {
		T1_SPGR, T1_ALT, T1_MPRAGE, T2, FLAIR
	}

	/**
	 * Requested output. Inside this class only {@link #ALL} changes behavior: it
	 * enables filling the intermediate diagnostic images (d0..d3).
	 */
	public enum OutputType {
		STRIPPED_IMAGE, STRIPPING_MASK, REGISTRATION_PRIOR, SEGMENTATION_PRIOR, INITIAL_MASK, ERODE_MASK, DESCENT_MASK, ALL
	}

	private static final String cvsversion = "$Revision: 1.1 $";
	public static final String revnum = cvsversion.replace("Revision: ", "")
			.replace("$", "").replace(" ", "");

	/** @return the CVS revision number of this class, e.g. "1.1". */
	public static String get_version() {
		return revnum;
	}

	// Output volumes produced by the last call to solve().
	private ImageDataUByte priorVol;
	private ImageDataUByte segVol;
	private ImageDataUByte marchVol; // never assigned in this class
	private ImageDataUByte currentVol;
	private ImageDataFloat d0Vol;
	private ImageDataFloat d1Vol; // never assigned in this class
	private ImageDataFloat d2Vol; // never assigned in this class
	private ImageDataFloat d3Vol; // never assigned in this class

	private ImageDataUByte maskVol;

	private ImageModality modality;
	// Inclusive segmentation-label ranges treated as CSF / GM / WM.
	private int CSFmin, CSFmax;
	private int GMmin, GMmax;
	private int WMmin, WMmax;
	// Intensity direction for the hill descent: +1, or -1 for T2.
	private int dir = 1;

	private OutputType outputType;

	private int nx, ny, nz, Nmask;

	private float initPrior = 0.5f;
	private float minPrior = 0.1f;
	private int i_initPrior = 0; // only printed, not otherwise used
	private int i_minPrior = 0; // prior threshold: minPrior * Nmask, truncated

	private int radius = 8;
	private int initialErosionSize = 2;
	private int mmcDilationSize = 2;
	private int mmcErosionSize = 2;
	private float rx, ry, rz; // voxel resolution along each axis

	private int iterations = 0;
	private boolean mask_good = false;

	/** Prints a separator line to standard output. */
	public void printBar() {
		System.out.println("########################################\n");
	}

	/** @return the final binary mask (1 = kept) from the last solve(), or null before. */
	public ImageDataUByte getMask() {
		return maskVol;
	}

	/** @return the prior volume used by the last solve(), or null before. */
	public ImageDataUByte getPrior() {
		return priorVol;
	}

	/** @return the segmentation used by the last solve(), or null before. */
	public ImageDataUByte getSegmentation() {
		return segVol;
	}

	/** @return always null: this volume is not assigned anywhere in this class. */
	public ImageDataUByte getMarchVol() {
		return marchVol;
	}

	/**
	 * @return the initial-mask labels from the last solve(): 1 = WM-range seed,
	 *         2 = GM-range seed, 0 = not in the initial mask. Null before solve().
	 */
	public ImageDataUByte getCurrentVol() {
		return currentVol;
	}

	/**
	 * @return image intensities inside the initial mask when the output type was
	 *         ALL, otherwise all zeros. Null before solve().
	 */
	public ImageDataFloat getd0Vol() {
		return d0Vol;
	}

	/** @return always null: this volume is not assigned anywhere in this class. */
	public ImageDataFloat getd1Vol() {
		return d1Vol;
	}

	/** @return always null: this volume is not assigned anywhere in this class. */
	public ImageDataFloat getd2Vol() {
		return d2Vol;
	}

	/** @return always null: this volume is not assigned anywhere in this class. */
	public ImageDataFloat getd3Vol() {
		return d3Vol;
	}

	/**
	 * Runs the skull-stripping pipeline and stores the auxiliary outputs in
	 * fields (see the getters).
	 *
	 * @param img                intensity image to strip
	 * @param segmentedImg       tissue label volume. Labels are compared against
	 *                           the CSF/GM/WM ranges chosen by {@code modality};
	 *                           presumably a FANTASM segmentation.
	 * @param probVol            prior volume. Values are compared with
	 *                           {@code Nmask} and {@code Nmask - 1}, so presumably
	 *                           it counts how many of {@code Nmask} masks cover
	 *                           each voxel (TODO confirm).
	 * @param modality           image contrast; selects label ranges and the
	 *                           intensity direction
	 * @param outputType         {@code ALL} additionally fills the diagnostic
	 *                           images
	 * @param initPrior          fraction of {@code Nmask}; only reported
	 * @param minPrior           fraction of {@code Nmask}. Prior values above
	 *                           it count toward the expected volume and allow
	 *                           type-1 growth.
	 * @param initialErosionSize starting erosion size. It is divided by twice the
	 *                           voxel resolution, so presumably physical units
	 *                           (TODO confirm). It is decremented on each retry.
	 * @param mmcDilationSize    dilation size of the closing step, divided by the
	 *                           voxel resolution
	 * @param mmcErosionSize     erosion size of the closing step, divided by the
	 *                           voxel resolution
	 * @param max_iterations     maximum number of erode/grow/close passes
	 * @param Nmask              number of masks the prior was built from
	 * @return {@code img} intensities inside the final mask and 0 elsewhere. The
	 *         last mask is returned even when the volume check fails; in that
	 *         case a failure message is printed.
	 */
	public ImageDataFloat solve(ImageData img, ImageData segmentedImg,
			ImageDataUByte probVol, ImageModality modality,
			OutputType outputType, float initPrior, float minPrior,
			int initialErosionSize, int mmcDilationSize, int mmcErosionSize,
			int max_iterations, int Nmask) {
		// Reset per-call loop state so that a reused instance starts fresh.
		this.iterations = 0;
		this.mask_good = false;

		this.Nmask = Nmask;
		this.nx = img.getRows();
		this.ny = img.getCols();
		this.nz = img.getSlices();
		this.modality = modality;
		this.outputType = outputType;
		this.initPrior = initPrior;
		this.minPrior = minPrior;
		this.i_initPrior = (int) (initPrior * (float) Nmask);
		this.i_minPrior = (int) (minPrior * (float) Nmask);
		this.initialErosionSize = initialErosionSize;
		this.mmcDilationSize = mmcDilationSize;
		this.mmcErosionSize = mmcErosionSize;

		rx = img.getHeader().getDimResolutions()[0];
		ry = img.getHeader().getDimResolutions()[1];
		rz = img.getHeader().getDimResolutions()[2];

		System.out.println("\n");
		printBar();
		System.out.println("SPECTRE2009 " + revnum + "\n");
		printBar();

		System.out.println("ROWS " + nx + " COLS " + ny + " SLICES " + nz
				+ "\nNMASK " + Nmask + "\nMODALITY " + modality
				+ "\nOUTPUT TYPE " + outputType + "\nFLOAT INIT PRIOR "
				+ initPrior + "\nFLOAT MIN PRIOR " + minPrior
				+ "\nINT INIT PRIOR " + i_initPrior + "\nINT MIN PRIOR "
				+ i_minPrior + "\nEROSION " + initialErosionSize + "\nMMC ("
				+ mmcDilationSize + ", " + mmcErosionSize + ")\n" + rx + " "
				+ ry + " " + rz);

		configureTissueClasses(modality);

		System.out.println("CREATE IMAGE "+img.getHeader().getDimResolutions().length);
		float[][][] image = (new ImageDataFloat(img)).toArray3d();
		System.out.println("CREATE SEGMENTATION "+img.getHeader().getDimResolutions().length);
		byte[][][] segmentation = (new ImageDataUByte(segmentedImg)).toArray3d();
		System.out.println("CREATE MASKS ");

		byte[][][] prior = (new ImageDataUByte(probVol)).toArray3d();
		float[][][] march = new float[nx][ny][nz];
		byte[][][] march_byte = new byte[nx][ny][nz];

		// Diagnostic images, filled only when outputType == ALL.
		float[][][] d0 = new float[nx][ny][nz];
		float[][][] d1 = new float[nx][ny][nz];
		float[][][] d2 = new float[nx][ny][nz];
		float[][][] d3 = new float[nx][ny][nz];

		boolean[][][] brain = new boolean[nx][ny][nz];
		byte[][][] current = new byte[nx][ny][nz];

		// 1. Expected brain volume: voxels whose prior exceeds the minimum prior.
		System.out.println("BUILD PRIOR");
		int priorVolume = 0;
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++)
					if (prior[x][y][z] > i_minPrior)
						priorVolume++;
		System.out.print("Prior Volume " + priorVolume + "\n");

		// 2. Initial mask. Use WM-range voxels with prior of at least Nmask - 1,
		// and GM-range voxels with prior exactly Nmask.
		System.out.print("Initializing mask ...");
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++) {
					byte seg = segmentation[x][y][z];
					if (inRange(seg, WMmin, WMmax) && prior[x][y][z] >= (Nmask - 1)) {
						brain[x][y][z] = true;
						current[x][y][z] = 1;
					} else if (inRange(seg, GMmin, GMmax) && prior[x][y][z] == Nmask) {
						brain[x][y][z] = true;
						current[x][y][z] = 2;
					}
				}
		System.out.println(" done");

		if (outputType == OutputType.ALL)
			maskImage(brain, image, d0);

		System.out.print("Removing 18 Connected Holes ... ");
		brain = ObjectProcessing.removeHoles18(brain, nx, ny, nz);
		System.out.println(" done");

		if (outputType == OutputType.ALL)
			maskImage(brain, image, d1);

		// 3. Erode / grow / close until the volume looks plausible.
		// NOTE(review): each retry continues from the mask produced by the
		// previous pass rather than from the initial mask. Confirm this is intended.
		int dx = 0;
		int dy = 0;
		int dz = 0;
		while (initialErosionSize >= 0 && !mask_good
				&& iterations < max_iterations) {
			iterations++;
			System.out.println("26 Connected Erosion with radius " + initialErosionSize);
			// Convert the erosion size into a per-axis voxel radius.
			dx = Numerics.round((float) initialErosionSize / (2 * rx));
			dy = Numerics.round((float) initialErosionSize / (2 * ry));
			dz = Numerics.round((float) initialErosionSize / (2 * rz));
			System.out.println(" " + dx + " " + dy + " " + dz + " ");
			brain = Morphology.erodeObject(brain, nx, ny, nz, dx, dy, dz);
			System.out.println("done\n");

			System.out.print("Get largest 6 Connected neighborhood ... ");
			int[][][] lb = ObjectProcessing.connected6Object3D(brain, nx, ny, nz);
			brain = ObjectProcessing.largestObjectFromLabel(lb, nx, ny, nz);
			System.out.println("done");

			if (outputType == OutputType.ALL)
				maskImage(brain, image, d2);

			// Signed distance to the mask boundary: -1 inside, +1 outside.
			System.out.print("Fast Marching ... ");
			for (int x = 0; x < nx; x++)
				for (int y = 0; y < ny; y++)
					for (int z = 0; z < nz; z++)
						march[x][y][z] = brain[x][y][z] ? -1.0f : 1.0f;
			march = FastMarching.signedDistanceFunction(march, nx, ny, nz, 2
					* radius + 2 * initialErosionSize);
			// Quantize to bytes. Values at or below -1 map to 0; others are
			// shifted by +1 and truncated.
			for (int x = 0; x < nx; x++)
				for (int y = 0; y < ny; y++)
					for (int z = 0; z < nz; z++) {
						if (march[x][y][z] > -1.00f) {
							march_byte[x][y][z] = (byte) (march[x][y][z] + 1.0f);
						} else {
							march_byte[x][y][z] = 0;
						}
					}
			System.out.println("done");

			recursiveHillDescent(brain, image, prior, segmentation, march_byte,
					radius + 2 * Numerics.round(initialErosionSize));

			if (outputType == OutputType.ALL)
				maskImage(brain, image, d3);

			System.out.println("Topologically consistent closing ... ");
			System.out.println(dx + " " + dy + " " + dz);
			dx = Numerics.round(mmcDilationSize / rx);
			dy = Numerics.round(mmcDilationSize / ry);
			dz = Numerics.round(mmcDilationSize / rz);
			brain = Morphology.dilateObject(brain, nx, ny, nz, dx, dy, dz);
			brain = ObjectProcessing.removeHoles6(brain, nx, ny, nz);

			dx = Numerics.round(mmcErosionSize / rx);
			dy = Numerics.round(mmcErosionSize / ry);
			dz = Numerics.round(mmcErosionSize / rz);
			brain = Morphology.erodeObject(brain, nx, ny, nz, dx, dy, dz);
			System.out.println("done");

			int currentVolume = countVoxels(brain);
			float minVolume = priorVolume * 0.900f;
			float maxVolume = priorVolume * 1.200f;

			if (minVolume > currentVolume || maxVolume < currentVolume) {
				System.out.println("*********************MASK LOOKS WRONG.");
				System.out.println("Mask size is : " + currentVolume);
				System.out.println("Mask min size: " + minVolume);
				System.out.println("Mask max size: " + maxVolume);

				// Retry with a smaller erosion.
				initialErosionSize--;

				if (initialErosionSize < 0 || iterations >= max_iterations) {
					System.out.println("*********************MASK LOOKS WRONG.");
					System.out.println("SPECTRE2009 HAS FAILED.");
				}
			} else {
				// Volume is within the expected range: stop iterating.
				mask_good = true;
			}
		}

		// 4. Build outputs. This relies on toArray3d() exposing the volume's backing data.
		maskVol = new ImageDataUByte(nx, ny, nz);
		byte[][][] maskMat = maskVol.toArray3d();
		ImageDataFloat resultData = new ImageDataFloat(nx, ny, nz);
		float[][][] result = resultData.toArray3d();
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++) {
					if (brain[x][y][z]) {
						result[x][y][z] = image[x][y][z];
						maskMat[x][y][z] = 1;
					} else {
						result[x][y][z] = 0.0f; // This is a bad assumption.
						maskMat[x][y][z] = 0;
					}
				}

		System.out.println("Passing results to vols.");
		resultData.setHeader(img.getHeader());
		resultData.setName(img.getName() + "_strip");
		priorVol = new ImageDataUByte(prior);
		priorVol.setName(img.getName() + "_prior");
		priorVol.setHeader(img.getHeader());
		maskVol.setName(img.getName() + "_mask");
		maskVol.setHeader(img.getHeader());
		segVol = new ImageDataUByte(segmentation);
		segVol.setName(img.getName() + "_fantasm");
		segVol.setHeader(img.getHeader());
		currentVol = new ImageDataUByte(current);
		currentVol.setName(img.getName() + "_current");
		currentVol.setHeader(img.getHeader());
		d0Vol = new ImageDataFloat(d0);
		d0Vol.setName(img.getName() + "_d0");
		d0Vol.setHeader(img.getHeader());
		// The march, d1, d2 and d3 images are intentionally not exported, so
		// their getters keep returning null.

		System.out.println("SPECTRE2009 Exit.");
		return resultData;
	}

	/**
	 * Sets the inclusive CSF/GM/WM segmentation-label ranges and the intensity
	 * direction {@code dir} for the given modality.
	 */
	private void configureTissueClasses(ImageModality modality) {
		// Reset so that a previous T2 run does not leak its direction into this one.
		dir = 1;
		// Fantasm segmentation classes, equivalent to -T flag, spgr by default
		switch (modality) {
		case T1_SPGR:
			CSFmin = 1;
			CSFmax = 2;
			GMmin = 2;
			GMmax = 3;
			WMmin = 2;
			WMmax = 4;
			break;
		case T1_ALT:
			CSFmin = 1;
			CSFmax = 1;
			GMmin = 2;
			GMmax = 2;
			WMmin = 2;
			WMmax = 5;
			break;
		case T1_MPRAGE:
			CSFmin = 1;
			CSFmax = 2;
			GMmin = 2;
			GMmax = 4;
			WMmin = 3;
			WMmax = 5;
			break;
		case FLAIR:
			CSFmin = 1;
			CSFmax = 2;
			GMmin = 2;
			GMmax = 5;
			WMmin = 2;
			WMmax = 4;
			break;
		case T2:
			CSFmin = 1;
			CSFmax = 2;
			GMmin = 2;
			GMmax = 4;
			WMmin = 2;
			WMmax = 3;
			// Inverts the intensity comparisons in the hill descent.
			dir = -1;
			break;
		default:
			System.out.println("\nUNKNOWN MODAILTY\n");
			CSFmin = 1;
			CSFmax = 2;
			GMmin = 2;
			GMmax = 4;
			WMmin = 3;
			WMmax = 4;
			break;
		}
	}

	/** Writes image intensities where mask is true and 0 elsewhere into dst. */
	private void maskImage(boolean[][][] mask, float[][][] image, float[][][] dst) {
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++)
					dst[x][y][z] = mask[x][y][z] ? image[x][y][z] : 0.0f;
	}

	/** Counts the true voxels of mask within the nx*ny*nz extent. */
	private int countVoxels(boolean[][][] mask) {
		int count = 0;
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				for (int z = 0; z < nz; z++)
					if (mask[x][y][z])
						count++;
		return count;
	}

	/** True when label lies in the inclusive range [min, max]. */
	private static boolean inRange(byte label, int min, int max) {
		return label >= min && label <= max;
	}

	/**
	 * Grows {@code brain} in place, starting from every voxel whose
	 * {@code march} value is 0. Voxels come off a BinaryMinHeap keyed by their
	 * march value, presumably in ascending order.
	 * <p>
	 * An 18-neighbour (MaskVolume18) that is not already in the mask, has not
	 * been visited, and has a march value at least that of the current voxel is
	 * added when either rule holds:
	 * <ul>
	 * <li>Type 1: prior above the minimum prior, (dir-adjusted) intensity not
	 * above the current voxel's, and a GM- or WM-range label.</li>
	 * <li>Type 2: intensity strictly below the current voxel's, march strictly
	 * greater, and a CSF-, GM- or WM-range label.</li>
	 * </ul>
	 * Statistics are printed at the end.
	 *
	 * @param steps currently unused
	 */
	private final void recursiveHillDescent(boolean[][][] brain,
			float[][][] image, byte[][][] prior, byte[][][] segmentation,
			byte[][][] march, int steps) {
		int outOfBounds = 0, alreadyInMask = 0, notOutward = 0;
		int addedType1 = 0, addedType2 = 0;

		boolean[][][] result = new boolean[nx][ny][nz];
		boolean[][][] visited = new boolean[nx][ny][nz];
		BinaryMinHeap heap = new BinaryMinHeap(0, nx, ny, nz);
		MaskVolume18 mask = new MaskVolume18();
		byte[] neighborsX = mask.getNeighborsX();
		byte[] neighborsY = mask.getNeighborsY();
		byte[] neighborsZ = mask.getNeighborsZ();

		System.out.print("Initializing heap ... ");
		for (int i = 0; i < nx; i++) {
			for (int j = 0; j < ny; j++) {
				for (int k = 0; k < nz; k++) {
					result[i][j][k] = brain[i][j][k];
					visited[i][j][k] = brain[i][j][k];
					if (march[i][j][k] == 0) {
						// Seeds use VoxelFloat, like the voxels added during
						// growth, so the heap only ever holds one element type.
						VoxelIndexed<VoxelFloat> seed = new VoxelIndexed<VoxelFloat>(
								new VoxelFloat((float) march[i][j][k]));
						seed.setRefPosition(i, j, k);
						heap.add(seed);
						visited[i][j][k] = true;
					}
				}
			}
		}
		System.out.println("done");
		System.out.println("Heap Size ... " + heap.size());

		while (heap.size() > 0) {
			VoxelIndexed<?> he = (VoxelIndexed<?>) heap.remove();
			int i = he.getRow();
			int j = he.getColumn();
			int k = he.getSlice();

			for (int koff = 0; koff < MaskVolume18.length; koff++) {
				int ni = i + neighborsX[koff];
				int nj = j + neighborsY[koff];
				int nk = k + neighborsZ[koff];

				if (ni < 0 || ni >= nx || nj < 0 || nj >= ny || nk < 0
						|| nk >= nz) {
					outOfBounds++;
					continue;
				}

				if (brain[ni][nj][nk] || visited[ni][nj][nk]) {
					alreadyInMask++;
					continue;
				}

				// Only move away from (or parallel to) the mask boundary.
				if (march[ni][nj][nk] < march[i][j][k]) {
					notOutward++;
					continue;
				}

				byte seg = segmentation[ni][nj][nk];
				boolean brainTissue = inRange(seg, GMmin, GMmax)
						|| inRange(seg, WMmin, WMmax);
				boolean add = false;

				if ((prior[ni][nj][nk] > i_minPrior)
						&& (dir * image[ni][nj][nk] <= dir * image[i][j][k])
						&& (march[ni][nj][nk] >= march[i][j][k])
						&& brainTissue) {
					add = true;
					addedType1++;
				} else if ((dir * image[ni][nj][nk] < dir * image[i][j][k])
						&& (march[ni][nj][nk] > march[i][j][k])
						&& (inRange(seg, CSFmin, CSFmax) || brainTissue)) {
					add = true;
					addedType2++;
				}

				if (add) {
					VoxelIndexed<VoxelFloat> vox = new VoxelIndexed<VoxelFloat>(
							new VoxelFloat((float) march[ni][nj][nk]));
					result[ni][nj][nk] = true;
					visited[ni][nj][nk] = true;
					vox.setRefPosition(ni, nj, nk);
					heap.add(vox);
				}
			}
		}

		// Write the grown mask back into the caller's array.
		for (int x = 0; x < nx; x++)
			for (int y = 0; y < ny; y++)
				System.arraycopy(result[x][y], 0, brain[x][y], 0, nz);

		System.out.println("recursiveHillDescent:\n\tOut of Bounds: " + outOfBounds
				+ "\n\tAlready in the mask: " + alreadyInMask
				+ "\n\tNot moving away from the mask: " + notOutward
				+ "\n\tType 1: " + addedType1 + "\n\tType 2: " + addedType2);
	}
}
