package edu.jhu.ece.iacl.algorithms.vabra;

import java.util.ArrayList;
import java.util.List;

import edu.jhmi.rad.medic.libraries.ImageFunctionsPublic;
import edu.jhu.ece.iacl.algorithms.registration.NDimHistogramModifier;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities.InterpolationType;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamWeightedVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMath;

/**
 * Multi-channel histogram bookkeeping: for each channel this class keeps a
 * joint subject/target histogram in addition to the per-image histograms
 * that are referenced here but declared outside this class (e.g.
 * {@code origDeformedSubject}, {@code origTarget}, {@code defSxPlus}).
 *
 * <p>The number of channels is {@code min(numOfSub, numOfTar)}. Every
 * channel starts with the same weight {@code 1 / numOfCh}; the weights are
 * used when the per-channel {@code RegistrationUtilities.NMI} values are
 * combined into a single cost or gradient.
 *
 * <p>All joint histograms are indexed as {@code [channel][subjectBin][targetBin]}.
 */
public class VabraHistogramsMultCh extends VabraHistograms{

	/** Joint histogram filled by {@link #updateHistograms}; source for the reset methods. */
	public int[][][] origJointST;
	/** Working copy of {@link #origJointST} used by the "current" cost and bin adjustments. */
	public int[][][] currentJointST;
	/** Joint histograms for the +/- probes along x, y and z used by the finite-difference gradient. */
	public int[][][] defSTxPlus,defSTxMinus,defSTyPlus,defSTyMinus,defSTzPlus,defSTzMinus;
	
	/** Number of channels handled, {@code min(numOfSub, numOfTar)}. */
	protected int numOfCh;
	/** Per-channel weights applied when combining channel costs/gradients. */
	protected double[] chWeights;
	
	/**
	 * Allocates all joint histograms and initializes uniform channel weights.
	 *
	 * @param numOfSub  number of subject channels
	 * @param numOfTar  number of target channels
	 * @param numOfBins number of bins per histogram axis
	 */
	public VabraHistogramsMultCh(int numOfSub, int numOfTar, int numOfBins) {
		super(numOfSub, numOfTar, numOfBins);
		
		// The same values are handed to the superclass constructor; they are
		// assigned again here exactly as in the original implementation.
		this.numOfBins = numOfBins;
		this.numOfSub = numOfSub;
		this.numOfTar = numOfTar;
		
		this.numOfCh = Math.min(numOfSub, numOfTar);
		
		// uniform weighting: every channel contributes 1/numOfCh
		chWeights = new double[numOfCh];
		for (int i = 0; i < numOfCh; i++) {
			chWeights[i] = 1.0 / numOfCh;
		}
		
		// allocate joint histograms
		defSTxPlus = new int[numOfCh][numOfBins][numOfBins];
		defSTyPlus = new int[numOfCh][numOfBins][numOfBins];
		defSTzPlus = new int[numOfCh][numOfBins][numOfBins];

		defSTxMinus = new int[numOfCh][numOfBins][numOfBins];
		defSTyMinus = new int[numOfCh][numOfBins][numOfBins];
		defSTzMinus = new int[numOfCh][numOfBins][numOfBins];
		
		origJointST = new int[numOfCh][numOfBins][numOfBins];
		
		currentJointST = new int[numOfCh][numOfBins][numOfBins];
	}
	
	/**
	 * Resets the superclass gradient histograms and then replaces each of the
	 * six gradient joint histograms with a fresh clone of {@link #origJointST}.
	 */
	public void resetGradientHistograms() {
		super.resetGradientHistograms();

		defSTxPlus = ImageData.clone(origJointST);
		defSTxMinus = ImageData.clone(origJointST);
		defSTyPlus = ImageData.clone(origJointST);
		defSTyMinus = ImageData.clone(origJointST);
		defSTzPlus = ImageData.clone(origJointST);
		defSTzMinus = ImageData.clone(origJointST);
	}
	
	/**
	 * Replaces the "current" subject histogram and joint histogram with fresh
	 * clones of their "orig" counterparts.
	 */
	public void resetCurrentHistograms() {
		currentDeformedSubject = ImageData.clone(origDeformedSubject);
		currentJointST = ImageData.clone(origJointST);
	}

	/**
	 * Recomputes the "orig" subject, target and joint histograms for every
	 * channel by delegating to {@code RegistrationUtilities.Histogram3D} and
	 * {@code RegistrationUtilities.JointHistogram3D}.
	 *
	 * @param normedTarget          target volumes; its {@code data} is histogrammed
	 * @param normedDeformedSubject deformed subject volumes; its {@code data} is histogrammed
	 * @param boundingBox           region passed through unchanged to the histogram routines
	 */
	public void updateHistograms(VabraVolumeCollection normedTarget, VabraVolumeCollection normedDeformedSubject, int[] boundingBox) {
	    // System.out.println(getClass().getCanonicalName()+"\t"+"UPDATE HISTOGRAMS ");
		if (origDeformedSubject == null) {
			// Only reports the condition; processing continues below.
			System.out.format("null Histograms");
			// initializeHistograms();
		}
		for (int ch = 0; ch < numOfCh; ch++) {
			RegistrationUtilities.Histogram3D(normedDeformedSubject.data, ch, numOfBins, boundingBox, origDeformedSubject);
			RegistrationUtilities.Histogram3D(normedTarget.data, ch, numOfBins, boundingBox, origTarget);
			RegistrationUtilities.JointHistogram3D(normedDeformedSubject.data,normedTarget.data, ch, 
					numOfBins, boundingBox, origJointST);
		}
	}
	
