package edu.jhu.ece.iacl.algorithms.topology;

import java.util.BitSet;

import edu.jhmi.rad.medic.algorithms.AlgorithmTopologyCorrection;
import edu.jhmi.rad.medic.libraries.TopologyPropagation;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipavWrapper;

import gov.nih.mipav.model.structures.ModelImage;
import gov.nih.mipav.model.structures.ModelStorageBase;
import gov.nih.mipav.view.dialogs.JDialogBase;

/**
 * Topology correction of scalar, binary and level-set volumes.
 * <p>
 * Thin wrapper around {@link AlgorithmTopologyCorrection}. Each entry point:
 * <ul>
 * <li>copies the input into a {@link ModelImage};</li>
 * <li>allocates a FLOAT result image with the same 2D (or first three) extents;</li>
 * <li>runs the algorithm and wraps the result as an {@link ImageDataFloat}.</li>
 * </ul>
 * <p>
 * {@link #solve} writes several instance fields before running the algorithm, so one instance
 * should not be used by concurrent callers.
 */
public class TopologyCorrection extends AbstractCalculation{
	private float lowestLevel = 0.0f;
	private float highestLevel = 1.0f;
	private boolean useMinMax = true;
	private int dimension = 2;
	private int objConnect = 4;
	private int bckConnect = 8;
	private String type = "scalar_image";
	private String connect2D = "2D:8/4";
	private String connect3D = "3D:26/6";
	private float minDistance = 0.0001f;
	/** Default skeleton source passed to the algorithm when no paint mask is supplied. */
	private String inputSkel = "intensity";
	private String propagationType = "object->background";
	private static final String[] types = { "scalar_image", "binary_object" };
	private static final String[] propagationTypes = { "object->background", "background->object" };
	public enum PropagationTypes {OBJECT_TO_BACKGROUND,BACKGROUND_TO_OBJECT};
	/** Algorithm instance from the most recent {@link #solve} call. */
	private AlgorithmTopologyCorrection algo = null;
	/** Volume used by correctBinary/correctScalar/correctLevelSet; null for the no-arg constructor. */
	private ImageDataMipav vol;

	/** Creates an instance with no stored volume; only {@link #solve} can be used. */
	public TopologyCorrection(){
	}

	/**
	 * Creates an instance operating on the given volume.
	 *
	 * @param vol volume to correct; wrapped in an {@link ImageDataMipav}
	 */
	public TopologyCorrection(ImageData vol){
		this.vol=new ImageDataMipav(vol);
	}

	/**
	 * Topology-corrects the stored volume as a binary object.
	 *
	 * @param rule a {@code ConnectivityRule} constant (18/6, 6/18, 26/6 or 6/26)
	 * @return the corrected volume, named {@code <name>_tc_<rule>}
	 */
	public ImageDataFloat correctBinary(int rule){
		return correct(rule,"binary_object");
	}

	/**
	 * Topology-corrects the stored volume as a scalar image.
	 *
	 * @param rule a {@code ConnectivityRule} constant (18/6, 6/18, 26/6 or 6/26)
	 * @return the corrected volume, named {@code <name>_tc_<rule>}
	 */
	public ImageDataFloat correctScalar(int rule){
		return correct(rule,"scalar_image");
	}

	/**
	 * Runs topology correction on {@code img}, optionally seeded by a paint mask.
	 * <p>
	 * The result is not named explicitly; the underlying result image is named
	 * {@code <source name>_corrected}.
	 *
	 * @param img            scalar image to correct
	 * @param maskImage      optional paint mask. When non-null, voxels strictly above
	 *                       {@code paintThreshold} are painted and the algorithm's skeleton
	 *                       source is {@code "paint_mask"}; when null, the default skeleton
	 *                       source ({@code "intensity"}) is used.
	 * @param paintThreshold threshold applied to {@code maskImage}
	 * @param conn           a {@code ConnectivityRule} constant. Unrecognized values fall back to
	 *                       6/6 connectivity.
	 * @param propType       propagation direction
	 * @return the corrected image
	 */
	public ImageDataFloat solve(ImageDataFloat img,ImageDataFloat maskImage,double paintThreshold,int conn,PropagationTypes propType){
		configureConnectivity(conn);
		// NOTE(review): this forces min/max normalization on every call, so setUseMinMax() has no
		// effect here. Kept for backward compatibility; confirm whether this is intended.
		useMinMax = true;
		propagationType = toPropagationString(propType);

		ModelImage image = img.getModelImageCopy();
		try {
			ModelImage resultImage = createResultImage(image);

			// Use a local so that a mask supplied in one call does not leak "paint_mask"
			// into later calls made without a mask.
			BitSet inputPaint = null;
			String skeletonSource = inputSkel;
			if (maskImage != null) {
				inputPaint = buildPaintMask(maskImage, paintThreshold);
				skeletonSource = "paint_mask";
			}

			algo = new AlgorithmTopologyCorrection(resultImage, image, type, minDistance, dimension, objConnect,
					bckConnect, skeletonSource, inputPaint, lowestLevel, highestLevel, useMinMax, propagationType);
			algo.runAlgorithm();

			return new ImageDataFloat(new ImageDataMipavWrapper(resultImage));
		} finally {
			// Release the temporary source copy, as correct() does.
			image.disposeLocal();
		}
	}

