#include "SVMTestApplication.h"

using namespace std;
//#define debug

// Scalar type of a single feature value.
typedef double                                      ValueType;
// One sample: a run-time sized vector of feature values.
typedef itk::VariableLengthVector<ValueType>       MeasurementVectorType;
// Container that holds all samples read from the test file.
typedef itk::Statistics::NumericsListSample<MeasurementVectorType> SampleType;
// SVM model type; in this file it is only used for a smart pointer that is
// assigned inside the SVM_READER section.
typedef itk::Statistics::SupportVectorMachine<MeasurementVectorType>        SVMType;
// Solver used by SVMTester::Testing to load the SVM file and classify samples.
typedef itk::Statistics::SVMSolver1<SampleType, SampleType>SVMSolverType;
// Vector of target values; in this file it is only referenced by SVMWriterType.
typedef itk::VariableLengthVector<double>          TargetVectorType;

// The SVM file writer type is only declared when SVM_WRITE is defined.
#ifdef SVM_WRITE
typedef itk::SVMFileWriter<MeasurementVectorType, TargetVectorType>  SVMWriterType;
#endif // SVM_WRITE
// Scalar type of a class label / classifier output.
typedef double                                        ClassLabelType;
// One label per sample (used for both expected and predicted labels).
typedef itk::VariableLengthVector<ClassLabelType>  LabelVectorType;


// Constructor: performs no work of its own.
SVMTester::SVMTester() {}

// Destructor: performs no work of its own.
SVMTester::~SVMTester() {}


/**
 * Classifies every sample of a test file with an SVM loaded from file and
 * writes one prediction per line to an output file.
 *
 * The test file is read as whitespace-separated text: the number of samples,
 * the number of features per sample, then the feature values of each sample
 * in turn. Target labels are NOT read from the file (that code is disabled
 * below), so the misclassification summary is only printed once targets are
 * actually loaded; otherwise a short notice is printed instead.
 *
 * Progress and diagnostics are written to stdout. On any error (no SVM file
 * name, unreadable or malformed test file, output file that cannot be
 * written) a message is written to stderr and the process terminates with
 * EXIT_FAILURE.
 *
 * @param testfilename         stream whose contents are the path of the test data file
 * @param svmFilename          stream whose contents name the SVM file; handed to the solver's load()
 * @param labelOutputFilename  stream whose contents are the path of the prediction file
 *                             (one value per line, written with "%f\n")
 */
