
#include "SVMTrainApplication.h"

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using namespace std;

/*
SVMTrainer::SVMTrainer()
{
}

SVMTrainer::~SVMTrainer()
{
}
*/


namespace
{
// Parses the whole of `text` as a double.
// Returns false (leaving `value` untouched) when the string is empty, has
// trailing non-numeric characters, or is out of range. This is unlike atof(),
// which silently yields 0.0 for such input.
bool ParseKernelDouble(const std::string& text, double& value)
{
  const char* begin = text.c_str();
  char* end = NULL;
  errno = 0;
  const double parsed = std::strtod(begin, &end);
  if (end == begin || *end != '\0' || errno == ERANGE)
  {
    return false;
  }
  value = parsed;
  return true;
}
} // namespace

/**
 * Configures the kernel of `svm` from a kernel name and its string parameters.
 *
 * Only the "rbf" kernel is currently enabled. It requires exactly one
 * parameter, gamma, which must be a fully numeric string. The RBF kernel is
 * also given `sample` via SetData().
 *
 * @param sample            training samples handed to the kernel.
 * @param svm               SVM whose kernel is set.
 * @param kernelTypeString  kernel name; only "rbf" is accepted.
 * @param kernelParameters  kernel parameters as strings.
 * @return true if the kernel was configured. Returns false for an unsupported
 *         kernel type, a wrong parameter count, or a non-numeric gamma.
 */
bool SVMTrainer::ParseSVMKernel( SampleType::Pointer sample, SVMType::Pointer svm,
               const std::string& kernelTypeString,
               const std::vector<std::string>& kernelParameters)
{
#if 0
  // Disabled: linear and polynomial kernels. Kept for reference; re-enable
  // only once the corresponding ITK kernel types are available.
  if (kernelTypeString == "linear")
  {
    typedef itk::Statistics::LinearSVMKernel<MeasurementVectorType,double>
      LinearKernelType;
    svm->SetKernel(LinearKernelType::New());
    return true;
  }
  if (kernelTypeString == "polynomial")
  {
    typedef itk::Statistics::
      PolynomialSVMKernel<MeasurementVectorType,double> PolynomialKernelType;
    PolynomialKernelType::Pointer polynomialKernel =
      PolynomialKernelType::New();

    if (kernelParameters.size() != 3)
    {
      return false;
    }
    polynomialKernel->SetGamma(atof(kernelParameters[0].c_str()));
    polynomialKernel->SetR(atof(kernelParameters[1].c_str()));
    polynomialKernel->SetPower(atof(kernelParameters[2].c_str()));
    svm->SetKernel(polynomialKernel);
    return true;
  }
#endif

  if (kernelTypeString != "rbf")
  {
    return false;
  }

  printf ("Your Kernel type is RBF.\n");

  // RBF takes a single parameter: gamma.
  if (kernelParameters.size() != 1)
  {
    return false;
  }

  double gamma = 0.0;
  if (!ParseKernelDouble(kernelParameters[0], gamma))
  {
    std::cerr << "Invalid RBF gamma: '" << kernelParameters[0] << "'" << std::endl;
    return false;
  }

  typedef itk::Statistics::RBFSVMKernel<MeasurementVectorType,double>
    RBFKernelType;
  RBFKernelType::Pointer rbfKernel = RBFKernelType::New();

  // Same call order as before: attach the kernel, then configure it.
  svm->SetKernel(rbfKernel);
  rbfKernel->SetGamma(gamma);
  rbfKernel->SetData(sample);

  cout << "gamma for RBF: " << kernelParameters[0].c_str() << endl;

  return true;
}

