package edu.umcu.plugins.spectro;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

import edu.jhu.bme.smile.commons.math.Spline;
import edu.jhu.ece.iacl.jist.io.FileExtensionFilter;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFile;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.VoxelType;
import edu.jhu.ece.iacl.jist.utility.JistLogger;


/**
 * JIST plugin that corrects a CEST dataset for per-voxel frequency offsets.
 * <p>
 * For every voxel (optionally restricted to a mask) the saturated signals are normalized,
 * placed on a frequency axis shifted by that voxel's value in the shift map, interpolated
 * with a spline and re-sampled at the nominal offsets listed in the shift-list file.
 * The first volume of the 4-D input is treated as the unsaturated reference (V0).
 */
public class CESTFreqCorr extends ProcessingAlgorithm{
	/****************************************************
	 * Input Parameters
	 ****************************************************/
	private ParamVolume CESTVolume;			// 4-D Volume containing MT data with a range of offsets
	private ParamVolume ShiftMap; 			// 3D-Volume containing F0 offsets
	private ParamVolume MaskVolume;			// 3D-Volume containing mask of the measured volume (optional)
	private ParamFile ShiftList;			// txt file containing list of input offsets, one per line
//	private ParamFloat Asymfreq;			// Frequency to calc asymmetry at
	private ParamBoolean NormBoo;		// true: normalize each voxel by its max saturated signal; false: by V0


	/****************************************************
	 * Output Parameters
	 ****************************************************/
	private ParamVolume CorrectedCESTVolume;	// Corrected CEST curves

	/** Upper clamp for corrected (normalized) values; larger interpolated values are treated as unphysical. */
	private static final double MAX_CORRECTED_SIGNAL = 5.0;

	private static final String cvsversion = "$Revision: 1.2 $";
	private static final String revnum = cvsversion.replace("Revision: ", "").replace("$", "").replace(" ", "");
	private static final String shortDescription = "CEST data freq correction";
	private static final String longDescription = "Corrects a CEST dataset using a frequency shift map\n" +
			"Inputs are:\n" +
			"  * 4-D volume containing the CEST dataset (including the unsaturated volume as the first volume (V0))\n" +
			"  * 3-D volume containing F0/B0 offsets\n" +
			"  * 3-D volume containing a mask of the brain (optional)\n" +
			"  * .txt file containing a list with the frequency offsets with respect to the water frequency.\n"+
			"Output will be the normalized signal intensities at the frequencies given in the list.";

	/**
	 * Registers plugin metadata and the input parameters (CEST volume, shift map,
	 * optional mask, shift-list file and the normalization switch).
	 *
	 * @param inputParams collection the input parameters are added to
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		/****************************************************
		 * Step 1. Set Plugin Information
		 ****************************************************/
		inputParams.setPackage("UMCU");
		inputParams.setCategory("Magnetization Transfer");
		inputParams.setLabel("CEST frequency correction");
		inputParams.setName("CEST frequency correction");

		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://www.nitrc.org/projects/jist");
		info.add(new AlgorithmAuthor("Daniel Polders","daniel.polders@gmail.com",""));
		info.setAffiliation("UMC Utrecht, dep of Radiology");
		info.setDescription(shortDescription+"\n"+longDescription);
		// Same separator as the short description so the two texts are not run together.
		info.setLongDescription(shortDescription+"\n"+longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.ALPHA);
//		info.add(new Citation(""));

		/****************************************************
		 * Step 2. Add input parameters to control system
		 ****************************************************/
		inputParams.add(CESTVolume=new ParamVolume("MTR Data (4D)",null,-1,-1,-1,-1));
		inputParams.add(ShiftMap=new ParamVolume("Shift map (3D)",null,-1,-1,-1,1));
		inputParams.add(MaskVolume=new ParamVolume("Mask (3D)",null,-1,-1,-1,1));
		MaskVolume.setMandatory(false);

