package edu.vanderbilt.masi.LabelFusion;

/**
 * Spatial layout of the "updates" used for spatially varying label-fusion
 * parameters.
 *
 * <p>A regular grid of {@code num_up[0] x num_up[1] x num_up[2]} seed points is
 * spread over the volume described by {@code obs.dims}. Each update gets a
 * rectangular window ({@code start_coords} .. {@code end_coords}, inclusive
 * voxel coordinates) derived from {@code win_dims}. For every voxel the class
 * precomputes the update whose seed is closest along each axis, plus the seven
 * neighboring updates on the voxel's side of that seed. {@link #get_interp_theta}
 * uses them to combine per-update parameters at that voxel.
 */
public class SpatialData {

	/** Total number of updates, {@code num_up[0] * num_up[1] * num_up[2]}. */
	public int num_upds;

	// number of updates along each axis
	private int [] num_up;

	// window size along each axis
	private int [] win_dims;

	// per voxel: [0] = closest update, [1..7] = neighbors stepped along
	// x, y, z, xy, xz, yz, xyz toward the voxel; -1 where the step leaves the grid
	private int [][][][] interp_matrix;

	/** Seed point voxel coordinates (x, y, z) of each update. */
	public int [][] seed_points;

	/** First voxel (x, y, z) of each update's window. */
	public int [][] start_coords;

	/** Last voxel (x, y, z) of each update's window (inclusive). */
	public int [][] end_coords;

	// grid subscripts (xi, yi, zi) -> update index
	private int [][][] sub_conv;

	// update index -> grid subscripts (xi, yi, zi)
	private int [][] ind_conv;

	// 0 = nearest-update lookup, anything else = inverse-distance weighting
	private int type;

	// seed coordinate of each grid index along x, y and z
	private int [] xu;
	private int [] yu;
	private int [] zu;

	/**
	 * Builds the update grid, the per-update windows and the per-voxel
	 * interpolation table.
	 *
	 * @param obs         observation whose {@code dims} give the volume size
	 * @param num_up_in   number of updates along x, y and z
	 * @param win_dims_in window size along x, y and z
	 * @param type_in     0 for nearest-update lookup, otherwise inverse-distance weighting
	 */
	public SpatialData(ObservationBase obs,
					      int [] num_up_in,
					      int [] win_dims_in,
					      int type_in) {

		// copy the inputs so later changes by the caller cannot desynchronize
		// the precomputed tables from the parameters they were built from
		num_up = num_up_in.clone();
		win_dims = win_dims_in.clone();
		type = type_in;
		num_upds = num_up[0] * num_up[1] * num_up[2];

		interp_matrix = new int [obs.dims[0]][obs.dims[1]][obs.dims[2]][8];
		seed_points = new int [num_upds][3];
		start_coords = new int [num_upds][3];
		end_coords = new int [num_upds][3];
		sub_conv = new int [num_up[0]][num_up[1]][num_up[2]];
		ind_conv = new int [num_upds][3];

		// seed points and windows must exist before the interpolation table,
		// which reads seed_points and sub_conv
		set_seeds_coordinates(obs);
		set_interp_matrix(obs);
	}

	/**
	 * Fills {@code interp_matrix}: for every voxel, the closest update and its
	 * seven neighbors on the voxel's side of that update's seed.
	 */
	private void set_interp_matrix(ObservationBase obs) {

		// the closest grid index along an axis depends only on the coordinate
		// along that axis, so compute it once per coordinate, not once per voxel
		int [] xc = closest_inds(obs.dims[0], 0);
		int [] yc = closest_inds(obs.dims[1], 1);
		int [] zc = closest_inds(obs.dims[2], 2);

		for (int x = 0; x < obs.dims[0]; x++) {
			for (int y = 0; y < obs.dims[1]; y++) {
				for (int z = 0; z < obs.dims[2]; z++) {

					int xi = xc[x];
					int yi = yc[y];
					int zi = zc[z];
					int ui = sub2ind(xi, yi, zi);

					// step toward the side of the seed the voxel lies on
					int sx = (x < seed_points[ui][0]) ? -1 : 1;
					int sy = (y < seed_points[ui][1]) ? -1 : 1;
					int sz = (z < seed_points[ui][2]) ? -1 : 1;

					int [] m = interp_matrix[x][y][z];
					m[0] = ui;
					m[1] = sub2ind(xi+sx, yi, zi);
					m[2] = sub2ind(xi, yi+sy, zi);
					m[3] = sub2ind(xi, yi, zi+sz);
					m[4] = sub2ind(xi+sx, yi+sy, zi);
					m[5] = sub2ind(xi+sx, yi, zi+sz);
					m[6] = sub2ind(xi, yi+sy, zi+sz);
					m[7] = sub2ind(xi+sx, yi+sy, zi+sz);
				}
			}
		}
	}

