package edu.jhu.ece.iacl.algorithms.vabra;

import edu.jhu.bme.smile.commons.optimize.BrentMethod1D;
import edu.jhu.bme.smile.commons.optimize.Optimizable1DContinuous;
import edu.jhu.bme.smile.commons.optimize.Optimizer1DContinuous;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;

/**
 * Coarse-level optimizer used by VABRA registration.
 * <p>
 * For one region centre at a time, {@link #coarseOptimize(int[], double[])} does four things:
 * <ul>
 * <li>estimates the cost gradient with respect to the three (x, y, z) coefficients of a single
 * radial basis function (RBF) centred there;</li>
 * <li>bounds the step size so that {@code |coeff| * rbf.getMaxValuesChange() * max|gradient|}
 * does not exceed {@link #lambda};</li>
 * <li>line-searches the coefficient with Brent's method;</li>
 * <li>commits the RBF to {@code imgSubTarPairs} when doing so lowers the cost.</li>
 * </ul>
 * <p>
 * NOTE(review): the original comments described two deformation-field update modes. Mode 0 is
 * the original ABA mode, where fields are updated through summation. Mode 1 re-applies the
 * deformation field after each update and registers using the new deformed image. See
 * https://putter.ece.jhu.edu/IACL:MinProjectLogs_VABRA_DefFieldUpdate for details. No mode
 * selector is declared in this class; confirm where the mode is chosen.
 */
public class VabraOptimizer extends AbstractCalculation{

	/** Coefficients whose magnitude is below this are never applied. */
	private static final double MIN_APPLIED_COEFF = 0.00001;
	/** Minimum |coefficient| for an optimized step to be accepted by {@link #coarseOptimize}. */
	private static final double MIN_ACCEPTED_COEFF = 0.005;
	/** Coefficients are rounded to multiples of 1/COEFF_ROUNDING before being applied. */
	private static final double COEFF_ROUNDING = 10000.0;
	/** Per-coefficient perturbation (step size) passed to the cost-gradient estimate. */
	private static final double GRADIENT_STEP = 0.1;

	/** Alias of {@code this}. Kept for code that references it; cleared by {@link #dispose()}. */
	VabraOptimizer globalPtr;
	/** Subject/target images, deformation field and cost bookkeeping being optimized. */
	VabraSubjectTargetPairs imgSubTarPairs;
	/** RBF kernel whose support defines the local region touched by each coarse step. */
	VabraRBF rbf;

	/**
	 * Jacobian threshold: upper bound on {@code |coeff| * rbf.getMaxValuesChange() * max|gradient|}
	 * for a step. Note that this field is static, so {@link #setLambda(double)} changes it for every
	 * instance.
	 */
	static double lambda = .350;
	/** Clipped RBF support {xmin, xmax, ymin, ymax, zmin, zmax} (inclusive) of the current step. */
	int[] localROI;
	/** Normalized, direction-weighted cost gradient (x, y, z) for the current coarse step. */
	double[] localCoarseGradient;
	/** Voxel centre of the RBF currently being optimized. */
	int[] coarseLocalRegionCenter;
	/** Not used in this class; presumably reserved for fine-level optimization. */
	int[][] fineLocalRegionCenters;
	/** Per-axis (x, y, z) weights. A weight of 0 removes that direction from the optimization. */
	double[] directionsOptmizationWeight;//Directional Constraints
	/** Not used in this class; kept for code that may reference it. */
	ImageData[]  totalDeformFieldSplit = new ImageData[3];

	/**
	 * Counts {@link #coarseGradient} calls since the last slow-call timing report. This is used
	 * only for diagnostic output. It is an instance field so that the count actually accumulates
	 * across calls.
	 */
	private int coarseGradientCallsSinceReport = 0;

	/**
	 * Creates an optimizer for the given image pairs.
	 *
	 * @param imgSubTarPairs images and deformation state to optimize. The RBF is sized from the
	 *        dimensions of its first original target.
	 * @param parent parent calculation, forwarded to {@link AbstractCalculation}
	 * @param directionsOptmizationWeight per-axis (x, y, z) weights applied to the gradient and the
	 *        deformation
	 */
	public VabraOptimizer(VabraSubjectTargetPairs imgSubTarPairs, AbstractCalculation parent, double[] directionsOptmizationWeight) {
		super(parent);
		this.directionsOptmizationWeight = directionsOptmizationWeight;
		this.imgSubTarPairs = imgSubTarPairs;

		globalPtr = this;
		localROI = new int[6];
		rbf = new VabraRBF(imgSubTarPairs.origTargetList.get(0).getRows(),imgSubTarPairs.origTargetList.get(0).getCols(),imgSubTarPairs.origTargetList.get(0).getSlices());
		localCoarseGradient = new double[3];
		coarseLocalRegionCenter = new int[3];
	}

