package edu.vanderbilt.masi.algorithms.adaboost;

import edu.jhu.ece.iacl.jist.io.FileReaderWriter;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import java.util.Arrays;

/**
 * Applies trained AdaBoost classifiers to a single image in order to refine an
 * estimated segmentation, one target label at a time.
 *
 * <p>The constructor loads the estimated segmentation and the feature images,
 * restricts processing to a cropping region (set up by {@code set_dimensions} in
 * the base class) and pre-computes, for every target label, the number of voxels
 * that will be classified and the label center. {@link #process} then classifies
 * the candidate voxels of one label; at each voxel the output estimate keeps the
 * label whose classifier output was the highest seen so far.
 */
public class SegAdapterImageTesting extends SegAdapterImageBase {
	
	// number of candidate voxels per target label (same indexing as tl)
	private int [] num_samples;
	// per-label center as returned by compute_label_center (same indexing as tl)
	private float [][] l_centers;
	// highest classifier output assigned so far at each voxel, in original (uncropped) dimensions
	private float [][][] full_prob;
	// the target labels, in the order used by all per-label arrays
	private int [] tl;
	// the refined segmentation returned by get_estimate()
	protected ImageData estimate;
	// the estimated segmentation as loaded from the input volume
	protected ImageData initial_estimate;
	// cropped feature channel values, indexed [x][y][z][channel]; null when there are no channels
	float [][][][] feature_chans = null;
	// per-channel values passed to (and presumably filled by) load_feature_channels; null when there are no channels
	float [] mean_feature_vals = null;
	// number of feature channels (size of the feature volume collection)
	int num_channels;
	
	/**
	 * Loads the estimated segmentation and the feature images and pre-computes the
	 * per-label sample counts and label centers.
	 *
	 * @param est_vol        the estimated segmentation to be refined
	 * @param feature_vols   additional feature images (may be empty)
	 * @param tlabels        the target labels to be processed
	 * @param dilations      per-label dilation amount (same indexing as tlabels)
	 * @param feature_rads   per-label feature radius in x, y and z (same indexing as tlabels)
	 * @param initial_masks  optional per-label masks in original image coordinates; when
	 *                       null the label masks are derived from the estimated segmentation
	 * @throws IllegalArgumentException if a per-label array has fewer entries than tlabels,
	 *                                  a feature radius has fewer than 3 entries, or a
	 *                                  provided per-label mask is null
	 */
	public SegAdapterImageTesting(ParamVolume est_vol,
								  ParamVolumeCollection feature_vols,
								  int [] tlabels,
								  int [] dilations,
								  int [][] feature_rads,
								  boolean [][][][] initial_masks) {
		
		// validate the per-label inputs, then find the largest dilation and feature radius
		validate_label_arrays(tlabels, dilations, feature_rads, initial_masks);
		tl = tlabels;
		int num_labels = tlabels.length;
		int max_dilation = 0;
		int [] max_feature_rad = new int [3];
		for (int i = 0; i < num_labels; i++) {
			max_dilation = Math.max(max_dilation, dilations[i]);
			for (int k = 0; k < 3; k++)
				max_feature_rad[k] = Math.max(max_feature_rad[k], feature_rads[i][k]);
		}
		
		// load the image information
		JistLogger.logOutput(JistLogger.INFO, "-> Loading estimated segmentation");
		initial_estimate = est_vol.getImageData(true);
		
		// set the original dimensions, cropping region and cropped dimensions
		JistLogger.logOutput(JistLogger.INFO, "-> Determining appropriate cropping region");
		set_dimensions(initial_estimate, max_dilation, max_feature_rad, initial_masks == null);
		
		// allocate the output estimate and the max probability map; every voxel starts at
		// -infinity so that any finite classifier output assigned to it replaces the initial value
		JistLogger.logOutput(JistLogger.INFO, "-> Allocating space for the output segmentation");
		estimate = initial_estimate.clone();
		estimate.setName(get_output_name(initial_estimate));
		full_prob = new float [orig_dims[0]][orig_dims[1]][orig_dims[2]];
		for (float [][] plane : full_prob)
			for (float [] row : plane)
				Arrays.fill(row, Float.NEGATIVE_INFINITY);
		
		JistLogger.logOutput(JistLogger.INFO, "-> Estimating Per-Label Masks");
		boolean [][][] fg_mask = compute_foreground_mask(max_dilation);
		
		// allocate space for the feature channel information
		num_channels = feature_vols.size();
		if (num_channels > 0) {
			feature_chans = new float [dims[0]][dims[1]][dims[2]][num_channels];
			mean_feature_vals = new float [num_channels];
		}
		
		// load the feature channels
		JistLogger.logOutput(JistLogger.INFO, "-> Loading Feature Images:");
		load_feature_channels(feature_vols, num_channels, 0, fg_mask, feature_chans, mean_feature_vals);
		
		// determine the number of samples and the center of every target label
		compute_label_statistics(tlabels, dilations, feature_rads, initial_masks, fg_mask);
		
		JistLogger.logOutput(JistLogger.INFO, "");
	}
	
