package edu.vanderbilt.masi.algorithms.labelfusion.simple;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataInt;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.vanderbilt.masi.utilities.AndrewUtils;

/**
 * Non-local SIMPLE label fusion.
 *
 * <p>Construction runs the whole iterative estimation. Each round:
 * <ol>
 * <li>estimates a label for every voxel: the consensus label where the label
 * volume reports consensus, otherwise a vote weighted by
 * {@code Patch.getWeight()};</li>
 * <li>rescores every candidate patch at non-consensus voxels by the fraction of
 * its neighbourhood whose observed labels agree with that estimate (stored via
 * {@code setDiceWeight});</li>
 * <li>discards patches whose score falls more than {@code numDevs} standard
 * deviations below the mean score at that voxel.</li>
 * </ol>
 * Iteration stops when the average score change is at most {@code epsilon} or
 * after {@code maxIter} rounds. Call {@link #getSegmentation()} for the result.
 */
public class NonLocalSIMPLE extends AbstractCalculation {

	/** Margin subtracted from the pruning threshold so weights equal to it up to float round-off survive. */
	private static final float THRESHOLD_MARGIN = 0.0001f;
	/** Standard deviation used when the computed one is NaN. */
	private static final float NAN_STD_REPLACEMENT = 0.0001f;
	/** Fewer patches than this (together with a low total weight) triggers the majority-vote fallback. */
	private static final int MIN_PATCHES_FOR_WEIGHTED_VOTE = 5;
	/** Total patch weight below this (together with few patches) triggers the majority-vote fallback. */
	private static final double MIN_TOTAL_WEIGHT_FOR_WEIGHTED_VOTE = 0.2;

	/** Observed labels/images and the candidate patches for every target voxel. */
	private NonLocalSIMPLELabelVolume obs;
	/** Pooling size. NOTE(review): currently only logged, not used by the algorithm. */
	private int poolRegion;
	/** Sum (later average) of patch weight changes during the current iteration. */
	private float diff;
	/** Number of patch weights recomputed during the current iteration. */
	int numPatches;
	/** Maximum number of iterations. */
	private int maxIter;
	/** Convergence threshold on the average patch weight change. */
	private float epsilon;
	/** Number of standard deviations below the mean weight at which patches are pruned. */
	private float numDevs;