	/**
	 * Drops the references to the histograms held by this object so they can
	 * be garbage collected. This includes the six gradient joint histograms,
	 * which are as large as {@link #origJointST} each.
	 */
	public void dispose(){
		origDeformedSubject = null;
		origTarget = null;
		origJointST = null;

		// used in optimization -- perhaps move to child class
		currentDeformedSubject = null;
		currentTarget = null;
		currentJointST = null;

		// gradient joint histograms allocated by this class
		defSTxPlus = null;
		defSTxMinus = null;
		defSTyPlus = null;
		defSTyMinus = null;
		defSTzPlus = null;
		defSTzMinus = null;
	}

	/**
	 * Computes the cost from the "orig" histograms.
	 *
	 * @return the negated, channel-weighted sum of {@code RegistrationUtilities.NMI}
	 */
	public double getOrigCost(){
		double weightedNmi = 0;
		for (int ch = 0; ch < numOfCh; ch++) {
			weightedNmi += chWeights[ch] * RegistrationUtilities.NMI(origDeformedSubject, origTarget, origJointST, ch, numOfBins);
		}
		return -weightedNmi;
	}
	
	/**
	 * Computes the cost from the "current" subject and joint histograms
	 * (the target histogram used is {@code origTarget}).
	 *
	 * @return the negated, channel-weighted sum of {@code RegistrationUtilities.NMI}
	 */
	public double getCurrentCost(){
		double weightedNmi = 0;
		for (int ch = 0; ch < numOfCh; ch++) {
			weightedNmi += chWeights[ch] * RegistrationUtilities.NMI(currentDeformedSubject, origTarget, currentJointST, ch, numOfBins);
		}
		return -weightedNmi;
	}
		
	/**
	 * Moves one sample per channel from {@code subBin[ch]} to {@code newBin[ch]}
	 * in the "orig" subject and joint histograms.
	 *
	 * @param subBin per-channel subject bin the sample currently occupies
	 * @param tarBin per-channel target bin of the sample
	 * @param newBin per-channel subject bin the sample moves to
	 */
	public void adjustOrigBins(int[] subBin, int[] tarBin, int[] newBin){
		for (int ch = 0; ch < numOfCh; ch++) {
			adjustBins(origDeformedSubject, origJointST, subBin[ch], tarBin[ch], newBin[ch], ch);
		}
	}
	
	/**
	 * Moves one sample per channel from {@code subBin[ch]} to {@code newBin[ch]}
	 * in the "current" subject and joint histograms.
	 *
	 * @param subBin per-channel subject bin the sample currently occupies
	 * @param tarBin per-channel target bin of the sample
	 * @param newBin per-channel subject bin the sample moves to
	 */
	public void adjustCurrentBins( int[] subBin, int[] tarBin, int[] newBin){
		for (int ch = 0; ch < numOfCh; ch++) {
			adjustBins(currentDeformedSubject, currentJointST, subBin[ch], tarBin[ch], newBin[ch], ch);
		}
	}
	
	/**
	 * Updates all six gradient histogram pairs for one sample by probing the
	 * subject at the original position displaced by +/- {@code defX},
	 * {@code defY} and {@code defZ} along the respective axis.
	 *
	 * @param subject     subject volumes that are sampled at the probe positions
	 * @param origX       x coordinate of the undisplaced sample
	 * @param origY       y coordinate of the undisplaced sample
	 * @param origZ       z coordinate of the undisplaced sample
	 * @param defX        displacement applied along x
	 * @param defY        displacement applied along y
	 * @param defZ        displacement applied along z
	 * @param targetBins  per-channel target bin of the sample
	 * @param subjectBins per-channel subject bin the sample currently occupies
	 */
	public void adjustAllGradientBins(VabraVolumeCollection subject, double origX, double origY, double origZ,  double defX, double defY, double defZ, int[] targetBins, int[] subjectBins){

		adjustGradientBins( subject, origX + defX, origY, origZ, defSxPlus, defSTxPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX - defX, origY, origZ, defSxMinus, defSTxMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY + defY, origZ, defSyPlus, defSTyPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY - defY, origZ, defSyMinus, defSTyMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ + defZ, defSzPlus, defSTzPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ - defZ, defSzMinus, defSTzMinus, targetBins, subjectBins);
		
	}
	
