package edu.vanderbilt.masi.algorithms.adaboost;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Shared image-handling helpers for the classes in the adaboost package.
 *
 * <p>Subclasses use these helpers to compute a (possibly cropped) working
 * region, load an estimated segmentation and feature channels into that
 * region, and build per-voxel feature vectors. Unless stated otherwise, the
 * 3-D arrays passed to the helpers are indexed in working-region coordinates
 * ({@code 0 .. dims[i]-1}), i.e. original coordinates minus {@code cr[i][0]}.
 */
public abstract class SegAdapterImageBase extends AbstractCalculation {
	
	/**
	 * Inclusive bounds of the working region in original-image coordinates:
	 * {@code cr[i][0]} is the first and {@code cr[i][1]} the last index kept
	 * along dimension {@code i}. Set by {@link #set_dimensions}.
	 */
	protected int [][] cr;
	
	/** Size of the original image along each of its three dimensions (each at least 1). */
	protected int [] orig_dims;
	
	/** Size of the working region; {@code dims[i] == cr[i][1] - cr[i][0] + 1}. */
	protected int [] dims;
	
	/** Prints information about this object; the content is defined by the subclass. */
	public abstract void print_info();
	
	/**
	 * Records the original image dimensions and computes the working region.
	 *
	 * <p>When {@code use_cropping} is true, the region starts as the bounding
	 * box of all voxels with {@code img.getShort(x, y, z) > 0}. Otherwise it
	 * is the whole image. If cropping is requested but no voxel is positive,
	 * the whole image is used. In all cases the region is then grown by
	 * {@code dilation + feature_rad[i]} voxels on each side, clamped to the
	 * image. Dilated masks and feature windows built around foreground voxels
	 * therefore stay inside it.
	 *
	 * @param img          label image whose positive voxels define the crop
	 * @param dilation     dilation radius (voxels) added as a margin
	 * @param feature_rad  per-dimension feature window radius, also added as a margin
	 * @param use_cropping whether to restrict the region to the foreground bounding box
	 * @throws RuntimeException if {@code feature_rad[i] > (orig_dims[i] - 1) / 2} for any dimension
	 */
	protected void set_dimensions(ImageData img,
								  int dilation,
								  int [] feature_rad,
								  boolean use_cropping) {
		
		// get the dimensions of this image (empty dimensions count as size 1)
		orig_dims = new int [3];
		orig_dims[0] = Math.max(img.getRows(), 1);
		orig_dims[1] = Math.max(img.getCols(), 1);
		orig_dims[2] = Math.max(img.getSlices(), 1);
		
		// make sure that the feature radius is valid
		for (int i = 0; i < 3; i++) {
			if ((orig_dims[i] - 1) / 2 < feature_rad[i]) {
				String err_msg = String.format("ERROR: Feature Radius of %d is invalid for Dimension %d (size = %d)", feature_rad[i], i+1, orig_dims[i]);
				JistLogger.logOutput(JistLogger.SEVERE, err_msg);
				throw new RuntimeException(err_msg);
			}
		}
		
		cr = new int [3][2];
		boolean found_foreground = false;
		
		if (use_cropping) {
			// Start from an inverted (empty) box so the foreground voxels define
			// it. Starting from the full extent would make the box unable to shrink.
			for (int i = 0; i < 3; i++) {
				cr[i][0] = orig_dims[i] - 1;
				cr[i][1] = 0;
			}
			for (int x = 0; x < orig_dims[0]; x++)
				for (int y = 0; y < orig_dims[1]; y++)
					for (int z = 0; z < orig_dims[2]; z++)
						if (img.getShort(x, y, z) > 0) {
							found_foreground = true;
							cr[0][0] = Math.min(cr[0][0], x);
							cr[1][0] = Math.min(cr[1][0], y);
							cr[2][0] = Math.min(cr[2][0], z);
							cr[0][1] = Math.max(cr[0][1], x);
							cr[1][1] = Math.max(cr[1][1], y);
							cr[2][1] = Math.max(cr[2][1], z);
						}
			if (!found_foreground)
				JistLogger.logOutput(JistLogger.INFO, "Cropping requested but no foreground voxels found; using the full image");
		}
		
		// no cropping (or nothing to crop to): use the full image extent
		if (!found_foreground) {
			for (int i = 0; i < 3; i++) {
				cr[i][0] = 0;
				cr[i][1] = orig_dims[i] - 1;
			}
		}

		// grow by the dilation and feature margins, then set the working dimensions
		dims = new int [3];
		for (int i = 0; i < 3; i++) {
			cr[i][0] = Math.max(cr[i][0] - dilation - feature_rad[i], 0);
			cr[i][1] = Math.min(cr[i][1] + dilation + feature_rad[i], orig_dims[i]-1);
			dims[i] = cr[i][1] - cr[i][0] + 1;
		}
	}
	
