package edu.jhu.ece.iacl.plugins.qualitycontrol;

//import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.jist.io.ImageDataReaderWriter;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamDouble;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataDouble;

import java.io.File;
import java.io.IOException;


/**
 * JIST quality-control module for segmentations.
 *
 * <p>Writes a three-panel snapshot of the segmentation (and, when supplied, an intensity volume)
 * through the segmentation's center of mass. It then computes a scalar metric for label N and
 * reports whether that metric satisfies the selected pass condition.
 */
public class MedicAlgorithmQualityControlSegmentation extends ProcessingAlgorithm{
	private ParamVolume inParamSegmentation;
	/** Optional; required only by the "Average Intensity Under Label N" monitor type. */
	private ParamVolume inParamIntensity;
	private ParamOption inParamPassCondition,inParamMonitorType;
	private ParamDouble inParamUpperRange, inParamLowerRange,inParamN;
	private ParamFile outSnapShot;
	private ParamBoolean outTestPassed;

	/** Width, in pixels, of the separator lines drawn between snapshot panels. */
	private static final int SNAPSHOT_PAD = 2;

	/****************************************************
	 * CVS Version Control
	 ****************************************************/
	private static final String cvsversion = "$Revision: 1.2 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "Takes a snap shot of a segmentation and checks if a derived measurement is within range";
	private static final String longDescription = "";


	/**
	 * Declares the module inputs and fills in the algorithm information.
	 *
	 * @param inputParams collection the inputs are added to
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		inputParams.add(inParamSegmentation=new ParamVolume("Segmentation Volume to Monitor"));
		inputParams.add(inParamIntensity=new ParamVolume("Intensity Volume to Monitor (Optional)"));
		inParamIntensity.setMandatory(false);
		String[] monitorType = {"Total Voxels of Label N", "Total Volume of Label N", "Average Intensity Under Label N"};
		inputParams.add(inParamMonitorType=new ParamOption("Monitor Type",monitorType));
		inputParams.add(inParamN=new ParamDouble("N",0));
		String[] passConditions = {"Inside Range", "Outside Range", "Above Upper Range", "Below Lower Range"};
		inputParams.add(inParamPassCondition=new ParamOption("Monitor Pass When",passConditions));
		inputParams.add(inParamUpperRange=new ParamDouble("Upper Range",0));
		inputParams.add(inParamLowerRange=new ParamDouble("Lower Range",0));
		inputParams.setPackage("IACL");
		inputParams.setCategory("Quality Control");
		inputParams.setLabel("Quality Control - Segmentation");
		inputParams.setName("QC_Segmentation");

		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("");
		info.setAffiliation("");
		info.setDescription(shortDescription);
		info.setLongDescription(shortDescription + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.RC);
	}


	/**
	 * Declares the module outputs: the snapshot file and the pass/fail flag.
	 *
	 * @param outputParams collection the outputs are added to
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(outSnapShot = new ParamFile("Center of mass Snapshot of Result"));
		outputParams.add(outTestPassed = new ParamBoolean("Quality Control Passed"));
	}


	/**
	 * Writes the snapshot, computes the selected metric and sets the pass/fail output.
	 *
	 * @param monitor progress monitor supplied by the pipeline (not used)
	 */
	protected void execute(CalculationMonitor monitor) {
		// Draw the snapshot at the segmentation's center of mass.
		writeSnapshot();

		double metric = computeMetric(inParamMonitorType.getIndex(), inParamN.getInt());
		System.out.println("Metric" + metric);

		boolean passed = passesCondition(inParamPassCondition.getIndex(), metric,
				inParamLowerRange.getDouble(), inParamUpperRange.getDouble());
		outTestPassed.setValue(passed);
	}

	/**
	 * Computes the metric selected by the "Monitor Type" option.
	 *
	 * @param monitorType index into the "Monitor Type" options
	 * @param label       label value N
	 * @return the metric; 0 for an unknown monitor type
	 */
	private double computeMetric(int monitorType, int label) {
		switch (monitorType) {
		case 0: // Total Voxels of Label N
			return findNumOfVoxelsWithLabel(label);
		case 1: { // Total Volume of Label N = voxel count * voxel volume
			float[] res = inParamSegmentation.getImageData().getHeader().getDimResolutions();
			return findNumOfVoxelsWithLabel(label) * (double) res[0] * res[1] * res[2];
		}
		case 2: // Average Intensity Under Label N
			return findIntensityUnderLabel(label);
		default:
			return 0;
		}
	}