	/** Releases references held by this optimizer. The instance must not be used afterwards. */
	public void dispose() {
		rbf = null;
		globalPtr = null;
		imgSubTarPairs = null;
		localROI = null;
		localCoarseGradient = null;
		coarseLocalRegionCenter = null;
		fineLocalRegionCenters = null;
		totalDeformFieldSplit = null;
		directionsOptmizationWeight = null;
	}

	/**
	 * One-dimensional objective for the Brent line search. Its value at {@code c} is the cost of
	 * applying coefficient {@code c} along {@link #localCoarseGradient} around
	 * {@link #coarseLocalRegionCenter}, without committing the change.
	 */
	public class CoarseOptimizer implements Optimizable1DContinuous
	{

		double domainMax,domainMin,tolerance;

		/**
		 * @param max upper bound of the coefficient search interval
		 * @param min lower bound of the coefficient search interval
		 */
		public CoarseOptimizer(double max, double min) {
			domainMax = max;
			domainMin = min;
			tolerance = 5.0e-5;
		}

		@Override
		public double getDomainMax() {
			return domainMax;
		}

		@Override
		public double getDomainMin() {
			return domainMin;
		}

		@Override
		public double getDomainTolerance() {
			return tolerance;
		}

		/** Returns the (uncommitted) cost for coefficient {@code c}. */
		public double getValue(double c) {
			return VabraOptimizer.this.coarseCostFunction(c, false);
		}
	}

	/**
	 * Writes the RBF support box around {@code regionCenter} into {@code out}, clipped to
	 * {@code imgSubTarPairs.boundingBox}, as {xmin, xmax, ymin, ymax, zmin, zmax} (inclusive).
	 */
	private void computeClippedSupport(int[] regionCenter, int[] out) {
		int scale = rbf.getScale();
		for (int axis = 0; axis < 3; axis++) {
			out[2 * axis] = Math.max(imgSubTarPairs.boundingBox[2 * axis], regionCenter[axis] - scale);
			out[2 * axis + 1] = Math.min(imgSubTarPairs.boundingBox[2 * axis + 1], regionCenter[axis] + scale);
		}
	}

	/** Returns max(|g_x|, |g_y|, |g_z|) of {@link #localCoarseGradient}. This is NaN if any component is NaN. */
	private double maxAbsCoarseGradient() {
		return Math.max(Math.abs(localCoarseGradient[0]),
				Math.max(Math.abs(localCoarseGradient[1]), Math.abs(localCoarseGradient[2])));
	}

	/**
	 * Estimates the cost gradient with respect to the three RBF coefficients at a region centre.
	 * <p>
	 * Every voxel in the clipped RBF support with a non-zero kernel value is perturbed by
	 * {@code GRADIENT_STEP * rbfValue} in x, y and z. The gradient is then read back through
	 * {@code imgSubTarPairs.getCurrentCostGradient} with those step sizes (presumably a
	 * finite-difference estimate; confirm in VabraSubjectTargetPairs). Finally each component is
	 * multiplied by its direction weight. Calls that take longer than one second print timing and
	 * memory diagnostics.
	 *
	 * @param regionCenter voxel (x, y, z) at which the RBF is centred
	 * @param currentCostGradient output array of length at least 3 that receives the weighted gradient
	 */
	public void coarseGradient(int[] regionCenter, double[] currentCostGradient) {
		double[] deltaC = {GRADIENT_STEP, GRADIENT_STEP, GRADIENT_STEP};

		imgSubTarPairs.resetCostGradient();

		int[] support = new int[6];
		computeClippedSupport(regionCenter, support);

		long start = System.currentTimeMillis();
		// Offsets mapping an image voxel to its index in rbf.values.
		int offX = rbf.getOffsetX() - regionCenter[0];
		int offY = rbf.getOffsetY() - regionCenter[1];
		int offZ = rbf.getOffsetZ() - regionCenter[2];
		for (int i = support[0]; i <= support[1]; i++) {
			for (int j = support[2]; j <= support[3]; j++) {
				for (int k = support[4]; k <= support[5]; k++) {
					double rbfVal = rbf.values[i + offX][j + offY][k + offZ];
					if (rbfVal != 0) {
						imgSubTarPairs.updateCostGradientAtPointWithDef(i, j, k,
								deltaC[0] * rbfVal, deltaC[1] * rbfVal, deltaC[2] * rbfVal);
					}
				}
			}
		}

		imgSubTarPairs.getCurrentCostGradient(currentCostGradient, deltaC);

		// Zero out (or scale) directions that are not being optimized.
		for (int axis = 0; axis < 3; axis++) {
			currentCostGradient[axis] *= directionsOptmizationWeight[axis];
		}

		long elapsed = System.currentTimeMillis() - start;
		if (elapsed > 1000) {
			System.out.format("Coarse Gradient time:%f, %d\n", (float) elapsed, coarseGradientCallsSinceReport);
			System.out.format("%f/%f\n", (float) Runtime.getRuntime().freeMemory(), (float) Runtime.getRuntime().totalMemory());
			coarseGradientCallsSinceReport = 0;
		} else {
			coarseGradientCallsSinceReport++;
		}
	}

