package edu.jhu.ece.iacl.algorithms.volume;

/**
 * @author Hanlin Wan
 *  
 * Finds the midsagittal plane that has minimal intersection with either hemisphere.
 */

import java.util.ArrayList;

import Jama.Matrix;
import edu.jhu.bme.smile.commons.optimize.DownhillSimplexND;
import edu.jhu.bme.smile.commons.optimize.OptimizableNDContinuous;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataUByte;

/**
 * Estimates a plane through an image volume from a set of landmark points
 * and a labelled image.
 *
 * <p>The plane is described by a point on the plane ({@code centroid}) and a
 * unit normal given by two angles: {@code theta} (stored in
 * {@code normal[3]}) and {@code phi} (stored in {@code normal[4]}).
 * {@link #estimatePlane()} starts from the landmark centroid with both angles
 * set to zero and runs a {@link DownhillSimplexND} search over the five
 * parameters, using as objective the number of plane voxels whose label in
 * {@code img} equals 3 (reported on stdout as "WM").
 *
 * <p>After {@link #estimatePlane()} has run, the results are available through
 * {@link #getPlane()}, {@link #getSegmentation()} and
 * {@link #getPlaneSegmentation()}.
 */
public class MSP extends AbstractCalculation{

	/** Landmark sets; each entry is indexed {@code [axis][point]} with {@code dim} axes. */
	private ArrayList<double[][]> landmarks;
	/** Input label image ({@code img}), brain mask and the three output volumes. */
	private ImageData img, plane, seg, planeSeg, brainMask;
	/** One weight per landmark point, in the same column order as {@code lands}. */
	private double[] weights;
	/** Number of landmark sets. */
	private int numLand;
	/** {@code dim x points} matrix holding all landmark coordinates, set after set. */
	private Matrix lands;
	/** Point on the plane; holds coordinate sums while the landmarks are being scanned. */
	private double[] centroid;
	/** Indices 0..2: plane normal components; index 3: theta; index 4: phi. */
	private double[] normal;
	/**
	 * {@code points}: total number of landmark points; {@code marker}: column
	 * of {@code lands} at which the next landmark set is written;
	 * {@code dim}: number of spatial axes.
	 */
	private int points=0, marker=0, dim=3;
	/** Image dimensions taken from {@code img}. {@code iters} is not referenced within this class. */
	int rows, cols, slices, iters=100;

	/**
	 * Creates an instance with no data attached. Such an instance cannot run
	 * {@link #estimatePlane()}; use the three-argument constructor for that.
	 */
	public MSP(){
		super();
		setLabel("MSP");
	}

	/**
	 * Creates an estimator for the given landmarks and images.
	 *
	 * @param landmarks landmark sets, each indexed {@code [axis][point]}; the
	 *                  number of points of a set is taken from its first row,
	 *                  so all rows of a set are assumed to have equal length
	 * @param img       labelled image; its size defines rows, cols and slices
	 * @param brainMask mask read with {@code getUByte}; voxels with a value
	 *                  greater than zero are labelled in the segmentation
	 */
	public MSP(ArrayList<double[][]> landmarks, ImageData img, ImageData brainMask){
		super();
		setLabel("MSP");
		this.landmarks=landmarks;
		this.img=img;
		this.brainMask = brainMask;
		numLand=landmarks.size();
		for (double[][] x : landmarks)
			points+=x[0].length;
		centroid = new double[dim];
		lands = new Matrix(dim,points);
		normal=new double[dim+2];
		weights=new double[points];
		rows=img.getRows();
		cols=img.getCols();
		slices=img.getSlices();
	}

	/**
	 * Runs the full estimation: gathers the landmarks, computes their
	 * centroid, derives the initial plane, optimizes the five plane
	 * parameters and writes the output volumes.
	 *
	 * <p>The method may be called repeatedly; every call starts again from
	 * the landmark data.
	 *
	 * @throws IllegalStateException if the instance was created without
	 *         landmarks, image or brain mask
	 */
	public void estimatePlane(){
		if (landmarks == null || img == null || brainMask == null) {
			throw new IllegalStateException(
					"MSP needs landmarks, an image and a brain mask; use the three-argument constructor");
		}

		// computeWeight() accumulates into marker and centroid, and centroid is
		// overwritten with the optimizer result below. Start both from zero so
		// that a repeated call does not write past the end of lands or begin
		// from stale sums.
		marker = 0;
		centroid = new double[dim];

		int c=0;
		for (int i = 0; i < numLand; i++) {
			double[][] land = landmarks.get(i);
			// Also copies the set into lands and adds its coordinate sums to centroid.
			double w=computeWeight(land);
			// Every point of a set shares the weight of that set.
			for (int j = 0; j < land[0].length; j++) {
				weights[c++]=w;
			}
		}
		// Turn the coordinate sums into the mean over all landmark points.
		for (int i = 0; i < dim; i++)
			centroid[i]/=points;
		getInitialPlane();

		// Search space: plane point (x, y, z) followed by theta and phi.
		double[] cen = new double[]{centroid[0],centroid[1],centroid[2],normal[3],normal[4]};
		double[] rad = new double[]{5,5,5,.1,.1};
		Plane p = new Plane();
		DownhillSimplexND opt = new DownhillSimplexND();
		opt.initialize(p, cen, rad);
		// NOTE(review): the meaning of the boolean argument is defined by
		// DownhillSimplexND and is not visible here - confirm it selects minimization.
		opt.optimize(true);
		double[] ext = opt.getExtrema();
		System.out.println("Iters: "+opt.getIterations());
		System.out.println("WM: "+p.getValue(ext));

		// Store the optimized parameters and the normal they describe.
		centroid[0]=ext[0];
		centroid[1]=ext[1];
		centroid[2]=ext[2];
		double[] n = normalFromAngles(ext[3], ext[4]);
		normal[0]=n[0];
		normal[1]=n[1];
		normal[2]=n[2];
		normal[3]=ext[3];
		normal[4]=ext[4];
		writeOutPlanes();
	}