	/** Returns the closest grid index along axis {@code t} for each coordinate 0..n-1. */
	private int [] closest_inds(int n, int t) {
		int [] inds = new int [n];
		for (int c = 0; c < n; c++) {
			inds[c] = find_closest_ind(c, t);
		}
		return inds;
	}

	/**
	 * Returns the grid index along axis {@code t} (0 = x, 1 = y, 2 = z) whose
	 * seed coordinate is closest to {@code u}; ties go to the lower index.
	 *
	 * <p>This is a true argmin, so it stays correct when rounding produces
	 * duplicate seed coordinates along an axis.
	 */
	private int find_closest_ind(int u, int t) {
		int [] cu = (t == 0) ? xu : (t == 1) ? yu : zu;

		int best = 0;
		int best_dist = Math.abs(u - cu[0]);
		for (int i = 1; i < num_up[t]; i++) {
			int dist = Math.abs(u - cu[i]);
			if (dist < best_dist) {
				best = i;
				best_dist = dist;
			}
		}
		return best;
	}

	/**
	 * Lays out the seed points evenly along each axis and computes each
	 * update's window and the index conversion tables.
	 */
	private void set_seeds_coordinates(ObservationBase obs) {

		// first and last seed coordinate along each axis: keep a half window
		// away from the borders unless the window is a single voxel wide
		int [] first = new int [3];
		int [] last = new int [3];
		for (int t = 0; t < 3; t++) {
			if (win_dims[t] == 1) {
				first[t] = 0;
				last[t] = obs.dims[t] - 1;
			} else {
				first[t] = win_dims[t] / 2 - 1;
				last[t] = obs.dims[t] - first[t] - 1;
			}
		}

		xu = spaced_positions(first[0], last[0], num_up[0]);
		yu = spaced_positions(first[1], last[1], num_up[1]);
		zu = spaced_positions(first[2], last[2], num_up[2]);

		// enumerate the grid with z fastest, x slowest
		int u = -1;
		for (int xi = 0; xi < num_up[0]; xi++) {
			for (int yi = 0; yi < num_up[1]; yi++) {
				for (int zi = 0; zi < num_up[2]; zi++) {
					u++;
					sub_conv[xi][yi][zi] = u;
					ind_conv[u][0] = xi;
					ind_conv[u][1] = yi;
					ind_conv[u][2] = zi;

					seed_points[u][0] = xu[xi];
					seed_points[u][1] = yu[yi];
					seed_points[u][2] = zu[zi];

					for (int t = 0; t < 3; t++) {
						set_window(u, t, ind_conv[u][t], obs.dims[t]);
					}
				}
			}
		}
	}

	/**
	 * Returns {@code count} integer positions spread evenly from {@code first}
	 * to {@code last} (both included, rounded half up).
	 *
	 * <p>A single position sits at {@code first}. If the range is empty
	 * ({@code last < first}, i.e. the window does not fit in the volume), every
	 * position is left at 0. Positions are computed directly rather than by
	 * accumulating an increment, so exactly {@code count} values are produced
	 * even when the spacing is below one voxel.
	 */
	private static int [] spaced_positions(int first, int last, int count) {
		int [] pos = new int [count];
		if (last < first) {
			return pos;
		}
		if (count == 1) {
			pos[0] = first;
			return pos;
		}
		int span = last - first;
		for (int k = 0; k < count; k++) {
			pos[k] = (int) Math.round(first + ((double) k * span) / (count - 1));
		}
		return pos;
	}