	/**
	 * Topology-corrects the stored level-set volume.
	 * <p>
	 * Steps:
	 * <ol>
	 * <li>Build a binary mask: voxels with value &gt; 0 become 0, all others become 254.</li>
	 * <li>Correct the mask as a binary object using the dual connectivity rule (object and
	 * background connectivity swapped, e.g. 18/6 becomes 6/18).</li>
	 * <li>Where the sign of the level set disagrees with the corrected mask, set the voxel to
	 * 0.001 (mask &lt;= 0) or -0.001 (mask &gt; 0).</li>
	 * </ol>
	 *
	 * @param rule a {@code ConnectivityRule} constant for the level set
	 * @return the corrected level set, named {@code <name>_tc_<rule>}
	 */
	public ImageDataFloat correctLevelSet(int rule){
		int rows = vol.getRows();
		int cols = vol.getCols();
		int slices = vol.getSlices();
		ImageDataFloat mask=new ImageDataFloat(rows,cols,slices);
		ImageDataFloat result=new ImageDataFloat(vol);
		// NOTE(review): assumes toArray3d() returns the backing array, since result is returned
		// after resultMat is modified — confirm against ImageDataFloat.
		float[][][] resultMat=result.toArray3d();
		mask.setName("mask");
		float[][][] maskMat=mask.toArray3d();
		for(int i=0;i<rows;i++){
			for(int j=0;j<cols;j++){
				for(int k=0;k<slices;k++){
					maskMat[i][j][k] = (vol.getFloat(i,j,k)>0) ? 0 : 254;
				}
			}
		}

		TopologyCorrection tc=new TopologyCorrection(mask);
		ImageDataFloat masktc=tc.correctBinary(dualRule(rule));
		float[][][] masktcMat=masktc.toArray3d();

		for(int i=0;i<rows;i++){
			for(int j=0;j<cols;j++){
				for(int k=0;k<slices;k++){
					boolean maskObject = masktcMat[i][j][k] > 0;
					boolean levelPositive = resultMat[i][j][k] > 0;
					if(!maskObject && !levelPositive){
						resultMat[i][j][k]=0.001f;
					} else if(maskObject && levelPositive){
						resultMat[i][j][k]=-0.001f;
					}
				}
			}
		}
		result.setName(vol.getName()+"_tc_"+ConnectivityCheck.toString(rule));
		return result;
	}

	/**
	 * Runs topology correction on the stored volume with fixed settings.
	 * <p>
	 * Fixed settings: dimension 3, skeleton source "intensity", no paint mask, levels [0, 1],
	 * min/max normalization on, object-&gt;background propagation.
	 *
	 * @param rule       a {@code ConnectivityRule} constant. Unrecognized values pass 0/0
	 *                   connectivity to the algorithm, as before.
	 * @param objectType "binary_object" or "scalar_image"
	 */
	private ImageDataFloat correct(int rule,String objectType){
		int cObj = 0;
		int cBack = 0;
		switch(rule){
		case ConnectivityRule.CONNECT_18_6:
			cObj=18;cBack=6;break;
		case ConnectivityRule.CONNECT_6_18:
			cObj=6;cBack=18;break;
		case ConnectivityRule.CONNECT_26_6:
			cObj=26;cBack=6;break;
		case ConnectivityRule.CONNECT_6_26:
			cObj=6;cBack=26;break;
		}
		ModelImage image=vol.getModelImageCopy();
		try {
			ModelImage resultImage = createResultImage(image);
			// NOTE(review): levels and useMinMax are hard-coded here, so setLowestLevel,
			// setHighestLevel and setUseMinMax only affect solve().
			AlgorithmTopologyCorrection correction = new AlgorithmTopologyCorrection(resultImage, image, objectType,
					minDistance, 3,
					cObj, cBack,
					"intensity", null,
					0, 1, true,
					"object->background");
			correction.runAlgorithm();
			correction.finalize();
			ImageDataFloat newVol=new ImageDataFloat(new ImageDataMipav(resultImage));
			newVol.setName(vol.getName()+"_tc_"+ConnectivityCheck.toString(rule));
			return newVol;
		} finally {
			// Release the temporary source copy even if the algorithm fails.
			image.disposeLocal();
		}
	}