	/**
	 * Writes the feature vector for voxel (x, y, z) into column
	 * {@code sample_num} of {@code X}. {@code X} is indexed
	 * {@code X[feature][sample]}.
	 *
	 * <p>Feature layout, in order:
	 * <ol>
	 *   <li>the 3 offsets {@code x - l_center[0]}, {@code y - l_center[1]}, {@code z - l_center[2]};</li>
	 *   <li>for every voxel in the window of radius {@code feature_rad} (clipped to
	 *       {@code dims}, iterated x, then y, then z), the {@code est_seg} value followed
	 *       by {@code num_channels} mean-subtracted channel values;</li>
	 *   <li>for each offset feature j in 0..2, the products {@code X[j] * X[k]} for every
	 *       feature k from j up to the end of the features written above.</li>
	 * </ol>
	 *
	 * @throws RuntimeException if the number of features written differs from
	 *         {@code num_features}, e.g. when the window is clipped at the region border
	 */
	protected void set_features(float [][] X,
							  int sample_num,
							  int x,
							  int y,
							  int z,
							  int [] feature_rad,
							  float [] l_center,
							  short [][][] est_seg,
							  float [][][][] feature_chans,
							  float [] mean_feature_vals,
							  int num_channels,
							  int num_features) {
		
		int cx = 0, ct, xl, xh, yl, yh, zl, zh, xi, yi, zi, ch, j, k;
		
		// add the current location relative to the center of the label
		X[cx++][sample_num] = x - l_center[0];
		X[cx++][sample_num] = y - l_center[1];
		X[cx++][sample_num] = z - l_center[2];
		
		// set the current region of interest (clipped to the working region)
		xl = Math.max(x - feature_rad[0], 0);
		xh = Math.min(x + feature_rad[0], dims[0]-1);
		yl = Math.max(y - feature_rad[1], 0);
		yh = Math.min(y + feature_rad[1], dims[1]-1);
		zl = Math.max(z - feature_rad[2], 0);
		zh = Math.min(z + feature_rad[2], dims[2]-1);
		
		// set the regional features
		for (xi = xl; xi <= xh; xi++)
			for (yi = yl; yi <= yh; yi++)
				for (zi = zl; zi <= zh; zi++) {
					X[cx++][sample_num] = est_seg[xi][yi][zi];
					for (ch = 0; ch < num_channels; ch++)
						X[cx++][sample_num] = feature_chans[xi][yi][zi][ch] - mean_feature_vals[ch];
				}
	
		// introduce the spatial correlations (offset features times every feature so far)
		ct = cx;
		for (j = 0; j < 3; j++)
			for (k = j; k < ct; k++)
				X[cx++][sample_num] = X[j][sample_num] * X[k][sample_num];
		
		// make sure we find the expected number of features
		if (cx != num_features) {
			JistLogger.logOutput(JistLogger.SEVERE, String.format("ERROR: Found %d features, expected %d features", cx, num_features));
			throw new RuntimeException("Incorrect number of features found");
		}
	}
	