	/**
	 * Evaluates the "Monitor Pass When" condition. Range bounds are exclusive.
	 *
	 * @param conditionIndex index into the "Monitor Pass When" options
	 * @param metric         value to test
	 * @param lowerD         lower bound of the range
	 * @param upperD         upper bound of the range
	 * @return true if the metric passes; false for an unknown condition or a NaN metric
	 */
	private static boolean passesCondition(int conditionIndex, double metric, double lowerD, double upperD) {
		switch (conditionIndex) {
		case 0: // Inside Range
			return metric > lowerD && metric < upperD;
		case 1: // Outside Range: either side of the range passes (was '&&', which never held)
			return metric > upperD || metric < lowerD;
		case 2: // Above Upper Range
			return metric > upperD;
		case 3: // Below Lower Range
			return metric < lowerD;
		default:
			return false;
		}
	}

	/**
	 * Averages the intensity volume over the voxels whose segmentation label equals {@code N}.
	 * Assumes the intensity volume is at least as large as the segmentation (TODO confirm).
	 *
	 * @param N label value
	 * @return mean intensity under label N, or NaN if no voxel has that label
	 * @throws IllegalStateException if no intensity volume was supplied
	 */
	protected double findIntensityUnderLabel(int N){
		ImageData seg = inParamSegmentation.getImageData();
		ImageData intensity = inParamIntensity.getImageData();
		if (intensity == null) {
			throw new IllegalStateException(
					"Monitor type 'Average Intensity Under Label N' requires an intensity volume");
		}
		int XN = seg.getRows();
		int YN = seg.getCols();
		int ZN = seg.getSlices();
		double intensitySum = 0;
		int counter = 0;
		for (int i = 0; i < XN; i++) {
			for (int j = 0; j < YN; j++) {
				for (int k = 0; k < ZN; k++) {
					if (seg.getInt(i, j, k) == N) {
						intensitySum += intensity.getDouble(i, j, k);
						// Count only labelled voxels so this is an average *under* the label.
						counter++;
					}
				}
			}
		}
		return (counter == 0) ? Double.NaN : intensitySum / counter;
	}

	/**
	 * Counts the segmentation voxels whose label equals {@code N}.
	 *
	 * @param N label value
	 * @return number of voxels with label N
	 */
	protected int findNumOfVoxelsWithLabel(int N){
		ImageData seg = inParamSegmentation.getImageData();
		int XN = seg.getRows();
		int YN = seg.getCols();
		int ZN = seg.getSlices();
		int numOfVoxels = 0;
		for (int i = 0; i < XN; i++) {
			for (int j = 0; j < YN; j++) {
				for (int k = 0; k < ZN; k++) {
					if (seg.getInt(i, j, k) == N) {
						numOfVoxels++;
					}
				}
			}
		}
		return numOfVoxels;
	}

	/**
	 * Renders orthogonal slices through the segmentation's center of mass and writes them as a
	 * JPEG into a per-algorithm subdirectory of the output directory.
	 *
	 * <p>Layout, along the first array index: X-Y slice at CoM z, then Z-X slice at CoM y, then
	 * Z-Y slice at CoM x. The panels are separated by pad-wide lines of value 1. Intensity slices,
	 * when present, are placed after the segmentation slices along the second index, separated by
	 * another line. Each volume is normalized by its own maximum. The intensity volume is assumed
	 * to be at least as large as the segmentation (TODO confirm).
	 */
	protected void writeSnapshot() {
		File outputDir = new File(this.getOutputDirectory()+File.separator+edu.jhu.ece.iacl.jist.utility.FileUtil.forceSafeFilename(this.getAlgorithmName()));
		// mkdirs (not mkdir) so missing parent directories are created too.
		if (!outputDir.isDirectory() && !outputDir.mkdirs()) {
			System.err.println("Unable to create snapshot directory: " + outputDir);
		}

		ImageData seg = inParamSegmentation.getImageData();
		// Optional input; the null checks below handle the "not supplied" case.
		ImageData intensity = inParamIntensity.getImageData();

		double segScale = normalizationScale(seg);
		double intensityScale = (intensity != null) ? normalizationScale(intensity) : 1;

		int XN = seg.getRows();
		int YN = seg.getCols();
		int ZN = seg.getSlices();
		int pad = SNAPSHOT_PAD;
		int maxDim = Math.max(Math.max(XN, YN), ZN);

		int IN = 3 * maxDim + 2 * pad; // size of the snapshot along the first index
		int JN = (intensity != null) ? 2 * maxDim + pad : maxDim; // size along the second index
		double[][] snapshot = new double[IN][JN];
		int[] com = computeCenterOfMass(seg);

		// Panel 1: X-Y slice at the center-of-mass z, centered within its panel.
		int offset = (maxDim - XN) / 2;
		for (int i = 0; i < XN; i++) {
			for (int j = 0; j < YN; j++) {
				snapshot[offset + i][j] = seg.getDouble(i, j, com[2]) / segScale;
				if (intensity != null) {
					snapshot[offset + i][j + pad + maxDim] = intensity.getDouble(i, j, com[2]) / intensityScale;
				}
			}
		}

		// Panel 2: Z-X slice at the center-of-mass y.
		offset = maxDim + pad + (maxDim - ZN) / 2;
		for (int k = 0; k < ZN; k++) {
			for (int i = 0; i < XN; i++) {
				snapshot[offset + k][i] = seg.getDouble(i, com[1], k) / segScale;
				if (intensity != null) {
					snapshot[offset + k][i + pad + maxDim] = intensity.getDouble(i, com[1], k) / intensityScale;
				}
			}
		}

		// Panel 3: Z-Y slice at the center-of-mass x.
		offset = 2 * maxDim + 2 * pad + (maxDim - ZN) / 2;
		for (int k = 0; k < ZN; k++) {
			for (int j = 0; j < YN; j++) {
				snapshot[offset + k][j] = seg.getDouble(com[0], j, k) / segScale;
				if (intensity != null) {
					snapshot[offset + k][j + pad + maxDim] = intensity.getDouble(com[0], j, k) / intensityScale;
				}
			}
		}

		// Separator lines between the three panels.
		for (int p = 0; p < pad; p++) {
			for (int j = 0; j < JN; j++) {
				snapshot[maxDim + p][j] = 1;
				snapshot[2 * maxDim + pad + p][j] = 1;
			}
		}

		// Separator between segmentation and intensity slices. It only exists when the snapshot
		// was allocated with room for the intensity slices; otherwise it would index past JN.
		if (intensity != null) {
			for (int i = 0; i < IN; i++) {
				for (int p = 0; p < pad; p++) {
					snapshot[i][maxDim + p] = 1;
				}
			}
		}

		ImageData out = new ImageDataDouble(snapshot);
		ImageDataReaderWriter imageRW = new ImageDataReaderWriter();
		File outFile = new File(outputDir, seg.getName() + ".jpg");
		imageRW.write(out, outFile);
		outSnapShot.setValue(outFile);
	}

