/*
 *
 */
package edu.jhu.ece.iacl.plugins.registration;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

import Jama.Matrix;

//import edu.jhu.ece.iacl.algorithms.PrinceGroupAuthors;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities;
import edu.jhu.ece.iacl.algorithms.registration.RegistrationUtilities.InterpolationType;
import edu.jhu.ece.iacl.jist.io.ArrayDoubleMtxReaderWriter;
import edu.jhu.ece.iacl.jist.io.ArrayDoubleReaderWriter;
import edu.jhu.ece.iacl.jist.io.FileExtensionFilter;
import edu.jhu.ece.iacl.jist.io.ImageDataReaderWriter;
import edu.jhu.ece.iacl.jist.pipeline.AbstractCalculation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.AlgorithmAuthor;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmInformation.Citation;
import edu.jhu.ece.iacl.jist.pipeline.AlgorithmRuntimeException;
import edu.jhu.ece.iacl.jist.pipeline.CalculationMonitor;
import edu.jhu.ece.iacl.jist.pipeline.DevelopmentStatus;
import edu.jhu.ece.iacl.jist.pipeline.ProcessingAlgorithm;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamFileCollection;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamMatrix;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamModel;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamOption;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolume;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamVolumeCollection;
import edu.jhu.ece.iacl.jist.structures.image.ImageData;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataFloat;
import edu.jhu.ece.iacl.jist.structures.image.ImageDataMipav;
import edu.jhu.ece.iacl.jist.structures.image.ImageHeader;
import edu.jhu.ece.iacl.jist.utility.JistLogger;
import edu.jhu.ece.iacl.jist.pipeline.parameter.ParamBoolean;
//import edu.jhu.ece.iacl.plugins.registration.MedicAlgorithmFLIRTCollection.FlirtWrapper;


/**
 * JIST plugin that combines 4x4 transformation matrices with deformation fields,
 * producing one combined deformation field per processed input deformation field.
 *
 * <p>Matrix {@code i} is paired with deformation field {@code i}. If the single
 * "Transform Matrix" parameter yields a non-null value, it is used as the only
 * matrix (so only the first deformation field is processed); otherwise the
 * matrices are read from the "Transformation Matrices" file collection.
 *
 * @author Min Chen (mchen55@jhu.edu)
 */
public class MedicAlgorithmCombTransAndDeforms extends ProcessingAlgorithm{
	//Inputs
	public ParamFileCollection inParamTransformMatrices;
	public ParamVolumeCollection inParamDefFields;
	public ParamMatrix inParamTransformMatrix;

	//Outputs
	public ParamVolumeCollection outParamCombDefField;


	//Internal Variables
	// Dimensions of the deformation field currently being processed; assigned per
	// field in ExecuteWrapper.execute and read by the field-building helpers below.
	int XN, YN, ZN; 
	// Number of vector components stored per voxel of a deformation field.
	int chN = 3;
	ArrayDoubleReaderWriter mtxRW;
	ImageDataReaderWriter volRW; 

	//Other Variables
	private static final String revnum = RegistrationUtilities.getVersion();
	private static final String shortDescription = "Combines a set of transformation matrices and deformation fields into a single deformation field.";
	private static final String longDescription = "The transformations and deformations are applied in the order inputted.";


	/**
	 * Declares the input parameters and the algorithm metadata.
	 *
	 * @param inputParams collection that receives the input parameters
	 */
	protected void createInputParameters(ParamCollection inputParams) {
		mtxRW = ArrayDoubleReaderWriter.getInstance();
		volRW = ImageDataReaderWriter.getInstance(); 

		inputParams.add(inParamTransformMatrix = new ParamMatrix("Transform Matrix",4,4));
		inParamTransformMatrix.setMandatory(false);
		inputParams.add(inParamTransformMatrices=new ParamFileCollection("Transformation Matrices",new FileExtensionFilter(new String[]{"mtx"})));
		inParamTransformMatrices.setMandatory(false);
		inputParams.add(inParamDefFields = new ParamVolumeCollection("Deformation Fields"));
		inParamDefFields.setMandatory(false);

		inputParams.setPackage("IACL");
		inputParams.setCategory("Registration.Volume");
		inputParams.setLabel("Combine Transformations and Deformations");
		inputParams.setName("CombTransAndDef");


		AlgorithmInformation info = getAlgorithmInformation();
		info.setWebsite("http://www.iacl.ece.jhu.edu/");
		info.add(new AlgorithmAuthor("Min Chen", "", ""));
		info.setDescription(shortDescription);
		// Separate the two sentences; plain concatenation produced "...field.The ...".
		info.setLongDescription(shortDescription + " " + longDescription);
		info.setVersion(revnum);
		info.setEditable(false);
		info.setStatus(DevelopmentStatus.BETA);
	}