	/**
	 * Checks that every per-label array provides an entry for each target label.
	 *
	 * @throws IllegalArgumentException if tlabels is null, any per-label array is too
	 *                                  short, a feature radius has fewer than 3 entries,
	 *                                  or a provided per-label mask is null
	 */
	private static void validate_label_arrays(int [] tlabels,
											  int [] dilations,
											  int [][] feature_rads,
											  boolean [][][][] initial_masks) {
		if (tlabels == null)
			throw new IllegalArgumentException("tlabels must not be null");
		int num_labels = tlabels.length;
		
		int num_dilations = (dilations == null) ? 0 : dilations.length;
		if (num_dilations < num_labels)
			throw new IllegalArgumentException(String.format(
				"Expected %d dilations (one per target label), got %d", num_labels, num_dilations));
		
		int num_rads = (feature_rads == null) ? 0 : feature_rads.length;
		if (num_rads < num_labels)
			throw new IllegalArgumentException(String.format(
				"Expected %d feature radii (one per target label), got %d", num_labels, num_rads));
		for (int i = 0; i < num_labels; i++)
			if (feature_rads[i] == null || feature_rads[i].length < 3)
				throw new IllegalArgumentException(String.format(
					"feature_rads[%d] must provide a radius for x, y and z", i));
		
		if (initial_masks != null) {
			if (initial_masks.length < num_labels)
				throw new IllegalArgumentException(String.format(
					"Expected %d initial masks (one per target label), got %d",
					num_labels, initial_masks.length));
			for (int i = 0; i < num_labels; i++)
				if (initial_masks[i] == null)
					throw new IllegalArgumentException(String.format(
						"initial_masks[%d] must not be null", i));
		}
	}
	
	/**
	 * Builds the cropped mask of all voxels with a positive value in the estimated
	 * segmentation, dilated by max_dilation.
	 */
	private boolean [][][] compute_foreground_mask(int max_dilation) {
		boolean [][][] fg_mask = new boolean[dims[0]][dims[1]][dims[2]];
		for (int x = cr[0][0]; x <= cr[0][1]; x++)
			for (int y = cr[1][0]; y <= cr[1][1]; y++)
				for (int z = cr[2][0]; z <= cr[2][1]; z++)
					fg_mask[x - cr[0][0]][y - cr[1][0]][z - cr[2][0]] =
						initial_estimate.getShort(x, y, z) > 0;
		dilate_mask(fg_mask, dims, max_dilation);
		return fg_mask;
	}
	
	/**
	 * Fills the cropped segmask for one label: from label_mask when it is non-null,
	 * otherwise from the voxels of the estimated segmentation equal to tlabel.
	 */
	private void fill_label_mask(boolean [][][] segmask, int tlabel, boolean [][][] label_mask) {
		for (int x = cr[0][0]; x <= cr[0][1]; x++)
			for (int y = cr[1][0]; y <= cr[1][1]; y++)
				for (int z = cr[2][0]; z <= cr[2][1]; z++)
					segmask[x - cr[0][0]][y - cr[1][0]][z - cr[2][0]] = (label_mask == null)
						? initial_estimate.getShort(x, y, z) == tlabel
						: label_mask[x][y][z];
	}
	
	/**
	 * Computes num_samples and l_centers for every target label.
	 *
	 * @param fg_mask  the dilated foreground mask from compute_foreground_mask
	 */
	private void compute_label_statistics(int [] tlabels,
										  int [] dilations,
										  int [][] feature_rads,
										  boolean [][][][] initial_masks,
										  boolean [][][] fg_mask) {
		int num_labels = tlabels.length;
		num_samples = new int [num_labels];
		l_centers = new float [num_labels][];
		
		// reused for every label; fill_label_mask overwrites the whole cropped region
		boolean [][][] segmask = new boolean[dims[0]][dims[1]][dims[2]];
		for (int l = 0; l < num_labels; l++) {
			fill_label_mask(segmask, tlabels[l], (initial_masks == null) ? null : initial_masks[l]);
			dilate_mask(segmask, dims, dilations[l]);
			
			// Without explicit masks the label mask is restricted to the dilated foreground.
			// With explicit masks it is intersected with itself, presumably only to count its voxels.
			boolean [][][] other = (initial_masks == null) ? fg_mask : segmask;
			num_samples[l] = intersect_mask(segmask, other, dims);
			l_centers[l] = compute_label_center(segmask, feature_rads[l]);
		}
	}
	
	/**
	 * Builds the name of the output image.
	 *
	 * The name is the input image name passed through FileReaderWriter.getFileName,
	 * followed by "_AdaBoost". For names ending in ".gz" or ".bz2", getFileName is
	 * applied twice. This is presumably to strip both the compression suffix and the
	 * image extension.
	 */
	private String get_output_name (ImageData img) {
		
		// get the original filename
		String pfix = img.getName();
		
		// try to get the prefix
		if (pfix.endsWith(".gz") || pfix.endsWith(".bz2"))
			pfix = FileReaderWriter.getFileName(pfix);
		pfix = FileReaderWriter.getFileName(pfix);
		
		// return the prefix appended with "_AdaBoost"
		return(String.format("%s_AdaBoost", pfix));
	}