void  SVMTester::Testing(std::ostringstream& testfilename, std::ostringstream& svmFilename, std::ostringstream& labelOutputFilename)
{
  std::cout << "SVMTester::Testing" << std::endl;
  std::cout << "SVM filename: " << svmFilename.str() << std::endl;

  if (svmFilename.str().empty())
  {
    std::cerr << "no support vector machine file specified." << std::endl;
    exit(EXIT_FAILURE);   // an error exit must not report success
  }

  SampleType::Pointer sample = SampleType::New();

  //
  // read the test data
  //
  std::ifstream infile(testfilename.str().c_str());
  std::cout << "Reading " << testfilename.str().c_str() << std::endl;

  if (!infile.is_open())
  {
    std::cerr << "Error: cannot open test file '" << testfilename.str() << "'." << std::endl;
    exit(EXIT_FAILURE);
  }

  // The header is validated before it is used to size any container: a failed
  // extraction or a negative count must not reach SetSize().
  int numOfSamples  = 0;
  int numOfFeatures = 0;
  if (!(infile >> numOfSamples >> numOfFeatures) || numOfSamples < 0 || numOfFeatures < 1)
  {
    std::cerr << "Error: invalid header in test file '" << testfilename.str()
              << "' (expected <number of samples> <number of features>)." << std::endl;
    exit(EXIT_FAILURE);
  }
  std::cout << "numOfSamples= " << numOfSamples << std::endl;
  //numOfFeatures--;
  std::cout << "numOfFeatures= " << numOfFeatures << std::endl;

  const unsigned int measurementVectorSize = static_cast<unsigned int>(numOfFeatures);
  sample->SetMeasurementVectorSize(measurementVectorSize);

  SampleType::MeasurementVectorType::ValueType tempdouble;

  // parameters.desired (target existed)
  // Expected labels. They are zero-filled because nothing below assigns them
  // while label reading is disabled; targetsLoaded records whether they carry
  // real data.
  LabelVectorType lv;
  lv.SetSize(numOfSamples);
  lv.Fill(0.0);
  bool targetsLoaded = false;

  for (int pointId = 0; pointId < numOfSamples; pointId++)
  {
    SampleType::MeasurementVectorType  mv;
    mv.SetSize(numOfFeatures);
    for (int j = 0; j < numOfFeatures; j++)
    {
      if (!(infile >> tempdouble))
      {
        std::cerr << "Error: test file '" << testfilename.str()
                  << "' ended or became unreadable at sample " << pointId
                  << ", feature " << j << "." << std::endl;
        exit(EXIT_FAILURE);
      }

#ifdef debug
      std::cout << tempdouble << " " ;
#endif //debug

      mv.SetElement (j, tempdouble);
    }

#ifdef debug
    std::cout << std::endl;
#endif //debug

    // Label reading is disabled; when it is re-enabled, targetsLoaded must be
    // set to true so that the statistics at the end are reported.
    //  infile >> label;   // CAUTION!!! LATER
    //  lv.SetElement (pointId, label);

#ifdef debug
    std::cout << "label= " << lv.GetElement (pointId) << std::endl;
#endif //debug

    sample->PushBack(mv);
    //membershipSample->AddInstance (label, instanceID++);
  }

  infile.close();

  std::cout << "DATA LOADED" << std::endl;

  //
  // load the support vector machine from file
  //
  SVMType::Pointer svm;

  // NOTE(review): SVMReaderType is not declared anywhere in the visible part
  // of this file and SetFileName() is handed the stream object itself, so this
  // section presumably does not build when SVM_READER is defined - confirm.
#ifdef SVM_READER
  SVMReaderType::Pointer svmReader = SVMReaderType::New();
  svmReader->SetFileName(svmFilename);

  try
  {
    svmReader->Update();
    svm = svmReader->GetSVM();
  }
  catch (...)
  {
    std::cerr << "Error loading SVM from file." << std::endl;
    exit(EXIT_FAILURE);
  }
#endif // SVM_READER


/*
  //
  // classify vectors using svm
  //
  std::vector<ClassLabelType> labels(sample->Size());
  unsigned int labelIndex = 0;
  for (SampleType::ConstIterator i = sample->Begin(); i != sample->End(); ++i)
  {
   // if (SVMApplicationUtils::
   //     OutputLevelAtLeast(SVMApplicationUtils::Verbose)) 
    //{
      std::cerr << std::setw(6) << std::setfill('0')
                << labelIndex << " --- "; 
    //}
    double svmLabel = svm->Classify(i.GetMeasurementVector());
    labels[labelIndex] = (svmLabel < 0.0 
                          ? SVMApplicationUtils::Class2Label 
                          : SVMApplicationUtils::Class1Label);
    //if (SVMApplicationUtils::
    //    OutputLevelAtLeast(SVMApplicationUtils::Verbose)) 
   // {
      std::cerr << labels[labelIndex] << std::endl; 
    //}
    ++labelIndex;
  }

*/
  ////////////////
  // Test

  std::cout << "testing sample size: " << sample->Size() << std::endl;

  SVMSolverType::Pointer solver = SVMSolverType::New();
  solver->SetTestingSamples(sample);

  // load svm file
  solver->load(svmFilename);

  // Open the prediction file before the progress bar is started so that a
  // failure is reported cleanly instead of crashing inside fprintf().
  FILE* fp = fopen (labelOutputFilename.str().c_str(), "w");
  if (!fp)
  {
    std::cerr << "Error: cannot open label output file '" << labelOutputFilename.str()
              << "' for writing." << std::endl;
    exit(EXIT_FAILURE);
  }

  // classification mode
  std::cout << "# Classification mode" << std::endl;

  // Progress bar state: a '#' is printed whenever the sample index reaches
  // the next threshold, with thresholds spaced at 1/20 of the sample count.
  int nextMarkAt  = 0;
  int marksShown  = 0;
  std::cout << "                    ";
  std::cout << "[________Test________]" << std::endl;
  std::cout << "                    [";
  std::cout.flush();

  const int l = numOfSamples;

  // Predicted value for every sample, in file order.
  LabelVectorType classify_lv;
  classify_lv.SetSize(numOfSamples);

  std::cout << "l: " << l << std::endl;

  // non-sparse mode
  for (int i = 0; i < l; i++)
  {
    SampleType::MeasurementVectorType  mv;
    mv.SetSize(numOfFeatures);
    mv = sample->GetMeasurementVector(i);

    tempdouble = solver->use(mv);
    classify_lv.SetElement(i, tempdouble);

    fprintf (fp, "%f\n", tempdouble);

    if (i >= nextMarkAt)
    {
      nextMarkAt = ++marksShown * l / 20;
      std::cout << "#";
      std::cout.flush();
    }
  }

  std::cout << std::endl;
  std::cout << "]" << std::endl;

  // Buffered write errors only surface when the stream is closed.
  if (fclose(fp) != 0)
  {
    std::cerr << "Error: failed to write label output file '" << labelOutputFilename.str()
              << "'." << std::endl;
    exit(EXIT_FAILURE);
  }

  printf("after svm_test\n");

  // only for non-sparse, classification -- from main of SVMTest
  // parameters.desired == true (target existed)
  if (targetsLoaded)
  {
    // A sample counts as misclassified when the expected and predicted values
    // do not have the same strict sign.
    int missclassified = 0;
    int missclassified_pos = 0;
    for (int i = 0; i < numOfSamples; i++)
    {
      if (lv.GetElement(i) * classify_lv.GetElement(i) <= 0)
      {
        missclassified++;
        if (lv.GetElement(i) < 0)
          missclassified_pos++;
      }
    }

    // Guard the percentages against an empty sample set (0/0).
    double z = (l > 0) ? 100.0 * ((double)missclassified) / ((double)l) : 0.0;
    std::cout << std::endl;
    std::cout << "# Number of missclassified          : " << missclassified << " [" << std::setprecision(4) << z << "%]" << std::endl;
    std::cout << "     -> False positives             : " << missclassified_pos << std::endl;
    std::cout << "     -> False negatives             : " << missclassified-missclassified_pos << std::endl;
    z = (l > 0) ? 100.0 * ((double)(l-missclassified)) / ((double)numOfSamples) : 0.0;
    std::cout << "# Number of correct classifications : " << l-missclassified << " ["  << std::setprecision(4) <<  z << "%]" << std::endl;
  }
  else
  {
    // Without targets the expected labels hold no data, so any statistics
    // computed from them would be meaningless.
    std::cout << std::endl;
    std::cout << "# No target labels were loaded; misclassification statistics skipped." << std::endl;
  }

  //
  // write labels to file
  //

  // NOTE(review): this section refers to `labels`, which only exists in the
  // commented-out code above, calls size()/c_str() directly on the stream
  // object, and uses a bare `throw;` outside of a handler. It presumably does
  // not build when FILE_EXPORT is defined - confirm before enabling.
#ifdef FILE_EXPORT

  if (labelOutputFilename.size() != 0)
  {
    try
    {
      std::ofstream output(labelOutputFilename.c_str());
      std::copy(labels.begin(), labels.end(), 
                std::ostream_iterator<ClassLabelType>(output, " "));
      if (output.bad()) throw;
      output.close();
    }
    catch (...)
    {
      std::cerr << "Error writing labels to file." << std::endl;
      exit(EXIT_FAILURE);
    }
  }
#endif 

}