	/**
	 * Declares the output collection of combined 3-component deformation fields.
	 *
	 * @param outputParams collection that receives the output parameters
	 */
	protected void createOutputParameters(ParamCollection outputParams) {
		outputParams.add(outParamCombDefField=new ParamVolumeCollection("Combined Deformation Fields",null,-1,-1,-1,3));
		// Outputs are written explicitly (writeAndFreeNow) as each one is produced.
		outParamCombDefField.setLoadAndSaveOnValidate(false);
	}


	/**
	 * Runs the combination under the given monitor.
	 *
	 * @param monitor progress monitor observing the worker
	 */
	protected void execute(CalculationMonitor monitor) throws AlgorithmRuntimeException {
		ExecuteWrapper wrapper=new ExecuteWrapper();
		monitor.observe(wrapper);
		wrapper.execute(this);
	}


	/** Worker that pairs each matrix with a deformation field and writes the combined result. */
	protected class ExecuteWrapper extends AbstractCalculation{
		/**
		 * Combines matrix {@code i} with deformation field {@code i} for every
		 * available matrix and writes each combined field immediately.
		 *
		 * @param alg owning algorithm, used when writing outputs
		 * @throws IllegalArgumentException if there are more matrices than
		 *         deformation fields, or a matrix is missing or not 4x4
		 */
		public void execute(ProcessingAlgorithm alg){
			this.setLabel("combine Volumes");

			List<ImageData> inDefFieldList = inParamDefFields.getImageDataList(); 
			ArrayList<double[][]> transMatricesList;
			Matrix singleMatrix = inParamTransformMatrix.getValue();
			if(singleMatrix != null){
				transMatricesList = new ArrayList<double[][]>();
				transMatricesList.add(singleMatrix.getArray());
			}else{
				transMatricesList = readMultiMatrices(inParamTransformMatrices.getValue());
			}

			int numMatrices = transMatricesList.size();
			int numFields = inDefFieldList.size();
			// Matrices and fields are paired by index; report a mismatch clearly
			// instead of failing with an IndexOutOfBoundsException mid-run.
			if(numMatrices > numFields){
				throw new IllegalArgumentException(MedicAlgorithmCombTransAndDeforms.class.getCanonicalName()
						+ ": " + numMatrices + " transformation matrices but only " + numFields
						+ " deformation fields were supplied");
			}
			if(numMatrices < numFields){
				System.err.println(MedicAlgorithmCombTransAndDeforms.class.getCanonicalName()
						+ ": only the first " + numMatrices + " of " + numFields
						+ " deformation fields have a matching transformation and will be processed");
			}

			for (int i = 0; i < numMatrices; i++){
				ImageData inDefField = inDefFieldList.get(i);
				ImageData[] inDefFieldSplit = RegistrationUtilities.split4DImageDataIntoArray(inDefField);
				XN = inDefField.getRows();
				YN = inDefField.getCols();
				ZN = inDefField.getSlices();
				ImageHeader inDefFieldHeader = inDefField.getHeader();
				float[] currentDimRes = inDefFieldHeader.getDimResolutions();
				String inDefFieldName = inDefField.getName();

				Matrix m = toTransformMatrix(transMatricesList.get(i), i);

				// Build a field from the matrix, then compose it with the input field.
				ImageData[] currentDef = createDefFieldFromTransMatrix(m, currentDimRes);
				currentDef = applyDeformation(currentDef, inDefFieldSplit);

				ImageDataFloat outVolume = new ImageDataFloat(RegistrationUtilities.combineImageDataArrayTo4D(currentDef));
				outVolume.setHeader(inDefFieldHeader);
				outVolume.setName(inDefFieldName.replace(".","_") + "_combDefField");
				outParamCombDefField.add(outVolume);
				outParamCombDefField.writeAndFreeNow(alg);
			}
		}
	}

	/**
	 * Validates a raw transformation array and copies it into a 4x4 Jama matrix.
	 *
	 * @param raw   transformation values, expected to be 4 rows of 4 columns
	 * @param index position of the matrix in the input list (for error messages)
	 * @return a new 4x4 matrix holding a copy of {@code raw}
	 * @throws IllegalArgumentException if {@code raw} is null or not 4x4
	 */
	private Matrix toTransformMatrix(double[][] raw, int index){
		if(raw == null || raw.length != 4){
			throw new IllegalArgumentException(getClass().getCanonicalName()
					+ ": invalid transformation " + index + " - must be 4x4");
		}
		Matrix m = new Matrix(4,4);
		for(int row = 0; row < 4; row++){
			if(raw[row] == null || raw[row].length != 4){
				throw new IllegalArgumentException(getClass().getCanonicalName()
						+ ": invalid transformation " + index + " - must be 4x4");
			}
			for(int col = 0; col < 4; col++){
				m.set(row, col, raw[row][col]);
			}
		}
		return m;
	}