	/**
	 * Sets the window of update {@code u} along axis {@code t}.
	 *
	 * <p>The window is centered on the seed point. The first and last updates
	 * along the axis are extended to the volume border, and a width-1 window
	 * collapses to the seed coordinate itself.
	 *
	 * @param u   update index
	 * @param t   axis (0 = x, 1 = y, 2 = z)
	 * @param gi  grid index of the update along axis {@code t}
	 * @param dim volume size along axis {@code t}
	 */
	private void set_window(int u, int t, int gi, int dim) {
		int seed = seed_points[u][t];
		double half = ((double) win_dims[t]) / 2;

		int start = (int) Math.ceil(seed - half + 1);
		int end = (int) Math.floor(seed + half - 1);

		if (gi == 0) {
			start = 0;
		}
		if (gi == num_up[t] - 1) {
			end = dim - 1;
		}
		if (win_dims[t] == 1) {
			start = seed;
			end = seed;
		}

		start_coords[u][t] = start;
		end_coords[u][t] = end;
	}

	/**
	 * Combines the per-update parameters at voxel (x, y, z) for rater/index
	 * {@code r} and first index {@code v}.
	 *
	 * @param theta_in parameters indexed {@code [v][label][r][update]}
	 * @return one value per label: the weighted sum over the voxel's
	 *         neighboring updates (see {@link #calc_weights})
	 */
	public double [] get_interp_theta(int x,
									  int y,
									  int z,
									  int r,
									  int v,
									  double [][][][] theta_in) {

		// the label index is the second dimension of theta_in, so take the
		// label count from there rather than from the first dimension
		double [][][] theta_v = theta_in[v];
		int num_labels = theta_v.length;
		double [] thetacol = new double [num_labels];
		double [] weights = calc_weights(x, y, z);
		int [] neighbors = interp_matrix[x][y][z];

		// neighbors with zero weight include the -1 "outside the grid" entries,
		// which must not be used as indices
		for (int s = 0; s < 8; s++) {
			if (weights[s] > 0) {
				int u = neighbors[s];
				for (int l = 0; l < num_labels; l++) {
					thetacol[l] += weights[s] * theta_v[l][r][u];
				}
			}
		}

		return thetacol;
	}

	/**
	 * Returns the eight neighbor weights for voxel (x, y, z).
	 *
	 * <p>With {@code type == 0} only the closest update gets weight 1.
	 * Otherwise each existing neighbor is weighted by inverse distance from the
	 * voxel to its seed, and the weights are normalized to sum to 1.
	 */
	private double [] calc_weights(int x,
								  int y,
								  int z) {

		double [] weights = new double [8];

		if (type == 0) {
			// nearest-update lookup
			weights[0] = 1;
		} else {
			// inverse-distance weighting; weight_sum > 0 because entry 0 (the
			// closest update) always exists
			double weight_sum = 0;
			for (int s = 0; s < 8; s++) {
				weights[s] = get_weight(x, y, z, s);
				weight_sum += weights[s];
			}
			for (int s = 0; s < 8; s++) {
				weights[s] /= weight_sum;
			}
		}

		return weights;
	}

	/**
	 * Returns the unnormalized inverse-distance weight of neighbor {@code s}
	 * of voxel (x, y, z).
	 *
	 * <p>Returns 0 when the neighbor lies outside the grid, and 100000 when the
	 * voxel coincides with the neighbor's seed (avoids dividing by zero).
	 */
	private double get_weight(int x, int y, int z, int s) {

		int u = interp_matrix[x][y][z][s];
		if (u < 0) {
			return 0;
		}

		double dx = seed_points[u][0] - x;
		double dy = seed_points[u][1] - y;
		double dz = seed_points[u][2] - z;
		double dist = Math.sqrt(dx * dx + dy * dy + dz * dz);

		return (dist == 0) ? 100000 : 1 / dist;
	}

	/** Converts grid subscripts to an update index, or -1 if they are outside the grid. */
	private int sub2ind(int xi, int yi, int zi) {

		if (xi < 0 || xi >= num_up[0] ||
		    yi < 0 || yi >= num_up[1] ||
		    zi < 0 || zi >= num_up[2]) {
			return -1;
		}
		return sub_conv[xi][yi][zi];
	}

}