	/**
	 * Converts the two plane angles into the components of the plane normal.
	 *
	 * @param theta angle stored in {@code normal[3]}
	 * @param phi   angle stored in {@code normal[4]}
	 * @return {@code {cos(phi)cos(theta), cos(phi)sin(theta), sin(phi)}}
	 */
	private static double[] normalFromAngles(double theta, double phi) {
		return new double[]{
				Math.cos(phi)*Math.cos(theta),
				Math.cos(phi)*Math.sin(theta),
				Math.sin(phi)};
	}

	/**
	 * Solves the plane equation for the first (row) coordinate at column
	 * {@code j} and slice {@code k}.
	 *
	 * <p>No guard is applied for {@code nx == 0}; the result is then infinite
	 * or NaN, following ordinary double arithmetic.
	 *
	 * @param px first coordinate of the point on the plane
	 * @param py second coordinate of the point on the plane
	 * @param pz third coordinate of the point on the plane
	 * @param nx first component of the plane normal
	 * @param ny second component of the plane normal
	 * @param nz third component of the plane normal
	 * @param j  column index
	 * @param k  slice index
	 * @return the (fractional) row coordinate of the plane at {@code (j, k)}
	 */
	private static double planeRow(double px, double py, double pz,
			double nx, double ny, double nz, int j, int k) {
		return px - (ny*(j-py)+nz*(k-pz))/nx;
	}

	/**
	 * Derives a plane normal from the weighted, centred landmarks and resets
	 * both angles to zero.
	 *
	 * <p>The last column of V from the SVD of the {@code points x dim} matrix
	 * is stored in {@code normal[0..2]}. Because theta and phi are then set to
	 * zero, {@link #estimatePlane()} starts its search from zero angles and
	 * later overwrites {@code normal[0..2]}; the SVD direction therefore does
	 * not seed the optimization.
	 */
	private void getInitialPlane() {
		// Centre every landmark on the centroid and scale it by its weight.
		Matrix x0 = new Matrix(dim,points);
		for (int i = 0; i < dim; i++) {
			for (int j = 0; j < points; j++) {
				x0.set(i,j,(lands.get(i,j)-centroid[i])*weights[j]);
			}
		}

		Matrix V = x0.transpose().svd().getV();
		for (int i = 0; i < dim; i++) {
			normal[i]=V.get(i, dim-1);
		}

		// NOTE(review): the angle initialisation from the SVD normal is disabled
		// below - confirm that starting from zero angles is intended.
//		normal[3]=Math.atan2(normal[1],normal[0]); //theta
//		normal[4]=Math.atan2(normal[2], Math.sqrt(Math.pow(normal[0], 2)+Math.pow(normal[1], 2))); //phi
		normal[3]=normal[4]=0;
		printCenter("Initial");
	}

	/**
	 * Processes one landmark set: copies its coordinates into {@code lands}
	 * starting at column {@code marker}, adds its coordinate sums to
	 * {@code centroid}, advances {@code marker} and returns the weight of the
	 * set.
	 *
	 * <p>The weight is the reciprocal of the sum of the squared per-axis
	 * variances. A set without spread (for example a single point) yields a
	 * sum of zero and therefore an infinite weight.
	 * NOTE(review): the accumulator is named {@code trace} but sums squared
	 * variances rather than the variances themselves - confirm this is intended.
	 *
	 * @param x landmark set indexed {@code [axis][point]}
	 * @return {@code 1 / sum(var[i]^2)} over the {@code dim} axes
	 */
	private double computeWeight(double[][] x) {
		int len=0;
		double[] var = new double[dim];
		for (int i = 0; i < dim; i++) {
			double E1=0;
			double E2=0;
			len = x[i].length;
			for (int j = 0; j < len; j++) {
				double n=x[i][j];
				E1+=n;
				E2+=n*n;
				lands.set(i,j+marker,n);
			}
			// The raw sum feeds the global centroid; the mean is only needed locally.
			centroid[i]+=E1;
			E1/=len;
			E2/=len;
			// Variance as E[x^2] - E[x]^2.
			var[i]=E2-E1*E1;
		}
		double trace=0;
		for (int i = 0; i < var.length; i++) {
			trace+=var[i]*var[i];
		}
		// Uses the length of the last axis row; all rows are assumed equally long.
		marker+=len;
		return 1/trace;
	}

