package edu.vanderbilt.masi.algorithms.adaboost;

import java.util.Random;

import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.utility.JistLogger;

/**
 * Randomly selects a subset of training samples across a set of
 * {@code SegAdapterImageTraining} images. Selection is bounded by a global cap
 * ({@code max_num_samples}), and the {@code equal_class} / {@code equal_examples}
 * flags are forwarded to the images' count/rate queries.
 *
 * <p>Usage suggested by the methods below (confirm against callers): call
 * {@link #set_sampling_information} once with all images, then for each image call
 * {@link #set_current_image} and then {@link #use_sample} once per candidate sample.
 *
 * <p>Instances hold mutable counters and are not synchronized.
 */
public class RandomSampler extends AbstractCalculation {

	/** Source of randomness for {@link #use_sample(boolean)}. */
	Random rand;
	/**
	 * Flag forwarded to the images' per-class count/rate queries. Presumably this
	 * requests class-balanced sampling; confirm in SegAdapterImageTraining.
	 */
	boolean equal_class;
	/**
	 * Flag forwarded to the images' per-image count/rate queries. Presumably this
	 * requests equal contribution per image; confirm in SegAdapterImageTraining.
	 */
	boolean equal_examples;
	/** Cap on accepted samples: min(max_num_samples, constrained total), never negative. */
	int num_samples_total;
	/** Number of samples accepted so far. */
	int real_num_samples;
	/** Number of positive samples accepted so far. */
	int real_num_pos_samples;
	/** Number of negative samples accepted so far. */
	int real_num_neg_samples;
	/** Requested upper bound on the number of samples. */
	int max_num_samples;
	/** Global keep fraction: num_samples_total / constrained total (0 if there is nothing to sample). */
	double sample_rate;
	/** Acceptance probability for a positive sample of the current image. */
	double currim_posrate;
	/** Acceptance probability for a negative sample of the current image. */
	double currim_negrate;

	/**
	 * Creates a sampler backed by an unseeded {@link Random}.
	 *
	 * @param max_num_samples upper bound on the number of samples to accept
	 * @param equal_class     flag forwarded to the images' per-class queries
	 * @param equal_examples  flag forwarded to the images' per-image queries
	 */
	public RandomSampler(int max_num_samples,
						 boolean equal_class,
						 boolean equal_examples) {
		this(max_num_samples, equal_class, equal_examples, new Random());
	}

	/**
	 * Creates a sampler with a fixed seed so that the sampling is reproducible.
	 *
	 * @param max_num_samples upper bound on the number of samples to accept
	 * @param equal_class     flag forwarded to the images' per-class queries
	 * @param equal_examples  flag forwarded to the images' per-image queries
	 * @param seed            seed for the underlying {@link Random}
	 */
	public RandomSampler(int max_num_samples,
						 boolean equal_class,
						 boolean equal_examples,
						 long seed) {
		this(max_num_samples, equal_class, equal_examples, new Random(seed));
	}

	/** Shared initialisation for the public constructors. */
	private RandomSampler(int max_num_samples,
						  boolean equal_class,
						  boolean equal_examples,
						  Random rand) {
		this.equal_class = equal_class;
		this.equal_examples = equal_examples;
		this.max_num_samples = max_num_samples;
		this.rand = rand;
		real_num_samples = 0;
		real_num_pos_samples = 0;
		real_num_neg_samples = 0;
	}