	/** Returns the number of candidate voxels for the target label at index ind. */
	public int get_num_samples(int ind) { return(num_samples[ind]); }
	
	/** Logs the original and cropped dimensions and the number of target labels. */
	public void print_info() {
		JistLogger.logOutput(JistLogger.INFO, "-> Found the following information: ");
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimensions: %dx%dx%d (%dx%dx%d)",
														    orig_dims[0], orig_dims[1], orig_dims[2],
														    dims[0], dims[1], dims[2]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of labels: %d", tl.length));
		JistLogger.logOutput(JistLogger.INFO, "");
	}
	
	/**
	 * Classifies every candidate voxel of one target label and updates the output estimate.
	 *
	 * @param num_features  length of the feature vector built by set_features
	 * @param tlabel        the target label being processed
	 * @param tlabel_ind    index of tlabel in the target label array given to the constructor
	 * @param dilation      dilation passed to load_initial_data when building the candidate mask
	 * @param feature_rad   feature radius in x, y and z; voxels closer than this to the
	 *                      border of the cropped region are not visited
	 * @param numiter       not used by this method
	 * @param learners      the trained weak learners passed to AdaBoost.classify
	 * @param initial_mask  optional mask for this label, passed on to load_initial_data
	 * @param fusion_type   fusion strategy passed to AdaBoost.classify
	 */
	public void process(int num_features,
						int tlabel,
						int tlabel_ind,
						int dilation,
						int [] feature_rad,
						int numiter,
						WeakLearner [][] learners,
						boolean [][][] initial_mask,
						int fusion_type) {

		final int total = num_samples[tlabel_ind];

		// allocate some space using the cropped dimensions
		boolean [][][] mask = new boolean[dims[0]][dims[1]][dims[2]];
		boolean [][][] segmask = new boolean[dims[0]][dims[1]][dims[2]];
		short [][][] est_seg = new short[dims[0]][dims[1]][dims[2]];
		// voxels are classified one at a time, so X only ever holds sample column 0
		float [][] X = new float [num_features][1];
		final int sample_num = 0;

		// load the estimated segmentation and resulting masks
		JistLogger.logOutput(JistLogger.INFO, "-> Performing pre-processing for this label");
		load_initial_data(initial_estimate, est_seg, mask, segmask, tlabel, dilation, total, false, initial_mask);

		// the center of the current label of interest (pre-computed in the constructor)
		float [] l_center = l_centers[tlabel_ind];
		JistLogger.logOutput(JistLogger.INFO, String.format("Label Center: %fx%fx%f", l_center[0], l_center[1], l_center[2]));

		// classify every voxel inside the candidate mask
		JistLogger.logOutput(JistLogger.INFO, "-> Applying Classifiers");
		int processed = 0;
		for (int x = feature_rad[0]; x < dims[0] - feature_rad[0]; x++)
			for (int y = feature_rad[1]; y < dims[1] - feature_rad[1]; y++)
				for (int z = feature_rad[2]; z < dims[2] - feature_rad[2]; z++) {
					if (!segmask[x][y][z])
						continue;

					JistLogger.printStatusBar(processed, total);

					// set the feature vector
					set_features(X, sample_num, x, y, z, feature_rad, l_center,
								 est_seg, feature_chans, mean_feature_vals, num_channels, num_features);

					float pp = AdaBoost.classify(X, sample_num, learners, fusion_type);

					// assign this label appropriately
					assign(pp, X, sample_num, tlabel, tlabel_ind);
					processed++;
				}
	}
	
	/**
	 * Writes tlabel into the output estimate for one classified voxel.
	 *
	 * The label is written only if pp exceeds the best classifier output recorded at
	 * that voxel so far. The voxel position in original image coordinates is recovered
	 * as X[0..2][sample_num] plus the label center plus the crop offset, rounded to the
	 * nearest integer.
	 *
	 * NOTE(review): this assumes set_features stores the voxel position relative to the
	 * label center (in cropped coordinates) in the first three features. Confirm this
	 * against set_features.
	 */
	private void assign(float pp,
					    float [][] X,
					    int sample_num,
					    int tlabel,
					    int tlabel_ind) {
		
		float [] center = l_centers[tlabel_ind];
		int ix = Math.round(X[0][sample_num] + center[0] + cr[0][0]);
		int iy = Math.round(X[1][sample_num] + center[1] + cr[1][0]);
		int iz = Math.round(X[2][sample_num] + center[2] + cr[2][0]);
		if (pp > full_prob[ix][iy][iz]) {
			full_prob[ix][iy][iz] = pp;
			estimate.set(ix, iy, iz, tlabel);
		}
	}
	
	/** Returns the refined output segmentation. */
	public ImageData get_estimate() { return(estimate); }
}
