package edu.jhu.ece.iacl.algorithms.MGDM.forces;

import edu.jhmi.rad.medic.utilities.*;
import edu.jhu.ece.iacl.algorithms.MGDM.*;
import edu.jhu.ece.iacl.algorithms.MGDM.forces.MgdmConstants.ForceType;

/**
 * This class implements the {@code INTENSITY} force for Mgdm. It acts to minimize the
 * (piecewise-constant) Mumford-Shah energy functional and yields results similar to those in
 * "Active Contours without Edges" by Chan and Vese.
 * 
 * @author Bhaskar Kishore
 * @author John Bogovic
 *
 */
public class MgdmForceIntensity extends MgdmForce {
	
	/** Force type passed to the superclass by every constructor (the {@code typein} arguments are ignored). */
	private static final MgdmConstants.ForceType type = ForceType.INTENSITY;
	
	// Instance used only as the enclosing object for the (non-static) inner initializer below.
	// NOTE: must stay declared after 'type', which its constructor reads during class initialization.
	private static MgdmForceIntensity helperInstance = new MgdmForceIntensity(ForceType.CUSTOM, 0, Integer.MIN_VALUE);
	// One initializer/updater shared by all instances, since they all read the same "centroids" entry.
	private static MgdmIntensityForceInitializer intensityInit = helperInstance.new MgdmIntensityForceInitializer();
	
	/**
	 * Optional per-channel weights; when non-null, entry c is used instead of m_Weight for
	 * intensity channel c. Kept per instance so that constructing one channel-weighted force
	 * does not change the weighting of every other MgdmForceIntensity.
	 */
	private double[] m_channelWeights;
	
	/**
	 * Returns the initializer/updater that maintains the shared "centroids" repository entry.
	 * The same object is returned for every instance.
	 */
	public MgdmForceInitializerUpdater getInitializerUpdater(){
		return intensityInit;
	}
	
	/**
	 * Creates an intensity force with a single scalar weight applied to all channels.
	 * 
	 * @param typein ignored; the force type is always {@code INTENSITY}
	 * @param w      weight applied to every intensity channel
	 * @param id     identifier passed through to the superclass
	 */
	public MgdmForceIntensity(ForceType typein, double w, int id){
		super(type, w, id);
	}
	
	/**
	 * Creates an intensity force with one weight per intensity channel. The scalar weight is
	 * set to 1 and is only used if {@code w} is null.
	 * 
	 * @param typein ignored; the force type is always {@code INTENSITY}
	 * @param w      per-channel weights (the array is stored by reference, not copied)
	 * @param id     identifier passed through to the superclass
	 */
	public MgdmForceIntensity(ForceType typein, double[] w, int id){
		super(type, 1, id);
		m_channelWeights = w;
	}

	/**
	 * Creates an intensity force with a scalar weight and registers the repository keys it
	 * declares as required.
	 * 
	 * @param w weight applied to every intensity channel
	 */
	public MgdmForceIntensity(double w){
		super(ForceType.INTENSITY, w, MgdmConstants.INVALID);
		setRequiredData(new String[]{"DeltaP","DeltaM","numintensity",
				"label","neighbor","intensityforces"});
	}

	/**
	 * Computes the intensity force at voxel {@code xyz} between label {@code lbl} and its
	 * neighbor label {@code nbr}.
	 * 
	 * For every intensity channel c, the weighted squared distance from the voxel value to the
	 * centroid of {@code lbl} (objerr) and to the centroid of {@code nbr} (comperr) is
	 * accumulated. The force magnitude is sqrt(comperr) - sqrt(objerr), which is positive
	 * when the voxel is closer to the centroid of {@code lbl}. Positive values are scaled by
	 * {@code dr.deltaP()}, negative values by {@code dr.deltaM()}.
	 * 
	 * {@code iteration} and {@code flipValue} are not used by this force.
	 * 
	 * @return the force value, or NaN if "intensityforces" or "centroids" cannot be read
	 *         from {@code repo}
	 */
	@Override
	public double getForce(int xyz, int lbl, int nbr, GdmDerivatives dr, int iteration,
			MgdmDataRepository repo, float flipValue, boolean debug) {
		
		float[][] intensityforces = null;
		float[][] centroids = null;
		 
		try {
			intensityforces	= (float[][]) repo.get("intensityforces");
			centroids		= (float[][]) repo.get("centroids");
		} catch(Exception e){
			System.err.println("REPO ERROR : Multiple fetch failure.");
			e.printStackTrace();
			return Double.NaN;
		}
		
		// Missing entries previously caused an NPE below; report them like a fetch failure.
		if(intensityforces == null || centroids == null){
			System.err.println("REPO ERROR : \"intensityforces\" or \"centroids\" not found.");
			return Double.NaN;
		}
		
		double objerr = 0;
		double comperr = 0;
		
		int numintensity = intensityforces.length;
		
		for (int c = 0; c < numintensity; c++) {
			// Differences are kept in float, exactly as the original inline expressions computed them.
			float dObj = intensityforces[c][xyz] - centroids[c][lbl];
			float dCmp = intensityforces[c][xyz] - centroids[c][nbr];
			
			if(m_channelWeights!=null){
				objerr  +=  m_channelWeights[c] * dObj * dObj;
				comperr +=  m_channelWeights[c] * dCmp * dCmp;
			}else{
				objerr  +=  m_Weight * dObj * dObj;
				comperr +=  m_Weight * dCmp * dCmp;
			}
		}
		
		double intforce = Math.sqrt(comperr)-Math.sqrt(objerr);
		
		if(debug){
			System.out.println("objerr: " + objerr);
			System.out.println("comperr: " + comperr);
			System.out.println("intforce: " + intforce);
			
			System.out.println("deltaP: " +  dr.deltaP());
			System.out.println("deltaM: " +  dr.deltaM());
		}

		// weight is specific to modality and applied above
		return Numerics.max(intforce, 0.0) * dr.deltaP() 
				+ Numerics.min(intforce, 0.0) * dr.deltaM();
	}
	