	/**
	 * Commits coefficient {@code coeff} for the current RBF if it passes the deformation constraint.
	 * <p>
	 * The constraint depends on the mode:
	 * <ul>
	 * <li>When {@code imgSubTarPairs.forceDiffeomorphismInRecon} is set, the step must pass
	 * {@code checkPositiveJacobian}.</li>
	 * <li>Otherwise the step's bound {@code |coeff| * rbf.getMaxValuesChange() * max|gradient|}
	 * must not exceed {@link #lambda}.</li>
	 * </ul>
	 * An accepted step is applied through {@code coarseCostFunction(coeff, true)} and recorded with
	 * {@code addRBFPoint}. Coefficients with magnitude below {@code MIN_APPLIED_COEFF} are ignored.
	 *
	 * @param coeff signed step along {@link #localCoarseGradient}
	 */
	void updateFromCoarseOptimization(double coeff) {
		if (Math.abs(coeff) < MIN_APPLIED_COEFF) return;

		double maxJacob = Math.abs(coeff) * rbf.getMaxValuesChange() * maxAbsCoarseGradient();

		boolean accepted;
		if (imgSubTarPairs.forceDiffeomorphismInRecon) {
			accepted = imgSubTarPairs.checkPositiveJacobian(imgSubTarPairs.totalDeformField, coeff,
					localCoarseGradient, coarseLocalRegionCenter, localROI, rbf, directionsOptmizationWeight);
		} else {
			accepted = maxJacob <= lambda;
		}

		if (accepted) {
			coarseCostFunction(coeff, true);
			imgSubTarPairs.addRBFPoint(coarseLocalRegionCenter, rbf.getScale(), imgSubTarPairs.currentDownSampleFactor, localCoarseGradient, coeff);
		}
	}

	/**
	 * Runs one coarse optimization step centred at {@code regionCenter}.
	 * <p>
	 * The step computes and normalizes the local gradient, line-searches the coefficient within the
	 * {@link #lambda}-derived bound and, if the cost decreases with |coefficient| of at least
	 * {@code MIN_ACCEPTED_COEFF}, commits it. The step is skipped (returns 0) when:
	 * <ul>
	 * <li>the starting cost is 0 or -2;</li>
	 * <li>the gradient gives no finite, positive step bound (for example, when all components are
	 * zero). This prevents an infinite or NaN search interval.</li>
	 * </ul>
	 *
	 * @param regionCenter voxel (x, y, z) at which the RBF is centred
	 * @param gradient unused; kept for interface compatibility
	 * @return the (negative) change in cost if a step was committed, otherwise 0
	 */
	public double coarseOptimize(int[] regionCenter, double[] gradient) {
		int gradParams = imgSubTarPairs.coarseGradientParameters();

		computeClippedSupport(regionCenter, localROI);

		// Recompute the gradient here because earlier steps may have changed the field.
		coarseGradient(regionCenter, localCoarseGradient);
		RegistrationUtilities.VectorNormalization(localCoarseGradient, gradParams);

		System.arraycopy(regionCenter, 0, coarseLocalRegionCenter, 0, 3);

		double originalCost = coarseCostFunction(0, false);
		// Skip if already perfectly registered.
		// NOTE(review): -2 appears to be a sentinel from getCurrentCost(); confirm its meaning.
		if (originalCost == -2 || originalCost == 0) {
			return 0;
		}

		double maxCoeff = lambda / (rbf.getMaxValuesChange() * maxAbsCoarseGradient());
		// A zero or NaN gradient (e.g. all directions weighted out) yields an infinite or NaN bound.
		// There is no meaningful direction to search, so skip the step.
		if (!(maxCoeff > 0.0) || Double.isInfinite(maxCoeff)) {
			return 0;
		}

		CoarseOptimizer fun = new CoarseOptimizer(maxCoeff, -maxCoeff);
		Optimizer1DContinuous opt = new BrentMethod1D();
		opt.initialize(fun);
		opt.optimize(true);
		double optimizedCoeff = opt.getExtrema();
		double optimizedCost = coarseCostFunction(optimizedCoeff, false);
		double delta = optimizedCost - originalCost;

		if (delta < 0 && Math.abs(optimizedCoeff) >= MIN_ACCEPTED_COEFF) {
			updateFromCoarseOptimization(optimizedCoeff);
			return delta;
		}
		return 0;
	}

