package edu.jhu.ece.iacl.algorithms.vabra;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.jhmi.rad.medic.libraries.ImageFunctionsPublic;
import edu.jhu.ece.iacl.algorithms.registration.NDimHistogramModifier;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities.InterpolationType;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamWeightedVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.structures.image.ImageDataMath;

/**
 * Normalized cross-correlation (NCC) cost bookkeeping built on top of {@link VabraHistograms}.
 *
 * <p>For every channel the class keeps six running statistics, stored in a
 * {@code double[numOfCh][6]} table:
 * <pre>
 *   [0] Q       = Sum_i (S_i - mean(S)) * (T_i - mean(T)) = Sum_i S_i * (T_i - mean(T))
 *   [1] V       = Sum_i (S_i - mean(S))^2                 = Sum_i S_i^2 - N * mean(S)^2
 *   [2] U       = Sum_i (T_i - mean(T))^2
 *   [3] mean(S)
 *   [4] mean(T)
 *   [5] NCC     = Q / sqrt(V * U)
 * </pre>
 * where S is the deformed subject, T the target and N the number of voxels in the
 * bounding box last passed to {@link #updateHistograms}.
 *
 * <p>When a single subject sample changes from S_i to S_i_new the statistics are updated
 * incrementally (T, mean(T) and U do not change):
 * <pre>
 *   dQ          = (S_i_new - S_i) * (T_i - mean(T))
 *   mean(S_new) = mean(S) + (S_i_new - S_i) / N
 *   dV          = S_i_new^2 - S_i^2 + N * (mean(S)^2 - mean(S_new)^2)
 *   NCC_new     = (Q + dQ) / sqrt((V + dV) * U)
 * </pre>
 *
 * <p>The reported cost is the negated, channel-weighted sum of the NCC values. If V or U
 * is zero for a channel the NCC of that channel is NaN or infinite (plain floating-point
 * division); no special handling is applied.
 */
public class VabraHistogramsNCC extends VabraHistograms{

	// Column layout of every per-channel statistics table (see class comment).
	private static final int IDX_Q = 0;
	private static final int IDX_V = 1;
	private static final int IDX_U = 2;
	private static final int IDX_MEAN_S = 3;
	private static final int IDX_MEAN_T = 4;
	private static final int IDX_NCC = 5;
	private static final int NUM_STATS = 6;

	/** Statistics computed by {@link #updateHistograms}; baseline for the resets. */
	public double[][] origNCC;
	/** Working copy of the statistics, restored from {@link #origNCC} by {@link #resetCurrentHistograms()}. */
	public double[][] currentNCC;
	/** Statistics for the +/- probes along x, y and z used by the finite-difference gradient. */
	public double[][] defSTxPlus,defSTxMinus,defSTyPlus,defSTyMinus,defSTzPlus,defSTzMinus;
	
	/** Number of channels compared: min(numOfSub, numOfTar). */
	protected int numOfCh;
	/** Number of voxels in the bounding box last given to {@link #updateHistograms}; 0 until then. */
	protected int numOfVoxels;
	/** Per-channel weights applied to the NCC values; initialised uniformly to 1/numOfCh. */
	protected double[] chWeights;
	
	/**
	 * Allocates the per-channel statistics tables and uniform channel weights.
	 *
	 * @param numOfSub number of subject channels
	 * @param numOfTar number of target channels
	 * @param numOfBins number of bins, forwarded to the superclass
	 */
	public VabraHistogramsNCC(int numOfSub, int numOfTar, int numOfBins) {
		super(numOfSub, numOfTar, numOfBins);
		
		this.numOfBins = numOfBins;
		this.numOfSub = numOfSub;
		this.numOfTar = numOfTar;
		this.numOfCh = Math.min(numOfSub, numOfTar);
		numOfVoxels = 0;

		chWeights = new double[numOfCh];
		Arrays.fill(chWeights, 1.0 / numOfCh);

		defSTxPlus = new double[numOfCh][NUM_STATS];
		defSTyPlus = new double[numOfCh][NUM_STATS];
		defSTzPlus = new double[numOfCh][NUM_STATS];

		defSTxMinus = new double[numOfCh][NUM_STATS];
		defSTyMinus = new double[numOfCh][NUM_STATS];
		defSTzMinus = new double[numOfCh][NUM_STATS];
		
		origNCC = new double[numOfCh][NUM_STATS];
		currentNCC = new double[numOfCh][NUM_STATS];
	}
	
