package edu.vanderbilt.masi.algorithms.adaboost;

import edu.jhu.ece.iacl.jist.pipeline.JistPreferences;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Training-time view of a single image for the segmentation adapter.
 * <p>
 * The constructor reads the estimated and manual segmentations once to count
 * the candidate training voxels (total, positive and negative) and to derive
 * per-class sampling rates. {@link #process} later reloads the data and writes
 * the labels and features of the voxels accepted by the {@code RandomSampler}
 * into the shared training arrays.
 */
public class SegAdapterImageTraining extends SegAdapterImageBase {
	
	/** Number of candidate samples counted for this image. */
	private int num_samples;
	/** Number of candidate samples whose manual label equals the target label. */
	private int num_pos_samples;
	/** Remaining candidate samples: {@code num_samples - num_pos_samples}. */
	private int num_neg_samples;
	/** Fraction of positive samples kept when the classes are balanced (0 if there are none). */
	private double pos_rate;
	/** Fraction of negative samples kept when the classes are balanced (0 if there are none). */
	private double neg_rate;
	/** Per-image scale factor set by {@link #set_rate}; 1 until it is called. */
	private double im_rate = 1.0;
	
	/**
	 * Reads the estimated and manual segmentations and counts the candidate
	 * training samples for this image.
	 * <p>
	 * The JIST debug level is temporarily set to {@code SEVERE} while the
	 * volumes are read and released, and is restored even if loading fails.
	 *
	 * @param est_vol      estimated segmentation volume
	 * @param man_vol      manual segmentation volume (passed to {@code verify_dims})
	 * @param dilation     dilation passed to {@code set_dimensions} and {@code dilate_mask}
	 * @param feature_rad  feature radius per axis, passed to {@code set_dimensions}
	 * @param tlabel       target label of interest
	 * @param initial_mask optional mask indexed in original (uncropped) image
	 *                     coordinates; when non-null it replaces the
	 *                     {@code estimate == tlabel} mask
	 */
	public SegAdapterImageTraining(ParamVolume est_vol,
								   ParamVolume man_vol,
						           int dilation,
						           int [] feature_rad,
						           int tlabel,
						           boolean [][][] initial_mask) {
		
		// temporary masks in cropped coordinates
		boolean [][][] tmp_mask;     // estimated foreground (value > 0)
		boolean [][][] tmp_segmask;  // estimated target label, or initial_mask when given
		boolean [][][] tmp_manmask;  // manual target label
		
		// load the image information (quietly); the previous debug level is
		// restored in the finally block so a failed load cannot leave it at SEVERE
		JistPreferences prefs = JistPreferences.getPreferences();
		int orig_level = prefs.getDebugLevel();
		prefs.setDebugLevel(JistLogger.SEVERE);
		try {
			ImageData img = est_vol.getImageData(true);
			
			// set the original dimensions, cropping region and cropped dimensions
			set_dimensions(img, dilation, feature_rad, initial_mask == null);
			
			// load the manual segmentation into memory
			ImageData man_img = man_vol.getImageData(true);
			verify_dims(man_img);
			
			// compute the temporary masks to determine the number of samples from this image
			tmp_mask = new boolean[dims[0]][dims[1]][dims[2]];
			tmp_segmask = new boolean[dims[0]][dims[1]][dims[2]];
			tmp_manmask = new boolean[dims[0]][dims[1]][dims[2]];
			for (int x = cr[0][0]; x <= cr[0][1]; x++) {
				for (int y = cr[1][0]; y <= cr[1][1]; y++) {
					for (int z = cr[2][0]; z <= cr[2][1]; z++) {
						int xi = x - cr[0][0];
						int yi = y - cr[1][0];
						int zi = z - cr[2][0];
						short val = img.getShort(x, y, z);
						tmp_mask[xi][yi][zi] = val > 0;
						tmp_manmask[xi][yi][zi] = man_img.getShort(x, y, z) == tlabel;
						// initial_mask is indexed with the uncropped coordinates
						tmp_segmask[xi][yi][zi] = (initial_mask == null)
								? val == tlabel
								: initial_mask[x][y][z];
					}
				}
			}
			
			// free the volumes (still quietly)
			est_vol.dispose();
			man_vol.dispose();
		} finally {
			prefs.setDebugLevel(orig_level);
		}
		
		// perform some operations on the masks
		dilate_mask(tmp_mask, dims, dilation);
		dilate_mask(tmp_segmask, dims, dilation);
		
		// set the number of samples: with no initial mask the dilated label mask
		// is restricted to the dilated foreground; otherwise the dilated initial
		// mask is used on its own (intersected with itself)
		if (initial_mask == null)
			num_samples = intersect_mask(tmp_segmask, tmp_mask, dims);
		else
			num_samples = intersect_mask(tmp_segmask, tmp_segmask, dims);
		
		num_pos_samples = intersect_mask(tmp_manmask, tmp_segmask, dims);
		num_neg_samples = num_samples - num_pos_samples;
		
		// balance the classes by subsampling the larger one to the size of the
		// smaller one; an empty class would otherwise yield 0/0 = NaN
		int min_posneg_samples = Math.min(num_pos_samples, num_neg_samples);
		pos_rate = safe_ratio(min_posneg_samples, num_pos_samples);
		neg_rate = safe_ratio(min_posneg_samples, num_neg_samples);
	}
	
	/**
	 * Returns {@code numer / denom}, or 0 when {@code denom} is not positive,
	 * so that empty counts never produce NaN or Infinity.
	 */
	private static double safe_ratio(double numer, double denom) {
		return (denom > 0) ? numer / denom : 0.0;
	}
	
	/**
	 * Class-weighted number of samples before image weighting and truncation.
	 *
	 * @param equal_class whether to apply the class-balancing rates
	 */
	private double weighted_num_samples(boolean equal_class) {
		return get_pos_rate(equal_class) * num_pos_samples +
			   get_neg_rate(equal_class) * num_neg_samples;
	}
	
	/** @return total number of candidate samples in this image */
	public int get_num_samples() { return(num_samples); }
	
	/**
	 * @param equal_class whether to apply the class-balancing rates
	 * @return number of samples to draw from this image (truncated)
	 */
	public int get_num_samples(boolean equal_class) {
		return((int)weighted_num_samples(equal_class));
	}
	
	/**
	 * @param equal_class   whether to apply the class-balancing rates
	 * @param equal_example whether to apply the per-image rate from {@link #set_rate}
	 * @return number of samples to draw from this image (truncated)
	 */
	public int get_num_samples(boolean equal_class, boolean equal_example) {
		return((int)(weighted_num_samples(equal_class) * get_rate(equal_example)));
	}
	
	/** @return number of positive candidate samples in this image */
	public int get_num_pos_samples() { return(num_pos_samples); }
	
	/**
	 * @param equal_class whether to apply the positive-class rate
	 * @return number of positive samples to draw (truncated)
	 */
	public int get_num_pos_samples(boolean equal_class) {
		double num = get_pos_rate(equal_class) * num_pos_samples;
		return((int)(num));
	}
	
	/**
	 * @param equal_class   whether to apply the positive-class rate
	 * @param equal_example whether to apply the per-image rate
	 * @return number of positive samples to draw (truncated)
	 */
	public int get_num_pos_samples(boolean equal_class, boolean equal_example) {
		double num = get_pos_rate(equal_class) * num_pos_samples;
		return((int)(num*get_rate(equal_example)));
	}
	
	/** @return number of negative candidate samples in this image */
	public int get_num_neg_samples() { return(num_neg_samples); }
	
	/**
	 * @param equal_class whether to apply the negative-class rate
	 * @return number of negative samples to draw (truncated)
	 */
	public int get_num_neg_samples(boolean equal_class) {
		double num = get_neg_rate(equal_class) * num_neg_samples;
		return((int)(num));
	}
	
	/**
	 * @param equal_class   whether to apply the negative-class rate
	 * @param equal_example whether to apply the per-image rate
	 * @return number of negative samples to draw (truncated)
	 */
	public int get_num_neg_samples(boolean equal_class, boolean equal_example) {
		double num = get_neg_rate(equal_class) * num_neg_samples;
		return((int)(num*get_rate(equal_example)));
	}
	
	/** @return the positive-class balancing rate, or 1 when balancing is off */
	public double get_pos_rate(boolean equal_class) { return((equal_class) ? pos_rate : 1); }
	
	/** @return the negative-class balancing rate, or 1 when balancing is off */
	public double get_neg_rate(boolean equal_class) { return((equal_class) ? neg_rate : 1); }
	
	/** @return the per-image rate set by {@link #set_rate}, or 1 when image weighting is off */
	public double get_rate(boolean equal_example) { return((equal_example) ? im_rate : 1); }
	
	/**
	 * Sets the per-image rate so that roughly {@code num} samples are drawn from
	 * this image after the (optional) class balancing.
	 *
	 * @param num         desired number of samples from this image
	 * @param equal_class whether class balancing is applied before image weighting
	 */
	public void set_rate(int num, boolean equal_class) {
		// an image with no usable samples would give num/0 (Infinity or NaN);
		// use 0 so nothing is drawn from it
		im_rate = safe_ratio(num, weighted_num_samples(equal_class));
	}
	
	/** Logs the original and cropped dimensions and the number of candidate samples. */
	public void print_info() {
		JistLogger.logOutput(JistLogger.INFO, "-> Found the following information: ");
		JistLogger.logOutput(JistLogger.INFO, String.format("Dimensions: %dx%dx%d (%dx%dx%d)",
														    orig_dims[0], orig_dims[1], orig_dims[2],
														    dims[0], dims[1], dims[2]));
		JistLogger.logOutput(JistLogger.INFO, String.format("Number of Possible Samples: %d", num_samples));
	}
	
	/**
	 * Loads image {@code num} and writes the label and features of every
	 * accepted sample into {@code Y} and {@code X}.
	 * <p>
	 * Only voxels inside the (cropped) label mask and at least
	 * {@code feature_rad} voxels away from the cropped border are considered;
	 * each is offered to {@code random_sampler.use_sample}, and accepted voxels
	 * are stored at the index returned by {@code random_sampler.get_sample_ind()}.
	 *
	 * @param num            index of this image in the volume collections
	 * @param Y              output class labels ({@code true} = manual label equals {@code tlabel})
	 * @param X              output feature matrix, filled by {@code set_features}
	 *                       (presumably one row per sample — confirm in base class)
	 * @param est_vols       estimated segmentation volumes
	 * @param man_vols       manual segmentation volumes
	 * @param feature_vols   additional feature-channel volumes
	 * @param num_channels   number of feature channels; no channel buffers are
	 *                       allocated when this is 0
	 * @param num_features   number of features per sample, passed to {@code set_features}
	 * @param tlabel         target label of interest
	 * @param dilation       passed to {@code load_initial_data}
	 * @param feature_rad    feature radius per axis
	 * @param random_sampler decides which candidate voxels are kept and where they go
	 * @param initial_mask   optional initial mask, passed to {@code load_initial_data}
	 */
	public void process(int num,
					    boolean [] Y,
					    float [][] X,
					    ParamVolumeCollection est_vols,
					    ParamVolumeCollection man_vols,
					    ParamVolumeCollection feature_vols,
					    int num_channels,
					    int num_features,
					    int tlabel,
					    int dilation,
					    int [] feature_rad,
					    RandomSampler random_sampler,
					    boolean [][][] initial_mask) {
		
		// make sure that the random sampler is using the information from this image
		random_sampler.set_current_image(this);
		
		// load the image
		JistLogger.logOutput(JistLogger.INFO, "-> Loading Estimated Segmentation");
		ImageData img = est_vols.getParamVolume(num).getImageData(true);
		
		// allocate some space using new dimensions
		boolean [][][] mask = new boolean[dims[0]][dims[1]][dims[2]];
		boolean [][][] segmask = new boolean[dims[0]][dims[1]][dims[2]];
		short [][][] est_seg = new short[dims[0]][dims[1]][dims[2]];
		short [][][] man_seg = new short[dims[0]][dims[1]][dims[2]];
		
		// allocate space for the feature channel information
		float [][][][] feature_chans = null;
		float [] mean_feature_vals = null;
		if (num_channels > 0) {
			feature_chans = new float [dims[0]][dims[1]][dims[2]][num_channels];
			mean_feature_vals = new float [num_channels];
		}
		
		// load the estimated segmentation and resulting masks
		load_initial_data(img, est_seg, mask, segmask, tlabel, dilation, num_samples, true, initial_mask);
		
		// compute the center of mass of the current label of interest
		float [] l_center = compute_label_center(segmask, feature_rad);
		JistLogger.logOutput(JistLogger.INFO, String.format("Label Center: %fx%fx%f", l_center[0], l_center[1], l_center[2]));
		
		// load the manual segmentation
		JistLogger.logOutput(JistLogger.INFO, "-> Loading Manual Segmentation");
		load_manual_segmentation(man_vols.getParamVolume(num), man_seg);
		
		// load the feature channels
		JistLogger.logOutput(JistLogger.INFO, "-> Loading Feature Images:");
		load_feature_channels(feature_vols, num_channels, num, mask, feature_chans, mean_feature_vals);
		
		// get the features of interest (skipping voxels within feature_rad of the border)
		JistLogger.logOutput(JistLogger.INFO, "-> Getting feature information...");
		for (int x = feature_rad[0]; x < dims[0] - feature_rad[0]; x++) {
			for (int y = feature_rad[1]; y < dims[1] - feature_rad[1]; y++) {
				for (int z = feature_rad[2]; z < dims[2] - feature_rad[2]; z++) {
					if (!segmask[x][y][z])
						continue;
					
					final boolean pos = man_seg[x][y][z] == tlabel;
					if (!random_sampler.use_sample(pos))
						continue;
					
					int sample_num = random_sampler.get_sample_ind();
					
					// set the class label
					Y[sample_num] = pos;
					
					// set the feature matrix
					set_features(X, sample_num, x, y, z, feature_rad, l_center, 
							 	 est_seg, feature_chans, mean_feature_vals, 
							 	 num_channels, num_features);
				}
			}
		}
		
		JistLogger.logOutput(JistLogger.INFO, "done.");
		JistLogger.logOutput(JistLogger.INFO, "");
	}
	
	/**
	 * Copies the cropped region of the manual segmentation into {@code man_seg}
	 * (cropped coordinates) and then disposes the loaded image data.
	 *
	 * @param man_vol manual segmentation volume (passed to {@code verify_dims})
	 * @param man_seg destination array sized to the cropped dimensions
	 */
	private void load_manual_segmentation(ParamVolume man_vol,
										  short [][][] man_seg) {
		
		ImageData img = man_vol.getImageData(true);
		verify_dims(img); 
		for (int x = cr[0][0]; x <= cr[0][1]; x++)
			for (int y = cr[1][0]; y <= cr[1][1]; y++)
				for (int z = cr[2][0]; z <= cr[2][1]; z++)
					man_seg[x - cr[0][0]][y - cr[1][0]][z - cr[2][0]] = img.getShort(x, y, z);
		img.dispose();
	}

}