	/**
	 * Computes central-difference gradients of {@code RegistrationUtilities.NMI}
	 * from the +/- gradient histograms and combines them over the channels
	 * using the channel weights.
	 *
	 * <p>Note that, unlike {@link #getOrigCost()} and {@link #getCurrentCost()},
	 * the values are not negated here.
	 *
	 * @param results output; elements 0..2 receive the weighted x, y, z gradients
	 * @param deltaC  step sizes; elements 0..2 are the x, y, z displacements
	 *                used as {@code 2 * deltaC[i]} in the denominator
	 */
	public void getCostGradients(double[] results, double[] deltaC){
		double gradx[] = new double[numOfCh];
		double grady[] = new double[numOfCh];
		double gradz[] = new double[numOfCh];
		
		// All per-channel differences are computed before results[] is written.
		for (int ch = 0; ch < numOfCh; ch++) {
			gradx[ch] = (RegistrationUtilities.NMI(defSxPlus, origTarget,defSTxPlus, ch, numOfBins)
					- RegistrationUtilities.NMI(defSxMinus,origTarget, defSTxMinus, ch, numOfBins))
					/ (2.0f * deltaC[0]);
			grady[ch] = (RegistrationUtilities.NMI(defSyPlus, origTarget, defSTyPlus, ch, numOfBins)
					- RegistrationUtilities.NMI(defSyMinus, origTarget, defSTyMinus, ch, numOfBins))
					/ (2.0f * deltaC[1]);
			gradz[ch] = (RegistrationUtilities.NMI(defSzPlus, origTarget, defSTzPlus, ch, numOfBins)
					- RegistrationUtilities.NMI(defSzMinus, origTarget, defSTzMinus, ch, numOfBins))
					/ (2.0f * deltaC[2]);
		}

		results[0] = 0.0;
		results[1] = 0.0;
		results[2] = 0.0;

		for (int ch = 0; ch < numOfCh; ch++) {
			results[0] += gradx[ch] * chWeights[ch];
			results[1] += grady[ch] * chWeights[ch];
			results[2] += gradz[ch] * chWeights[ch];
		}
	}
	
	/** Intentionally empty: nothing is committed here. */
	public void commitCurrentJointHistogram(){
		//Should be already done
	}
	
	/**
	 * Samples the subject at {@code (x, y, z)} and moves one sample per channel
	 * from its present subject bin to the bin of the sampled value, in both the
	 * given subject histogram and the given joint histogram.
	 *
	 * <p>Positions outside {@code [0, XN) x [0, YN) x [0, ZN)} are not
	 * interpolated; the per-channel {@code subject.minValsD} value is used
	 * instead.
	 *
	 * @param subject     subject volumes to sample
	 * @param x           probe x coordinate
	 * @param y           probe y coordinate
	 * @param z           probe z coordinate
	 * @param subjectHist subject histogram to adjust, indexed {@code [channel][bin]}
	 * @param jointHist   joint histogram to adjust, indexed {@code [channel][subjectBin][targetBin]}
	 * @param targetBins  per-channel target bin of the sample
	 * @param subjectBins per-channel subject bin the sample currently occupies
	 */
	private void adjustGradientBins(VabraVolumeCollection subject,double x,double y,double z, int[][] subjectHist,int[][][] jointHist, int[] targetBins, int[] subjectBins){
		// Named differently from the numOfSub field to avoid shadowing it.
		final int sampledChannels = subjectBins.length;
		final double[] sampledVals = new double[sampledChannels];
		
		if (x < subject.getXN() && x >= 0 && y < subject.getYN() && y >= 0 && z < subject.getZN() && z >= 0) {
			subject.interpolate(x, y, z, sampledVals);
		} else {
			for (int ch = 0; ch < sampledChannels; ch++) {
				sampledVals[ch] = subject.minValsD[ch];
			}
		}
		
		// The joint histograms only hold numOfCh channels, and every other
		// method of this class works on numOfCh channels, so never step past
		// that even when more subject bins are supplied.
		final int channels = Math.min(numOfCh, sampledChannels);
		for (int ch = 0; ch < channels; ch++) {
			int sampledBin = subject.calculateBin(sampledVals[ch], ch);
			adjustBins(subjectHist, jointHist, subjectBins[ch], targetBins[ch], sampledBin, ch);
		}

	}
	

	/**
	 * Moves a single sample of channel {@code ch} from subject bin
	 * {@code subBin} to subject bin {@code newBin}, keeping the target bin
	 * fixed, in both histograms.
	 *
	 * @param subjectHist subject histogram, indexed {@code [channel][bin]}
	 * @param jointHist   joint histogram, indexed {@code [channel][subjectBin][targetBin]}
	 * @param subBin      subject bin losing the sample
	 * @param tarBin      target bin of the sample
	 * @param newBin      subject bin gaining the sample
	 * @param ch          channel index
	 */
	private void adjustBins(int[][] subjectHist, int[][][] jointHist, int subBin, int tarBin, int newBin, int ch){
		subjectHist[ch][newBin] += 1;
		subjectHist[ch][subBin] -= 1;
		jointHist[ch][newBin][tarBin] += 1;
		jointHist[ch][subBin][tarBin] -= 1;
	}
		
	
}