	/**
	 * Dilates {@code mask} in place with a cubic structuring element.
	 * Every voxel within {@code dilation} voxels (per axis) of an originally
	 * true voxel is set to true. The growth is computed from a snapshot of the
	 * input, so newly set voxels do not seed further growth.
	 *
	 * @param mask     mask to dilate; at least {@code dims} in each dimension
	 * @param dims     extent of the mask to process
	 * @param dilation half-width of the cubic neighbourhood, in voxels
	 */
	protected void dilate_mask(boolean [][][] mask,
							   int [] dims,
							   int dilation) {
		
		// snapshot the input mask
		boolean [][][] tmask = new boolean [dims[0]][dims[1]][dims[2]];
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				System.arraycopy(mask[x][y], 0, tmask[x][y], 0, dims[2]);
		
		// perform the dilation
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++) {
					if (!tmask[x][y][z])
						continue;
					
					// set the current region of interest (clipped to dims)
					int xl = Math.max(x - dilation, 0);
					int xh = Math.min(x + dilation, dims[0]-1);
					int yl = Math.max(y - dilation, 0);
					int yh = Math.min(y + dilation, dims[1]-1);
					int zl = Math.max(z - dilation, 0);
					int zh = Math.min(z + dilation, dims[2]-1);
					
					for (int xi = xl; xi <= xh; xi++)
						for (int yi = yl; yi <= yh; yi++)
							for (int zi = zl; zi <= zh; zi++)
								mask[xi][yi][zi] = true;
				}
	}
	
	/**
	 * Replaces {@code segmask} with {@code segmask AND mask} over {@code dims}.
	 *
	 * @return the number of voxels that remain true in {@code segmask}
	 */
	protected int intersect_mask(boolean [][][] segmask,
								 boolean [][][] mask,
								 int [] dims) {
		int num = 0;
		for (int x = 0; x < dims[0]; x++)
			for (int y = 0; y < dims[1]; y++)
				for (int z = 0; z < dims[2]; z++) {
					segmask[x][y][z] = segmask[x][y][z] && mask[x][y][z];
					if (segmask[x][y][z])
						num++;
				}
		return(num);
	}
	
	/**
	 * Checks that {@code img} matches the dimensions recorded by
	 * {@link #set_dimensions} and has a single component.
	 *
	 * @throws RuntimeException if the dimensions or component count differ
	 */
	protected void verify_dims(ImageData img) {
		
		// make sure that the dimensions match
		if (orig_dims[0] != Math.max(img.getRows(), 1) ||
			orig_dims[1] != Math.max(img.getCols(), 1) ||
			orig_dims[2] != Math.max(img.getSlices(), 1) ||
			1 != Math.max(img.getComponents(), 1)) {
			JistLogger.logOutput(JistLogger.SEVERE, "Error: Rater Dimensions do not match");
			throw new RuntimeException("Error: Dimensions do not match");
		}
		
	}
	
	/**
	 * Computes the mean position, in working-region coordinates, of the
	 * {@code segmask} voxels that are at least {@code feature_rad[i]} voxels
	 * away from both ends of the working region along every axis.
	 *
	 * <p>Sums are accumulated in double precision. A float accumulator would
	 * lose integer precision once the coordinate sums exceed 2^24. If no voxel
	 * qualifies, every component is NaN (0 / 0).
	 *
	 * @return a new 3-element array {x, y, z}
	 */
	protected float [] compute_label_center(boolean [][][] segmask,
										    int [] feature_rad) {
		double sum_x = 0, sum_y = 0, sum_z = 0;
		long count = 0;
		
		for (int x = feature_rad[0]; x < dims[0] - feature_rad[0]; x++)
			for (int y = feature_rad[1]; y < dims[1] - feature_rad[1]; y++)
				for (int z = feature_rad[2]; z < dims[2] - feature_rad[2]; z++)
					if (segmask[x][y][z]) {
						sum_x += x;
						sum_y += y;
						sum_z += z;
						count++;
					}
		
		float [] l_center = new float [3];
		l_center[0] = (float)(sum_x / count);
		l_center[1] = (float)(sum_y / count);
		l_center[2] = (float)(sum_z / count);
		return(l_center);
	}

	/**
	 * Loads the feature channels for {@code currnum} into {@code feature_chans}
	 * and computes their means over {@code mask}.
	 *
	 * <p>Volumes {@code currnum * num_channels} through
	 * {@code (currnum + 1) * num_channels - 1} of {@code feature_vols} are read.
	 * Each is cropped to the working region and stored in channel
	 * {@code j - currnum * num_channels}. Every loaded image is disposed, even
	 * if verification or reading fails.
	 *
	 * <p>{@code mean_feature_vals[ch]} becomes (its incoming value + the sum of
	 * the channel over {@code mask}) / (number of mask voxels). Callers
	 * normally pass a zero-filled array. The sum is accumulated in double
	 * precision.
	 *
	 * @throws RuntimeException if a feature volume's dimensions do not match
	 */
	protected void load_feature_channels(ParamVolumeCollection feature_vols,
									     int num_channels,
									     int currnum,
									     boolean [][][] mask,
									     float [][][][] feature_chans,
									     float [] mean_feature_vals) {

		if (num_channels <= 0) {
			JistLogger.logOutput(JistLogger.INFO, "No feature images provided");
			return;
		}
		
		int j_start = currnum * num_channels;
		
		for (int ji = 0; ji < num_channels; ji++) {
			
			double sum = mean_feature_vals[ji];
			int count = 0;
			
			// load the feature information; always release the image afterwards
			ImageData img = feature_vols.getParamVolume(j_start + ji).getImageData(true);
			try {
				verify_dims(img);
				for (int x = cr[0][0]; x <= cr[0][1]; x++)
					for (int y = cr[1][0]; y <= cr[1][1]; y++)
						for (int z = cr[2][0]; z <= cr[2][1]; z++) {
							
							// set the relative indices
							int xi = x - cr[0][0];
							int yi = y - cr[1][0];
							int zi = z - cr[2][0];
							
							float v = img.getFloat(x, y, z);
							feature_chans[xi][yi][zi][ji] = v;
							
							if (mask[xi][yi][zi]) {
								sum += v;
								count++;
							}
						}
			} finally {
				img.dispose();
			}
			
			// set the final mean value for this feature
			mean_feature_vals[ji] = (float)(sum / count);
			JistLogger.logOutput(JistLogger.INFO, String.format("Mean Intensity: %f", mean_feature_vals[ji]));
		}
	}

	/**
	 * Loads the estimated segmentation into the working region and builds the
	 * sampling masks.
	 *
	 * <ul>
	 *   <li>{@code est_seg} receives the raw label values.</li>
	 *   <li>{@code mask} is {@code label > 0}, dilated by {@code dilation}.</li>
	 *   <li>{@code segmask} is either {@code label == tlabel} or, if
	 *       {@code initial_mask} is non-null, {@code initial_mask}. Note that
	 *       {@code initial_mask} is indexed in ORIGINAL image coordinates. The
	 *       result is dilated by {@code dilation}, then intersected with
	 *       {@code mask} only when {@code initial_mask} is null.</li>
	 * </ul>
	 *
	 * @param dispose_img whether to dispose of {@code img} once its data has been copied
	 * @param num_samples expected number of true voxels in the final {@code segmask}
	 * @throws RuntimeException if the final {@code segmask} count differs from {@code num_samples}
	 */
	protected void load_initial_data(ImageData img,
									 short [][][] est_seg,
									 boolean [][][] mask,
									 boolean [][][] segmask,
									 int tlabel,
									 int dilation,
									 int num_samples,
									 boolean dispose_img,
									 boolean [][][] initial_mask) {

		for (int x = cr[0][0]; x <= cr[0][1]; x++)
			for (int y = cr[1][0]; y <= cr[1][1]; y++)
				for (int z = cr[2][0]; z <= cr[2][1]; z++) {
					int xi = x - cr[0][0];
					int yi = y - cr[1][0];
					int zi = z - cr[2][0];
					short val = img.getShort(x, y, z);
					est_seg[xi][yi][zi] = val;
					mask[xi][yi][zi] = val > 0;
					if (initial_mask == null)
						segmask[xi][yi][zi] = val == tlabel;
					else
						segmask[xi][yi][zi] = initial_mask[x][y][z];
				}
		
		// the image is not used below, so dispose of it now (exactly once), if desired
		if (dispose_img)
			img.dispose();

		// perform some operations on the masks
		dilate_mask(mask, dims, dilation);
		dilate_mask(segmask, dims, dilation);
		int tmp_num_samples;
		if (initial_mask == null)
			tmp_num_samples = intersect_mask(segmask, mask, dims);
		else
			// intersecting segmask with itself leaves it unchanged; this only counts it
			tmp_num_samples = intersect_mask(segmask, segmask, dims);
		
		// make sure the number of samples is correct
		if (tmp_num_samples != num_samples) {
			JistLogger.logOutput(JistLogger.SEVERE, String.format("ERROR: Found %d samples, expected %d samples", tmp_num_samples, num_samples));
			throw new RuntimeException("Incorrect number of samples found");
		}

	}

}