		inputParams.add(ShiftList=new ParamFile("List of Shifts (Hz)",new FileExtensionFilter(new String[]{"txt"})));
		//inputParams.add(Asymfreq= new ParamFloat ("Asymmetry frequency (Hz)"));
		inputParams.add(NormBoo = new ParamBoolean("Normalize Data",true));
	}

	/**
	 * Registers the single output: a float 4-D volume with the corrected CEST curves.
	 *
	 * @param outputParams collection the output parameters are added to
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		/****************************************************
		 * Step 1. Add output parameters to control system
		 ****************************************************/
		CorrectedCESTVolume = new ParamVolume("Corrected CEST data",VoxelType.FLOAT,-1,-1,-1,-1);
		CorrectedCESTVolume.setName("CorrCESTVolume");
		outputParams.add(CorrectedCESTVolume);
	}

	/**
	 * Runs the correction inside an {@link AlgorithmWrapper} observed by the given monitor.
	 *
	 * @param monitor progress monitor supplied by the pipeline
	 * @throws AlgorithmRuntimeException declared by the pipeline contract
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {
		AlgorithmWrapper wrapper=new AlgorithmWrapper();
		monitor.observe(wrapper);
		wrapper.execute();
	}

	/**
	 * Normalizes one voxel's saturated signals and re-samples them at the nominal offsets
	 * after compensating for the voxel's frequency shift.
	 *
	 * @param sSat saturated signals, one per entry of {@code nominal}; normalized in place
	 * @param s0 the voxel's unsaturated (V0) signal
	 * @param nominal nominal offset frequencies from the shift list, in file order
	 * @param voxelShift the voxel's shift-map value, subtracted from every nominal offset
	 * @param normalizeToMax if true divide by the largest saturated signal, otherwise by {@code s0}
	 * @return corrected values clamped to [0, MAX_CORRECTED_SIGNAL], or {@code null} when the
	 *         normalization divisor is not a positive finite number (the voxel is then skipped
	 *         instead of writing Inf/NaN into the output)
	 */
	private static double[] correctSpectrum(double[] sSat, double s0, double[] nominal,
			double voxelShift, boolean normalizeToMax) {
		int nOffsets = nominal.length;
		double[] freqs = new double[nOffsets];
		double maxint = 0.0;
		for (int m = 0; m < nOffsets; m++) {
			// Actual saturation frequency relative to this voxel's water resonance.
			freqs[m] = nominal[m] - voxelShift;
			if (maxint < sSat[m]) {
				maxint = sSat[m];
			}
		}

		double divisor = normalizeToMax ? maxint : s0;
		if (!(divisor > 0.0) || Double.isInfinite(divisor)) {
			return null;
		}
		for (int m = 0; m < nOffsets; m++) {
			sSat[m] = sSat[m] / divisor;
		}

		// NOTE(review): Spline presumably expects freqs in ascending order; the shift list is used
		// in file order — confirm the input file is sorted. Offsets near the ends of the range may
		// be extrapolated, which the clamp below partially guards against.
		Spline cestCurve = new Spline(freqs, sSat);
		double[] corrected = new double[nOffsets];
		for (int n = 0; n < nOffsets; n++) {
			double corrval = cestCurve.spline_value(nominal[n]);
			if (Double.isNaN(corrval) || corrval < 0) {		// scrub unphysical values
				corrval = 0;
			} else if (corrval > MAX_CORRECTED_SIGNAL) {
				corrval = MAX_CORRECTED_SIGNAL;
			}
			corrected[n] = corrval;
		}
		return corrected;
	}

	/**
	 * Calculation that performs the frequency correction on the current parameter values.
	 */
	protected class AlgorithmWrapper extends AbstractCalculation {

		/**
		 * Reads the nominal offset frequencies from the shift-list file, one value per line.
		 * Reading stops at end of file or at the first blank (whitespace-only) line. Lines that
		 * do not parse as a finite number are skipped and logged.
		 *
		 * @return the valid offsets in file order, or {@code null} if the file could not be read
		 */
		private double[] readShiftList() {
			List<Double> shifts = new ArrayList<Double>();
			BufferedReader rdr = null;
			try {
				rdr = new BufferedReader(new FileReader(ShiftList.getValue()));
				String line;
				while ((line = rdr.readLine()) != null) {
					String trimmed = line.trim();
					if (trimmed.isEmpty()) {
						break;							// a blank line ends the list
					}
					double val;
					try {
						val = Double.parseDouble(trimmed);
					} catch (NumberFormatException e) {
						val = Double.NaN;				// malformed line: treat as invalid
					}
					if (Double.isNaN(val) || Double.isInfinite(val)) {
						JistLogger.logOutput(JistLogger.INFO, getClass().getCanonicalName()+"\tIgnoring invalid shift value: '"+line+"'");
						continue;
					}
					shifts.add(val);
					JistLogger.logOutput(JistLogger.INFO, getClass().getCanonicalName()+"\tShift: "+val);
				}
			} catch (IOException e) {
				JistLogger.logError(JistLogger.SEVERE, getClass().getCanonicalName()+"\tCannot parse input shift file: "+e.getMessage());
				return null;
			} finally {
				if (rdr != null) {
					try {
						rdr.close();
					} catch (IOException ignored) {
						// Values were already read; a failure to close is not actionable here.
					}
				}
			}

			double[] shiftArray = new double[shifts.size()];
			for (int i = 0; i < shiftArray.length; i++) {
				shiftArray[i] = shifts.get(i);
			}
			return shiftArray;
		}

		/**
		 * Validates the inputs, corrects every (masked) voxel and publishes the result as
		 * the output volume. Validation failures are logged and abort without output.
		 */
		protected void execute() {
			this.setLabel("Loading Data");
			JistLogger.logOutput(JistLogger.INFO, getClass().getCanonicalName()+"\t START");

			/****************************************************
			 * Step 2. Parse the input data
			 ****************************************************/
			ImageDataFloat cestvol=new ImageDataFloat(CESTVolume.getImageData());
			ImageDataFloat shiftvol=new ImageDataFloat(ShiftMap.getImageData());
			boolean normbool = NormBoo.getValue();
			int r=cestvol.getRows(), c=cestvol.getCols(), s=cestvol.getSlices(), t = cestvol.getComponents();
			int rw=shiftvol.getRows(), cw=shiftvol.getCols(), sw=shiftvol.getSlices();
			this.setTotalUnits(r);

			// The mask is optional; when absent every voxel is processed.
			ImageData maskvol = MaskVolume.getImageData();

			double[] shiftArray = readShiftList();
			if (shiftArray == null) {
				return;											// read failure already logged
			}

			// Check if MTR and shiftvol dimensions match
			if ((r!=rw)||(c!=cw)||(s!=sw)){
				JistLogger.logError(JistLogger.SEVERE, getClass().getCanonicalName()+"\tThe dimensions of MTR volumes and shift map do not match, aborting.");
				return;
			}
			// Check if t and shiftArray.length are the same (first volume is the unsaturated V0)
			JistLogger.logOutput(JistLogger.INFO, getClass().getCanonicalName()+"\tNumber of MTW volumes (incl. non MTw): "+t+"\t Number of offset frequencies found in file: "+shiftArray.length );
			if (shiftArray.length == 0){
				JistLogger.logError(JistLogger.SEVERE, getClass().getCanonicalName()+"\tNo valid offset frequencies found in shift file, aborting.");
				return;
			}
			if (shiftArray.length != t-1){
				JistLogger.logError(JistLogger.SEVERE, getClass().getCanonicalName()+"\tThe number of MTW volumes ("+ (t-1) +") and shift values ("+ shiftArray.length +") do not match, aborting.");
				return;
			}
			// Check if MTR and mask dimensions match
			if ((maskvol != null) && ((r!=maskvol.getRows())||(c!=maskvol.getCols())||(s!=maskvol.getSlices()))){
				JistLogger.logError(JistLogger.SEVERE, "The dimensions of MTR volumes and mask do not match, aborting.");
				return;
			}

			/****************************************************
			 * Step 3. Setup memory for the computed volumes
			 ****************************************************/
			int nOffsets = t-1;
			ImageData corr_cest = new ImageDataFloat(r,c,s,nOffsets);
			corr_cest.setHeader(cestvol.getHeader());
			corr_cest.setName(cestvol.getName() + (normbool ? "_f0corr_normalized" : "_f0corr"));

			/****************************************************
			 * Step 4. Run the core algorithm voxel by voxel.
			 ****************************************************/
			this.setLabel("Correcting Data for frequency offsets...");

			for(int i=0;i<r;i++) {											//Loop rows
				this.setCompletedUnits(i);
				for(int j=0;j<c;j++){										//Loop Columns
					for(int k=0;k<s;k++) {									//Loop Slices
						// Written as !(x > 0) so NaN mask values are skipped, as before.
						if((maskvol != null) && !(maskvol.getFloat(i,j,k) > 0.0)){
							continue;
						}
						double[] sSat = new double[nOffsets];
						for(int m=0;m<nOffsets;m++){
							sSat[m]=(double)cestvol.getFloat(i,j,k,m+1);		// volume 0 is V0
						}
						double[] corrected = correctSpectrum(sSat, (double)cestvol.getFloat(i,j,k,0),
								shiftArray, shiftvol.getFloat(i,j,k), normbool);
						if (corrected == null) {
							continue;		// no usable normalization divisor for this voxel
						}
						for(int n=0;n<nOffsets;n++){
							corr_cest.set(i,j,k,n, corrected[n]);
						}
					}
				}
			}

			/****************************************************
			 * Step 5. Publish the result.
			 ****************************************************/
			CorrectedCESTVolume.setValue(corr_cest);
			JistLogger.logOutput(JistLogger.INFO, getClass().getCanonicalName()+"\t FINISHED");
		}
	}
}