	/**
	 * Returns the divisor used to normalize a volume for display: its maximum value, or 1 when the
	 * maximum is 0 or infinite (avoids NaN/zero-filled snapshots).
	 */
	private static double normalizationScale(ImageData vol) {
		double max = calculateMinAndMaxVals(vol, vol)[1];
		return (max == 0 || Double.isInfinite(max)) ? 1 : max;
	}

	/**
	 * Computes the center of mass of {@code seg}, using voxel values as weights.
	 *
	 * <p>Accumulates in double precision. The original int accumulators truncated every product
	 * via compound assignment and could overflow. The result is truncated toward zero and clamped
	 * to valid indices. When the weights sum to zero, the geometric center of the volume is used.
	 *
	 * @param seg volume to analyze
	 * @return {x, y, z} voxel indices
	 */
	private static int[] computeCenterOfMass(ImageData seg) {
		int[] dims = {seg.getRows(), seg.getCols(), seg.getSlices()};
		double total = 0;
		double[] weighted = new double[3];
		for (int i = 0; i < dims[0]; i++) {
			for (int j = 0; j < dims[1]; j++) {
				for (int k = 0; k < dims[2]; k++) {
					double v = seg.getDouble(i, j, k);
					total += v;
					weighted[0] += v * i;
					weighted[1] += v * j;
					weighted[2] += v * k;
				}
			}
		}
		int[] com = new int[3];
		for (int d = 0; d < 3; d++) {
			int c = (total != 0) ? (int) (weighted[d] / total) : dims[d] / 2;
			com[d] = Math.max(0, Math.min(dims[d] - 1, c));
		}
		return com;
	}

	/**
	 * Computes the minimum and maximum voxel value over two volumes.
	 *
	 * <p>Iterates over the dimensions of {@code sub}, so {@code tar} must be at least as large.
	 * NaN voxels are ignored. For an empty volume the result is {+Infinity, -Infinity}.
	 *
	 * @param sub first volume; its dimensions define the iteration range
	 * @param tar second volume
	 * @return {@code {min, max}}
	 */
	static public double[] calculateMinAndMaxVals(ImageData sub, ImageData tar) {
		int XN = sub.getRows();
		int YN = sub.getCols();
		int ZN = sub.getSlices();
		double max = Double.NEGATIVE_INFINITY;
		double min = Double.POSITIVE_INFINITY;
		for (int i = 0; i < XN; i++) {
			for (int j = 0; j < YN; j++) {
				for (int k = 0; k < ZN; k++) {
					double s = sub.getDouble(i, j, k);
					double t = tar.getDouble(i, j, k);
					// Explicit comparisons (not Math.max/min) so NaN values are skipped.
					if (s > max) max = s;
					if (s < min) min = s;
					if (t > max) max = t;
					if (t < min) min = t;
				}
			}
		}
		return new double[] {min, max};
	}


}
