package edu.jhu.ece.iacl.plugins.labeling.staple;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

import edu.jhmi.rad.medic.utilities.CropParameters;
import edu.jhmi.rad.medic.utilities.CubicVolumeCropper;
import edu.jhu.ece.iacl.algorithms.manual_label.staple.*;
import edu.jhu.ece.iacl.io.CubicVolumeReaderWriter;
import edu.jhu.ece.iacl.io.StringReaderWriter;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.pipeline.AlgorithmInformation.Citation;
import edu.jhu.ece.iacl.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.pipeline.parameter.ParamFloat;
import edu.jhu.ece.iacl.pipeline.parameter.ParamInteger;
import edu.jhu.ece.iacl.pipeline.parameter.ParamDouble;
import edu.jhu.ece.iacl.pipeline.parameter.ParamObject;
import edu.jhu.ece.iacl.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.structures.geom.GridPt;
import edu.jhu.ece.iacl.structures.image.ImageData;
import edu.jhu.ece.iacl.structures.image.ImageDataFloat;

public class MedicAlgorithmSTAPLE extends ProcessingAlgorithm{
	
	private ParamVolumeCollection ratervols;
	private ParamDouble eps;
	private ParamInteger maxiters;
	private ParamOption init;
	private ParamOption connectivity;
	private ParamFloat beta;
	
	private ParamBoolean probFlag;
	
	
	private ParamObject<String> pl;
	private ParamVolumeCollection truthOut;
	private ParamVolume labelvol;
	
	private CubicVolumeReaderWriter rw = CubicVolumeReaderWriter.getInstance();
	
	private static final String rcsid =
		"$Id: MedicAlgorithmSTAPLE.java,v 1.5 2009/06/03 13:01:35 bogovic Exp $";
	private static final String cvsversion =
		"$Revision: 1.5 $";
	private static final String revnum = cvsversion.replace("$Revision: 1.5 $", "").replace(" $", "");


	protected void createInputParameters(ParamCollection inputParams) {

		inputParams.add(ratervols=new ParamVolumeCollection("Rater Volumes"));
		inputParams.setName("staplevolume");
		inputParams.setLabel("STAPLE "+revnum);

		inputParams.add(init=new ParamOption("Initialization Type", new String[]{"Performance", "Truth"}));

		inputParams.add(eps = new ParamDouble("Max Delta for Convergence"));
		eps.setValue((double)0.00001);
		
		inputParams.setCategory("IACL.Labeling.STAPLE");
		inputParams.add(maxiters=new ParamInteger("Max Iterations"));
		maxiters.setValue(new Integer(50));
		
		inputParams.add(beta=new ParamFloat("MRF Regularization parameter",0.0f,Float.MAX_VALUE,0.0f));
		inputParams.add(connectivity=new ParamOption("Connectivity",new String[]{"6","18","26"}));
		
		inputParams.add(probFlag = new ParamBoolean("Output Label Probabilities"));
		probFlag.setValue(false);
		
		AlgorithmInformation info=getAlgorithmInformation();
		info.setWebsite("");
		info.setDescription("STAPLE - Simultaneous Truth and Performance Level Estimation");
		info.add(new AlgorithmAuthor("John Bogovic", "bogovic@jhu.edu", ""));
		info.setAffiliation("Johns Hopkins University, Department of Electrical Engineering");
		info.add(new Citation("Warfield SK, Zou KH, Wells WM, \"Simultaneous Truth and Performance Level Estimation (STAPLE): An Algorithm for the Validation of Image Segmentation\" IEEE TMI 2004; 23(7):903-21  "));	
		info.setVersion(revnum);	
		info.setLongDescription("Given a number of labelings of a particular strucute, STAPLE returns membership functions of the Truth");
	
	}
	
	protected void createOutputParameters(ParamCollection outputParams) {	
		
		outputParams.add(labelvol = new ParamVolume("Label Volume", null, -1,-1,-1,-1));
		
		outputParams.add(truthOut = new ParamVolumeCollection("Label Probabilities"));
		truthOut.setMandatory(false);
		
		pl = new ParamObject<String>("PerformanceLevels",new StringReaderWriter());
		outputParams.add(pl);
		
	}
	
	protected void execute(CalculationMonitor monitor) {
		

		String name = ratervols.getImageDataList().get(0).getName();
		
		//crop input volumes 
		//Find crop parameters - using addition of all rater images
		ImageData added = addedVolume();
		CubicVolumeCropper cropper = new CubicVolumeCropper();	
		cropper.crop(added, 0, 1);
		CropParameters cp = cropper.getLastCropParams();
		ArrayList<ImageData> croppedRaterList = new ArrayList<ImageData>(ratervols.getImageDataList().size());
		
		for(int i=0; i<ratervols.getImageDataList().size(); i++){
			croppedRaterList.add(cropper.crop(ratervols.getImageDataList().get(i), cp));
		}
		
		GridPt.Connectivity conn;
		if(connectivity.getIndex()==0){
			conn=GridPt.Connectivity.SIX;
		}else if(connectivity.getIndex()==1){
			conn=GridPt.Connectivity.EIGHTEEN;
		}else{
			conn=GridPt.Connectivity.TWENTYSIX;
		}
		
		
		//clean up the old rater volumes
		ratervols.dispose();
		
		STAPLEmulti staple = new STAPLEmulti(croppedRaterList);
		staple.setmaxIters(maxiters.getInt());
		staple.setEps(eps.getFloat());
		staple.setInit(init.getValue());
		staple.setConnectivity(conn);
		staple.setBeta(beta.getFloat());
		staple.distributeBeta();	//distribute beta?
		staple.iterate();
		
		if(probFlag.getValue()){
			truthOut.setValue(staple.getTruth());
		}
		
		labelvol.setValue(cropper.uncrop(staple.getHardSeg(), cp));
		labelvol.setFileName("LabelVolume_"+this.toString());
		
		ArrayList<ImageData> truth = staple.getTruth();
		ArrayList<ImageData> truthuc = new ArrayList<ImageData>(truth.size());
		for(int i=0; i<truth.size(); i++){
			ImageData f = truth.get(i);
			ImageData d = cropper.uncrop(f, cp);
			f.dispose();
			truthuc.add(d);
		}
		truthOut.setValue(truthuc);
		pl.setObject(staple.getPeformanceLevel().toString());
		pl.setFileName(name+"_StapleRaterPerformance");
		
		System.out.println("FINISHED");
	}
	
	private ImageData addedVolume(){
		ImageData added = ratervols.getImageDataList().get(0).clone();
		int rows=added.getRows();
		int cols=added.getCols();
		int slices=added.getSlices();
		
		for(int l = 1; l<ratervols.getImageDataList().size(); l++){
			ImageData vol = ratervols.getImageDataList().get(l);
			for (int i = 0; i < rows; i++) {
				for (int j = 0; j < cols; j++) {
					for (int k = 0; k < slices; k++) {
						added.set(i,j,k,added.getInt(i,j,k)+vol.getInt(i,j,k));
					}
				}
			}
		}
		return added;
	}
	public void writeVolumes(List<ImageData> volumes){
		File dir = new File("/home/john/Desktop/truthVols");
		for(int i=0; i<volumes.size(); i++){
			rw.write(volumes.get(i), dir);
		}
		
	}
}

