
#include "LLSBiasCorrector.h"

#include "itkImage.h"
#include "itkImageDuplicator.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkMacro.h"
#include "itkOutputWindow.h"
#include "itkRescaleIntensityImageFilter.h"
#include "itkTextOutput.h"

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <math.h>
#include <stdlib.h>

#include "ProbabilisticBiasCorrectionCLP.h"

int main(int argc, char** argv)
  {
  PARSE_ARGS;
  
  bool violated=false;
  if (inputVolume1.size() == 0) { violated = true; std::cout << "  --inputVolume1 Required! "  << std::endl; }
  if (inputWhiteProbability.size() == 0) { violated = true; std::cout << "  --inputWhiteProbability Required! "  << std::endl; }
  if (outputVolume1.size() == 0) { violated = true; std::cout << "  --outputVolume1 Required! "  << std::endl; }
  if (violated) exit(1);

  // Use text output
  itk::TextOutput::Pointer textout = itk::TextOutput::New();
  itk::OutputWindow::SetInstance(textout);

  typedef itk::Image<float, 3> FloatImageType;
  typedef itk::Image<short, 3> ShortImageType;

  // Read probabilities
  typedef itk::ImageFileReader<FloatImageType> ReaderType;
  ReaderType::Pointer whiteReader = ReaderType::New();
  whiteReader->SetFileName(inputWhiteProbability.c_str());
  whiteReader->Update();
 
  std::vector<FloatImageType::Pointer> probImages;
  probImages.push_back(whiteReader->GetOutput());
  unsigned int numProbs = 1;
  
  if(inputGrayProbability.size() != 0)
    {
    ReaderType::Pointer grayReader = ReaderType::New();
    grayReader->SetFileName(inputGrayProbability.c_str());
    grayReader->Update();
    probImages.push_back(grayReader->GetOutput());
    numProbs++;
    }

  if(inputCSFProbability.size() != 0)
    {
    ReaderType::Pointer csfReader = ReaderType::New();
    csfReader->SetFileName(inputCSFProbability.c_str());
    csfReader->Update();
    probImages.push_back(csfReader->GetOutput());
    numProbs++;
    }

  // Read input images
  typedef itk::ImageFileReader<FloatImageType> ReaderType;
  ReaderType::Pointer vol1Reader = ReaderType::New();
  vol1Reader->SetFileName(inputVolume1.c_str());
  vol1Reader->Update();

  std::vector<FloatImageType::Pointer> inputImages;
 
  unsigned int numImages = 1; 
  inputImages.push_back(vol1Reader->GetOutput());
   
  if(inputVolume2.size() !=0 )
    {
    ReaderType::Pointer vol2Reader = ReaderType::New();
    vol2Reader->SetFileName(inputVolume2.c_str());
    vol2Reader->Update();
    inputImages.push_back(vol2Reader->GetOutput());
    numImages++;
    }  
  
  if(inputVolume3.size() !=0 )
    {
    ReaderType::Pointer vol3Reader = ReaderType::New();
    vol3Reader->SetFileName(inputVolume3.c_str());
    vol3Reader->Update();
    inputImages.push_back(vol3Reader->GetOutput());
    numImages++;
    }
  
  // Create output images
  std::vector<FloatImageType::Pointer> outputImages;
  
  typedef itk::ImageDuplicator< FloatImageType > DuplicateImageType;
  DuplicateImageType::Pointer duplicateImage = DuplicateImageType::New();
  for(unsigned int k=0;k<numImages;k++)
    {
    duplicateImage->SetInputImage(inputImages[k]);
    duplicateImage->Modified();
    duplicateImage->Update();
    outputImages.push_back(duplicateImage->GetOutput()); 
    } 
  //
  // Do bias correction
  //

  typedef LLSBiasCorrector<FloatImageType, FloatImageType> BiasCorrectorType;

  // Compute mask for getting samples (sum of probs > 0)
  BiasCorrectorType::MaskImageType::Pointer maskImg =
    BiasCorrectorType::MaskImageType::New();
  maskImg->SetRegions(inputImages[0]->GetLargestPossibleRegion());
  maskImg->Allocate();
  maskImg->SetSpacing(inputImages[0]->GetSpacing());
  maskImg->FillBuffer(0);

  typedef itk::ImageRegionIteratorWithIndex<BiasCorrectorType::MaskImageType>
    IteratorType;
  IteratorType it(maskImg, maskImg->GetLargestPossibleRegion());
  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
    {
    FloatImageType::IndexType ind = it.GetIndex();
    double sump = 0;
    for (unsigned int i = 0; i < numProbs; i++)
      sump += probImages[i]->GetPixel(ind);
    if (sump > 0)
      it.Set(1);
    }

  BiasCorrectorType::Pointer biascorr = BiasCorrectorType::New();

  biascorr->SetClampBias(false);
  biascorr->SetSampleSpacing(4.0);
  biascorr->SetMask(maskImg);
  // Make MaxDegree an input variable.  For high bias set to 6.
  //biascorr->SetMaxDegree(4); // Most MRI protocols do well with poly order 4
  biascorr->SetMaxDegree(inputPolynomialDegree); // Most MRI protocols do well with poly order 4
  biascorr->SetProbabilities(probImages);

  // Do correction and store results in place (in place breaks it, doesn't it?
  // Method params is input, output, do all voxels
  //biascorr->CorrectImages(inputImages, inputImages, true);
  biascorr->CorrectImages(inputImages, outputImages, true);

  // Write output as short image in [0, 4096]

  //typedef itk:ImageFileWriter<FloatImageType> FloatWriterType;
  //FloatWriterType::Pointer imageWriter = FloatWriterType::New();
  //imageWriter->SetInput(InputImages[0]);

  std::vector<std::string> outputFileNames;
  outputFileNames.push_back(outputVolume1);
  
  if(outputVolume2.size() != 0 && inputVolume2.size() != 0)
  {
    outputFileNames.push_back(outputVolume2);
  }
  if(outputVolume3.size() != 0 && inputVolume3.size() != 0)
  {
    outputFileNames.push_back(outputVolume3);
  }

  for (unsigned int i = 0; i < numImages; i++)
    {
    typedef itk::RescaleIntensityImageFilter<FloatImageType, ShortImageType>
      ConverterType;

    ConverterType::Pointer converter = ConverterType::New();
    converter->SetOutputMinimum(0);
    converter->SetOutputMaximum(4096);
    //converter->SetInput(inputImages[i]);
    converter->SetInput(outputImages[i]);
    converter->Update();
    
    typedef itk::ImageFileWriter<ShortImageType> WriterType;
    WriterType::Pointer writer = WriterType::New();

    std::ostringstream oss;
    //oss << "corrected_image" << i << ".nii.gz" << std::ends;
 
    //writer->SetFileName(oss.str().c_str());
    writer->SetFileName(outputFileNames[i]);
    writer->SetInput(converter->GetOutput());
    writer->Update();
    }

  return 0;

  }
