
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"

#include "itkDiscreteGaussianImageFilter.h"
#include "itkImageDuplicator.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkNearestNeighborInterpolateImageFunction.h"
#include "itkNumericTraits.h"
#include "itkRescaleIntensityImageFilter.h"
#include "itkWarpImageFilter.h"

#include "itkOutputWindow.h"
#include "itkTextOutput.h"

#include "itkMultiThreader.h"

#include "itksys/SystemTools.hxx"

#include "vtkOutputWindow.h"
#include "vtkSmartPointer.h"
#include "vtkUnstructuredGrid.h"
#include "vtkUnstructuredGridReader.h"
#include "vtkUnstructuredGridWriter.h"

#include "itkVersion.h"
#include "vtkVersion.h"

#include "mu.h"
#include "muFile.h"
#include "Log.h"
#include "MersenneTwisterRNG.h"
#include "Timer.h"

#include "DTImage.h"
#include "DTImageReader.h"
#include "DTImageScalarSource.h"
#include "DTImageWriter.h"

// Tumor simulation pipeline
#include "TumorMassEffectGenerator.h"
#include "DTImageDestructiveWarpFilter.h"
#include "TumorEdemaInfiltrationGenerator.h"

#include "ContrastEnhancementFilter.h"

#include "TumorSimulationParameters.h"
#include "TumorSimulationParametersXMLFile.h"

#include <exception>
#include <iostream>
#include <sstream>
#include <stdexcept>

// Path separator character used when assembling input/output file names.
// MSVC and Open Watcom builds use a backslash; all other toolchains use '/'.
// Open Watcom predefines __WATCOMC__; the spelling __WATCOM_C__ is also
// accepted so any build setup that defines it by hand keeps working.
#if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__WATCOM_C__)
#define FOLDER_SEPARATOR '\\'
#else
#define FOLDER_SEPARATOR '/'
#endif


// Forward declaration so main() can call the simulation driver, which is
// defined further down in this file.
void runTumorSimulation(TumorSimulationParameters* params);

// Print the command-line synopsis for this program to stderr, followed by
// an empty line.
//
// progname : the name the program was invoked with (normally argv[0]).
void
printUsage(char* progname)
{
  std::cerr << "Usage: " << progname << " <params.xml> " << std::endl
            << std::endl;
}

int
main(int argc, char** argv)
{

  //std::cerr << "Run without any arguments to see command line options" << std::endl;

  if (argc != 2)
  {
    printUsage(argv[0]);
    return -1;
  }

  itk::OutputWindow::SetInstance(itk::TextOutput::New());

  try
  {
    std::cout << "Reading " << argv[1] << "..." << std::endl;
    TumorSimulationParameters::Pointer tumorsimp = readTumorSimulationParametersXML(argv[1]);
    if (tumorsimp.IsNull())
    {
      std::cerr << "Failed creating XML object, bad input file?" << std::endl;
      return -1;
    }
    runTumorSimulation(tumorsimp);
  }
  catch (itk::ExceptionObject& e)
  {
    std::cerr << e << std::endl;
    return -1;
  }
  catch (std::exception& e)
  {
    std::cerr << "Exception: " << e.what() << std::endl;
    return -1;
  }
  catch (std::string& s)
  {
    std::cerr << "Exception: " << s << std::endl;
    return -1;
  }
  catch (...)
  {
    std::cerr << "Unknown exception" << std::endl;
    return -1;
  }

  return 0;

}

void
writeUShortImage(const char* fn, itk::Image<float, 3>* img)
{ 
  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::Image<unsigned short, 3> UShortImageType;

  //unsigned short maxUShort = itk::NumericTraits<unsigned short>::max();

  typedef itk::RescaleIntensityImageFilter<FloatImageType, UShortImageType>
    RescalerType;

  RescalerType::Pointer res = RescalerType::New();
  res->SetInput(img);
  res->SetOutputMinimum(0);
  //res->SetOutputMaximum(maxUShort);
  res->SetOutputMaximum(4096);
  res->Update();

  typedef itk::ImageFileWriter<UShortImageType> WriterType;

  WriterType::Pointer writer = WriterType::New();
  writer->SetInput(res->GetOutput());
  writer->SetFileName(fn);
  writer->UseCompressionOn();
  writer->Update();
}

void
writeUShortProbability(const char* fn, itk::Image<float, 3>* img)
{ 
  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::Image<unsigned short, 3> UShortImageType;

  UShortImageType::Pointer outImg = UShortImageType::New();
  outImg->CopyInformation(img);
  outImg->SetRegions(img->GetLargestPossibleRegion());
  outImg->Allocate();

  unsigned short maxUShort = itk::NumericTraits<unsigned short>::max();

  typedef itk::ImageRegionIteratorWithIndex<FloatImageType> FloatIteratorType;
  FloatIteratorType it(img, img->GetLargestPossibleRegion());
  
  for (it.GoToBegin(); !it.IsAtEnd(); ++it) 
  {
    FloatImageType::IndexType ind = it.GetIndex();
    float val = it.Get() * maxUShort;
    if (val < 0.0)
      val = 0.0;
    if (val > maxUShort)
      val = maxUShort;
    outImg->SetPixel(ind, static_cast<unsigned short>(val + 0.5));
  }
  
  typedef itk::ImageFileWriter<UShortImageType> WriterType;

  WriterType::Pointer writer = WriterType::New();
  writer->SetInput(outImg);
  writer->SetFileName(fn);
  writer->UseCompressionOn();
  writer->Update();
}

// Clamp a probability value in place to the closed interval [0, 1].
// NaN fails both comparisons and is therefore left unchanged.
static inline void _clampProb(float& p)
{
  if (p < 0.0)
  {
    p = 0.0;
  }
  else if (p > 1.0)
  {
    p = 1.0;
  }
}

// Linearly rescale the intensities of img to the 12-bit range: the minimum
// maps to 0 and the maximum to 4095. Returns the rescale filter's output
// image; the returned smart pointer keeps it alive after the filter is gone.
itk::Image<float, 3>::Pointer
rescaleImage12(itk::Image<float, 3>* img)
{
  typedef itk::Image<float, 3> ImageType;
  typedef itk::RescaleIntensityImageFilter<ImageType, ImageType> FilterType;

  FilterType::Pointer rescaler = FilterType::New();
  rescaler->SetInput(img);
  rescaler->SetOutputMinimum(0);
  rescaler->SetOutputMaximum(4095);
  rescaler->Update();

  ImageType::Pointer rescaled = rescaler->GetOutput();
  return rescaled;
}

