
// Do texture synthesis given image and mask

#include "itkBinaryBallStructuringElement.h"
#include "itkBinaryCrossStructuringElement.h"
#include "itkBinaryErodeImageFilter.h"
#include "itkCastImageFilter.h"
#include "itkImage.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkRescaleIntensityImageFilter.h"

#include "itkOutputWindow.h"
#include "itkTextOutput.h"

#include "WeiLevoyTextureGenerator.h"

#include "DynArray.h"
#include "Timer.h"

#include <iostream>
#include <sstream>
#include <string>

#include <stdlib.h>

// Spatial dimension shared by every image this tool reads, builds or writes.
// (A namespace-scope const already has internal linkage, so no `static` needed.)
const int Dimension = 3;

// Pixel-type variants of the Dimension-D ITK image used in this file:
//  - unsigned char  : label image and the per-label binary masks
//  - unsigned short : pixel type of the texture files written to disk
//  - float          : input intensities and synthesized texture output
typedef itk::Image<unsigned char, Dimension> ByteImageType;
typedef itk::Image<unsigned short, Dimension> UShortImageType;
typedef itk::Image<float, Dimension> FloatImageType;

static void writeUShortImage(const char* fn, FloatImageType* img)
{
  typedef itk::CastImageFilter<FloatImageType, UShortImageType>
    CasterType;

  CasterType::Pointer rescaler = CasterType::New();
  rescaler->SetInput(img);
  rescaler->Update();

  typedef itk::ImageFileWriter<UShortImageType> WriterType;

  WriterType::Pointer writer = WriterType::New();
  writer->SetFileName(fn);
  writer->SetInput(rescaler->GetOutput());
  writer->UseCompressionOn();
  writer->Update();
}

static int
_real_main(int argc, char** argv)
{
  itk::OutputWindow::SetInstance(itk::TextOutput::New());

  Timer timer;
  timer.Start();

  std::cout << "Reading input image..." << std::endl;
  FloatImageType::Pointer inputImg;
  {
    typedef itk::ImageFileReader<FloatImageType> ReaderType;
    ReaderType::Pointer reader = ReaderType::New();

    reader->SetFileName(argv[1]);
    reader->Update();

    inputImg = reader->GetOutput();
  }

  std::cout << "Reading label image..." << std::endl;
  ByteImageType::Pointer labelImg;
  {
    typedef itk::ImageFileReader<ByteImageType> ReaderType;
    ReaderType::Pointer reader = ReaderType::New();

    reader->SetFileName(argv[2]);
    reader->Update();

    labelImg = reader->GetOutput();
  }

  std::cout << "Computing number of labels..." << std::endl;
  typedef itk::ImageRegionIteratorWithIndex<ByteImageType> LabelIteratorType;
  LabelIteratorType labelIt(labelImg, labelImg->GetLargestPossibleRegion());

  int maxLabel = 0;
  for (labelIt.GoToBegin(); !labelIt.IsAtEnd(); ++labelIt)
  {
    int c = labelIt.Get();
    if (c > maxLabel)
      maxLabel = c;
  }

  // TODO:
  // Currently fixed to brainweb size, make it part of args?
  //FloatImageType::SizeType outSize =
  //  inputImg->GetLargestPossibleRegion().GetSize();
  //FloatImageType::SpacingType outSpacing = inputImg->GetSpacing();
  FloatImageType::SizeType outSize;
/*
  // Old BW
  outSize[0] = 181;
  outSize[1] = 217;
  outSize[2] = 181;
*/
  // New BW
  outSize[0] = 256;
  outSize[1] = 256;
  outSize[2] = 181;
  FloatImageType::SpacingType outSpacing;
  outSpacing.Fill(1.0);

  std::cout << "Generating textures for each label..." << std::endl;
  DynArray<FloatImageType::Pointer> outputTextures;
  for (int i = 1; i <= maxLabel; i++)
  {
    std::cout << "=============================" << std::endl;
    std::cout << "Texture from label " << i << std::endl;

    ByteImageType::Pointer maskImg = ByteImageType::New();
    maskImg->SetRegions(labelImg->GetLargestPossibleRegion());
    maskImg->Allocate();
    maskImg->SetOrigin(labelImg->GetOrigin());
    maskImg->SetSpacing(labelImg->GetSpacing());

    for (labelIt.GoToBegin(); !labelIt.IsAtEnd(); ++labelIt)
    {
      ByteImageType::IndexType ind = labelIt.GetIndex();

      int c = labelIt.Get();
      if (c == i)
        maskImg->SetPixel(ind, 1);
      else
        maskImg->SetPixel(ind, 0);
    }

/*
NOTE: Already part of texture generator
    // Erode mask to avoid sampling region borders
    //typedef itk::BinaryBallStructuringElement<unsigned char, 3>
    //  StructElementType;
    typedef itk::BinaryCrossStructuringElement<unsigned char, 3>
      StructElementType;
    typedef
      itk::BinaryErodeImageFilter<ByteImageType, ByteImageType,
        StructElementType> ErodeType;

    StructElementType structel;
    structel.SetRadius(1);
    structel.CreateStructuringElement();

    ErodeType::Pointer erode = ErodeType::New();
    erode->SetErodeValue(1);
    erode->SetInput(maskImg);
    erode->SetKernel(structel);

    erode->Update();

    maskImg = erode->GetOutput();
*/

    typedef WeiLevoyTextureGenerator<FloatImageType> TextureGenerator;

    TextureGenerator::Pointer textgen = TextureGenerator::New();
    textgen->SetMaskImage(maskImg);
    textgen->SetTextureImage(inputImg);
    textgen->SetNumberOfLevels(1);
    textgen->SetNeighborhoodRadius(2);
    textgen->SetSize(outSize);
    textgen->SetSpacing(outSpacing);

    textgen->Update();

    std::ostringstream oss;
    oss << argv[3] << "_" << i << ".mha" << std::ends;

    std::cout << "Writing " << oss.str() << "..." << std::endl;

    writeUShortImage(oss.str().c_str(), textgen->GetOutput());
  }

  timer.Stop();
  std::cout << "Texture gen took " << timer.GetElapsedHours() << " hours, "
    << timer.GetElapsedMinutes() << " minutes, "
    << timer.GetElapsedSeconds() << " seconds"
    << std::endl;

  return 0;

}

// Program entry point.
//
// Usage: <program> <image> <mask> <output prefix>
//
// Validates the argument count, runs _real_main and returns its status. Any
// exception escaping _real_main is reported on stderr and turned into a -1
// exit status.
int
main(int argc, char** argv)
{

  if (argc != 4)
  {
    // argv[0] is a null pointer when argc == 0; streaming it would be UB.
    std::cerr << "Usage: " << (argc > 0 ? argv[0] : "<program>");
    std::cerr << " <image> <mask> <output prefix>" << std::endl;
    return -1;
  }

  try
  {
    // Propagate the worker's status rather than discarding it.
    return _real_main(argc, argv);
  }
  // NOTE(review): left non-const in case this ITK version's operator<< only
  // accepts a non-const ExceptionObject reference.
  catch (itk::ExceptionObject& e)
  {
    std::cerr << e << std::endl;
    return -1;
  }
  catch (const std::exception& e)
  {
    std::cerr << "Exception: " << e.what() << std::endl;
    return -1;
  }
  catch (const std::string& s)
  {
    std::cerr << "Exception: " << s << std::endl;
    return -1;
  }
  catch (...)
  {
    std::cerr << "Unknown exception" << std::endl;
    return -1;
  }
}