	/**
	 * Delegates to the superclass and then resets all six gradient probe tables to
	 * independent copies of {@link #origNCC}.
	 */
	public void resetGradientHistograms() {
		super.resetGradientHistograms();

		for(int c = 0; c < numOfCh; c++){
			defSTxPlus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
			defSTxMinus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
			defSTyPlus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
			defSTyMinus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
			defSTzPlus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
			defSTzMinus[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
		}
	}
	
	/** Resets {@link #currentNCC} to an independent copy of {@link #origNCC}. */
	public void resetCurrentHistograms() {
		for(int c = 0; c < numOfCh; c++){
			currentNCC[c] = Arrays.copyOf(origNCC[c],origNCC[c].length);
		}
	}

	/**
	 * Recomputes {@link #origNCC} from scratch over the given bounding box and records the
	 * number of voxels in it.
	 *
	 * @param normedTarget target volumes; {@code data[ch].getInt(i, j, k)} is read per channel
	 * @param normedDeformedSubject deformed subject volumes, read the same way
	 * @param boundingBox inclusive index ranges {@code {xMin, xMax, yMin, yMax, zMin, zMax}}
	 */
	public void updateHistograms(VabraVolumeCollection normedTarget, VabraVolumeCollection normedDeformedSubject, int[] boundingBox) {
		if (origDeformedSubject == null) {
			System.out.format("null Histograms");
			// initializeHistograms();
		}

		// The loops below visit min..max inclusive on every axis, so each extent is
		// (max - min + 1). An inverted range visits nothing and therefore counts as 0.
		int extentX = Math.max(0, boundingBox[1] - boundingBox[0] + 1);
		int extentY = Math.max(0, boundingBox[3] - boundingBox[2] + 1);
		int extentZ = Math.max(0, boundingBox[5] - boundingBox[4] + 1);
		numOfVoxels = extentX * extentY * extentZ;
		
		for (int ch = 0; ch < numOfCh; ch++) {
			double[] stats = origNCC[ch];

			// First pass: mean(S) and mean(T).
			double sumS = 0;
			double sumT = 0;
			for (int i = boundingBox[0]; i <= boundingBox[1]; i++) {
				for (int j = boundingBox[2]; j <= boundingBox[3]; j++) {
					for (int k = boundingBox[4]; k <= boundingBox[5]; k++) {
						sumS += normedDeformedSubject.data[ch].getInt(i, j, k);
						sumT += normedTarget.data[ch].getInt(i, j, k);
					}
				}
			}
			stats[IDX_MEAN_S] = sumS / numOfVoxels;
			stats[IDX_MEAN_T] = sumT / numOfVoxels;

			// Second pass: Q, V and U around those means.
			double q = 0;
			double v = 0;
			double u = 0;
			for (int i = boundingBox[0]; i <= boundingBox[1]; i++) {
				for (int j = boundingBox[2]; j <= boundingBox[3]; j++) {
					for (int k = boundingBox[4]; k <= boundingBox[5]; k++) {
						double subMinusMean = normedDeformedSubject.data[ch].getInt(i, j, k) - stats[IDX_MEAN_S];
						double tarMinusMean = normedTarget.data[ch].getInt(i, j, k) - stats[IDX_MEAN_T];

						q += subMinusMean * tarMinusMean;
						v += subMinusMean * subMinusMean;
						u += tarMinusMean * tarMinusMean;
					}
				}
			}
			stats[IDX_Q] = q;
			stats[IDX_V] = v;
			stats[IDX_U] = u;
			stats[IDX_NCC] = ncc(q, v, u);
		}
	}
	
	/**
	 * Drops the references to the original/current data and statistics tables. The
	 * gradient probe tables are not released here. The instance must not be used for
	 * cost evaluation afterwards.
	 */
	public void dispose(){
		origDeformedSubject = null;
		origTarget = null;
		origNCC = null;

		// used in optimization -- perhaps move to child class
		currentDeformedSubject = null;
		currentTarget = null;
		currentNCC = null;
	}

	/** @return the negated channel-weighted sum of the NCC values in {@link #origNCC} */
	public double getOrigCost(){
		return negatedWeightedNcc(origNCC);
	}
	
	/** @return the negated channel-weighted sum of the NCC values in {@link #currentNCC} */
	public double getCurrentCost(){
		return negatedWeightedNcc(currentNCC);
	}
		
	/**
	 * Applies a single-sample change to {@link #origNCC}.
	 *
	 * @param subBin previous subject value per channel
	 * @param tarBin target value per channel at the same voxel
	 * @param newBin new subject value per channel
	 */
	public void adjustOrigBins(int[] subBin, int[] tarBin, int[] newBin){
		adjustNCC(origNCC, subBin, tarBin, newBin);
	}
	
	/**
	 * Applies a single-sample change to {@link #currentNCC}.
	 *
	 * @param subBin previous subject value per channel
	 * @param tarBin target value per channel at the same voxel
	 * @param newBin new subject value per channel
	 */
	public void adjustCurrentBins( int[] subBin, int[] tarBin, int[] newBin){
		adjustNCC(currentNCC, subBin, tarBin, newBin);
	}
	
	/**
	 * Updates the six gradient probe tables for one voxel by sampling the subject at the
	 * original position displaced by +/- defX, +/- defY and +/- defZ.
	 *
	 * @param subject subject volumes to sample
	 * @param origX x coordinate of the undisplaced sample position
	 * @param origY y coordinate of the undisplaced sample position
	 * @param origZ z coordinate of the undisplaced sample position
	 * @param defX probe displacement along x
	 * @param defY probe displacement along y
	 * @param defZ probe displacement along z
	 * @param targetBins target value per channel at this voxel
	 * @param subjectBins current subject value per channel at this voxel
	 */
	public void adjustAllGradientBins(VabraVolumeCollection subject, double origX, double origY, double origZ,  double defX, double defY, double defZ, int[] targetBins, int[] subjectBins){

		adjustGradientBins( subject, origX + defX, origY, origZ, defSTxPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX - defX, origY, origZ, defSTxMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY + defY, origZ, defSTyPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY - defY, origZ, defSTyMinus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ + defZ, defSTzPlus, targetBins, subjectBins);
		adjustGradientBins( subject, origX, origY, origZ - defZ, defSTzMinus, targetBins, subjectBins);
		
	}
	
	/**
	 * Writes the channel-weighted central differences of the probe NCC values into
	 * {@code results[0..2]}: (plus - minus) / (2 * deltaC[axis]).
	 *
	 * <p>NOTE(review): the differences are taken on the NCC values themselves, while
	 * {@link #getOrigCost()} and {@link #getCurrentCost()} return the negated NCC;
	 * confirm the sign convention expected by the caller.
	 *
	 * @param results output array, at least three elements
	 * @param deltaC probe step per axis, at least three elements
	 */
	public void getCostGradients(double[] results, double[] deltaC){
		// Read the steps before touching results so the two arrays may alias safely.
		double stepX = 2.0 * deltaC[0];
		double stepY = 2.0 * deltaC[1];
		double stepZ = 2.0 * deltaC[2];

		double gradX = 0.0;
		double gradY = 0.0;
		double gradZ = 0.0;
		for (int ch = 0; ch < numOfCh; ch++) {
			gradX += (defSTxPlus[ch][IDX_NCC] - defSTxMinus[ch][IDX_NCC]) / stepX * chWeights[ch];
			gradY += (defSTyPlus[ch][IDX_NCC] - defSTyMinus[ch][IDX_NCC]) / stepY * chWeights[ch];
			gradZ += (defSTzPlus[ch][IDX_NCC] - defSTzMinus[ch][IDX_NCC]) / stepZ * chWeights[ch];
		}

		results[0] = gradX;
		results[1] = gradY;
		results[2] = gradZ;
	}
	
	/** Intentionally a no-op for this class. */
	public void commitCurrentJointHistogram(){
		//Should be already done
	}
	
	/**
	 * Samples the subject at (x, y, z), converts the sample to a bin per channel and applies
	 * the resulting single-sample change to the given statistics table. Positions outside
	 * [0, XN) x [0, YN) x [0, ZN) use the subject's per-channel minimum value instead of
	 * interpolation.
	 */
	private void adjustGradientBins(VabraVolumeCollection subject,double x,double y,double z, double[][] stats, int[] targetBins, int[] subjectBins){
		double[] probeVals = new double[numOfSub];
		int[] probeBins = new int[numOfSub];
		
		boolean inside = x >= 0 && x < subject.getXN()
				&& y >= 0 && y < subject.getYN()
				&& z >= 0 && z < subject.getZN();
		if (inside) {
			subject.interpolate(x, y, z, probeVals);
		} else {
			for (int ch = 0; ch < numOfSub; ch++) probeVals[ch] = subject.minValsD[ch];
		}
		
		for (int ch = 0; ch < numOfSub; ch++) {
			probeBins[ch] = subject.calculateBin(probeVals[ch], ch);
		}
		adjustNCC(stats, subjectBins, targetBins, probeBins);
	}
	

	/**
	 * Incrementally updates Q, mean(S), V and NCC in {@code stats} for every channel when
	 * one subject sample changes from {@code subBin[ch]} to {@code newBin[ch]}.
	 *
	 * @throws IllegalStateException if no voxel count is available yet (numOfVoxels <= 0),
	 *         because the mean update divides by it
	 */
	private void adjustNCC(double[][] stats, int[] subBin, int[] tarBin, int[] newBin){
		if (numOfVoxels <= 0) {
			throw new IllegalStateException(
					"NCC statistics cannot be adjusted before updateHistograms has counted a non-empty bounding box (numOfVoxels="
					+ numOfVoxels + ")");
		}

		for (int ch = 0; ch < numOfCh; ch++){
			double[] s = stats[ch];
			// Work in double so neither the squares nor the division are done in int arithmetic.
			double oldS = subBin[ch];
			double newS = newBin[ch];
			double deltaS = newS - oldS;

			//dQ = (S_i_new - S_i)(T_i - mean(T))
			s[IDX_Q] += deltaS * (tarBin[ch] - s[IDX_MEAN_T]);

			//mean(S_new) = mean(S) + (S_i_new - S_i)/N
			double oldMean = s[IDX_MEAN_S];
			double newMean = oldMean + deltaS / numOfVoxels;
			s[IDX_MEAN_S] = newMean;
			
			//dV = S_i_new^2 - S_i^2 + N*(mean(S)^2 - mean(S_new)^2)
			//   = (S_i_new - S_i) * ((S_i_new + S_i) - (mean(S) + mean(S_new)))
			//since N*(mean(S) - mean(S_new)) = -(S_i_new - S_i)
			s[IDX_V] += deltaS * ((newS + oldS) - (oldMean + newMean));
			
			//NCC_new = Q_new/sqrt(V_new*U)
			s[IDX_NCC] = ncc(s[IDX_Q], s[IDX_V], s[IDX_U]);
		}
	}

	/** Returns Q / sqrt(V * U); NaN or infinite when V * U is not positive. */
	private static double ncc(double q, double v, double u){
		return q / Math.sqrt(v * u);
	}

	/** Returns the negated sum over channels of chWeights[ch] * stats[ch][NCC]. */
	private double negatedWeightedNcc(double[][] stats){
		double weighted = 0;
		for (int ch = 0; ch < numOfCh; ch++){
			weighted += chWeights[ch] * stats[ch][IDX_NCC];
		}
		return -weighted;
	}
		
	
}
