package edu.vanderbilt.masi.algorithms.labelfusion;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;

import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.utility.JistLogger;


/**
 * Accumulates pairwise label co-occurrence counts between rater images into a
 * {@link MultiClassPerformanceParameters} instance and writes the normalized
 * result to {@code theta_file.txt} in the output directory.
 *
 * <p>Each rater image is assigned a class through a rater-map text file that
 * holds one class name per line; line {@code i} names the class of image
 * {@code i}. Class names are mapped to consecutive integers in order of first
 * appearance.
 */
public class InitialProbabilityCalculator {

	/** Rater images; image {@code i} belongs to class {@code classes.get(i)}. */
	private List<ImageData> images;
	/** Class index of each rater, in rater-map line order. */
	private List<Integer> classes;
	/** Class index to class name. */
	private HashMap<Integer,String> intToClass;
	/** Class name to class index. */
	private HashMap<String,Integer> classToInt;
	/** Number of distinct class names found in the rater map. */
	private int numClasses;
	/** Number of lines (raters) read from the rater map. */
	private int numRaters;
	/** Destination file written by {@link #run()}. */
	private File thetaFile;
	/** Per class: the largest label value seen in that class's images, plus one. */
	private int[] numLabels;
	/** Loop bounds taken from the first image: rows, cols, slices, components. */
	private int[] dim;
	/** Accumulator for the pairwise label counts. */
	private MultiClassPerformanceParameters theta;

	/**
	 * Reads the rater map, derives the image dimensions from the first image,
	 * determines the label count of every class and prepares the parameter
	 * matrices.
	 *
	 * @param pImages   collection holding one image per rater
	 * @param pRaterMap file parameter pointing at the rater-map text file
	 * @param outDir    directory (created if missing) that receives {@code theta_file.txt}
	 * @throws IllegalArgumentException if no images are supplied or there are
	 *         fewer images than raters listed in the rater map
	 */
	public InitialProbabilityCalculator(ParamVolumeCollection pImages,ParamFile pRaterMap,File outDir){
		outDir.mkdirs();
		setRaterMap(pRaterMap.getValue());
		JistLogger.logOutput(JistLogger.WARNING,String.format("There were %d classes found and %d atlases in total", numClasses,numRaters));
		images = pImages.getImageDataList();
		// Fail with a clear message instead of an index/null error further down.
		if(images == null || images.isEmpty())
			throw new IllegalArgumentException("No input images were provided");
		if(images.size() < numRaters)
			throw new IllegalArgumentException(String.format(
					"The rater map lists %d raters but only %d images were provided", numRaters, images.size()));
		thetaFile = new File(outDir,"theta_file.txt");
		numLabels = new int[numClasses];
		dim = new int[4];
		ImageData first = images.get(0);
		dim[0] = first.getRows();
		dim[1] = first.getCols();
		dim[2] = first.getSlices();
		dim[3] = first.getComponents();
		// A non-positive extent is treated as a singleton dimension so the loops still run once.
		for(int i=0;i<dim.length;i++)
			if(dim[i] <=0)
				dim[i] = 1;
		JistLogger.logOutput(JistLogger.WARNING, String.format("The images are of size [%d %d %d %d]", dim[0],dim[1],dim[2],dim[3]));
		calculateNumLabels();
		JistLogger.logOutput(JistLogger.WARNING, describeLabelCounts());
		theta = new MultiClassPerformanceParameters();
		theta.setUpToBuildMatrices(numLabels,intToClass);
	}

	/**
	 * Builds the human-readable summary of how many labels each class has.
	 *
	 * @return the summary sentence, or a fixed message when no classes exist
	 */
	private String describeLabelCounts(){
		if(numClasses == 0)
			return "There are no classes.";
		StringBuilder sb = new StringBuilder("There are ");
		for(int i=0;i<numClasses;i++){
			if(i > 0)
				sb.append(", ");
			sb.append(numLabels[i]);
			sb.append(" labels of class ");
			sb.append(intToClass.get(i));
		}
		sb.append(".");
		return sb.toString();
	}

