#include "itkListSample.h"
#include "SurvivalPredictor.h"
#include "itkCSVNumericObjectFileWriter.h"
#include "itkCovarianceSampleFilter.h"
#include "itkMeanSampleFilter.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <string>
#include <utility>
#include <vector>

// Image type used for every modality and label map read in this file: 3-D volumes with float voxels.
using ImageType = itk::Image<float, 3>;

/**
 * Destructor. Nothing is released by hand here: mFeatureExtractionLocalPtr and
 * mFeatureScalingLocalPtr are accessed with '.' elsewhere in this file, so they do not appear
 * to be raw owning pointers that would require an explicit delete.
 */
SurvivalPredictor::~SurvivalPredictor()
{
}
namespace
{
  /// True when @p extension is one of the image formats accepted by the subject loader.
  bool HasSupportedImageExtension(const std::string &extension)
  {
    return extension == HDR_EXT || extension == NII_EXT || extension == NII_GZ_EXT;
  }

  /// Returns a lower-cased copy of @p text. Each byte is passed to std::tolower as unsigned char,
  /// which avoids undefined behaviour for bytes >= 0x80.
  std::string ToLowerAscii(std::string text)
  {
    std::transform(text.begin(), text.end(), text.begin(),
      [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return text;
  }

  /// True when @p haystack contains at least one of @p keywords.
  bool ContainsAnyKeyword(const std::string &haystack, std::initializer_list<const char *> keywords)
  {
    return std::any_of(keywords.begin(), keywords.end(),
      [&haystack](const char *keyword) { return haystack.find(keyword) != std::string::npos; });
  }

  /// Locates the image of one modality inside <subjectPath>/<folderName>.
  /// A folder holding exactly one file yields that file without any further check. Otherwise the
  /// last file whose lower-cased name contains one of @p keywords and whose extension is supported
  /// is returned. Returns "" when the folder is missing or nothing matches.
  std::string FindModalityImage(const std::string &subjectPath, const std::string &folderName,
    std::initializer_list<const char *> keywords)
  {
    const std::string folderPath = subjectPath + "/" + folderName;
    if (!cbica::directoryExists(folderPath))
      return "";

    const std::vector<std::string> files = cbica::filesInDirectory(folderPath);
    if (files.size() == 1)
      return folderPath + "/" + files[0];

    std::string match;
    for (const auto &file : files)
    {
      const std::string filePath = folderPath + "/" + file;
      // Keywords are matched against the file name only: matching the full path would let the
      // folder name itself (e.g. "T1CE" contains "t1ce") or any parent directory satisfy the test.
      if (ContainsAnyKeyword(ToLowerAscii(file), keywords)
        && HasSupportedImageExtension(cbica::getFilenameExtension(filePath)))
      {
        match = filePath;
      }
    }
    return match;
  }
}

/**
 * Treats every sub-directory of @p directoryname as one subject and collects its image paths.
 *
 * Each subject is searched for the folders SEGMENTATION, T1CE, T1, T2, FLAIR, PERFUSION (rCBV,
 * PSR and PH maps) and DTI (AX, FA, RAD and TR maps). When every image is found, the subject name
 * is appended to @p qualifiedSubjectNames and its paths to the matching output vectors, all at the
 * same index. Otherwise the subject name is appended to @p rejectedSubjectNames.
 * The output vectors are appended to, never cleared.
 */
void SurvivalPredictor::LoadQualifiedSubjectsFromGivenDirectory(std::string directoryname,
  std::vector<std::string> & qualifiedSubjectNames,
  std::vector<std::string> & t1FileNames,
  std::vector<std::string> & t1ceFileNames,
  std::vector<std::string> & t2FileNames,
  std::vector<std::string> & t2FlairFileNames,
  std::vector<std::string> & axFileNames,
  std::vector<std::string> & faFileNames,
  std::vector<std::string> & radFileNames,
  std::vector<std::string> & trFileNames,
  std::vector<std::string> & rcbvFileNames,
  std::vector<std::string> & psrFileNames,
  std::vector<std::string> & phFileNames,
  std::vector<std::string> & labelNames,
  std::vector<std::string> & rejectedSubjectNames)
{
  std::vector<std::string> subjectNames = cbica::subdirectoriesInDirectory(directoryname);
  mLastEncounteredError = "";

  for (const auto &subjectName : subjectNames)
  {
    const std::string subjectPath = directoryname + "/" + subjectName;

    const std::string labelPath = FindModalityImage(subjectPath, "SEGMENTATION",
      { "label-map", "label", "segmentation", "labelmap" });
    const std::string t1ceFilePath = FindModalityImage(subjectPath, "T1CE",
      { "t1ce", "t1-ce", "t1c", "t1_ce", "t1gd", "t1-gd", "t1_gd" });
    const std::string t1FilePath = FindModalityImage(subjectPath, "T1", { "t1" });
    const std::string t2FilePath = FindModalityImage(subjectPath, "T2", { "t2" });
    const std::string t2FlairFilePath = FindModalityImage(subjectPath, "FLAIR", { "flair" });

    // Perfusion maps share one folder: each file is assigned to the first keyword it matches
    // (rcbv, then psr, then ph); for each map the last matching file wins.
    std::string rcbvFilePath;
    std::string psrFilePath;
    std::string phFilePath;
    const std::string perfusionPath = subjectPath + "/PERFUSION";
    if (cbica::directoryExists(perfusionPath))
    {
      for (const auto &file : cbica::filesInDirectory(perfusionPath))
      {
        const std::string filePath = perfusionPath + "/" + file;
        if (!HasSupportedImageExtension(cbica::getFilenameExtension(filePath)))
          continue;
        const std::string name = ToLowerAscii(file);
        if (ContainsAnyKeyword(name, { "rcbv" }))
          rcbvFilePath = filePath;
        else if (ContainsAnyKeyword(name, { "psr" }))
          psrFilePath = filePath;
        else if (ContainsAnyKeyword(name, { "ph" }))
          phFilePath = filePath;
      }
    }

    // DTI maps use the same scheme. "tr" is tested last because it also occurs inside other
    // names (e.g. "str"), so the AX/FA/RAD keywords must get the first chance to claim a file.
    std::string axFilePath;
    std::string faFilePath;
    std::string radFilePath;
    std::string trFilePath;
    const std::string dtiPath = subjectPath + "/DTI";
    if (cbica::directoryExists(dtiPath))
    {
      for (const auto &file : cbica::filesInDirectory(dtiPath))
      {
        const std::string filePath = dtiPath + "/" + file;
        if (!HasSupportedImageExtension(cbica::getFilenameExtension(filePath)))
          continue;
        const std::string name = ToLowerAscii(file);
        if (ContainsAnyKeyword(name, { "ax", "axial" }))
          axFilePath = filePath;
        else if (ContainsAnyKeyword(name, { "fa", "fractional" }))
          faFilePath = filePath;
        else if (ContainsAnyKeyword(name, { "rad", "radial" }))
          radFilePath = filePath;
        else if (ContainsAnyKeyword(name, { "tr", "trace" }))
          trFilePath = filePath;
      }
    }

    const bool anyImageMissing = labelPath.empty()
      || t1FilePath.empty() || t2FilePath.empty() || t1ceFilePath.empty() || t2FlairFilePath.empty()
      || rcbvFilePath.empty() || psrFilePath.empty() || phFilePath.empty()
      || axFilePath.empty() || faFilePath.empty() || radFilePath.empty() || trFilePath.empty();
    if (anyImageMissing)
    {
      rejectedSubjectNames.push_back(subjectName);
      continue;
    }

    t1ceFileNames.push_back(t1ceFilePath);
    t1FileNames.push_back(t1FilePath);
    t2FileNames.push_back(t2FilePath);
    t2FlairFileNames.push_back(t2FlairFilePath);
    axFileNames.push_back(axFilePath);
    faFileNames.push_back(faFilePath);
    radFileNames.push_back(radFilePath);
    trFileNames.push_back(trFilePath);
    rcbvFileNames.push_back(rcbvFilePath);
    psrFileNames.push_back(psrFilePath);
    phFileNames.push_back(phFilePath);
    labelNames.push_back(labelPath);
    qualifiedSubjectNames.push_back(subjectName);
  }
}


/**
 * Computes the mean and the sample standard deviation (n - 1 denominator) of @p intensities.
 *
 * @return {mean, std}. Degenerate inputs do not divide by zero: an empty input yields {0, 0}
 *         and a single value yields {value, 0}.
 */
std::vector<double> SurvivalPredictor::GetStatisticalFeatures(std::vector<double> intensities)
{
  const std::size_t count = intensities.size();
  if (count == 0)
    return std::vector<double>{ 0.0, 0.0 };

  double sum = 0.0;
  for (const double value : intensities)
    sum += value;
  const double mean = sum / static_cast<double>(count);

  // Squared deviations need their own accumulator; reusing the running sum would add
  // sum(intensities) to the variance.
  double squaredDeviations = 0.0;
  for (const double value : intensities)
  {
    const double deviation = value - mean;
    squaredDeviations += deviation * deviation;
  }
  const double stdDev = (count > 1)
    ? std::sqrt(squaredDeviations / static_cast<double>(count - 1))
    : 0.0;

  std::vector<double> StatisticalFeatures;
  StatisticalFeatures.push_back(mean);
  StatisticalFeatures.push_back(stdDev);
  return StatisticalFeatures;
}
/**
 * Counts how many @p intensities fall into each histogram bin.
 *
 * Bin centres run from @p start to @p end (inclusive) in steps of @p interval. Each bin spans
 * centre +/- interval/2 (integer halving), clamped to [0, 255], and the last bin is extended up
 * to 255. A bin whose lower edge is 0 includes that edge. Every other bin excludes its lower edge,
 * so a value on a boundary shared with the previous bin is counted once.
 *
 * @return One count per bin (stored as double); empty when @p interval <= 0 or @p start > @p end.
 */
std::vector<double> SurvivalPredictor::GetHistogramFeatures(std::vector<double> intensities, int start, int interval, int end)
{
  std::vector<double> finalBins;

  // Without this guard a non-positive step never reaches end (infinite loop), and start > end
  // would leave no bins, making the "last bin" access below undefined.
  if (interval <= 0 || start > end)
    return finalBins;

  const int halfWidth = interval / 2;

  // (lower, upper) edges of every bin, clamped into [0, 255].
  std::vector<std::pair<int, int> > ranges;
  for (int centre = start; centre <= end; centre += interval)
  {
    int lowerbound = centre - halfWidth;
    if (lowerbound < 0)
      lowerbound = 0;

    int upperbound = centre + halfWidth;
    if (upperbound > 255)
      upperbound = 255;

    ranges.push_back(std::make_pair(lowerbound, upperbound));
  }
  // The last bin reaches 255 so that every value up to 255 lands in some bin.
  ranges.back().second = 255;

  finalBins.reserve(ranges.size());
  for (const auto &range : ranges)
  {
    const int lower = range.first;
    const int upper = range.second;
    const auto count = std::count_if(intensities.begin(), intensities.end(),
      [lower, upper](double value)
      {
        // Closed on the left only for the bin starting at 0, so zero-valued entries are counted.
        const bool aboveLower = (lower == 0) ? (value >= lower) : (value > lower);
        return aboveLower && value <= upper;
      });
    finalBins.push_back(static_cast<double>(count));
  }
  return finalBins;
}

/**
 * Builds the volumetric feature vector from region sizes.
 *
 * @return Nine values, in this order:
 *         tuSize, neSize, edemaSize, totalSize, tuSize + neSize,
 *         100 * (tuSize + neSize) / totalSize, 100 * edemaSize / totalSize,
 *         100 * tuSize / (tuSize + neSize), 100 * neSize / (tuSize + neSize).
 *         A percentage whose denominator is 0 is reported as 0 rather than NaN/inf.
 */
std::vector<double> SurvivalPredictor::GetVolumetricFeatures(double edemaSize, double tuSize, double neSize, double totalSize)
{
  // 100 * part / whole, guarded so an empty denominator does not put NaN/inf into the feature vector.
  const auto percentOf = [](double part, double whole)
  {
    return (whole != 0.0) ? 100 * (part / whole) : 0.0;
  };

  const double tuPlusNeSize = tuSize + neSize;

  std::vector<double> VolumetricFeatures;
  VolumetricFeatures.reserve(9);
  VolumetricFeatures.push_back(tuSize);
  VolumetricFeatures.push_back(neSize);
  VolumetricFeatures.push_back(edemaSize);
  VolumetricFeatures.push_back(totalSize);

  VolumetricFeatures.push_back(tuPlusNeSize);
  VolumetricFeatures.push_back(percentOf(tuPlusNeSize, totalSize));
  VolumetricFeatures.push_back(percentOf(edemaSize, totalSize));
  VolumetricFeatures.push_back(percentOf(tuSize, tuPlusNeSize));
  VolumetricFeatures.push_back(percentOf(neSize, tuPlusNeSize));

  return VolumetricFeatures;
}


void SurvivalPredictor::PrepareNewSurvivalPredictionModel(const std::string inputdirectory, const std::string outputdirectory)
{
  VariableSizeMatrixType FeaturesOfAllSubjects;
  std::vector<std::string> t1ceFileNames;
  std::vector<std::string> t1FileNames;
  std::vector<std::string> t2FileNames;
  std::vector<std::string> t2FlairFileNames;
  std::vector<std::string> axFileNames;
  std::vector<std::string> faFileNames;
  std::vector<std::string> radFileNames;
  std::vector<std::string> trFileNames;
  std::vector<std::string> rcbvFileNames;
  std::vector<std::string> psrFileNames;
  std::vector<std::string> phFileNames;
  std::vector<std::string> labelNames;
  std::vector<std::string> rejectedSubjectNames;
  std::vector<std::string> qualifiedSubjectNames;
  //std::vector<int> qualifiedIndices;

  LoadQualifiedSubjectsFromGivenDirectory(inputdirectory,
    qualifiedSubjectNames,
    t1FileNames,
    t1ceFileNames,
    t2FileNames,
    t2FlairFileNames,
    axFileNames,
    faFileNames,
    radFileNames,
    trFileNames,
    rcbvFileNames,
    psrFileNames,
    phFileNames,
    labelNames,
    rejectedSubjectNames);

  // //--------------------------------------------------------------------
  // //qualifiedIndices.push_back(1);
  // //qualifiedIndices.push_back(2);
  // //qualifiedIndices.push_back(3);
  // //qualifiedSubjectNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC");
  // //phFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC_PH.nii.gz");
  // //psrFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC_PSR.nii.gz");
  // //labelNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_labels.nii.gz");
  // //trFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_TR_to_t1ce_str.nii.gz");
  // //axFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_AX_to_t1ce_str.nii.gz");
  // //faFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_FA_to_t1ce_str.nii.gz");
  // //radFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_RAD_to_t1ce_str.nii.gz");
  // //t1FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_t1_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_t2_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FlairFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_flair_sus_byte_n3_r_strip_hist.nii.gz");
  // //t1ceFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_t1ce_sus_byte_n3_strip_hist.nii.gz");
  // //rcbvFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAC/AAAC0_rcbv_r_strip.nii.gz");

  // //qualifiedSubjectNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA");
  // //phFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_PH.nii.gz");
  // //psrFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_PSR.nii.gz");
  // //labelNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_labels.nii.gz");
  // //rcbvFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_rcbv_r_strip.nii.gz");
  // //trFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_TR_to_t1ce_str.nii.gz");
  // //axFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_AX_to_t1ce_str.nii.gz");
  // //faFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_FA_to_t1ce_str.nii.gz");
  // //radFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_RAD_to_t1ce_str.nii.gz");
  // //t1FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_t1_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_t2_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FlairFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_flair_sus_byte_n3_r_strip_hist.nii.gz");
  // //t1ceFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAA/AAAA0_t1ce_sus_byte_n3_strip_hist.nii.gz");
  // //qualifiedSubjectNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB");
  // //phFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB_PH.nii.gz");
  // //psrFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB_PSR.nii.gz");
  // //rcbvFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_rcbv_r_strip.nii.gz");
  // //labelNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_labels.nii.gz");
  // //trFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_TR_to_t1ce_str.nii.gz");
  // //axFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_AX_to_t1ce_str.nii.gz");
  // //faFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_FA_to_t1ce_str.nii.gz");
  // //radFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_RAD_to_t1ce_str.nii.gz");
  // //t1FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t1_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t2_sus_byte_n3_r_strip_hist.nii.gz");
  // //t2FlairFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_flair_sus_byte_n3_r_strip_hist.nii.gz");
  // //t1ceFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t1ce_sus_byte_n3_strip_hist.nii.gz");

  std::vector<double> ages;
  std::vector<double> survival;

  //// generic csv reader which uses header information to parse
  //auto parsedCSV = cbica::parseCSVFile(inputdirectory + "/features.csv", "Survival,Ages", "", false, false);
  //for (size_t i = 0; i < parsedCSV.size(); i++)
  //{
  //  survival.push_back(std::atof(parsedCSV[i].inputImages[0].c_str()));
  //  ages.push_back(std::atof(parsedCSV[i].inputImages[1].c_str()));
  //}

  typedef itk::CSVArray2DFileReader<double> ReaderType;
  ReaderType::Pointer readerMean = ReaderType::New();
  readerMean->SetFileName(inputdirectory + "/features.csv");
  readerMean->SetFieldDelimiterCharacter(',');
  readerMean->HasColumnHeadersOff();
  readerMean->HasRowHeadersOff();
  readerMean->Parse();
  typedef vnl_matrix<double> MatrixType;
  MatrixType dataMatrix = readerMean->GetArray2DDataObject()->GetMatrix();

  for (unsigned int i = 0; i < dataMatrix.rows(); i++)
  {
    ages.push_back(dataMatrix(i, 0));
    survival.push_back(dataMatrix(i, 1));
  }

  //readerMean->SetFileName(inputdirectory + "/trainingsheet.csv");
  //readerMean->SetFieldDelimiterCharacter(',');
  //readerMean->HasColumnHeadersOff();
  //readerMean->HasRowHeadersOff();
  //readerMean->Parse();
  //dataMatrix = readerMean->GetArray2DDataObject()->GetMatrix();
  //FeaturesOfAllSubjects.SetSize(20, 167);
  //for (unsigned int i = 0; i < dataMatrix.rows(); i++)
  //{
  //  for (unsigned int j = 0; j < dataMatrix.cols(); j++)
  //  {
  //    FeaturesOfAllSubjects(i, j) = dataMatrix(i, j);
  //  }
  //}
  // //---------------------------------------------------------------------------


  FeaturesOfAllSubjects.SetSize(qualifiedSubjectNames.size(), 170);

  for (unsigned int sid = 0; sid < qualifiedSubjectNames.size(); sid++)
  {
    ImageType::Pointer T1CEImagePointer;
    ImageType::Pointer T2FlairImagePointer;
    ImageType::Pointer T1ImagePointer;
    ImageType::Pointer T2ImagePointer;
    ImageType::Pointer AXImagePointer;
    ImageType::Pointer RADImagePointer;
    ImageType::Pointer FAImagePointer;
    ImageType::Pointer TRImagePointer;
    ImageType::Pointer RCBVImagePointer;
    ImageType::Pointer PSRImagePointer;
    ImageType::Pointer PHImagePointer;
    ImageType::Pointer LabelImagePointer;

    LabelImagePointer = ReadNiftiImage<ImageType>(labelNames[sid]);
    RCBVImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(rcbvFileNames[sid]));
    PSRImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(psrFileNames[sid]));
    PHImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(phFileNames[sid]));
    T1CEImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t1ceFileNames[sid]));
    T2FlairImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t2FlairFileNames[sid]));
    T1ImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t1FileNames[sid]));
    T2ImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t2FileNames[sid]));
    AXImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(axFileNames[sid]));
    RADImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(radFileNames[sid]));
    FAImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(faFileNames[sid]));
    TRImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(trFileNames[sid]));


    std::vector<double> TestFeatures = LoadTestData<ImageType>(T1CEImagePointer, T2FlairImagePointer, T1ImagePointer, T2ImagePointer,
      RCBVImagePointer, PSRImagePointer, PHImagePointer, AXImagePointer, FAImagePointer, RADImagePointer, TRImagePointer, LabelImagePointer);

    typedef vnl_matrix<double> MatrixType;
    MatrixType data;

    //data.set_size(170, 1);
    //for (int i = 0; i < TestFeatures.size(); i++)
    //  data(i, 0) = TestFeatures[i];
    //typedef itk::CSVNumericObjectFileWriter<double, 170, 1> WriterType;
    //WriterType::Pointer writer = WriterType::New();
    //writer->SetFileName(qualifiedSubjectNames[sid] + "_tData.csv");
    //writer->SetInput(&data);
    //writer->Write();
    for (int i = 0; i < TestFeatures.size(); i++)
      FeaturesOfAllSubjects(sid, i) = TestFeatures[i];
  }



  VariableSizeMatrixType scaledFeatureSet;
  VariableLengthVectorType meanVector;
  VariableLengthVectorType stdVector;

  mFeatureScalingLocalPtr.ScaleGivenTrainingFeatures(FeaturesOfAllSubjects, scaledFeatureSet, meanVector, stdVector);
  VariableSizeMatrixType ScaledFeatureSetAfterAddingAge;
  ScaledFeatureSetAfterAddingAge.SetSize(scaledFeatureSet.Rows(), scaledFeatureSet.Cols() + 1);
  for (unsigned int i = 0; i < scaledFeatureSet.Rows(); i++)
  {
    unsigned int j = 0;
    for (j = 0; j < scaledFeatureSet.Cols(); j++)
    {
      ScaledFeatureSetAfterAddingAge(i, j) = scaledFeatureSet(i, j);
    }
    ScaledFeatureSetAfterAddingAge(i, j) = ages[i]; 
  }

  VariableSizeMatrixType SixModelFeatures;
  VariableSizeMatrixType EighteenModelFeatures;
  mFeatureExtractionLocalPtr.FormulateSurvivalTrainingData(ScaledFeatureSetAfterAddingAge, survival, SixModelFeatures, EighteenModelFeatures);

  // //--------writing in files--------------------------------
  // //-----------------------writing in files to compare results------------------------------
   typedef vnl_matrix<double> MatrixType;
   MatrixType data;
  
   data.set_size(170, 1); // TOCHECK - are these hard coded sizes fine?
   for (unsigned int i = 0; i < meanVector.Size(); i++)
     data(i, 0) = meanVector[i];
   typedef itk::CSVNumericObjectFileWriter<double, 170, 1> WriterType;
   WriterType::Pointer writer = WriterType::New();
   writer->SetFileName(outputdirectory + "/mean.csv");
   writer->SetInput(&data);
   writer->Write();
  
   for (unsigned int i = 0; i < stdVector.Size(); i++)
     data(i, 0) = stdVector[i];
   writer->SetFileName(outputdirectory + "/std.csv");
   writer->SetInput(&data);
   writer->Write();
  //
  // data.set_size(7, 171);
  // for (unsigned int i = 0; i < ScaledFeatureSetAfterAddingAge.Rows(); i++)
  // {
  //   for (unsigned int j = 0; j < ScaledFeatureSetAfterAddingAge.Cols(); j++)
  //   {
  //     data(i, j) = ScaledFeatureSetAfterAddingAge(i, j);
  //   }
  // }
  // typedef itk::CSVNumericObjectFileWriter<double, 7, 171> WriterTypeMatrix;
  // WriterTypeMatrix::Pointer writermatrix = WriterTypeMatrix::New();
  // writermatrix->SetFileName(outputdirectory + "/scaledfeatures.csv");
  // writermatrix->SetInput(&data);
  // writermatrix->Write();
  //
  // data.set_size(7, 172);
  // for (unsigned int i = 0; i < SixModelFeatures.Rows(); i++)
  // {
  //   for (unsigned int j = 0; j < SixModelFeatures.Cols(); j++)
  //   {
  //     data(i, j) = SixModelFeatures(i, j);
  //   }
  // }
  // writermatrix->SetFileName(outputdirectory + "/sixmodel.csv");
  // writermatrix->SetInput(&data);
  // writermatrix->Write();
  // 
  // for (unsigned int i = 0; i < EighteenModelFeatures.Rows(); i++)
  // {
  //   for (unsigned int j = 0; j < EighteenModelFeatures.Cols(); j++)
  //   {
  //     data(i, j) = EighteenModelFeatures(i, j);
  //   }
  // }
  // writermatrix->SetFileName(outputdirectory + "/eighteenmodel.csv");
  // writermatrix->SetInput(&data);
  // writermatrix->Write();
  //---------------------------------------------------------------------------

  //SixModelFeatures.SetSize(7, 172);
  //EighteenModelFeatures.SetSize(7, 172);
  if (!trainOpenCVSVM(SixModelFeatures, outputdirectory + "/" + mSixTrainedFile, true, false).empty())
  {
    trainOpenCVSVM(EighteenModelFeatures, outputdirectory + "/" + mEighteenTrainedFile, true, false);
  }
  else
  {
    ShowErrorMessage("Training for Six month survivor patients failed. Please check.");
  }
}

