package edu.jhu.ece.iacl.plugins.topology;

import java.util.BitSet;

import edu.jhmi.rad.medic.algorithms.AlgorithmTopologyCorrection;
import edu.jhu.ece.iacl.algorithms.PrinceGroupAuthors;
import edu.jhu.ece.iacl.algorithms.topology.ConnectivityRule;
import edu.jhu.ece.iacl.algorithms.topology.TopologyCorrection;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.*;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipavWrapper;

import gov.nih.mipav.model.structures.ModelImage;
import gov.nih.mipav.model.structures.ModelStorageBase;
import gov.nih.mipav.view.MipavUtil;
import gov.nih.mipav.view.dialogs.JDialogBase;

public class MedicAlgorithmTopologyCorrection extends ProcessingAlgorithm {
	private ParamVolume vol;
	private ParamVolume tcvol;
	private float lowestLevel = 0.0f;
	private float highestLevel = 1.0f;
	private boolean useMinMax = true;
	private boolean thresholdStop = false;
	private int dimension = 2;
	private int objConnect = 4;
	private int bckConnect = 8;
	private String type = "scalar_image";
	private String connect2D = "2D:8/4";
	private String connect3D = "3D:26/6";
	private float minDistance = 0.0001f;
	private String inputSkel = "intensity";
	private	static String[]	inputSkels = { "intensity", "paint_mask" };
	private String propagationType = "object->background";
	private static String[] types = { "scalar_image", "binary_object" };

	private static String[] propagationTypes = { "object->background", "background->object" };
	private AlgorithmTopologyCorrection algo = null;
	ParamDouble lowestLevelParam;
	ParamDouble highestLevelParam;
	ParamDouble regularizeParam;
	ParamBoolean thresholdStopParam;
	ParamBoolean useMinMaxParam;
	ParamOption imageType;
	ParamOption initialSkel;
	ParamOption connect2DParam;
	ParamOption connect3DParam;
	ParamOption propagationTypeParam;
	ParamVolume maskVol;

	private static final String revnum = AlgorithmTopologyCorrection.get_version();

	protected void createInputParameters(ParamCollection inputParams) {
		inputParams.add(vol = new ParamVolume("Volume"));
		inputParams.add(imageType = new ParamOption("Image Type", types));
		inputParams.add(connect3DParam = new ParamOption("Connectivity (Foreground,Background)",new String[]{"(18,6)","(6,18)","(26,6)","(6,26)","(6,6)"}));
		connect3DParam.setValue(0);
		inputParams.add(propagationTypeParam = new ParamOption("Propagation Direction", propagationTypes));
		propagationTypeParam.setValue(0);
		inputParams.add(initialSkel = new ParamOption("Initialization", inputSkels));
		initialSkel.setValue(0);
		inputParams.add(maskVol = new ParamVolume("Paint mask"));
		maskVol.setMandatory(false);
		inputParams.add(lowestLevelParam = new ParamDouble("Lowest Level", 0.0));
		inputParams.add(highestLevelParam = new ParamDouble("Highest Level", 1.0));
		inputParams.add(useMinMaxParam = new ParamBoolean("Normalize", true));
		inputParams.add(thresholdStopParam = new ParamBoolean("Stop at intensity boundaries", false));
		inputParams.add(regularizeParam = new ParamDouble("Regularize Amount", 0.0001));

		inputParams.setLabel("Topology Correction");
		inputParams.setName("topology_correction");

		inputParams.setPackage("IACL");
		inputParams.setCategory("Topology");

		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://www.iacl.ece.jhu.edu/");
		info.setVersion(revnum);
		info.setEditable(false);
		info.setDescription("Fast-marching topology correction algorithm.");
		info.setLongDescription("Corrects a level set to be topologically consistent. If the algorithm is being run background->foreground, do not specify an initial mask.");
		
		info.add(PrinceGroupAuthors.pierreLouisBazin);
		info.add(PrinceGroupAuthors.blakeLucas);
		info.setStatus(DevelopmentStatus.RC);
	}