	/**
	 * Sets the connectivity fields used by solve().
	 * The dimension is always 3; unrecognized rules fall back to 6/6 ("3D:all").
	 */
	private void configureConnectivity(int conn){
		dimension = 3;
		switch (conn) {
		case ConnectivityRule.CONNECT_26_6:
			objConnect = 26; bckConnect = 6; connect3D = "3D:26/6"; break;
		case ConnectivityRule.CONNECT_18_6:
			objConnect = 18; bckConnect = 6; connect3D = "3D:18/6"; break;
		case ConnectivityRule.CONNECT_6_26:
			objConnect = 6; bckConnect = 26; connect3D = "3D:6/26"; break;
		case ConnectivityRule.CONNECT_6_18:
			objConnect = 6; bckConnect = 18; connect3D = "3D:6/18"; break;
		default:
			objConnect = 6; bckConnect = 6; connect3D = "3D:all"; break;
		}
	}

	/** Maps the propagation enum to the algorithm's string, without relying on ordinal(). */
	private static String toPropagationString(PropagationTypes propType){
		switch (propType) {
		case BACKGROUND_TO_OBJECT:
			return "background->object";
		case OBJECT_TO_BACKGROUND:
		default:
			return "object->background";
		}
	}

	/** Returns the dual rule (object/background swapped); unrecognized rules map to 0. */
	private static int dualRule(int rule){
		switch(rule){
		case ConnectivityRule.CONNECT_18_6:
			return ConnectivityRule.CONNECT_6_18;
		case ConnectivityRule.CONNECT_6_18:
			return ConnectivityRule.CONNECT_18_6;
		case ConnectivityRule.CONNECT_26_6:
			return ConnectivityRule.CONNECT_6_26;
		case ConnectivityRule.CONNECT_6_26:
			return ConnectivityRule.CONNECT_26_6;
		default:
			return 0;
		}
	}

	/**
	 * Allocates a FLOAT result image named {@code <source>_corrected}.
	 * Its extents are the source's 2D extents, or the first three extents otherwise.
	 */
	private static ModelImage createResultImage(ModelImage source){
		int nDims = (source.getNDims() == 2) ? 2 : 3;
		int[] destExtents = new int[nDims];
		System.arraycopy(source.getExtents(), 0, destExtents, 0, nDims);
		return new ModelImage(ModelStorageBase.FLOAT, destExtents, source.getImageName() + "_corrected");
	}

	/**
	 * Builds a paint BitSet indexed as x + nx*y + nx*ny*z.
	 * A bit is set when the mask voxel is strictly greater than {@code threshold}.
	 */
	private static BitSet buildPaintMask(ImageDataFloat maskImage, double threshold){
		int nx=maskImage.getRows();
		int ny=maskImage.getCols();
		int nz=maskImage.getSlices();
		BitSet paint=new BitSet(nx*ny*nz);
		for(int z=0;z<nz;z++){
			for(int y=0;y<ny;y++){
				for(int x=0;x<nx;x++){
					paint.set(x + nx*y + nx*ny*z, maskImage.getFloat(x, y, z) > threshold);
				}
			}
		}
		return paint;
	}

	/** @return the upper intensity level passed to the algorithm by solve() */
	public float getHighestLevel() {
		return highestLevel;
	}
	/** @param highestLevel upper intensity level; only used by solve() */
	public void setHighestLevel(float highestLevel) {
		this.highestLevel = highestLevel;
	}
	/** @return the lower intensity level passed to the algorithm by solve() */
	public float getLowestLevel() {
		return lowestLevel;
	}
	/** @param lowestLevel lower intensity level; only used by solve() */
	public void setLowestLevel(float lowestLevel) {
		this.lowestLevel = lowestLevel;
	}
	/** @return the current useMinMax flag (reset to true by every solve() call) */
	public boolean isUseMinMax() {
		return useMinMax;
	}
	/**
	 * Sets the useMinMax flag.
	 * NOTE(review): solve() resets this flag to true and correct*() hard-codes true, so the
	 * value set here does not currently reach the algorithm.
	 */
	public void setUseMinMax(boolean useMinMax) {
		this.useMinMax = useMinMax;
	}

}