	/**
	 * Builds the patch volume from the observations and immediately runs the
	 * iterative estimation.
	 *
	 * @param obsLabsList      observed label volumes
	 * @param obsImgsList      observed intensity images (their count is logged as the number of observations)
	 * @param targetImage      image being segmented
	 * @param patchRadiusIn    forwarded to {@code NonLocalSIMPLELabelVolume}
	 * @param searchRadiusIn   forwarded to {@code NonLocalSIMPLELabelVolume}
	 * @param searchSTDIn      forwarded to {@code NonLocalSIMPLELabelVolume}
	 * @param minPatchWeightIn forwarded to {@code NonLocalSIMPLELabelVolume}
	 * @param maxNumPatchesIn  forwarded to {@code NonLocalSIMPLELabelVolume}
	 * @param poolRegionIn     pooling size (currently only logged)
	 * @param epsilonIn        iteration stops once the average weight change is at most this value
	 * @param maxIterIn        maximum number of iterations
	 * @param numDevsIn        patches scoring more than this many standard deviations below the
	 *                         mean are discarded each iteration
	 * @throws Exception propagated from constructing the {@code NonLocalSIMPLELabelVolume}
	 */
	public NonLocalSIMPLE(List<ImageData> obsLabsList,
			List<ImageData> obsImgsList, ImageData targetImage,
			int patchRadiusIn, int searchRadiusIn, float searchSTDIn,
			float minPatchWeightIn, int maxNumPatchesIn, int poolRegionIn,
			float epsilonIn, int maxIterIn, float numDevsIn) throws Exception {
		super();
		this.numDevs = numDevsIn;
		this.epsilon = epsilonIn;
		JistLogger.logOutput(JistLogger.WARNING, "Starting Non-Local SIMPLE");
		this.poolRegion = poolRegionIn;
		this.maxIter = maxIterIn;
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("We have %d observed images.", obsImgsList.size()));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("The pooling size is %d", this.poolRegion));
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("Dev val is %.3f", this.numDevs));

		obs = new NonLocalSIMPLELabelVolume(obsLabsList, obsImgsList, targetImage, patchRadiusIn, searchRadiusIn,
				searchSTDIn, minPatchWeightIn, maxNumPatchesIn);
		run();
	}

	/**
	 * Alternates between estimating the segmentation and re-weighting/pruning
	 * patches. Stops when the average weight change is at most {@code epsilon}
	 * or after {@code maxIter} iterations.
	 */
	private void run() {
		int[] dims = obs.getDims();
		int[][][][] currSeg = new int[dims[0]][dims[1]][dims[2]][dims[3]];
		for (int i = 0; i < this.maxIter; i++) {
			JistLogger.logOutput(JistLogger.WARNING, "Starting iteration " + i);
			JistLogger.logFlush();
			this.diff = 0f;
			this.numPatches = 0;
			for (int x = 0; x < dims[0]; x++) {
				for (int y = 0; y < dims[1]; y++) {
					for (int z = 0; z < dims[2]; z++) {
						for (int c = 0; c < dims[3]; c++) {
							currSeg[x][y][z][c] = estimateLabel(x, y, z, c);
						}
					}
				}
			}
			runSIMPLEWeighting(currSeg);
			// If no patch weight was recomputed nothing can change any more. Report 0
			// instead of 0/0 = NaN, which would never pass the convergence test below.
			this.diff = (this.numPatches > 0) ? this.diff / this.numPatches : 0f;
			JistLogger.logOutput(JistLogger.WARNING,
					String.format("The average change in weight was %.3f", this.diff));
			if (this.diff <= this.epsilon) {
				break;
			}
			JistLogger.logFlush();
		}
	}

	/**
	 * Returns the consensus label for voxels the label volume flags as consensus,
	 * otherwise the patch-weighted estimate from {@link #getLabelEstimate}.
	 */
	private int estimateLabel(int x, int y, int z, int c) {
		if (obs.isConsensus(x, y, z, c)) {
			return obs.getConsLab(x, y, z, c);
		}
		return getLabelEstimate(x, y, z, c);
	}

	/**
	 * For every non-consensus voxel, rescores its patches against {@code seg}
	 * and then prunes the low-scoring ones.
	 *
	 * @param seg current segmentation estimate, indexed [x][y][z][c]
	 */
	private void runSIMPLEWeighting(int[][][][] seg) {
		for (int x = 0; x < seg.length; x++) {
			for (int y = 0; y < seg[x].length; y++) {
				for (int z = 0; z < seg[x][y].length; z++) {
					for (int c = 0; c < seg[x][y][z].length; c++) {
						if (!obs.isConsensus(x, y, z, c)) {
							runNonLocalPooling(seg, x, y, z, c);
							runSIMPLEWeightRemoval(x, y, z, c);
						}
					}
				}
			}
		}
	}

	/**
	 * Removes every patch at voxel (x,y,z,c) whose dice weight is below
	 * {@code mean - numDevs * std} of the dice weights at that voxel, and stores
	 * the surviving patches back into the label volume.
	 */
	private void runSIMPLEWeightRemoval(int x, int y, int z, int c) {
		Patch[][] patches = obs.getVoxelPatches(x, y, z, c);
		int nInit = 0;
		for (Patch[] ps : patches) {
			nInit += ps.length;
		}
		float[] weights = new float[nInit];
		int k = 0;
		for (Patch[] ps : patches) {
			for (Patch p : ps) {
				weights[k++] = p.getDiceWeight();
			}
		}
		float mean = AndrewUtils.calculateMean(weights);
		float std = AndrewUtils.calculateSTD(weights, mean);
		if (Float.isNaN(std)) {
			std = NAN_STD_REPLACEMENT;
		}
		// Lower the cutoff slightly so weights equal to it up to float round-off
		// (e.g. all weights identical, std 0) are not discarded.
		float thresh = mean - this.numDevs * std - THRESHOLD_MARGIN;

		Patch[][] newPatches = new Patch[patches.length][];
		int totalPatches = 0;
		for (int j = 0; j < patches.length; j++) {
			Patch[] kept = new Patch[patches[j].length];
			int numKept = 0;
			for (Patch p : patches[j]) {
				if (p.getDiceWeight() >= thresh) {
					kept[numKept++] = p;
				}
			}
			newPatches[j] = Arrays.copyOf(kept, numKept);
			totalPatches += numKept;
		}
		if (totalPatches == 0 && nInit > 0) {
			logAllPatchesRemoved(x, y, z, c, patches, nInit, mean, std);
		}
		obs.setPatches(x, y, z, c, newPatches);
	}

	/** Logs a warning, including every original dice weight, when pruning removed all patches at a voxel. */
	private static void logAllPatchesRemoved(int x, int y, int z, int c, Patch[][] patches,
			int nInit, float mean, float std) {
		StringBuilder patchString = new StringBuilder("[");
		for (Patch[] ps : patches) {
			for (Patch p : ps) {
				patchString.append(String.format("%.4f, ", p.getDiceWeight()));
			}
		}
		patchString.append(']');
		JistLogger.logOutput(JistLogger.WARNING,
				String.format("WARNING: No Patches remaining at [%d %d %d %d]. There were %d initially. Mean %.3f STD %.3f", x, y, z, c, nInit, mean, std));
		JistLogger.logOutput(JistLogger.WARNING, String.format(
				"Patches: %s", patchString.toString()));
	}

	/** Recomputes the dice weight of every patch at voxel (x,y,z,c) against {@code seg}. */
	private void runNonLocalPooling(int[][][][] seg, int x, int y, int z, int c) {
		Patch[][] patches = obs.getVoxelPatches(x, y, z, c);
		for (int j = 0; j < patches.length; j++) {
			for (Patch p : patches[j]) {
				runNonLocalPooling(seg, x, y, z, c, j, p);
			}
		}
	}

	/**
	 * Sets the dice weight of patch {@code p} to the fraction of neighbourhood
	 * voxels where {@code seg} around (x,y,z,c) matches observation {@code j}'s
	 * label at the same offset around the patch location (p.getX/Y/Z/C).
	 * The offset ranges come from {@code obs.getD?Low/High}; NOTE(review):
	 * presumably these clip the neighbourhood so both positions stay in bounds.
	 * The weight change reported by the patch is added to {@code diff}.
	 */
	private void runNonLocalPooling(int[][][][] seg, int x, int y, int z, int c, int j, Patch p) {
		int xi = p.getX();
		int yi = p.getY();
		int zi = p.getZ();
		int ci = p.getC();
		int dxl = obs.getDXLow(xi, x);
		int dxh = obs.getDXHigh(xi, x);
		int dyl = obs.getDYLow(yi, y);
		int dyh = obs.getDYHigh(yi, y);
		int dzl = obs.getDZLow(zi, z);
		int dzh = obs.getDZHigh(zi, z);
		int numSame = 0;
		int numTotal = 0;
		for (int dx = -dxl; dx <= dxh; dx++) {
			for (int dy = -dyl; dy <= dyh; dy++) {
				// Inclusive upper bound, matching the dx/dy loops (previously '<',
				// which skipped the high-z slice).
				for (int dz = -dzl; dz <= dzh; dz++) {
					int lab1 = seg[x + dx][y + dy][z + dz][c];
					int lab2 = obs.getLabel(xi + dx, yi + dy, zi + dz, ci, j);
					numTotal++;
					if (lab1 == lab2) {
						numSame++;
					}
				}
			}
		}
		float w = (float) numSame / (float) numTotal;
		p.setDiceWeight(w);
		this.diff += p.getWeightDifference();
		this.numPatches++;
	}

	/**
	 * Returns the label with the largest summed {@code Patch.getWeight()} over
	 * the patches at (x,y,z,c). Ties go to the lowest label. If there are fewer
	 * than {@value #MIN_PATCHES_FOR_WEIGHTED_VOTE} patches and their total weight
	 * is below {@value #MIN_TOTAL_WEIGHT_FOR_WEIGHTED_VOTE}, logs a warning and
	 * returns the label volume's majority-vote estimate instead.
	 */
	private int getLabelEstimate(int x, int y, int z, int c) {
		Patch[][] patches = obs.getVoxelPatches(x, y, z, c);
		float[] weights = new float[obs.getMaxLabel() + 1]; // zero-initialised by Java
		int n = 0;
		float totWeight = 0f;
		for (Patch[] ps : patches) {
			for (Patch p : ps) {
				weights[p.getLab()] += p.getWeight();
				n++;
				totWeight += p.getWeight();
			}
		}
		if (n < MIN_PATCHES_FOR_WEIGHTED_VOTE && totWeight < MIN_TOTAL_WEIGHT_FOR_WEIGHTED_VOTE) {
			String msg = String.format("Only %d patches at [%d %d %d %d] with total signal %.4f. Using majority vote instead.", n, x, y, z, c, totWeight);
			JistLogger.logOutput(JistLogger.WARNING, msg);
			return obs.getMajorityVoteEstimate(x, y, z, c);
		}
		int best = 0;
		for (int i = 1; i < weights.length; i++) {
			if (weights[i] > weights[best]) {
				best = i;
			}
		}
		return best;
	}

	/**
	 * Builds the final segmentation at the original (uncropped) image size.
	 * Labels are recomputed from the current (post-pruning) patches and written
	 * at the voxel offset given by the cropping region's lower bounds.
	 *
	 * @return an {@code ImageDataInt} named "segmentation" carrying the label volume's header
	 */
	public ImageData getSegmentation() {
		JistLogger.logOutput(JistLogger.WARNING,
				"Writing the segmentation");
		int[] dims = obs.getDims();
		int[] origDims = obs.getOrigDims();
		int[][] cropping = obs.getCroppingRegion();
		ImageData im = new ImageDataInt(origDims[0], origDims[1], origDims[2], origDims[3]);
		im.setHeader(obs.getHeader());
		im.setName("segmentation");
		for (int x = 0; x < dims[0]; x++) {
			for (int y = 0; y < dims[1]; y++) {
				for (int z = 0; z < dims[2]; z++) {
					for (int c = 0; c < dims[3]; c++) {
						int l = estimateLabel(x, y, z, c);
						im.set(x + cropping[0][0], y + cropping[1][0], z + cropping[2][0], c + cropping[3][0], l);
					}
				}
			}
		}
		return im;
	}

}