	/**
	 * Builds the three output volumes from the current {@code centroid} and
	 * {@code normal}:
	 * <ul>
	 * <li>{@code plane}: 1 at the plane voxel of every (column, slice) pair
	 *     whose plane row lies inside the volume;</li>
	 * <li>{@code planeSeg}: 2D image holding the {@code img} value found at
	 *     each of those plane voxels;</li>
	 * <li>{@code seg}: for voxels with a brain mask value above zero, 1 when
	 *     the row index is at or below the plane and 2 otherwise.</li>
	 * </ul>
	 */
	private void writeOutPlanes() {
		plane=new ImageDataUByte(rows,cols,slices);
		seg=new ImageDataUByte(rows,cols,slices);
		planeSeg=new ImageDataUByte(cols,slices);
		for (int j = 0; j < cols; j++) {
			for (int k = 0; k < slices; k++) {
				double x = planeRow(centroid[0], centroid[1], centroid[2],
						normal[0], normal[1], normal[2], j, k);
				int pt = (int)Math.round(x);
				if (pt>=0 && pt<rows) {
					plane.set(pt,j,k,1);
					planeSeg.set(j,k,img.get(pt,j,k));
				}
				// The side test uses the unrounded plane position.
				for (int i = 0; i < rows; i++) {
					if (brainMask.getUByte(i,j,k)>0) {
						if (i<=x)
							seg.set(i,j,k,1);
						else
							seg.set(i,j,k,2);
					}
				}
			}
		}
		printCenter("Optimized");
	}

	/**
	 * Prints the current plane point and the two angles to stdout.
	 *
	 * @param t heading printed before the values
	 */
	public void printCenter(String t) {
		System.out.print(t+":\n  Centriod: ");
		for (int i = 0; i < centroid.length; i++) {
			System.out.print(centroid[i]+", ");
		}
		System.out.println("\n  Theta: "+normal[3]+"   Phi: "+normal[4]);
	}

	/**
	 * @return the volume marking the plane voxels, or {@code null} before
	 *         {@link #estimatePlane()} has run
	 */
	public ImageData getPlane() {
		return plane;
	}

	/**
	 * @return the volume labelling masked voxels 1 or 2 by side of the plane,
	 *         or {@code null} before {@link #estimatePlane()} has run
	 */
	public ImageData getSegmentation() {
		return seg;
	}

	/**
	 * @return the 2D image of {@code img} values sampled on the plane, or
	 *         {@code null} before {@link #estimatePlane()} has run
	 */
	public ImageData getPlaneSegmentation() {
		return planeSeg;
	}

	/**
	 * Objective handed to the optimizer. A parameter vector is
	 * {@code {x, y, z, theta, phi}}: a point on the plane followed by the two
	 * angles of the plane normal.
	 */
	private class Plane implements OptimizableNDContinuous {

		/**
		 * @return {@code null}; no upper bounds are supplied. How the
		 *         optimizer treats {@code null} is not visible here.
		 */
		public double[] getDomainMax() {
			return null;
		}

		/**
		 * @return {@code null}; no lower bounds are supplied. How the
		 *         optimizer treats {@code null} is not visible here.
		 */
		public double[] getDomainMin() {
			return null;
		}

		/** @return the tolerance reported to the optimizer, {@code 1e-5} */
		public double getDomainTolerance() {
			return 1e-5;
		}

		/** @return the number of optimized parameters, {@code dim + 2} */
		public int getNumberOfDimensions() {
			return dim+2;
		}

		/**
		 * Counts the (column, slice) pairs whose plane voxel lies inside the
		 * volume and carries label 3 in {@code img}.
		 *
		 * @param val parameter vector {@code {x, y, z, theta, phi}}
		 * @return the number of such voxels, as a double
		 */
		public double getValue(double[] val) {
			double w=0;
			double[] n = normalFromAngles(val[3], val[4]);
			for (int j = 0; j < cols; j++) {
				for (int k = 0; k < slices; k++) {
					double x = planeRow(val[0], val[1], val[2], n[0], n[1], n[2], j, k);
					int pt = (int)Math.round(x);
					if (pt>=0 && pt<rows && img.getUByte(pt,j,k)==3)
						w++;
				}
			}
			return w;
		}
	}
}