	@Override
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(tcvol = new ParamVolume("Topologically Correct Volume"));
	}

	@Override
	public void execute(CalculationMonitor monitor) {
    	String tmpStr;
    	minDistance=regularizeParam.getFloat();
        lowestLevel = lowestLevelParam.getFloat();
		highestLevel=highestLevelParam.getFloat();
		
		if (connect3DParam.getValue().equals("(26,6)")) {
            dimension = 3;
			objConnect = 26;
			bckConnect = 6;
			connect3D = "3D:26/6";
        } else if (connect3DParam.getValue().equals("(18,6)")) {
            dimension = 3;
			objConnect = 18;
			bckConnect = 6;
			connect3D = "3D:18/6";
		} else if (connect3DParam.getValue().equals("(6,26)")) {
            dimension = 3;
			objConnect = 6;
			bckConnect = 26;
			connect3D = "3D:6/26";
		} else if (connect3DParam.getValue().equals("(6,18)")) {
            dimension = 3;
			objConnect = 6;
			bckConnect = 18;
			connect3D = "3D:6/18";
		} else if (connect3DParam.getValue().equals("(6,6)")) {
            dimension = 3;
			objConnect = 6;
			bckConnect = 6;
			connect3D = "3D:all";
		}
        
		type = imageType.getValue();
		useMinMax = useMinMaxParam.getValue();
        thresholdStop = thresholdStopParam.getValue();
        propagationType = propagationTypeParam.getValue();
		inputSkel = initialSkel.getValue();
		ModelImage image = vol.getImageData().getModelImageCopy();
		int[] destExtents;
		String name = JDialogBase.makeImageName(image.getImageName(), "_corrected");
		if (image.getNDims() == 2) { // source image is 2D
			destExtents = new int[2];
			destExtents[0] = image.getExtents()[0]; // X dim
			destExtents[1] = image.getExtents()[1]; // Y dim
		} else { // source image is 3D or more (?)
			destExtents = new int[3];
			destExtents[0] = image.getExtents()[0];
			destExtents[1] = image.getExtents()[1];
			destExtents[2] = image.getExtents()[2];
		}
		ModelImage resultImage = new ModelImage(ModelStorageBase.FLOAT, destExtents, JDialogBase.makeImageName(image
				.getImageName(), "_corrected"));
		// Create algorithm
		//Bhaskar
		//ImageDataMipav maskImage=new ImageDataMipav((propagationTypeParam.getIndex()==0)?maskVol.getImageData():null);
		//ImageDataMipav maskImage=(propagationTypeParam.getIndex()==0)?new ImageDataMipav(maskVol.getImageData()):null;
		ImageDataMipav maskImage=null;
		BitSet inputPaint=null;
		if(inputSkel.equals("paint_mask") ){
			maskImage = new ImageDataMipav(maskVol.getImageData());
			int nx=maskImage.getRows();
			int ny=maskImage.getCols();
			int nz=maskImage.getSlices();
			inputPaint=new BitSet(nx*ny*nz);
			int count=0;
			for(int x=0;x<nx;x++){
				for(int y=0;y<ny;y++){
					for(int z=0;z<nz;z++){
						count+=(maskImage.getFloat(x, y, z)>0)?1:0;
						inputPaint.set( (x) + (nx)*(y) + (nx)*(ny)*(z), (maskImage.getFloat(x, y, z)>0) );
					}
				}
			}
		}

		algo = new AlgorithmTopologyCorrection(resultImage, image, type, minDistance, dimension, objConnect,
				bckConnect, inputSkel, inputPaint, lowestLevel, highestLevel, useMinMax, propagationType, thresholdStop);
		algo.runAlgorithm();
		tcvol.setValue(new ImageDataMipavWrapper(resultImage));
		//tcvol.getImageData().getModelImageCopy().copyFileTypeInfo(vol.getImageData().getModelImageCopy());
		tcvol.getImageData().setHeader(vol.getImageData().getHeader());
		image.disposeLocal();
	}
}