	/**
	 * Scans every rater image and records, per class, the maximum label value
	 * plus one in {@link #numLabels}.
	 *
	 * <p>All images are traversed with the bounds in {@link #dim}, which come
	 * from the first image only.
	 */
	private void calculateNumLabels(){
		for(int i=0;i<numRaters;i++){
			ImageData im = images.get(i);
			int c = classes.get(i);
			for(int x=0;x<dim[0];x++){
				for(int y=0;y<dim[1];y++){
					for(int z=0;z<dim[2];z++){
						for(int k=0;k<dim[3];k++){
							int l = im.getInt(x,y,z,k);
							if(l+1 > numLabels[c]) numLabels[c] = l+1;
						}
					}
				}
			}
		}
	}

	/**
	 * @return the file that {@link #run()} writes the parameters to
	 */
	public File getThetaFile(){
		return thetaFile;
	}

	/**
	 * Accumulates counts for every rater pair {@code (i, j)} with {@code j >= i},
	 * normalizes the parameters and writes them to the theta file.
	 */
	public void run(){
		ImageData im1,im2;
		int c1,c2;
		for(int i=0;i<numRaters;i++){
			c1 = classes.get(i);
			im1 = images.get(i);
			// NOTE(review): j starts at i, so each image is also paired with itself;
			// runTwoImages then adds both orderings for that self-pair. Confirm this is intended.
			for(int j=i;j<numRaters;j++){
				JistLogger.logOutput(JistLogger.WARNING, "String images "+i+" and "+j);
				JistLogger.logFlush();
				c2 = classes.get(j);
				im2 = images.get(j);
				runTwoImages(im1,im2,c1,c2);
			}
		}
		theta.normalize();
		writeOutput();
	}

	/**
	 * Writes the number of classes on the first line followed by the string
	 * form of the parameters. I/O failures are reported on the error stream and
	 * the writer is always closed.
	 */
	private void writeOutput(){
		BufferedWriter bw = null;
		try {
			bw = new BufferedWriter(new FileWriter(thetaFile));
			bw.write(Integer.toString(numClasses));
			bw.write("\n");
			bw.write(theta.toString());
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if(bw != null){
				try {
					// close() also flushes, so a failure here means the file may be incomplete.
					bw.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/**
	 * Visits every position within {@link #dim} and adds the label pair observed
	 * in the two images to the parameters, once for each ordering of the pair.
	 *
	 * @param im1 first image
	 * @param im2 second image
	 * @param c1  class index of {@code im1}
	 * @param c2  class index of {@code im2}
	 */
	private void runTwoImages(ImageData im1, ImageData im2, int c1, int c2){
		int l1,l2;
		for(int i=0;i<dim[0];i++){
			for(int j=0;j<dim[1];j++){
				for(int k=0;k<dim[2];k++){
					for(int l=0;l<dim[3];l++){
						l1 = im1.getInt(i, j,k,l);
						l2 = im2.getInt(i, j,k,l);
						theta.addValue(c1, c2, l1, l2, 1f);
						theta.addValue(c2, c1, l2, l1, 1f);
					}
				}
			}
		}
	}

	/**
	 * Parses the rater map: every line is a class name, and the line order gives
	 * the rater order. Populates {@link #classes}, the two name/index maps,
	 * {@link #numClasses} and {@link #numRaters}.
	 *
	 * <p>The collections are created before the file is opened so they are
	 * never {@code null}, even when the file cannot be read. I/O failures are
	 * reported and leave the counts at their previous values.
	 *
	 * <p>NOTE(review): lines are used verbatim, so a blank line counts as a
	 * class named "" and as one rater.
	 *
	 * @param f the rater-map text file
	 */
	private void setRaterMap(File f){
		classes = new ArrayList<Integer>();
		intToClass = new HashMap<Integer,String>();
		classToInt = new HashMap<String,Integer>();
		BufferedReader br = null;
		try 
		{
			br = new BufferedReader(new FileReader(f));
			String line;
			int n = 0;
			while((line=br.readLine())!=null){
				Integer index = classToInt.get(line);
				if(index == null){
					// First occurrence of this class name: give it the next free index.
					index = n;
					classToInt.put(line, index);
					intToClass.put(index, line);
					n++;
				}
				classes.add(index);
			}
			numClasses = classToInt.size();
			numRaters = classes.size();
		} catch (IOException e) {
			JistLogger.logOutput(JistLogger.WARNING, "Unable to read the rater map file "+f);
			e.printStackTrace();
		} finally {
			if(br != null){
				try {
					br.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

}