// Add bias field using random poly coefficients.
//
// Multiplies img in place by exp(P(x, y, z)), where P is a random polynomial
// of total degree <= maxDegree in the voxel coordinates. Each coordinate is
// normalized to roughly [-1, 1) across the image region. The k-th
// coefficient is a uniform draw from the global Mersenne Twister mapped to
// [-1, 1] and scaled by 2^-k. The bias factor is clamped to [0.2, 5.0], and
// resulting intensities are clamped to be non-negative.
//
// img       : image modified in place.
// maxDegree : maximum total polynomial degree; 0 returns immediately
//             without touching img or drawing random numbers.
void insertBias(itk::Image<float, 3>* img, unsigned int maxDegree)
{
  if (maxDegree == 0)
    return;

  // Number of monomials x^i y^j z^k with i + j + k <= maxDegree,
  // i.e. C(maxDegree + 3, 3).
  unsigned int numCoeffs = (maxDegree+1)*(maxDegree+2)*(maxDegree+3)/6;

  DynArray<float> coeffs;
  coeffs.Initialize(numCoeffs, 0.0);

  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();

  for (unsigned int k = 0; k < numCoeffs; k++)
  {
    float a = 1.0 / pow(2.0f, static_cast<float>(k));
    float r = 2.0 * rng->GenerateUniformRealClosedInterval() - 1.0;
    coeffs[k] = a * r;
  }

  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::ImageRegionIterator<FloatImageType> IteratorType;

  FloatImageType::RegionType region = img->GetLargestPossibleRegion();
  FloatImageType::SizeType size = region.GetSize();
  FloatImageType::IndexType start = region.GetIndex();

  IteratorType it(img, region);
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    FloatImageType::IndexType ind = it.GetIndex();

    // Normalized coordinates, centered on the region. Offsets are measured
    // from the region start so regions with a non-zero start index are
    // still centered correctly.
    float x[3];
    for (unsigned int dim = 0; dim < 3; dim++)
    {
      x[dim] = ((ind[dim] - start[dim]) - size[dim]/2.0)
        / static_cast<float>(size[dim]) * 2.0;
    }

    float logbias = 0;

    unsigned int c = 0;

    // Integer loop counters enumerate the monomials in the same order the
    // coefficients were drawn. Exponents become float only at the pow()
    // call, so the evaluated values match a float-counter loop exactly.
    for (unsigned int order = 0; order <= maxDegree; order++)
      for (unsigned int xorder = 0; xorder <= order; xorder++)
        for (unsigned int yorder = 0; yorder <= (order-xorder); yorder++)
        {
          unsigned int zorder = order - xorder - yorder;
          float poly = static_cast<float>(
            pow(x[0], static_cast<float>(xorder)) *
            pow(x[1], static_cast<float>(yorder)) *
            pow(x[2], static_cast<float>(zorder)));
          logbias += coeffs[c] * poly;
          c++;
        }

    float bias = expf(logbias);

    // Keep the multiplicative bias within a factor of 5 in either direction.
    if (bias < 0.2)
      bias = 0.2;
    if (bias > 5.0)
      bias = 5.0;

    float newval = it.Get() * bias;
    if (newval < 0.0)
      newval = 0.0;
    it.Set(newval);
  }

}