	/**
	 * Evaluates the cost after applying coefficient {@code coeff} along {@link #localCoarseGradient}
	 * to every voxel in {@link #localROI} with a non-zero RBF value.
	 * <p>
	 * The coefficient is first rounded to a multiple of {@code 1/COEFF_ROUNDING}. When
	 * {@code coeff == 0}, the cost is reset and returned without touching any voxel.
	 * <p>
	 * NOTE(review): the direction weights are applied here as well as in {@link #coarseGradient},
	 * so they act squared. This is a no-op for 0/1 weights; confirm it is intended for fractional
	 * weights.
	 *
	 * @param coeff signed step along the current coarse gradient
	 * @param commitUpdate if true, the deformation is kept and {@code finalDefFieldUpdate} is called
	 *        on {@link #localROI}; otherwise the evaluation is tentative
	 * @return {@code imgSubTarPairs.getCurrentCost()} after the update
	 */
	public double coarseCostFunction(double coeff, boolean commitUpdate) {
		imgSubTarPairs.resetCost();
		if (coeff == 0) return imgSubTarPairs.getCurrentCost();

		double roundedCoeff = Math.round(coeff * COEFF_ROUNDING) / COEFF_ROUNDING;

		// Per-axis displacement per unit of RBF value (direction weight * coefficient * gradient).
		double stepX = directionsOptmizationWeight[0] * (roundedCoeff * localCoarseGradient[0]);
		double stepY = directionsOptmizationWeight[1] * (roundedCoeff * localCoarseGradient[1]);
		double stepZ = directionsOptmizationWeight[2] * (roundedCoeff * localCoarseGradient[2]);

		// Offsets mapping an image voxel to its index in rbf.values.
		int offX = rbf.getOffsetX() - coarseLocalRegionCenter[0];
		int offY = rbf.getOffsetY() - coarseLocalRegionCenter[1];
		int offZ = rbf.getOffsetZ() - coarseLocalRegionCenter[2];

		for (int i = localROI[0]; i <= localROI[1]; i++) {
			for (int j = localROI[2]; j <= localROI[3]; j++) {
				for (int k = localROI[4]; k <= localROI[5]; k++) {
					double rbfVal = rbf.values[i + offX][j + offY][k + offZ];
					if (rbfVal != 0) {
						imgSubTarPairs.updateCurrentCostAtPointWithDef(i, j, k,
								stepX * rbfVal, stepY * rbfVal, stepZ * rbfVal, commitUpdate);
					}
				}
			}
		}

		// Propagate to the copies and the final deformation field only when actually updating.
		if (commitUpdate) imgSubTarPairs.finalDefFieldUpdate(localROI);
		return imgSubTarPairs.getCurrentCost();
	}

	/** Returns the RBF kernel used by this optimizer. */
	public VabraRBF getRBF(){
		return rbf;
	}

	/**
	 * Sets the Jacobian threshold {@link #lambda}. Because {@code lambda} is static, this affects
	 * every {@code VabraOptimizer} instance.
	 *
	 * @param newLambda new threshold value
	 */
	public void setLambda(double newLambda){
		lambda = newLambda;
	}

}