/**
 * Combines two vectors of estimates element-wise into one score per index.
 *
 * Each estimate x is first mapped to
 *   g(x) = (|x| < 2 ? x : 0) + (x > 1 ? 1 : 0) - (x < -1 ? 1 : 0)
 * and the result for index i is g(estimates1[i]) + g(estimates2[i]).
 *
 * @return One value per index present in BOTH inputs (the shorter input determines the length),
 *         so a shorter @p estimates2 is never read out of bounds.
 */
inline std::vector< double > estimateCombination(const std::vector< double > &estimates1, const std::vector< double > &estimates2)
{
  // Evaluated in double: the inputs are double, and float intermediates would discard precision.
  const auto mapEstimate = [](double x)
  {
    const double inRange = (std::abs(x) < 2) ? x : 0.0;
    const double positiveStep = (x > 1) ? 1.0 : 0.0;
    const double negativeStep = (x < -1) ? 1.0 : 0.0;
    return inRange + (positiveStep - negativeStep);
  };

  const std::size_t count = std::min(estimates1.size(), estimates2.size());
  std::vector< double > returnVec(count);
  for (std::size_t i = 0; i < count; i++)
  {
    returnVec[i] = mapEstimate(estimates1[i]) + mapEstimate(estimates2[i]);
  }
  return returnVec;
}