	/**
	 * Computes the global sampling parameters from the training images and logs a summary.
	 *
	 * <p>Steps:
	 * <ol>
	 * <li>Sums the raw sample counts.</li>
	 * <li>Calls {@code set_rate} on every image, passing the smallest
	 *     {@code get_num_samples(equal_class)} found across the images.</li>
	 * <li>Sums the constrained counts.</li>
	 * <li>Sets {@link #num_samples_total} and {@link #sample_rate}.</li>
	 * </ol>
	 *
	 * @param ims training images; each image is mutated through {@code set_rate}
	 */
	public void set_sampling_information(SegAdapterImageTraining [] ims) {

		JistLogger.logOutput(JistLogger.INFO, "-> Determining Sampling Information");

		// Accumulate in long: summing per-image counts over many large images
		// can exceed Integer.MAX_VALUE and silently wrap around in an int.
		long orig_num_samples = 0;
		long orig_num_pos_samples = 0;
		long orig_num_neg_samples = 0;

		// smallest number of samples available on any single image
		int min_possible_samples_im = Integer.MAX_VALUE;
		for (SegAdapterImageTraining im : ims) {
			orig_num_samples += im.get_num_samples();
			orig_num_pos_samples += im.get_num_pos_samples();
			orig_num_neg_samples += im.get_num_neg_samples();
			int im_samples = im.get_num_samples(equal_class);
			if (im_samples < min_possible_samples_im) {
				min_possible_samples_im = im_samples;
			}
		}

		// set the relative rates for each image (must precede the constrained counts below)
		for (SegAdapterImageTraining im : ims) {
			im.set_rate(min_possible_samples_im, equal_class);
		}

		// total number of possible samples given the above constraints
		long con_num_samples = 0;
		long con_num_pos_samples = 0;
		long con_num_neg_samples = 0;
		for (SegAdapterImageTraining im : ims) {
			con_num_samples += im.get_num_samples(equal_class, equal_examples);
			con_num_pos_samples += im.get_num_pos_samples(equal_class, equal_examples);
			con_num_neg_samples += im.get_num_neg_samples(equal_class, equal_examples);
		}

		// Clamp to [0, con_num_samples]. The result is at most max_num_samples,
		// so the narrowing cast back to int is lossless.
		num_samples_total = (int) Math.max(0L, Math.min((long) max_num_samples, con_num_samples));

		// Guard the division: with no constrained samples (e.g. an empty array) the
		// rate would otherwise be 0/0 = NaN and poison the per-image rates.
		sample_rate = (con_num_samples > 0)
				? ((double) num_samples_total) / ((double) con_num_samples)
				: 0.0;
		long num_pos_samples_total = (long) (con_num_pos_samples * sample_rate);
		long num_neg_samples_total = (long) (con_num_neg_samples * sample_rate);

		// print out some sample information
		JistLogger.logOutput(JistLogger.INFO, String.format("Total Number of Possible Samples: %d", orig_num_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Total Number of Possible Positive Samples: %d", orig_num_pos_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Total Number of Possible Negative Samples: %d", orig_num_neg_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Constrained Number of Possible Samples: %d", con_num_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Constrained Number of Possible Positive Samples: %d", con_num_pos_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Constrained Number of Possible Negative Samples: %d", con_num_neg_samples));
		JistLogger.logOutput(JistLogger.INFO, String.format("Approximate Final Number of Samples: %d", num_samples_total));
		JistLogger.logOutput(JistLogger.INFO, String.format("Approximate Final Number of Positive Samples: %d", num_pos_samples_total));
		JistLogger.logOutput(JistLogger.INFO, String.format("Approximate Final Number of Negative Samples: %d", num_neg_samples_total));
		JistLogger.logOutput(JistLogger.INFO, String.format("Overall Sample Rate: %f", sample_rate));
		JistLogger.logOutput(JistLogger.INFO, "");
	}

	/** @return the cap on accepted samples computed by {@link #set_sampling_information} */
	public int get_num_samples_total() { return num_samples_total; }

	/** @return number of samples accepted so far */
	public int get_num_samples_final() { return real_num_samples; }

	/** @return number of positive samples accepted so far */
	public int get_num_pos_samples_final() { return real_num_pos_samples; }

	/** @return number of negative samples accepted so far */
	public int get_num_neg_samples_final() { return real_num_neg_samples; }

	/** @return zero-based index of the most recently accepted sample, or -1 if none yet */
	public int get_sample_ind() { return real_num_samples - 1; }

	/**
	 * Sets the acceptance probabilities used by {@link #use_sample} for one image.
	 * Each probability is {@code sample_rate * im.get_rate(equal_examples)} times the
	 * image's per-class rate.
	 *
	 * @param im image whose samples are about to be offered to {@link #use_sample}
	 */
	public void set_current_image(SegAdapterImageTraining im) {
		double im_rate = sample_rate * im.get_rate(equal_examples);
		currim_posrate = im_rate * im.get_pos_rate(equal_class);
		currim_negrate = im_rate * im.get_neg_rate(equal_class);
	}

	/**
	 * Decides randomly whether to keep a candidate sample from the current image.
	 * The counters are updated when the sample is kept.
	 *
	 * @param pos {@code true} for a positive sample, {@code false} for a negative one
	 * @return {@code true} if the sample was accepted; always {@code false} once
	 *         {@link #num_samples_total} samples have been accepted
	 */
	public boolean use_sample(boolean pos) {

		// global cap reached: reject everything from here on
		if (real_num_samples >= num_samples_total) {
			return false;
		}

		double rate = pos ? currim_posrate : currim_negrate;

		// Bernoulli trial. nextDouble() is uniform on [0, 1), so "accept iff
		// draw < rate" accepts with probability exactly rate (clamped to [0, 1]).
		// Written as !(draw < rate) so that a rate of 0 never accepts and a NaN
		// rate rejects instead of accepting.
		if (!(rand.nextDouble() < rate)) {
			return false;
		}

		if (pos) {
			real_num_pos_samples++;
		} else {
			real_num_neg_samples++;
		}
		real_num_samples++;

		return true;
	}
}