//int main(int argc, char** argv)
void SVMTrainer::Training(std::string& kernelType, StringVectorType& kernelParameters, std::ostringstream& svmOutputFilename, std::string&  svmOutputFormat, std::string& trainfilename, float epsilon)  
{
  //
  // parse command line arguments
  //


  cout << "Parameters you chose for SVM Training: " << endl;
  cout << "kernelType: " << kernelType << endl;
  cout << "svmOutputFilename: " << svmOutputFilename << endl;
  cout << "svmOutputFormat: " << svmOutputFormat << ends;
  cout << "trainfilename: " << trainfilename << endl;


  if (kernelType.size() == 0)
  {
    std::cerr << "a kernel type must be specified." << endl;
    exit(0);    
  }
  if (svmOutputFilename.str().size() == 0)
  {
    std::cerr << "no support vector machine output file specified." << endl;
    exit(0);    
  }

  //
  // load vectors from file
  //
  SampleType::Pointer           sample           = SampleType::New();


  // NEW file reader
  int pointId;
 
 int numOfSamples;
 int numOfFeatures;
 
 ifstream infile;
 infile.open (trainfilename.c_str()); 
  
 infile >> numOfSamples;
 cout << "numOfSamples= " << numOfSamples << endl;
  
 infile >> numOfFeatures;
 numOfFeatures--;
 cout << "numOfFeatures= " << numOfFeatures << endl;

 unsigned int measurementVectorSize = numOfFeatures;
 sample->SetMeasurementVectorSize(measurementVectorSize);

 SampleType::MeasurementVectorType::ValueType tempdouble;
//  typedef SVMApplicationUtils::ClassLabelType LabelType;
//  LabelType label;
 int label; // not unsigned int


  LabelVectorType lv;
  lv.SetSize(numOfSamples);
  for (pointId=0; pointId<numOfSamples; pointId++) {    
    SampleType::MeasurementVectorType  mv;
    mv.SetSize(numOfFeatures);
    for (int j=0; j<numOfFeatures; j++) {

       infile >> tempdouble;

#ifdef debug
       cout << tempdouble << " " ;	
#endif // debug

       mv.SetElement (j, tempdouble);
    }

#ifdef debug
    cout << endl;
#endif //debug

    infile >> label;
    lv.SetElement (pointId, label);

#ifdef debug
    cout << "label= " << lv.GetElement (pointId) << endl;
#endif // debug

  sample->PushBack(mv);
  //membershipSample->AddInstance (label, instanceID++);

  }
  
  infile.close();

  cout << "DATA LOADED" << endl;




  //
  // create a support vector machine
  //
  SVMType::Pointer svm = SVMType::New();


  //
  // setup the SVM solver
  //
  SVMSolverType::Pointer solver = SVMSolverType::New();

   int tempint = sample->Size();
   printf ("sample_size: %d\n", tempint);

   tempint = sample->GetMeasurementVectorSize();
   printf ("MeasurementVectorSize: %d\n", tempint);

  solver->SetClassTrainingSamples(sample);

  //
  // setup the kernel for the svm
  //
  bool didSetupKernel =
   ParseSVMKernel(sample, svm, kernelType, kernelParameters);
  if (!didSetupKernel)
  {
      std::cerr << "Could not parse kernel: " << endl;

    exit(0);
  }


  solver->SetLabel(lv);
  solver->SetSVM(svm);
  solver->SetEPS(epsilon);     

  std::cout << "Training..." << std::endl;

  solver->train();
  solver->save(svmOutputFilename);



#ifdef SVM_WRITE
  //
  // write SVM configuration to file
  //
  SVMWriterType::Pointer svmWriter = SVMWriterType::New();
  svmWriter->SetFileName(svmOutputFilename.c_str());
  if (svmOutputFormat == "ascii")
  {
    svmWriter->SetWriteASCII(true);
  }
  else 
  {
    svmWriter->SetWriteASCII(false);
  }
  svmWriter->SetInput(svm);
//  if (SVMApplicationUtils::OutputLevelAtLeast(SVMApplicationUtils::Verbose)) 
//  {
    std::cerr << "Writing SVM...";
//  }
  try 
  {
    svmWriter->Update();
 //   if (SVMApplicationUtils::OutputLevelAtLeast(SVMApplicationUtils::Verbose)) 
 //   {
      std::cerr << "DONE" << std::endl;  
 //   }
  }
  catch (...)
  {
    std::cerr << std::endl << "Error writing SVM to file." << std::endl;
  }
#endif // SVM_WRITE



}