void SurvivalPredictor::SurvivalPredictionOnExistingModel(const std::string modeldirectory, const std::string inputdirectory, const std::string
  outputdirectory)
{
  VariableSizeMatrixType FeaturesOfAllSubjects;
  std::vector<std::string> t1ceFileNames;
  std::vector<std::string> t1FileNames;
  std::vector<std::string> t2FileNames;
  std::vector<std::string> t2FlairFileNames;
  std::vector<std::string> axFileNames;
  std::vector<std::string> faFileNames;
  std::vector<std::string> radFileNames;
  std::vector<std::string> trFileNames;
  std::vector<std::string> rcbvFileNames;
  std::vector<std::string> psrFileNames;
  std::vector<std::string> phFileNames;
  std::vector<std::string> labelNames;
  std::vector<std::string> rejectedSubjectNames;
  std::vector<std::string> qualifiedSubjectNames;
  std::vector<int> qualifiedIndices;

  LoadQualifiedSubjectsFromGivenDirectory(inputdirectory,
    qualifiedSubjectNames,
    t1FileNames,
    t1ceFileNames,
    t2FileNames,
    t2FlairFileNames,
    axFileNames,
    faFileNames,
    radFileNames,
    trFileNames,
    rcbvFileNames,
    psrFileNames,
    phFileNames,
    labelNames,
    rejectedSubjectNames);

  //qualifiedSubjectNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB");
  //phFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB_PH.nii.gz");
  //psrFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB_PSR.nii.gz");
  //rcbvFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_rcbv_r_strip.nii.gz");
  //labelNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_labels.nii.gz");
  //trFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_TR_to_t1ce_str.nii.gz");
  //axFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_AX_to_t1ce_str.nii.gz");
  //faFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_FA_to_t1ce_str.nii.gz");
  //radFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_RAD_to_t1ce_str.nii.gz");
  //t1FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t1_sus_byte_n3_r_strip_hist.nii.gz");
  //t2FileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t2_sus_byte_n3_r_strip_hist.nii.gz");
  //t2FlairFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_flair_sus_byte_n3_r_strip_hist.nii.gz");
  //t1ceFileNames.push_back("Z:/Projects/SurvivalIntegration/TrainingData/AAAB/AAAB0_t1ce_sus_byte_n3_strip_hist.nii.gz");



  //-------------------------------------------------
  typedef itk::CSVArray2DFileReader<double> ReaderType;
  ReaderType::Pointer reader = ReaderType::New();
  typedef vnl_matrix<double> MatrixType;
  std::vector<double> ages;

  //// generic csv reader which uses header information to parse
  //auto parsedCSV = cbica::parseCSVFile(inputdirectory + "/features.csv", "Survival,Ages", "", false, false);
  //for (size_t i = 0; i < parsedCSV.size(); i++)
  //{
  //  survival.push_back(std::atof(parsedCSV[i].inputImages[0].c_str()));
  //  ages.push_back(std::atof(parsedCSV[i].inputImages[1].c_str()));
  //}

  MatrixType dataMatrix;
  if (cbica::fileExists(inputdirectory + "/features.csv"))
  {
    reader->SetFileName(inputdirectory + "/features.csv");
    reader->SetFieldDelimiterCharacter(',');
    reader->HasColumnHeadersOff();
    reader->HasRowHeadersOff();
    reader->Parse();
    dataMatrix = reader->GetArray2DDataObject()->GetMatrix();
  }
  else
  {
    m_LastError = "Couldn't find the file 'features.csv' in the input directory.";
    return;
  }

  for (unsigned int i = 0; i < dataMatrix.rows(); i++)
  {
    ages.push_back(dataMatrix(i, 0));
  }

  //reader->SetFileName(inputdirectory + "/testsheet.csv");
  //reader->SetFieldDelimiterCharacter(',');
  //reader->HasColumnHeadersOff();
  //reader->HasRowHeadersOff();
  //reader->Parse();
  //dataMatrix = reader->GetArray2DDataObject()->GetMatrix();
  //FeaturesOfAllSubjects.SetSize(10, 167);
  //for (unsigned int i = 0; i < dataMatrix.rows(); i++)
  //{
  //  for (unsigned int j = 0; j < dataMatrix.cols(); j++)
  //  {
  //    FeaturesOfAllSubjects(i, j) = dataMatrix(i, j);
  //  }
  //}

  MatrixType meanMatrix;
  if (cbica::fileExists(modeldirectory + "/mean.csv"))
  {
    reader->SetFileName(modeldirectory + "/mean.csv");
    reader->SetFieldDelimiterCharacter(',');
    reader->HasColumnHeadersOff();
    reader->HasRowHeadersOff();
    reader->Parse();
    meanMatrix = reader->GetArray2DDataObject()->GetMatrix();
  }
  else
  {
    m_LastError = "Couldn't find the file 'mean.csv' in the input directory.";
    return;
  }

  MatrixType stdMatrix;
  if (cbica::fileExists(modeldirectory + "/std.csv"))
  {
    reader->SetFileName(modeldirectory + "/std.csv");
    reader->SetFieldDelimiterCharacter(',');
    reader->HasColumnHeadersOff();
    reader->HasRowHeadersOff();
    reader->Parse();
    stdMatrix = reader->GetArray2DDataObject()->GetMatrix();
  }
  else
  {
    m_LastError = "Couldn't find the file 'std.csv' in the input directory.";
    return;
  }

  VariableLengthVectorType mean;
  VariableLengthVectorType stddevition;
  mean.SetSize(meanMatrix.size());
  stddevition.SetSize(meanMatrix.size());
  for (unsigned int i = 0; i < meanMatrix.size(); i++)
  {
    mean[i] = meanMatrix(i, 0);
    stddevition[i] = stdMatrix(i, 0);
  }

  //----------------------------------------------------
  FeaturesOfAllSubjects.SetSize(qualifiedSubjectNames.size(), 170);

  for (unsigned int sid = 0; sid < qualifiedSubjectNames.size(); sid++)
  {
    ImageType::Pointer T1CEImagePointer;
    ImageType::Pointer T2FlairImagePointer;
    ImageType::Pointer T1ImagePointer;
    ImageType::Pointer T2ImagePointer;
    ImageType::Pointer AXImagePointer;
    ImageType::Pointer RADImagePointer;
    ImageType::Pointer FAImagePointer;
    ImageType::Pointer TRImagePointer;
    ImageType::Pointer RCBVImagePointer;
    ImageType::Pointer PSRImagePointer;
    ImageType::Pointer PHImagePointer;
    ImageType::Pointer LabelImagePointer;

    LabelImagePointer = ReadNiftiImage<ImageType>(labelNames[sid]);
    RCBVImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(rcbvFileNames[sid]));
    PSRImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(psrFileNames[sid]));
    PHImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(phFileNames[sid]));
    T1CEImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t1ceFileNames[sid]));
    T2FlairImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t2FlairFileNames[sid]));
    T1ImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t1FileNames[sid]));
    T2ImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(t2FileNames[sid]));
    AXImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(axFileNames[sid]));
    RADImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(radFileNames[sid]));
    FAImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(faFileNames[sid]));
    TRImagePointer = RescaleImageIntensity<ImageType>(ReadNiftiImage<ImageType>(trFileNames[sid]));

    std::vector<double> TestFeatures = LoadTestData<ImageType>(T1CEImagePointer, T2FlairImagePointer, T1ImagePointer, T2ImagePointer,
      RCBVImagePointer, PSRImagePointer, PHImagePointer, AXImagePointer, FAImagePointer, RADImagePointer, TRImagePointer, LabelImagePointer);
    //------------------integrate with the test code------------------
    for (int i = 0; i < TestFeatures.size(); i++)
      FeaturesOfAllSubjects(sid, i) = TestFeatures[i];
  }
  VariableSizeMatrixType ScaledTestingData = mFeatureScalingLocalPtr.ScaleGivenTestingFeatures(FeaturesOfAllSubjects, mean, stddevition);

  VariableSizeMatrixType ScaledFeatureSetAfterAddingAge;
  ScaledFeatureSetAfterAddingAge.SetSize(ScaledTestingData.Rows(), ScaledTestingData.Cols() + 2);
  for (unsigned int i = 0; i < ScaledTestingData.Rows(); i++)
  {
    unsigned int j = 0;
    for (j = 0; j < ScaledTestingData.Cols(); j++)
    {
      ScaledFeatureSetAfterAddingAge(i, j) = ScaledTestingData(i, j);
    }
    ScaledFeatureSetAfterAddingAge(i, j) = ages[i];
    ScaledFeatureSetAfterAddingAge(i, j + 1) = 0;
  }

  try
  {
    // VectorVectorDouble resultSixMonths = mClassificationLocalPtr->Testing(ScaledTestingData, true, modeldirectory + "/SixMonthsModel.model");
    //VectorDouble resultSixMonths = testOpenCVSVM(ScaledTestingData, modeldirectory + "/SixMonthsModel.model");
    //VectorVectorDouble resultEighteenMonths = mClassificationLocalPtr->Testing(ScaledTestingData, true, modeldirectory + "/EighteenModelFile.model");
    VectorDouble resultEighteenMonths = testOpenCVSVM(ScaledFeatureSetAfterAddingAge, modeldirectory + "/" + mEighteenTrainedFile);
    VectorDouble resultSixMonths = testOpenCVSVM(ScaledFeatureSetAfterAddingAge, modeldirectory + "/" + mSixTrainedFile);

    // Estimates=(abs(Estimates)<=1).*Estimates+(Estimates>1)-(Estimates<-1);
    VectorDouble results = estimateCombination(resultSixMonths, resultEighteenMonths);
    std::ofstream myfile;
    myfile.open(outputdirectory + "/results.csv");
    myfile << "SubjectName,Result\n";

    for (size_t i = 0; i < results.size(); i++)
    {
      myfile << qualifiedSubjectNames[i] + "," + std::to_string(results[i]) + "\n";
    }
    myfile.close();

    // write qualifiedSubjectNames[i] and results[i] in a csv - this will be a variable size csv

  }
  catch (itk::ExceptionObject & excp)
  {
    cbica::Logging(loggerFile, "Error caught during testing: " + std::string(excp.GetDescription()));
    //exit(EXIT_FAILURE);
  }

}