package edu.jhu.ece.iacl.algorithms.vabra;

import java.util.ArrayList;
import java.util.List;

import edu.jhmi.rad.medic.libraries.ImageFunctionsPublic;
import edu.jhu.ece.iacl.algorithms.registration.NDimHistogramModifier;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities.InterpolationType;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamWeightedVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMath;

/**
 * Sum-of-squared-differences (SSD) variant of {@link VabraHistograms}.
 * <p>
 * Instead of joint histograms, this class keeps one SSD value per channel between the
 * deformed subject and the target. It also keeps six perturbed copies of that SSD
 * ({@code defST[x|y|z][Plus|Minus]}) for sample points displaced along each axis.
 * {@link #getCostGradients(double[], double[])} turns these into central-difference
 * gradients. Channels are combined with uniform weights {@code 1/numOfCh}.
 */
public class VabraHistogramsSSD extends VabraHistograms{

	/** Per-channel SSD set by {@link #updateHistograms} and adjusted by {@link #adjustOrigBins}. */
	public double[] origSSD;
	/** Per-channel SSD adjusted by {@link #adjustCurrentBins}; reset from {@link #origSSD}. */
	public double[] currentSSD;
	/** Per-channel SSD with the sample displaced by +/- a step in x, y or z (finite-difference gradients). */
	public double[] defSTxPlus,defSTxMinus,defSTyPlus,defSTyMinus,defSTzPlus,defSTzMinus;

	/** Number of channels compared: the smaller of the subject and target channel counts. */
	protected int numOfCh;
	/** Weight of each channel in the combined cost; every entry is {@code 1/numOfCh}. */
	protected double[] chWeights;

	/**
	 * Allocates per-channel SSD storage and uniform channel weights.
	 *
	 * @param numOfSub  number of subject channels
	 * @param numOfTar  number of target channels
	 * @param numOfBins number of intensity bins (stored for the superclass)
	 */
	public VabraHistogramsSSD(int numOfSub, int numOfTar, int numOfBins) {
		super(numOfSub, numOfTar, numOfBins);

		// NOTE(review): the superclass constructor may already store these; kept as in the original.
		this.numOfBins = numOfBins;
		this.numOfSub = numOfSub;
		this.numOfTar = numOfTar;

		// Only channels present in both images can be compared.
		this.numOfCh = Math.min(numOfSub, numOfTar);

		chWeights = new double[numOfCh];
		for (int i = 0; i < numOfCh; i++ ) chWeights[i] = 1/((double)numOfCh);

		defSTxPlus = new double[numOfCh];
		defSTyPlus = new double[numOfCh];
		defSTzPlus = new double[numOfCh];

		defSTxMinus = new double[numOfCh];
		defSTyMinus = new double[numOfCh];
		defSTzMinus = new double[numOfCh];

		origSSD = new double[numOfCh];
		currentSSD = new double[numOfCh];
	}

	/**
	 * Resets the superclass gradient state and re-seeds every perturbed SSD
	 * ({@code defST*}) from {@link #origSSD}.
	 */
	public void resetGradientHistograms() {
		super.resetGradientHistograms();

		for(int c = 0; c < numOfCh; c++){
			defSTxPlus[c] = origSSD[c];
			defSTxMinus[c] = origSSD[c];
			defSTyPlus[c] = origSSD[c];
			defSTyMinus[c] = origSSD[c];
			defSTzPlus[c] = origSSD[c];
			defSTzMinus[c] = origSSD[c];
		}
	}

	/** Copies {@link #origSSD} into {@link #currentSSD} for every channel. */
	public void resetCurrentHistograms() {
		for(int c = 0; c < numOfCh; c++){
			currentSSD[c] = origSSD[c];
		}
	}

	/**
	 * Recomputes {@link #origSSD} over the inclusive bounding box
	 * {@code [x0,x1] x [y0,y1] x [z0,z1]}.
	 * <p>
	 * Voxel values are read with {@code getInt}; presumably these are bin indices of the
	 * normalised volumes (TODO confirm).
	 *
	 * @param normedTarget          normalised target volumes, one per channel
	 * @param normedDeformedSubject normalised deformed subject volumes, one per channel
	 * @param boundingBox           {@code {x0, x1, y0, y1, z0, z1}}, all bounds inclusive
	 */
	public void updateHistograms(VabraVolumeCollection normedTarget, VabraVolumeCollection normedDeformedSubject, int[] boundingBox) {
		// NOTE(review): origDeformedSubject is a superclass field that this class never uses;
		// it is only checked here as a diagnostic.
		if (origDeformedSubject == null) {
			System.out.format("null Histograms");
		}
		for (int ch = 0; ch < numOfCh; ch++) {
			double ssd = 0;
			for (int i = boundingBox[0]; i <= boundingBox[1]; i++) {
				for (int j = boundingBox[2]; j <= boundingBox[3]; j++) {
					for (int k = boundingBox[4]; k <= boundingBox[5]; k++) {
						// Read each voxel once and square in double precision, so large
						// differences cannot overflow int arithmetic.
						double diff = (double) normedDeformedSubject.data[ch].getInt(i, j, k)
								- normedTarget.data[ch].getInt(i, j, k);
						ssd += diff * diff;
					}
				}
			}
			origSSD[ch] = ssd;
		}
	}