	/**
	 * Reads one transformation array from each file, in list order.
	 *
	 * @param files matrix files; {@code null} is treated as an empty list
	 * @return the arrays read, one per file
	 * @throws IllegalArgumentException if the reader returns {@code null} for a file
	 */
	private ArrayList<double[][]> readMultiMatrices(List<File> files){
		ArrayList<double[][]> allxfms = new ArrayList<double[][]>();
		if(files == null){
			return allxfms;
		}
		ArrayDoubleMtxReaderWriter rw = new ArrayDoubleMtxReaderWriter();
		for(File file : files){
			double[][] xfm = rw.read(file);
			if(xfm == null){
				throw new IllegalArgumentException(getClass().getCanonicalName()
						+ ": could not read transformation matrix from " + file);
			}
			allxfms.add(xfm);
		}
		return allxfms;
	}

	/**
	 * Composes two deformation fields voxel by voxel.
	 *
	 * <p>For each voxel {@code p = (i,j,k)} the result is
	 * {@code epiDef(p) + currentDef(p + epiDef(p))}, where {@code currentDef} is
	 * sampled with trilinear interpolation. Vectors are added directly to voxel
	 * indices, so both fields are presumably in voxel units — confirm with callers.
	 * Uses the current {@code XN}/{@code YN}/{@code ZN} dimensions.
	 *
	 * @param currentDef field sampled at the displaced positions ({@code chN} components)
	 * @param epiDef     field applied first, read at each voxel ({@code chN} components)
	 * @return a new combined field; the inputs are not modified
	 */
	public ImageData[] applyDeformation(ImageData[] currentDef, ImageData[] epiDef){
		double[] currentVec = new double[chN];
		double[] newVec = new double[chN];

		// Start from copies of currentDef so every voxel is then overwritten in place.
		ImageData[] newDef = new ImageDataFloat[chN];
		for (int c = 0; c < chN; c++) newDef[c] = currentDef[c].clone();

		for(int i = 0; i < XN; i++)
			for(int j = 0; j < YN; j++)
				for(int k = 0; k < ZN; k++){
					for(int c = 0; c < chN; c++) newVec[c] = epiDef[c].getDouble(i, j, k); 

					// Sample currentDef where epiDef points from this voxel.
					for(int c = 0; c < chN; c++) currentVec[c] = RegistrationUtilities.Interpolation(currentDef[c], XN, YN, ZN, 
							newVec[0]+i, newVec[1]+j, newVec[2]+k, InterpolationType.TRILINEAR); 

					for(int c = 0; c < chN; c++) newDef[c].set(i,j,k, newVec[c]+currentVec[c]);
				}
		return newDef;
	}

	/**
	 * Converts a 4x4 transformation matrix into a deformation field.
	 *
	 * <p>For each voxel {@code (i,j,k)} with physical position
	 * {@code p = (i*dimRes[0], j*dimRes[1], k*dimRes[2])}, the stored vector is
	 * {@code (q - p) / dimRes} (component-wise), where {@code q} is the first three
	 * components of {@code newTrans^-1 * [p, 1]}. No homogeneous division is
	 * performed, which assumes an affine bottom row of [0 0 0 1] — NOTE(review): confirm.
	 * Output size is the current {@code XN}/{@code YN}/{@code ZN}.
	 *
	 * @param newTrans invertible 4x4 transformation
	 * @param dimRes   voxel resolutions along x, y and z
	 * @return a new {@code chN}-component field in voxel units
	 * @throws RuntimeException from Jama if {@code newTrans} is singular
	 */
	public ImageData[] createDefFieldFromTransMatrix(Matrix newTrans,float[] dimRes){
		// The inverse does not depend on the voxel; computing it once replaces a
		// full 4x4 LU solve per voxel.
		double[][] inv = newTrans.inverse().getArray();

		ImageData[] newDef = new ImageDataFloat[chN];
		for(int c = 0; c < chN; c++) newDef[c] = new ImageDataFloat(XN,YN,ZN);

		double[] pos = new double[chN];
		for(int i = 0; i < XN; i++)
			for(int j = 0; j < YN; j++)
				for(int k = 0; k < ZN; k++){
					pos[0] = i*dimRes[0];
					pos[1] = j*dimRes[1];
					pos[2] = k*dimRes[2];

					for(int c = 0; c < chN; c++){
						double mapped = inv[c][0]*pos[0] + inv[c][1]*pos[1] + inv[c][2]*pos[2] + inv[c][3];
						newDef[c].set(i,j,k, (mapped - pos[c])/dimRes[c]);
					}
				}

		return newDef;
	}

}