void runTumorSimulation(TumorSimulationParameters* params)
{
  if (!params->CheckValues())
    throw std::runtime_error("Invalid parameter values");

  itk::OutputWindow::SetInstance(itk::TextOutput::New());
  vtkOutputWindow::SetInstance(vtkOutputWindow::GetInstance());

/*
  // Avoid using all CPUs?
  unsigned int maxThreads = itk::MultiThreader::GetGlobalDefaultNumberOfThreads();
  maxThreads = maxThreads / 4 * 3;
  if (maxThreads < 2)
    maxThreads = 2;
  itk::MultiThreader::SetGlobalMaximumNumberOfThreads(maxThreads);
*/
    
  // Create and start a new timer (for the whole process)
  Timer* timer = new Timer();
  
  // Initialize random number generators
  srand(542948474);
  MersenneTwisterRNG* rng = MersenneTwisterRNG::GetGlobalInstance();
  rng->Initialize(87584359);

  // Directory separator string
  std::string separator = std::string("/");
  separator[0] = FOLDER_SEPARATOR;

  // Get output directory
  std::string outdir = params->GetOutputDirectory();
  // Make sure last character in output directory string is a separator
  if (outdir[outdir.size()-1] != FOLDER_SEPARATOR)
    outdir += separator;

  // Create the output directory, stop if it does not exist
  if(!mu::create_dir(outdir.c_str()))
    return;

  // Write out the parameters in XML
  {
    std::string xmlfn = outdir + params->GetDatasetName() + ".xml";
    writeTumorSimulationParametersXML(xmlfn.c_str(), params);
  }

  // Set up the logger
  {
    std::string logfn = outdir + "tumorsim_" + params->GetDatasetName() + ".log";
    (mu::Log::GetInstance())->EchoOn();
    (mu::Log::GetInstance())->SetOutputFileName(logfn.c_str());
  }

  muLogMacro(<< "Brain Tumor MRI Simulator\n");
  muLogMacro(<< "========================================\n");
  muLogMacro(<< "Marcel Prastawa - prastawa@sci.utah.edu\n");
  muLogMacro(<< "This software is provided for research purposes only\n");
  muLogMacro(<< "\n");
  muLogMacro(<< "Version 1.2.3\n");
  muLogMacro(<< "Program compiled on: " << __DATE__ << "\n");
  muLogMacro(<< "\n");
    
  muLogMacro(<< "Using ITK version " << itk::Version::GetITKVersion() << "\n");
  muLogMacro(<< "Using VTK version " << vtkVersion::GetVTKVersion() << "\n");
  muLogMacro(<< "\n");

  // Write input parameters
  params->PrintSelf(std::cout);
  params->PrintSelf((mu::Log::GetInstance())->GetFileObject());

  muLogMacro(<< "\n");

  std::string name = params->GetDatasetName();

  // Get input directory
  std::string inputdir = params->GetInputDirectory();
  // Make sure last character is a separator
  if (inputdir[inputdir.size()-1] != FOLDER_SEPARATOR)
    inputdir += separator;

  // Check parameters

  if (params->GetInfiltrationIterations() < 2)
  {
    throw std::runtime_error("Need at least two infiltration steps");
  }

  float infil_t = params->GetInfiltrationTimeStep() * params->GetInfiltrationIterations();
  if (params->GetInfiltrationEarlyTime() >= infil_t)
  {
    throw std::runtime_error("Early infiltration time is invalid");
  }

  if (params->GetDeformationKappa() <= 0.0)
  {
    throw std::runtime_error("Deformation kappa must be > 0");
  }
  if (params->GetDeformationDamping() <= 0.0)
  {
    throw std::runtime_error("Deformation damping must be > 0");
  }

  ContrastEnhancementFilter::EnhancementMode enhmode;
  if (itksys::SystemTools::Strucmp(params->GetContrastEnhancementType().c_str(),
        "none") == 0)
    enhmode = ContrastEnhancementFilter::NoEnhancement;
  else if (
    itksys::SystemTools::Strucmp(params->GetContrastEnhancementType().c_str(),
      "ring") == 0)
    enhmode = ContrastEnhancementFilter::RingEnhancement;
  else if (
    itksys::SystemTools::Strucmp(params->GetContrastEnhancementType().c_str(),
      "uniform") == 0)
    enhmode = ContrastEnhancementFilter::UniformEnhancement;
  else
    throw std::runtime_error("Unknown contrast enhancement mode");

  //
  // Obtain input data
  //

  // Read mesh
  muLogMacro(<< "Reading VTK mesh...\n");
  std::string meshfn = inputdir + "mesh.vtk";

  vtkSmartPointer<vtkUnstructuredGridReader> meshReader =
    vtkSmartPointer<vtkUnstructuredGridReader>::New();
  meshReader->SetFileName(meshfn.c_str());
  meshReader->Update();

  vtkSmartPointer<vtkUnstructuredGrid> mesh = meshReader->GetOutput();

  // Read labels
  muLogMacro(<< "Reading voxel labels...\n");
  std::string labelfn = inputdir + "labels.mha";

  typedef itk::Image<unsigned char, 3> ByteImageType;
  typedef itk::ImageFileReader<ByteImageType> ByteReaderType;
  typedef itk::ImageRegionIteratorWithIndex<ByteImageType> ByteIteratorType;
    
  ByteImageType::Pointer labelImg;
  {
    ByteReaderType::Pointer reader = ByteReaderType::New();
    reader->SetFileName(labelfn.c_str());
    reader->Update();
    labelImg = reader->GetOutput();
  }

  // Read and process probabilities
  muLogMacro(<< "Reading anatomical probabilities...\n");
  DynArray<std::string> probNames;
  probNames.Append("p_white.mha");
  probNames.Append("p_gray.mha");
  probNames.Append("p_csf.mha");
  probNames.Append("p_dura.mha");
  probNames.Append("p_vessel.mha");

  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::ImageFileReader<FloatImageType> FloatReaderType;
  typedef itk::ImageRegionIteratorWithIndex<FloatImageType> FloatIteratorType;

  DynArray<FloatImageType::Pointer> probImages;
  for (unsigned int k = 0; k < probNames.GetSize(); k++)
  {
    std::string probfn = inputdir + probNames[k];

    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(probfn.c_str());

    try
    {
      reader->Update();
    }
    catch (...)
    {
      continue;
    }

    FloatImageType::Pointer img = reader->GetOutput();

    FloatIteratorType it(img, img->GetLargestPossibleRegion());
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
      it.Set(it.Get() / 255.0);

    probImages.Append(img);
  }

  // Read tumor seed
  muLogMacro(<< "Reading tumor seed image...\n");
  ByteImageType::Pointer tumorSeedImg;
  {
    ByteReaderType::Pointer reader = ByteReaderType::New();
    reader->SetFileName(params->GetDeformationSeedFileName());
    reader->Update();
    tumorSeedImg = reader->GetOutput();

    ByteIteratorType sIt(tumorSeedImg, tumorSeedImg->GetLargestPossibleRegion());
    for (sIt.GoToBegin(); !sIt.IsAtEnd(); ++sIt)
    {
      // seed = seed * brain mask
      ByteImageType::IndexType ind = sIt.GetIndex();

      float sumP = 0;
      for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
        sumP += probImages[k]->GetPixel(ind);
      if (sumP < 1e-2)
        sIt.Set(0);

/*
      // seed = seed * tissue mask
      unsigned char l = labelImg->GetPixel(sIt.GetIndex());
      if (l != 1 && l != 2)
        sIt.Set(0);
*/
    }
  }

  // Write seed in output directory for record keeping
  {
    typedef itk::ImageFileWriter<ByteImageType> ByteWriterType;
    ByteWriterType::Pointer writer = ByteWriterType::New();
    std::string fn = outdir + params->GetDatasetName() + "_seed.mha";

    writer->SetInput(tumorSeedImg);
    writer->SetFileName(fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  // Read and process DTI
  muLogMacro(<< "Reading diffusion tensor image...\n");
  std::string dtifn = inputdir + "dti.mha";

  DTImagePointer dtImg;
  {
    DTImageReader::Pointer dtReader = DTImageReader::New();
    dtReader->ReadStackedImage(dtifn.c_str());
    dtImg = dtReader->GetOutput();
  }

  muLogMacro(<< "Obtaining MD image...\n");
  FloatImageType::Pointer mdImg;
  {
    DTImageScalarSource::Pointer dtiss = DTImageScalarSource::New();
    dtiss->SetInput(dtImg);
    mdImg = dtiss->GetMDImage();
  }

  muLogMacro(<< "Scaling DT image...\n");
  FloatIteratorType mdIt(mdImg, mdImg->GetLargestPossibleRegion());
  
  float maxMD = 0.0; 

  for (mdIt.GoToBegin(); !mdIt.IsAtEnd(); ++mdIt)
  { 
    float v = mdIt.Get();
    if (v > maxMD)
      maxMD = v;
  }

  float maxTr = (3.0 * maxMD);

  DTImageIteratorType dtIt(dtImg, dtImg->GetLargestPossibleRegion());

  for (dtIt.GoToBegin(); !dtIt.IsAtEnd(); ++dtIt)
  {
    DiffusionTensor D = dtIt.Get();
    D /= maxTr;
    dtIt.Set(D);
  }

  // Read texture images
  muLogMacro(<< "Reading precomputed texture images...\n");
  DynArray<FloatImageType::Pointer> t1Textures;
  for (unsigned int k = 0; k < 5; k++)
  {
    std::ostringstream oss;
    oss << inputdir << "textures" << FOLDER_SEPARATOR << "t1_" << k+1 << ".mha"
      << std::ends;
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(oss.str().c_str());
    reader->Update();
    t1Textures.Append(reader->GetOutput());
  }

  DynArray<FloatImageType::Pointer> t2Textures;
  for (unsigned int k = 0; k < 5; k++)
  {
    std::ostringstream oss;
    oss << inputdir << "textures" << FOLDER_SEPARATOR << "t2_" << k+1 << ".mha"
      << std::ends;
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(oss.str().c_str());
    reader->Update();
    t2Textures.Append(reader->GetOutput());
  }

  DynArray<FloatImageType::Pointer> gadTextures;
  for (unsigned int k = 0; k < 6; k++)
  {
    std::ostringstream oss;
    if (k < 5)
      oss << inputdir << "textures" << FOLDER_SEPARATOR << "t1_" << k+1
        << ".mha" << std::ends;
    else
      oss << inputdir << "textures" << FOLDER_SEPARATOR << "gad_" << k+1
        << ".mha" << std::ends;
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(oss.str().c_str());
    reader->Update();
    gadTextures.Append(reader->GetOutput());
  }

  DynArray<FloatImageType::Pointer> flairTextures;
  for (unsigned int k = 0; k < 5; k++)
  {
    std::ostringstream oss;
    oss << inputdir << "textures" << FOLDER_SEPARATOR << "flair_" << k+1
      << ".mha" << std::ends;
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(oss.str().c_str());
    reader->Update();
    flairTextures.Append(reader->GetOutput());
  }

  // Read background MRI for T1 and T2
  FloatImageType::Pointer t1bgImg;
  {
    std::string t1bgfn = inputdir + "textures" + FOLDER_SEPARATOR + "t1_bg.mha";
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(t1bgfn.c_str());
    reader->Update();
    t1bgImg = reader->GetOutput();
  }

  FloatImageType::Pointer t2bgImg;
  {
    std::string t2bgfn = inputdir + "textures" + FOLDER_SEPARATOR + "t2_bg.mha";
    FloatReaderType::Pointer reader = FloatReaderType::New();
    reader->SetFileName(t2bgfn.c_str());
    reader->Update();
    t2bgImg = reader->GetOutput();
  }

  //
  // Mass effect simulation
  //

  muLogMacro(<< "Mass effect simulation...\n");

  typedef itk::ImageRegionIteratorWithIndex<ByteImageType> LabelIteratorType;

  // Mark / include tumor seed voxels in the label image
  LabelIteratorType labelIt(labelImg, labelImg->GetLargestPossibleRegion());

  for (labelIt.GoToBegin(); !labelIt.IsAtEnd(); ++labelIt)
  {
    //if (labelIt.Get() >= 5)
    //  muLogMacro( << ">->->- WARNING: read label value " << labelIt.Get() << "\n");
    ByteImageType::IndexType ind = labelIt.GetIndex();
    if (tumorSeedImg->GetPixel(ind) != 0)
      labelIt.Set(5);
  }

  TumorMassEffectGenerator* massEffect = new TumorMassEffectGenerator();

  // Young modulus for parenchyma = 694 Pa = 694 N/m^2
  // Brain space is specified in mm, so need to use 694*1e-6 N/mm^2
  //massEffect->SetBrainMaterialParameters(694, 0.4);
  //massEffect->SetFalxMaterialParameters(200000, 0.4);
  //massEffect->SetBrainMaterialParameters(2000, 0.4);
  //massEffect->SetBrainMaterialParameters(694e-6, 0.4);
  massEffect->SetBrainMaterialParameters(
    params->GetBrainYoungModulus()*1e-6, params->GetBrainPoissonRatio());
  //massEffect->SetFalxMaterialParameters(2e-1, 0.4);
  massEffect->SetFalxMaterialParameters(
    params->GetFalxYoungModulus()*1e-6, params->GetFalxPoissonRatio());

  massEffect->SetVMFKappa(params->GetDeformationKappa());
  massEffect->SetDamping(params->GetDeformationDamping());

  massEffect->SetLabelImage(labelImg);
  massEffect->SetInitialVTKMesh(mesh);
  massEffect->SetDeformationIterations(params->GetDeformationIterations());
  // Pressure specified in KPa (so x kPa = x * 1e-3 N / mm^2)
  massEffect->SetPressure(params->GetDeformationInitialPressure() * 1e-3);
  massEffect->SetUseQHull(params->GetUseQHull());
  massEffect->SetDeformationSolverIterations(params->GetDeformationSolverIterations());
  massEffect->SetNumberOfThreads(params->GetNumberOfThreads());

  massEffect->ComputeDeformation();

  typedef TumorMassEffectGenerator::DeformationFieldType
    DeformationFieldType;

  DeformationFieldType::Pointer def1Img = massEffect->GetDeformation();
  DeformationFieldType::Pointer invDef1Img = massEffect->GetInverseDeformation();

  // Update variables
  mesh = massEffect->GetCurrentMesh()->GetVTKMesh();

  typedef itk::WarpImageFilter<
    ByteImageType, ByteImageType, DeformationFieldType> ByteWarperType;
  
  typedef itk::NearestNeighborInterpolateImageFunction<
    ByteImageType, double> NNInterpolatorType;
  
  {
    ByteWarperType::Pointer warper = ByteWarperType::New();
    warper->SetEdgePaddingValue(0);
    warper->SetInput(labelImg);
    warper->SetInterpolator(NNInterpolatorType::New());
    warper->SetOutputDirection(labelImg->GetDirection());
    warper->SetOutputOrigin(labelImg->GetOrigin());
    warper->SetOutputSpacing(labelImg->GetSpacing());
    //warper->SetOutputSize(labelImg->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
    warper->SetDisplacementField(def1Img);
#else
    warper->SetDeformationField(def1Img);
#endif
    warper->Update();

    labelImg = warper->GetOutput();

    typedef itk::ImageFileWriter<ByteImageType> ByteWriterType;
    ByteWriterType::Pointer writer = ByteWriterType::New();
    std::string label1fn = outdir + params->GetDatasetName() + "_warped_labels1.mha";

    writer->SetInput(labelImg);
    writer->SetFileName(label1fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  // Write intermediate results

  typedef itk::ImageFileWriter<DeformationFieldType> DeformationWriterType;

  {
    DeformationWriterType::Pointer writer = DeformationWriterType::New();
    std::string def1fn = outdir + params->GetDatasetName() + "_def1.mha";

    writer->SetInput(def1Img);
    writer->SetFileName(def1fn.c_str());
    writer->UseCompressionOn();
    writer->Update();

    def1fn = outdir + params->GetDatasetName() + "_def1_inverse.mha";
    writer->SetInput(massEffect->GetInverseDeformation());
    writer->SetFileName(def1fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  // Clean up
  delete massEffect;

  //
  // Warp and destroy DTI
  //

  // Modify tensors in seed region
  muLogMacro(<< "Modifying tensors in tumor seed region...\n");
  DiffusionTensor::MatrixType id(3, 3);
  id.set_identity();

  DiffusionTensor idT;
  idT.FromMatrix(id  * 0.1);

  for (dtIt.GoToBegin(); !dtIt.IsAtEnd(); ++dtIt)
  {
    DTImageIndexType ind = dtIt.GetIndex();
    if (tumorSeedImg->GetPixel(ind) != 0)
    {
      DiffusionTensor T = dtIt.Get();
      //dtIt.Set(T * 2.0);

/*
      DiffusionTensor::MatrixType Tmat = T.GetMatrix();
      float md = (Tmat(0, 0) + Tmat(1, 1) + Tmat(2,2)) / 3.0;
      Tmat.set_identity();
      Tmat += (id * 1e-5);
      Tmat *= (md * 2.0);
      T.FromMatrix(Tmat);
      dtIt.Set(T);
*/

      // Already divided by max trace, so D in tumor should be identity?
      dtIt.Set(idT);
    }
  }

  // Apply deformation
  muLogMacro(<< "Warping and destroying DTI...\n");
  DTImageDestructiveWarpFilter::Pointer dtWarper =
    DTImageDestructiveWarpFilter::New();

  dtWarper->SetInput(dtImg);
  dtWarper->SetDeformationField(def1Img);
  dtWarper->SetInverseDeformationField(invDef1Img);
  dtWarper->Update();

  dtImg = dtWarper->GetOutput();

  // Write intermediate results
  {
    DTImageWriter::Pointer dtWriter = DTImageWriter::New();
    dtWriter->SetInput(dtImg);
    std::string outname = outdir + name + "_modified_dti.mha";
    dtWriter->WriteStackedImage(outname.c_str());
  }

  //
  // Infiltration simulation
  //

  muLogMacro(<< "Estimating infiltration...\n");

  TumorEdemaInfiltrationGenerator* infilGen =
    new TumorEdemaInfiltrationGenerator();
  
  infilGen->SetLabelImage(labelImg);
  infilGen->SetInitialVTKMesh(mesh);
  infilGen->SetDTImage(dtImg);
  infilGen->SetUseQHull(params->GetUseQHull());

  infilGen->SetReactionCoefficient(params->GetInfiltrationReactionCoefficient());
  infilGen->SetWhiteMatterTensorMultiplier(params->GetWhiteMatterTensorMultiplier());
  infilGen->SetGrayMatterTensorMultiplier(params->GetGrayMatterTensorMultiplier());
  
  infilGen->SetInfiltrationTimeStep(params->GetInfiltrationTimeStep());
  infilGen->SetInfiltrationIterations(params->GetInfiltrationIterations());
  infilGen->SetEarlyInfiltrationTime(params->GetInfiltrationEarlyTime());

  infilGen->SetLambda(params->GetInfiltrationBodyForceCoefficient() * 1e-3);
  infilGen->SetLambdaDamping(params->GetInfiltrationBodyForceDamping());
  infilGen->SetDeformationIterations(
    params->GetInfiltrationBodyForceIterations());

  infilGen->SetDeformationSolverIterations(
    params->GetDeformationSolverIterations());
  infilGen->SetInfiltrationSolverIterations(
    params->GetInfiltrationSolverIterations());
  infilGen->SetNumberOfThreads(params->GetNumberOfThreads());

  //infilGen->SetBrainMaterialParameters(694.0, 0.4);
  //infilGen->SetFalxMaterialParameters(200000.0, 0.4);
  //infilGen->SetBrainMaterialParameters(694e-6, 0.4);
  infilGen->SetBrainMaterialParameters(
    params->GetBrainYoungModulus()*1e-6, params->GetBrainPoissonRatio());
  //infilGen->SetFalxMaterialParameters(2e-1, 0.4);
  infilGen->SetFalxMaterialParameters(
    params->GetFalxYoungModulus()*1e-6, params->GetFalxPoissonRatio());

  FloatImageType::Pointer infilImg = infilGen->ComputeInfiltration();
  FloatImageType::Pointer earlyInfilImg = infilGen->GetEarlyInfiltration();

  muLogMacro(<< "Applying infiltration body forces...\n");

  //infilGen->ComputeDeformation();

  DeformationFieldType::Pointer def2Img = infilGen->GetDeformation();

  // Update variables

  mesh = infilGen->GetCurrentMesh()->GetVTKMesh();

  // Write intermediate results

  {
    ByteWarperType::Pointer warper = ByteWarperType::New();
    warper->SetEdgePaddingValue(0);
    warper->SetInput(labelImg);
    warper->SetInterpolator(NNInterpolatorType::New());
    warper->SetOutputDirection(labelImg->GetDirection());
    warper->SetOutputOrigin(labelImg->GetOrigin());
    warper->SetOutputSpacing(labelImg->GetSpacing());
    //warper->SetOutputSize(labelImg->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
    warper->SetDisplacementField(def2Img);
#else
    warper->SetDeformationField(def2Img);
#endif
    warper->Update();

    labelImg = warper->GetOutput();

    typedef itk::ImageFileWriter<ByteImageType> ByteWriterType;
    ByteWriterType::Pointer writer = ByteWriterType::New();
    std::string label2fn = outdir + params->GetDatasetName() + "_warped_labels2.mha";

    writer->SetInput(labelImg);
    writer->SetFileName(label2fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  typedef itk::ImageFileWriter<DeformationFieldType> DeformationWriterType;

  {
    DeformationWriterType::Pointer writer = DeformationWriterType::New();

    std::string def2fn = outdir + params->GetDatasetName() + "_def2.mha";
    writer->SetInput(def2Img);
    writer->SetFileName(def2fn.c_str());
    writer->UseCompressionOn();
    writer->Update();

    def2fn = outdir + params->GetDatasetName() + "_def2_inverse.mha";
    writer->SetInput(infilGen->GetInverseDeformation());
    writer->SetFileName(def2fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  // Adjust generated tumor+edema probs to exclude non-brain tissue
  FloatIteratorType eIt(
    earlyInfilImg, earlyInfilImg->GetLargestPossibleRegion());
  for (eIt.GoToBegin(); !eIt.IsAtEnd(); ++eIt)
  {
    FloatImageType::IndexType ind = eIt.GetIndex();
  
    float sump = 0;
    for (unsigned int k = 0; k < probImages.GetSize(); k++)
      sump += probImages[k]->GetPixel(ind);
    if (sump < 1e-10)
    {
      infilImg->SetPixel(ind, 0);
      earlyInfilImg->SetPixel(ind, 0);
    }
  }

  std::string infilfn = outdir + name + "_infil.mha";
  //writeUShortProbability(infilfn.c_str(), infilImg);
  writeUShortImage(infilfn.c_str(), infilImg);

  std::string earlyinfilfn = outdir + name + "_early_infil.mha";
  //writeUShortProbability(earlyinfilfn.c_str(), earlyInfilImg);
  writeUShortImage(earlyinfilfn.c_str(), earlyInfilImg);

  // Clean up
  delete infilGen;

  //
  // Generate final probabilities
  //

  muLogMacro(<< "Computing probabilities...\n");

  DynArray<FloatImageType::Pointer> outProbImages;

  //for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
  for (unsigned int k = 0; k < probImages.GetSize(); k++)
  {
    typedef itk::ImageDuplicator<FloatImageType> DuplicatorType;
    DuplicatorType::Pointer duper = DuplicatorType::New();
    duper->SetInputImage(probImages[k]);
    duper->Update();
    outProbImages.Append(duper->GetOutput());
  }


  FloatIteratorType pIt(
    outProbImages[0], outProbImages[0]->GetLargestPossibleRegion());

  // Blur seed and subtract from probs
  //TODO: no need, do it all at once later?
  FloatImageType::Pointer blurredSeedImg;
  {
    typedef itk::DiscreteGaussianImageFilter<ByteImageType, FloatImageType>
      BlurFilterType;
    BlurFilterType::Pointer blurf = BlurFilterType::New();
    blurf->SetInput(tumorSeedImg);
    blurf->SetVariance(2.0);
    blurf->Update();
    blurredSeedImg = blurf->GetOutput();
  }

  for (pIt.GoToBegin(); !pIt.IsAtEnd(); ++pIt)
  {
    FloatImageType::IndexType ind = pIt.GetIndex();
    float s = blurredSeedImg->GetPixel(ind);
    _clampProb(s);

    float sump = 1e-10;
    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
      sump += outProbImages[k]->GetPixel(ind);

    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
    { 
      float p = outProbImages[k]->GetPixel(ind);
      float newp = p - p/sump*s;
      _clampProb(newp);
      outProbImages[k]->SetPixel(ind, newp);
    }
  }

  typedef itk::WarpImageFilter<
    FloatImageType, FloatImageType, DeformationFieldType> FloatWarperType;
  
  typedef itk::LinearInterpolateImageFunction<
    FloatImageType, double> LinearInterpolatorType;

  muLogMacro(<< "  Applying mass effect deformation...\n");
  for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
  {
    FloatWarperType::Pointer warper = FloatWarperType::New();
    warper->SetEdgePaddingValue(0);
    warper->SetInput(outProbImages[k]);
    warper->SetInterpolator(LinearInterpolatorType::New());
    warper->SetOutputDirection(outProbImages[k]->GetDirection());
    warper->SetOutputOrigin(outProbImages[k]->GetOrigin());
    warper->SetOutputSpacing(outProbImages[k]->GetSpacing());
    //warper->SetOutputSize(outProbImages[k]->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
    warper->SetDisplacementField(def1Img);
#else
    warper->SetDeformationField(def1Img);
#endif
    warper->Update();

    outProbImages[k] = warper->GetOutput();
  }

  // Apply body force
  muLogMacro(<< "  Applying infiltration body force deformations...\n");
  // Excluding tumor and edema since they already have deformation
  for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
  {
    FloatWarperType::Pointer warper = FloatWarperType::New();
    warper->SetEdgePaddingValue(0);
    warper->SetInput(outProbImages[k]);
    warper->SetInterpolator(LinearInterpolatorType::New());
    warper->SetOutputDirection(outProbImages[k]->GetDirection());
    warper->SetOutputOrigin(outProbImages[k]->GetOrigin());
    warper->SetOutputSpacing(outProbImages[k]->GetSpacing());
    //warper->SetOutputSize(outProbImages[k]->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
    warper->SetDisplacementField(def2Img);
#else
    warper->SetDeformationField(def2Img);
#endif
    warper->Update();

    outProbImages[k] = warper->GetOutput();
  }

  // Subtract early infiltration
  // TODO: do it together with late infil?
  for (pIt.GoToBegin(); !pIt.IsAtEnd(); ++pIt)
  {
    FloatImageType::IndexType ind = pIt.GetIndex();
    
    float p_tumor0 = earlyInfilImg->GetPixel(ind);
    _clampProb(p_tumor0);
  
    float sump = 1e-10;
    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
      sump += outProbImages[k]->GetPixel(ind);

    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
    {
      float p = outProbImages[k]->GetPixel(ind);
      float newp = p - (p/sump)*p_tumor0;
      _clampProb(newp);
      outProbImages[k]->SetPixel(ind, newp);
    }
  }

  // Adjust for tumor and edema
  FloatImageType::Pointer tumorImg = earlyInfilImg;

  FloatImageType::Pointer edemaImg = FloatImageType::New();
  edemaImg->CopyInformation(outProbImages[0]);
  edemaImg->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  edemaImg->Allocate();
  edemaImg->FillBuffer(0);

  for (pIt.GoToBegin(); !pIt.IsAtEnd(); ++pIt)
  {
    FloatImageType::IndexType ind = pIt.GetIndex();

    float p_brain = 1e-10;
    float p_brain_orig = 1e-10;
    //for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
    for (unsigned int k = 0; k < probImages.GetSize(); k++)
    {
      p_brain += outProbImages[k]->GetPixel(ind);
      p_brain_orig += probImages[k]->GetPixel(ind);
    }

    float p_tissue = 1e-10;
    float p_tissue_orig = 1e-10;
    //for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
    for (unsigned int k = 0; k < probImages.GetSize(); k++)
    {
      // Exclude csf
      //if (k == 2)
        //continue;

      p_tissue += outProbImages[k]->GetPixel(ind);
      p_tissue_orig += probImages[k]->GetPixel(ind);
    }

    float infil = infilImg->GetPixel(ind);
    float early_infil = earlyInfilImg->GetPixel(ind);

    _clampProb(infil);
    _clampProb(early_infil);

    float late_infil = infil - early_infil;
/*
    if (late_infil < 0)
      early_infil = infil;
    _clampProb(late_infil);
*/

    // Initial tumor prob from early infil
    float p_tumor = early_infil;
    //float p_tumor = early_infil * p_tissue;

    float p_edema = late_infil;
    //float p_edema = late_infil * p_tissue;
    //float p_edema = late_infil * origp_brain;


/*  
// Other possible adjustments:
    p_wm -= (p_wm*early_infil + p_wm*late_infil) / p_brain;
    p_wm -= (p_wm*early_infil + p_wm*late_infil);
*/  
    
    float p_bad = p_tumor + p_edema;

    float rem = p_brain - p_bad;
    _clampProb(rem);

    // Adjust probs, except tumor
    for (unsigned int k = 0; k < (outProbImages.GetSize()-1); k++)
    {
      // Exclude csf
      //if (k == 2)
      //  continue;

      float p_k = outProbImages[k]->GetPixel(ind);
      //float p_adjust = p_k - p_k/p_tissue * p_bad;
      float p_adjust = p_k - p_k/p_brain * p_bad;
      //float p_adjust = outProbImages[k]->GetPixel(ind) * rem;
      _clampProb(p_adjust);
      outProbImages[k]->SetPixel(ind, p_adjust);
    }
    
    _clampProb(p_tumor);
    tumorImg->SetPixel(ind, p_tumor);

    _clampProb(p_edema);
    edemaImg->SetPixel(ind, p_edema);
  }

  // Inserting edema probability before tumor, vessels last
  FloatImageType::Pointer vesselImg = outProbImages[outProbImages.GetSize()-1];

  outProbImages[outProbImages.GetSize()-1] = edemaImg;
  outProbImages.Append(tumorImg);
  outProbImages.Append(vesselImg);

  // Rescale probabilities to original p(brain)
  for (pIt.GoToBegin(); !pIt.IsAtEnd(); ++pIt)
  {
    FloatImageType::IndexType ind = pIt.GetIndex();

    float orig_sump = 1e-10;
    //for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
    for (unsigned int k = 0; k < probImages.GetSize(); k++)
      orig_sump += probImages[k]->GetPixel(ind);
    
    float sump = 1e-10;
    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
      sump += outProbImages[k]->GetPixel(ind);
    
    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
    {
      float p = outProbImages[k]->GetPixel(ind);
      outProbImages[k]->SetPixel(ind, p/sump * orig_sump);
    }
  }

  // Write probabilities
  for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
  {
    std::ostringstream oss;
    oss << outdir << name << "_prob" << k+1 << ".mha" << std::ends;
    writeUShortProbability(oss.str().c_str(), outProbImages[k]);
  }

  //
  // Compute contrast enhancement
  //

#if 0
  // Combine deformations and warp p(vessel)
  muLogMacro(<< "Warping p(vessel)...\n");

  FloatImageType::Pointer vesselImg = probImages[probImages.GetSize()-1];

  // Repeated application of deformations
  FloatWarperType::Pointer warper1 = FloatWarperType::New();
  warper1->SetEdgePaddingValue(0);
  warper1->SetInput(vesselImg);
  warper1->SetInterpolator(LinearInterpolatorType::New());
  warper1->SetOutputDirection(vesselImg->GetDirection());
  warper1->SetOutputOrigin(vesselImg->GetOrigin());
  warper1->SetOutputSpacing(vesselImg->GetSpacing());
  //warper1->SetOutputSize(vesselImg->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
  warper1->SetDisplacementField(def1Img);
#else
  warper1->SetDeformationField(def1Img);
#endif
  warper1->Update();

  vesselImg = warper1->GetOutput();

  FloatWarperType::Pointer warper2 = FloatWarperType::New();
  warper2->SetEdgePaddingValue(0);
  warper2->SetInput(vesselImg);
  warper2->SetInterpolator(LinearInterpolatorType::New());
  warper2->SetOutputDirection(vesselImg->GetDirection());
  warper2->SetOutputOrigin(vesselImg->GetOrigin());
  warper2->SetOutputSpacing(vesselImg->GetSpacing());
  //warper2->SetOutputSize(vesselImg->GetLargestPossibleRegion().GetSize());
#if ITK_VERSION_MAJOR >= 4
  warper2->SetDisplacementField(def2Img);
#else
  warper2->SetDeformationField(def2Img);
#endif
  warper2->Update();

  vesselImg = warper2->GetOutput();
  
  {
    // Write p(vessel) as last probability
    std::ostringstream oss;
    oss << outdir << name << "_prob" << outProbImages.GetSize()+1 << ".mha" << std::ends;
    writeUShortProbability(oss.str().c_str(), vesselImg);

  }
#endif

  // Compute p(enhancement)
  muLogMacro(<< "Simulating contrast enhancement in T1...\n");

  // Extract tissue probabilities (no vessel)
  DynArray<FloatImageType::Pointer> tissueOutProbImages;
  for (unsigned int i = 0; i < (outProbImages.GetSize()-1); i++)
    tissueOutProbImages.Append(outProbImages[i]);

  ContrastEnhancementFilter::Pointer enhgen = ContrastEnhancementFilter::New();

  enhgen->SetNumberOfThreads(params->GetNumberOfThreads());
    
  enhgen->SetMode(enhmode);
  enhgen->SetInputProbabilities(tissueOutProbImages);
  enhgen->SetVesselProbability(vesselImg);
  enhgen->Update();

  DynArray<FloatImageType::Pointer> outEnhProbImages = enhgen->GetOutput();

  // Write probabilities
  for (unsigned int k = 0; k < outEnhProbImages.GetSize(); k++)
  {
    std::ostringstream oss;
    oss << outdir << name << "_enh_prob" << k+1 << ".mha" << std::ends;
    writeUShortProbability(oss.str().c_str(), outEnhProbImages[k]);
  }

  enhgen = 0;

  // Blur probability images before image synthesis to make boundaries less
  // obvious

  for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
  {
    typedef itk::DiscreteGaussianImageFilter<FloatImageType, FloatImageType>
      BlurFilterType;
    BlurFilterType::Pointer blurf = BlurFilterType::New();
    blurf->SetInput(outProbImages[k]);
    blurf->SetVariance(0.5);
    blurf->Update();
    outProbImages[k] = blurf->GetOutput();
  }

  for (unsigned int k = 0; k < outEnhProbImages.GetSize(); k++)
  {
    typedef itk::DiscreteGaussianImageFilter<FloatImageType, FloatImageType>
      BlurFilterType;
    BlurFilterType::Pointer blurf = BlurFilterType::New();
    blurf->SetInput(outEnhProbImages[k]);
    blurf->SetVariance(0.5);
    blurf->Update();
    outEnhProbImages[k] = blurf->GetOutput();
  }

  //
  // Generate MR images
  //
  muLogMacro(<< "Generating multimodal MR images\n");

  // Mix textures
  DynArray<float> mixGadWeights;
  mixGadWeights.Append(1.2);
  mixGadWeights.Append(1.0);
  mixGadWeights.Append(0.8);
  mixGadWeights.Append(0.75);
  mixGadWeights.Append(1.0);
  mixGadWeights.Append(2.2);
  mixGadWeights.Append(0.05);

  DynArray<float> mixT1Weights;
  mixT1Weights.Append(1.2);
  mixT1Weights.Append(1.0);
  mixT1Weights.Append(0.8);
  mixT1Weights.Append(0.9);
  mixT1Weights.Append(1.1);
  mixT1Weights.Append(0.05);

  DynArray<float> mixT2Weights;
  mixT2Weights.Append(0.6);
  mixT2Weights.Append(1.0);
  mixT2Weights.Append(1.2);
  mixT2Weights.Append(2.25);
  mixT2Weights.Append(1.0);
  mixT2Weights.Append(1.5);

  DynArray<float> mixFLAIRWeights;
  mixFLAIRWeights.Append(1.0);
  mixFLAIRWeights.Append(1.0);
  mixFLAIRWeights.Append(0.7);
  mixFLAIRWeights.Append(0.8);
  mixFLAIRWeights.Append(1.1);
  mixFLAIRWeights.Append(0.4);

  FloatImageType::Pointer gadImg = FloatImageType::New();
  gadImg->CopyInformation(outProbImages[0]);
  gadImg->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  gadImg->Allocate();
  gadImg->FillBuffer(0);

  for (unsigned int k = 0; k < gadTextures.GetSize(); k++)
  {
    float w = mixGadWeights[k];

    FloatIteratorType it(gadImg, gadImg->GetLargestPossibleRegion());
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
      FloatImageType::IndexType ind = it.GetIndex();
      float p = outEnhProbImages[k]->GetPixel(ind);
      float t = gadTextures[k]->GetPixel(ind);
      it.Set(it.Get() + w*p*t);
    }
  }

  FloatImageType::Pointer t1Img = FloatImageType::New();
  t1Img->CopyInformation(outProbImages[0]);
  t1Img->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  t1Img->Allocate();
  t1Img->FillBuffer(0);

  for (unsigned int k = 0; k < t1Textures.GetSize(); k++)
  {
    float w = mixT1Weights[k];

    FloatIteratorType it(t1Img, t1Img->GetLargestPossibleRegion());
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
      FloatImageType::IndexType ind = it.GetIndex();
      float p = outProbImages[k]->GetPixel(ind);
      float t = t1Textures[k]->GetPixel(ind);
      it.Set(it.Get() + w*p*t);
    }
  }

  FloatImageType::Pointer t2Img = FloatImageType::New();
  t2Img->CopyInformation(outProbImages[0]);
  t2Img->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  t2Img->Allocate();
  t2Img->FillBuffer(0);

  for (unsigned int k = 0; k < t2Textures.GetSize(); k++)
  {
    float w = mixT2Weights[k];

    FloatIteratorType it(t2Img, t2Img->GetLargestPossibleRegion());
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
      FloatImageType::IndexType ind = it.GetIndex();
      float p = outProbImages[k]->GetPixel(ind);
      float t = t2Textures[k]->GetPixel(ind);
      it.Set(it.Get() + w*p*t);
    }
  }

  FloatImageType::Pointer flairImg = FloatImageType::New();
  flairImg->CopyInformation(outProbImages[0]);
  flairImg->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  flairImg->Allocate();
  flairImg->FillBuffer(0);

  for (unsigned int k = 0; k < flairTextures.GetSize(); k++)
  {
    float w = mixT2Weights[k];

    FloatIteratorType it(flairImg, flairImg->GetLargestPossibleRegion());
    for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
      FloatImageType::IndexType ind = it.GetIndex();
      float p = outProbImages[k]->GetPixel(ind);
      float t = flairTextures[k]->GetPixel(ind);
      it.Set(it.Get() + w*p*t);
    }
  }

  // Add background T1 and T2
  ByteImageType::Pointer fgMask = ByteImageType::New();
  fgMask->CopyInformation(gadImg);
  fgMask->SetRegions(gadImg->GetLargestPossibleRegion());
  fgMask->Allocate();
  fgMask->FillBuffer(1);

//TODO: fix T1, T1Gad BG
// BG too bright, why?
// make max(bg) = max(t1)
// FLAIR BG also too bright
// Add new entries on mix*Weights and use them here?

  if (params->GetDrawBackground())
  {
    FloatIteratorType bIt(gadImg, gadImg->GetLargestPossibleRegion());
    for (bIt.GoToBegin(); !bIt.IsAtEnd(); ++bIt)
    {
      FloatImageType::IndexType ind = bIt.GetIndex();
      gadImg->SetPixel(ind,
        gadImg->GetPixel(ind) + mixGadWeights[mixGadWeights.GetSize()-1] * t1bgImg->GetPixel(ind));
      t1Img->SetPixel(ind,
        t1Img->GetPixel(ind) + mixT1Weights[mixT1Weights.GetSize()-1] * t1bgImg->GetPixel(ind));
      t2Img->SetPixel(ind,
        t2Img->GetPixel(ind) + mixT2Weights[mixT2Weights.GetSize()-1] * t2bgImg->GetPixel(ind));
      flairImg->SetPixel(ind,
        flairImg->GetPixel(ind) + mixFLAIRWeights[mixFLAIRWeights.GetSize()-1] * t2bgImg->GetPixel(ind));
    }
  }
  else
  {
    // No background
    // Remove vessels in background of T1-Gad
    DynArray<FloatImageType::Pointer> blurredProbImages;
    for (unsigned int k = 0; k < (probImages.GetSize()-1); k++)
    {
      typedef itk::DiscreteGaussianImageFilter<FloatImageType, FloatImageType>
        BlurFilterType;
      BlurFilterType::Pointer blurf = BlurFilterType::New();
      blurf->SetInput(probImages[k]);
      blurf->SetVariance(4.0);
      blurf->Update();
      blurredProbImages.Append(blurf->GetOutput());
    }

    FloatIteratorType bIt(gadImg, gadImg->GetLargestPossibleRegion());
    for (bIt.GoToBegin(); !bIt.IsAtEnd(); ++bIt)
    {
      FloatImageType::IndexType ind = bIt.GetIndex();
      float sumP = 0;
      for (unsigned int k = 0; k < blurredProbImages.GetSize(); k++)
        sumP += blurredProbImages[k]->GetPixel(ind);
      if (sumP < 0.05)
      {
        gadImg->SetPixel(ind, 0);
        fgMask->SetPixel(ind, 0);
      }
    }
  }

  // Rescale MR images to [0, 4095]
  gadImg = rescaleImage12(gadImg);
  t1Img = rescaleImage12(t1Img);
  t2Img = rescaleImage12(t2Img);
  flairImg = rescaleImage12(flairImg);

  // Add some Gaussian noise
  float gadNoiseStd = params->GetGadNoiseStddev();
  float t1NoiseStd = params->GetT1NoiseStddev();
  float t2NoiseStd = params->GetT2NoiseStddev();
  float flairNoiseStd = params->GetFLAIRNoiseStddev();

  float gadNoiseVar = gadNoiseStd * gadNoiseStd;
  float t1NoiseVar = t1NoiseStd * t1NoiseStd;
  float t2NoiseVar = t2NoiseStd * t2NoiseStd;
  float flairNoiseVar = flairNoiseStd * flairNoiseStd;

  FloatIteratorType rIt(gadImg, gadImg->GetLargestPossibleRegion());

  for (rIt.GoToBegin(); !rIt.IsAtEnd(); ++rIt)
  {
    FloatImageType::IndexType ind = rIt.GetIndex();
    gadImg->SetPixel(ind,
      gadImg->GetPixel(ind) + 
      rng->GenerateNormal(0.0, gadNoiseVar));
    t1Img->SetPixel(ind,
      t1Img->GetPixel(ind) +
      rng->GenerateNormal(0.0, t1NoiseVar));
    t2Img->SetPixel(ind,
      t2Img->GetPixel(ind) + 
      rng->GenerateNormal(0.0, t2NoiseVar));
    flairImg->SetPixel(ind,
      flairImg->GetPixel(ind) + 
      rng->GenerateNormal(0.0, flairNoiseVar));
  }

  // Add randomly generated bias field
  // Make sure coeffs do not get too large
  muLogMacro(<< "Inserting artificial inhomogeneity\n");

  insertBias(gadImg, params->GetGadMaxBiasDegree());
  insertBias(t1Img, params->GetT1MaxBiasDegree());
  insertBias(t2Img, params->GetT2MaxBiasDegree());
  insertBias(flairImg, params->GetFLAIRMaxBiasDegree());

  // Zero out bg if neeeded
  if (!params->GetDrawBackground())
  {
    for (rIt.GoToBegin(); !rIt.IsAtEnd(); ++rIt)
    {
      FloatImageType::IndexType ind = rIt.GetIndex();
      if (fgMask->GetPixel(ind) != 0)
        continue;
      gadImg->SetPixel(ind, 0);
      t1Img->SetPixel(ind, 0);
      t2Img->SetPixel(ind, 0);
      flairImg->SetPixel(ind, 0);
    }
  }

  // Write synthetic images
  muLogMacro(<< "Writing final images\n");

  std::string gadfn = outdir + name + "_T1Gad.mha";
  writeUShortImage(gadfn.c_str(), gadImg);
  std::string t1fn = outdir + name + "_T1.mha";
  writeUShortImage(t1fn.c_str(), t1Img);
  std::string t2fn = outdir + name + "_T2.mha";
  writeUShortImage(t2fn.c_str(), t2Img);
  std::string flairfn = outdir + name + "_FLAIR.mha";
  writeUShortImage(flairfn.c_str(), flairImg);

  // Create discrete ground truth
  ByteImageType::Pointer truthLabelImg = ByteImageType::New();
  truthLabelImg->CopyInformation(outProbImages[0]);
  truthLabelImg->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  truthLabelImg->Allocate();
  truthLabelImg->FillBuffer(0);

  FloatImageType::Pointer bgProbImg = FloatImageType::New();
  bgProbImg->CopyInformation(outProbImages[0]);
  bgProbImg->SetRegions(outProbImages[0]->GetLargestPossibleRegion());
  bgProbImg->Allocate();
  bgProbImg->FillBuffer(0);

  ByteIteratorType it(truthLabelImg, truthLabelImg->GetLargestPossibleRegion());

  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    FloatImageType::IndexType ind = it.GetIndex();

    float sumP = 0;
    for (unsigned int k = 0; k < outProbImages.GetSize(); k++)
      sumP += outProbImages[k]->GetPixel(ind);

    bgProbImg->SetPixel(ind, 1.0-sumP);
  }

  outProbImages.Append(bgProbImg);

  // Max posterior
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    FloatImageType::IndexType ind = it.GetIndex();

    float maxP = outProbImages[0]->GetPixel(ind);
    unsigned char c = 1;

    for (unsigned int k = 1; k < outProbImages.GetSize(); k++)
    {
      float p = outProbImages[k]->GetPixel(ind);
      if (p > maxP)
      {
        maxP = p;
        c = k+1;
      }
    }

    //if (maxP < 1e-10 || c == outProbImages.GetSize())
    if (c == outProbImages.GetSize())
      c = 0;

    it.Set(c);
  }

  {
    typedef itk::ImageFileWriter<ByteImageType> ByteWriterType;
    ByteWriterType::Pointer writer = ByteWriterType::New();
    std::string fn = outdir + params->GetDatasetName() + "_discrete_truth.mha";

    writer->SetInput(truthLabelImg);
    writer->SetFileName(fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
  {
    FloatImageType::IndexType ind = it.GetIndex();

    unsigned char c = it.Get();

    // Mark active tumor regions
    if (c == 5)
    {
      float pN = outEnhProbImages[4]->GetPixel(ind);
      float pE = outEnhProbImages[5]->GetPixel(ind);

      if (pE > pN)
        it.Set(outProbImages.GetSize());
    }
  }

  {
    typedef itk::ImageFileWriter<ByteImageType> ByteWriterType;
    ByteWriterType::Pointer writer = ByteWriterType::New();
    std::string fn = outdir + params->GetDatasetName() + "_discrete_enh_truth.mha";

    writer->SetInput(truthLabelImg);
    writer->SetFileName(fn.c_str());
    writer->UseCompressionOn();
    writer->Update();
  }

  // Done
  timer->Stop();

  muLogMacro(<< "All simulation processes took "
    << timer->GetElapsedHours() << " hours, ");
  muLogMacro(<< timer->GetElapsedMinutes() << " minutes, ");
  muLogMacro(<< timer->GetElapsedSeconds() << " seconds\n");

  delete timer;


}