	/** Releases the superclass volume references and this class's SSD arrays. */
	public void dispose(){
		origDeformedSubject = null;
		origTarget = null;
		origSSD = null;

		// used in optimization -- perhaps move to child class
		currentDeformedSubject = null;
		currentTarget = null;
		currentSSD = null;
	}

	/** @return the channel-weighted sum of {@link #origSSD} */
	public double getOrigCost(){
		return weightedSum(origSSD);
	}

	/** @return the channel-weighted sum of {@link #currentSSD} */
	public double getCurrentCost(){
		return weightedSum(currentSSD);
	}

	/**
	 * Incrementally updates {@link #origSSD} after one voxel's subject bin changes
	 * from {@code subBin} to {@code newBin}.
	 */
	public void adjustOrigBins(int[] subBin, int[] tarBin, int[] newBin){
		adjustSSD(origSSD, subBin, tarBin, newBin);
	}

	/**
	 * Incrementally updates {@link #currentSSD} after one voxel's subject bin changes
	 * from {@code subBin} to {@code newBin}.
	 */
	public void adjustCurrentBins( int[] subBin, int[] tarBin, int[] newBin){
		adjustSSD(currentSSD, subBin, tarBin, newBin);
	}

	/**
	 * Updates all six perturbed SSDs. Each one resamples {@code subject} at
	 * {@code (origX, origY, origZ)} shifted by +/- {@code defX}, {@code defY} or {@code defZ}
	 * along a single axis.
	 *
	 * @param targetBins  per-channel target bin at this voxel
	 * @param subjectBins per-channel subject bin currently accounted for at this voxel
	 */
	public void adjustAllGradientBins(VabraVolumeCollection subject, double origX, double origY, double origZ,  double defX, double defY, double defZ, int[] targetBins, int[] subjectBins){
		adjustGradientBins( subject, origX + defX, origY, origZ, defSTxPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX - defX, origY, origZ, defSTxMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY + defY, origZ, defSTyPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY - defY, origZ, defSTyMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ + defZ, defSTzPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ - defZ, defSTzMinus, targetBins, subjectBins);
	}

	/**
	 * Computes the channel-weighted central-difference gradient of the SSD cost.
	 *
	 * @param results output; {@code results[0..2]} receive the x, y and z gradient components
	 * @param deltaC  step sizes used for the x, y and z perturbations
	 */
	public void getCostGradients(double[] results, double[] deltaC){
		results[0] = 0.0;
		results[1] = 0.0;
		results[2] = 0.0;

		// Same arithmetic as per-channel (plus - minus) / (2 * delta), then weighted,
		// without allocating temporary per-channel arrays.
		for (int ch = 0; ch < numOfCh; ch++) {
			results[0] += ((defSTxPlus[ch] - defSTxMinus[ch]) / (2.0 * deltaC[0])) * chWeights[ch];
			results[1] += ((defSTyPlus[ch] - defSTyMinus[ch]) / (2.0 * deltaC[1])) * chWeights[ch];
			results[2] += ((defSTzPlus[ch] - defSTzMinus[ch]) / (2.0 * deltaC[2])) * chWeights[ch];
		}
	}

	/** No-op: the SSD is already updated incrementally, so there is nothing to commit. */
	public void commitCurrentJointHistogram(){
		//Should be already done
	}

	/** Returns the sum over channels of {@code chWeights[ch] * values[ch]}. */
	private double weightedSum(double[] values){
		double costValD = 0;
		for (int ch = 0; ch < numOfCh; ch++){
			costValD += chWeights[ch] * values[ch];
		}
		return costValD;
	}

	/**
	 * Resamples {@code subject} at {@code (x, y, z)}, bins the values and updates
	 * {@code jointHist} (an SSD array) for that voxel.
	 * <p>
	 * Points outside the volume use each channel's {@code minValsD} instead of interpolating.
	 */
	private void adjustGradientBins(VabraVolumeCollection subject,double x,double y,double z, double[] jointHist, int[] targetBins, int[] subjectBins){
		double[] testValsD = new double[numOfSub];
		int[] testBins=new int[numOfSub];

		if (x < subject.getXN() && x >= 0 && y < subject.getYN() && y >= 0 && z < subject.getZN() && z >= 0) {
			subject.interpolate(x, y, z, testValsD);
		} else {
			for (int ch = 0; ch < numOfSub; ch++) testValsD[ch] = subject.minValsD[ch];
		}

		for (int ch = 0; ch < numOfSub; ch++) {
			testBins[ch] = subject.calculateBin(testValsD[ch], ch);
		}
		adjustSSD(jointHist, subjectBins,targetBins, testBins);
	}

	/**
	 * Replaces one voxel's contribution to each channel's SSD: removes
	 * {@code (subBin - tarBin)^2} and adds {@code (newBin - tarBin)^2}.
	 * Differences are squared in double precision to avoid int overflow.
	 */
	private void adjustSSD(double[] SSD, int[] subBin, int[] tarBin, int[] newBin){
		for (int ch = 0; ch < numOfCh; ch++){
			double oldDiff = (double) subBin[ch] - tarBin[ch];
			double newDiff = (double) newBin[ch] - tarBin[ch];
			SSD[ch] -= oldDiff * oldDiff;
			SSD[ch] += newDiff * newDiff;
		}
	}
}