	/** No per-instance initialization; shared data is set up by the inner initializer. */
	@Override
	public void initForce(MgdmDataRepository repo) {
		// Do nothing
	}
	
	/** Parameters are not parsed by this force; the call is a no-op. */
	public void fromParameters(String[] params){
		// Do nothing
	}
	
	/**
	 * Returns a new force with the same scalar weight and the same per-channel weights (if any).
	 * The copy is built with the {@code (double)} constructor, so it also declares the
	 * required repository keys.
	 */
	public MgdmForceIntensity clone(){
		MgdmForceIntensity copy = new MgdmForceIntensity(this.m_Weight);
		copy.m_channelWeights = this.m_channelWeights;
		return copy;
	}
	
	/**
	 * This class adds a "centroids" key and corresponding array to the MgdmDataRepository.  This array
	 * stores the average intensities for each intensity channel for each object represented by the decomposition.
	 * 
	 * Since all MgdmForceIntensity objects use the same "centroids" array, a single static
	 * instance of this class (returned by {@code getInitializerUpdater()}) maintains it for all of them.
	 * 
	 * @author John Bogovic
	 * @see MgdmForceIntensity
	 * @see MgdmForceInitializerUpdater
	 *
	 */
	public class MgdmIntensityForceInitializer extends MgdmForceInitializerUpdater{
		
		/**
		 * Allocates a [numIntensities][numLabels] "centroids" array, adds it to {@code repo} and
		 * fills it via {@link #updateRepoForForce}. Errors are reported and the method returns
		 * without adding or updating anything further.
		 */
		@Override
		public void initializeRepoForForce(MgdmDecomposition mgdmDecomp, MgdmDataRepository repo) {
			
			float[][] intensity = null;
			try{
				intensity = (float[][])repo.get("intensityforces");
			}catch(Exception e){
				System.err.println("REPOSITORY FETCH ERROR.");
				e.printStackTrace();
				return;
			}
			
			// Previously this case was only printed and then failed with an NPE on intensity.length.
			if(intensity == null){
				System.err.println("REPOSITORY FETCH ERROR: \"intensityforces\" is missing.");
				return;
			}
			
			int numLabels 		= mgdmDecomp.getNumLabels();
			int numIntensities  = intensity.length;
			
			float[][] centroids = new float[numIntensities][numLabels];
			try{
				repo.add("centroids", centroids);
			}catch(Exception e){
				System.out.println("REPOSITORY ADD ERROR.");
				System.err.println("REPOSITORY ADD ERROR.");
				e.printStackTrace();
				return;
			}
			
			updateRepoForForce(mgdmDecomp, repo);
		}


		/**
		 * Recomputes, in place, the mean intensity of every channel for every label, using the
		 * label returned by {@code mgdmDecomp.getLabel(xyz, 0)} at each iterated voxel.
		 * Labels with no voxels keep a centroid of 0.
		 */
		@Override
		public void updateRepoForForce(MgdmDecomposition mgdmDecomp, MgdmDataRepository repo) {
			
			float[][] centroids = null;
			float[][] intensity = null;
			
			try{
				centroids = (float[][])repo.get("centroids");
				intensity =  (float[][])repo.get("intensityforces");
			}catch(Exception e){
				System.err.println("REPOSITORY FETCH ERROR.");
				e.printStackTrace();
				return;
			}
			
			if(centroids == null || intensity == null){
				System.err.println("REPOSITORY FETCH ERROR: \"centroids\" or \"intensityforces\" is missing.");
				return;
			}
			
			int numIntensities  = intensity.length;
			int numLabels 		= mgdmDecomp.getNumLabels();
			
			// reset the accumulators
			for(int n=0; n<numIntensities; n++){
				for(int lb=0; lb<numLabels; lb++){
					centroids[n][lb] = 0;
				}
			}
			
			int[] labelcounts = new int[numLabels];
			
			MgdmDecomposition.MgdmDecompositionIterator it = mgdmDecomp.iterator();
			while(it.hasNext()){ // loop over space
				
				int xyz = it.next();
				int thislabel = mgdmDecomp.getLabel(xyz, 0);
				
				// increment count for label
				labelcounts[thislabel]++;
				
				// sum intensity values for this label
				for(int n=0; n<numIntensities; n++){
					centroids[n][thislabel] += intensity[n][xyz];
				}
			} // end loop over space
			
			// divide the sums by the voxel counts to obtain mean intensities
			for(int lb=0; lb<numLabels; lb++){
				if(labelcounts[lb]>0){
					for(int n=0; n<numIntensities; n++){
						centroids[n][lb] /= labelcounts[lb];
					}
				}
			}
		}
		
	}
	
	
}
