/*=========================================================================

Program:   Insight Segmentation & Registration Toolkit
Module:    $RCSfile: itkSVMSolver1.txx,v $
Language:  C++
Date:      $Date: 2006/09/15 15:26:15 $
Version:   $Revision: 1.3 $

Copyright (c) Insight Software Consortium. All rights reserved.
See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

This software is distributed WITHOUT ANY WARRANTY; without even 
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#ifndef __SVMSolver1_txx
#define __SVMSolver1_txx

#include "itkSVMSolver1.h"
#include <iomanip>

using namespace std;

//#define debug

namespace itk
{
namespace Statistics
{

/**
 * Default constructor.
 *
 * Leaves the solver in classification mode with no prepared system and no
 * support vectors, and installs the default tuning values.
 */
template<class TVector, class TOutput>
SVMSolver1<TVector,TOutput>
::SVMSolver1()
{
  // Nothing has been prepared, trained or loaded yet.
  svm_is_ready = false;
  support_vectors_is_ready = false;
  n_support_vectors_bound = 0;
  n_support_vectors = 0;

  // Classification is the default; eps_regression is only read by the
  // regression branch of prepareToLauch().
  regression_mode = false;
  eps_regression = 0.7;

  // Upper bound used for every alpha (copied into Cx[] by prepareToLauch()).
  C = 100;

  // Value handed to Cache()/CacheClassification(); presumably megabytes,
  // judging by the name -- TODO confirm against the superclass.
  cache_size_meg = 50;

  // outputSamplePointer is not created here; load() allocates it.
}

/**
 * Destructor. This class releases nothing explicitly.
 *
 * NOTE(review): prepareToLauch() allocates this->Cx with new[]; it is not
 * freed here. Whether the superclass releases it is not visible from this
 * file -- confirm.
 */
template<class TVector, class TOutput>
SVMSolver1<TVector,TOutput>
::~SVMSolver1()
{
}
  

/**
 * Train on the samples referenced by Superclass::m_SamplePointer.
 *
 * Steps:
 *  1. size the problem (one variable per sample) and call prepareToLauch();
 *  2. run this->solve() and this->bCompute();
 *  3. collect every sample whose alpha exceeds eps_bornes into
 *     support_vectors (sample index) and sv_alpha (signed coefficient),
 *     also counting how many of them sit within eps_bornes of C.
 *
 * Progress is reported on stdout. On return support_vectors_is_ready is true.
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::train( )
{
  const int vectorLength = Superclass::m_SamplePointer->GetMeasurementVectorSize ();
  printf ("MeasurementVectorSize: %d\n", vectorLength);

  const int sampleCount = Superclass::m_SamplePointer->Size();
  printf ("l_: %d\n", sampleCount);

  // Classification mode: as many optimisation variables as samples.
  this->m_NumSamples = sampleCount;
  lm = sampleCount;

  prepareToLauch();
  cout << "# System loaded\n";

  this->solve();
  cout << "# System thermonuclearised\n";

  if(!this->bCompute())
  {
    cout << "! Warning : b is not unique. It's probably wrong.\n";
    cout << "! I think you are using silly parameters.\n";
  }

  // The two passes below must pick exactly the same samples, so the
  // floating-point comparisons are assumed to be deterministic
  // (translated from the original French remark).

  // Pass 1: count the support vectors and those lying at the bound C.
  n_support_vectors = 0;
  n_support_vectors_bound = 0;
  for(int idx = 0; idx < sampleCount; ++idx)
  {
    if(!(this->alpha[idx] > this->eps_bornes))
      continue;

    ++n_support_vectors;
    if(this->alpha[idx] > C - this->eps_bornes)
      ++n_support_vectors_bound;
  }

  support_vectors.SetSize(n_support_vectors);
  sv_alpha.SetSize(n_support_vectors);

  // Pass 2: record index and signed coefficient of each support vector.
  // The counter is rebuilt and ends at the value found in pass 1.
  n_support_vectors = 0;
  for(int idx = 0; idx < sampleCount; ++idx)
  {
    if(!(this->alpha[idx] > this->eps_bornes))
      continue;

    const int slot = n_support_vectors++;
    support_vectors[slot] = idx;

    if(!regression_mode)
    {
      // Classification: the coefficient carries the sign of the label.
      sv_alpha[slot] = this->m_Label.GetElement(idx)*this->alpha[idx];
    }
    else if(idx < this->m_NumSamples)
    {
      sv_alpha[slot] = -this->alpha[idx];
    }
    else
    {
      sv_alpha[slot] =  this->alpha[idx];
    }
  }

  cout << "# " << n_support_vectors << " support vectors\n";
  cout << "# With " << n_support_vectors_bound << " support vectors at C\n";
  support_vectors_is_ready = true;
}


/**
 * Print the object.
 *
 * Only forwards to the superclass implementation; no SVMSolver1-specific
 * state is written to the stream.
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::PrintSelf( std::ostream& os, Indent indent ) const
{
  this->Superclass::PrintSelf(os, indent);
}


/**
 * Allocate and initialise the optimiser state for m_NumSamples variables.
 *
 * Sizes the working arrays, resets alpha/status/bound bookkeeping, fills the
 * per-variable bound Cx[] with C, initialises the gradient, initialises the
 * kernel and the caches, and finally sets svm_is_ready.
 *
 * Classification is enforced: if regression_mode is set on entry it is
 * switched off (with a message) before the gradient is initialised.
 *
 * NOTE(review): Cx is allocated with new[] on every call and no matching
 * delete[] is visible in this file -- confirm who owns it.
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::prepareToLauch()
{
  printf ("\nSVMSolver1::prepareToLauch()\n");
  cout << "# Squatting memory\n";

  // Every working array gets one slot per optimisation variable.
  this->n_active_var = this->m_NumSamples;
  this->active_var.SetSize(this->m_NumSamples);
  this->active_var_new.SetSize(this->m_NumSamples);
  this->grad.SetSize(this->m_NumSamples);
  this->not_at_bound_at_iter.SetSize(this->m_NumSamples);
  this->alpha.SetSize(this->m_NumSamples);
  this->status_alpha.SetSize(this->m_NumSamples);
  this->Cx = new double[Superclass::m_NumSamples];

  // Start from alpha == 0 with every variable active.
  for(int k = 0; k < this->m_NumSamples; ++k)
  {
    this->Cx[k] = C;
    this->alpha[k] = 0;
    this->status_alpha[k] = 1;
    this->not_at_bound_at_iter[k] = 0;
    this->active_var[k] = k;
  }

  // Only classification is supported for now: force the flag off.
  if(regression_mode)
  {
    regression_mode = false;
    cout << "Classification mode...\n" << endl;
  }

  if(!regression_mode)
  {
    // Classification: the gradient starts at -1 for every variable.
    for(int k = 0; k < Superclass::m_NumSamples; ++k)
    {
      Superclass::grad.SetElement(k, -1);
#ifdef debug
      cout << "grad.GetElement: " << Superclass::grad.GetElement(k) << endl;
#endif // debug
    }
  }
  else
  {
    // Regression set-up. Currently unreachable because the flag was cleared
    // just above; kept so the mode can be re-enabled later.
    for(int k = 0; k < lm; ++k)
    {
      Superclass::grad[k] =  Superclass::m_Label.GetElement(k) + eps_regression;
      this->m_Label.SetElement(k, 1);
    }
    for(int k = lm; k < Superclass::m_NumSamples; ++k)
    {
      Superclass::grad[k] =  -Superclass::m_Label.GetElement(k-lm) + eps_regression;
      this->m_Label.SetElement(k, -1);
    }
  }

  Superclass::m_SVM->GetKernel()->init();

  this->Cache(cache_size_meg);
  this->CacheClassification(cache_size_meg);

  svm_is_ready = true;
}

/**
 * Analytically solve the two-variable sub-problem for the pair (xi, xj)
 * and store the result in alpha[xi] and alpha[xj].
 *
 * With s = label(xi)*label(xj), the pair moves along the line that keeps
 * label(xi)*alpha[xi] + label(xj)*alpha[xj] constant, i.e.
 *     delta alpha[xj] = -s * delta alpha[xi].
 * [L, H] is the interval of alpha[xi] for which both variables stay inside
 * [0, C] on that line.
 *
 * Reads Superclass::old_alpha_xi / old_alpha_xj (presumably the values of the
 * pair before this step -- confirm in the superclass), the kernel rows
 * k_xi / k_xj and the gradient; finishes by calling updateStatus() on both
 * indices.
 *
 * \param xi index of the first variable of the working pair
 * \param xj index of the second variable of the working pair
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::analyticSolve(int xi, int xj)
{
  double ww, H, L;

  const double s = Superclass::m_Label.GetElement(xi)*Superclass::m_Label.GetElement(xj);

#ifdef debug
  cout << "s : " << s << endl;
#endif // debug

  // Feasible interval [L, H] for alpha[xi] along the constraint line.
  if(s < 0)
  {
    // Opposite labels: alpha[xi] - alpha[xj] is conserved.
    ww = Superclass::old_alpha_xi - Superclass::old_alpha_xj;
    L = ((ww   > 0.0) ? ww :  0.0);
    H = ((C+ww >   C) ? C  : C+ww);
  }
  else
  {
    // Same labels: alpha[xi] + alpha[xj] is conserved.
    ww = Superclass::old_alpha_xi + Superclass::old_alpha_xj;
    L = ((ww-C > 0.0) ? ww-C : 0.0);
    H = ((ww   >   C) ? C    :  ww);
  }

#ifdef debug
  cout << "ww: " << ww << ", L: " << L << ", H: " << H << endl;

  cout << "Superclass::k_xi[" << xi << "]: " << Superclass::k_xi[xi] << endl;
  cout << "Superclass::k_xi[" << xj << "]: " << Superclass::k_xi[xj] << endl;
  cout << "Superclass::k_xj[" << xj << "]: " << Superclass::k_xj[xj] << endl;
#endif // debug

  // Curvature of the objective along the constraint line.
  const double eta = Superclass::k_xi[xi] - 2.*s*Superclass::k_xi[xj] + Superclass::k_xj[xj];

#ifdef debug
    cout << "eta: " << eta << endl;
#endif // debug

  if(eta > 0)
  {
    // Regular case: unconstrained minimiser, clipped to [L, H].
    double alph = Superclass::old_alpha_xi + (s*Superclass::grad[xj] - Superclass::grad[xi])/eta;

    if(alph > H)
      alph = H;
    else if(alph < L)
      alph = L;

#ifdef debug
   cout << "debugging: before alpha" << endl;
#endif // debug

    Superclass::alpha[xi] = alph;
  }
  else
  {
    // Degenerate case (no positive curvature): the optimum is at one end of
    // [L, H]; the sign of the directional derivative selects which one.
    const double slope = Superclass::grad[xi] - s*Superclass::grad[xj];
    Superclass::alpha[xi] = (slope > 0) ? L : H;
  }

  // Move alpha[xj] so the pair stays on the constraint line. The original
  // code used "+=" in the degenerate branch, which left the line (and could
  // leave [0, C]); both branches must apply delta_j = -s * delta_i.
  Superclass::alpha[xj] -= s*(Superclass::alpha[xi]-Superclass::old_alpha_xi);

#ifdef debug
   cout << "Superclass::alpha[" << xi << "]: " << Superclass::alpha[xi] << endl;
   cout << "Superclass::alpha[" << xj << "]: " << Superclass::alpha[xj] << endl;
#endif // debug

  this->updateStatus(xi);
  this->updateStatus(xj);
}


/**
 * Write the trained model to the file named by svmfilename.str().
 *
 * File layout, in write order:
 *   - two text comment lines starting with '#';
 *   - regression_mode and sparse_mode (sizeof(bool) bytes each);
 *   - kernel id (int) and kernel width (double);
 *   - n_support_vectors (int), number of features (int), bias b (double);
 *   - for each support vector: its coefficient (double) followed by the
 *     elements of its measurement vector, each stored as a double.
 * Values are raw memory dumps in the byte order of the writing machine.
 *
 * Prints a message and calls exit(0) if there are no support vectors yet or
 * if the file cannot be opened.
 *
 * NOTE(review): the kernel id (2) and width (6000.0) are constants, not read
 * back from the kernel in use -- confirm this is intended.
 * NOTE(review): mv.Size() values are written per vector while the header
 * advertises m_NumFeatures; load() relies on the two being equal.
 *
 * \param svmfilename stream whose contents are the output file name
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::save(std::ostringstream& svmfilename)
{
  cout << "SVMSolver1::save()" << endl;

  if(!support_vectors_is_ready)
  {
    cout << "$ Nothing to save.\n\n";
    exit(0);
  }

  const std::string modelPath = svmfilename.str();
  ofstream f(modelPath.c_str(), ios::out | ios::trunc | ios::binary);

  cout << "writing " << modelPath.c_str() << endl;

  if(!f)
  {
    cout << "$ File error. Arg.\n" << endl;
    exit(0);
  }

  // Text preamble; load() skips every leading line that starts with '#'.
  f << "# " << n_support_vectors << " support vectors inside" << endl;
  f << "# With " << n_support_vectors_bound << " support vectors at C" << endl;

  // Mode flags.
  const bool sparse_mode = Superclass::sparse_mode;
  f.write(reinterpret_cast<const char *>(&regression_mode), sizeof(bool));
  f.write(reinterpret_cast<const char *>(&sparse_mode), sizeof(bool));

  // Kernel description: id 2 is reported as Gaussian (RBF) by load().
  const int kernel_id = 2;
  const double kernelStd = 6000.0;
  f.write(reinterpret_cast<const char *>(&kernel_id), sizeof(int));
  f.write(reinterpret_cast<const char *>(&kernelStd), sizeof(double));

  // Model dimensions and bias.
  const int n_input_dim = Superclass::m_NumFeatures;
  const double b = Superclass::b;
  f.write(reinterpret_cast<const char *>(&n_support_vectors), sizeof(int));
  f.write(reinterpret_cast<const char *>(&n_input_dim), sizeof(int));
  f.write(reinterpret_cast<const char *>(&b), sizeof(double));

  cout << "n_support_vectors: " << n_support_vectors << endl;

  XVectorType mv;
  mv.SetSize(Superclass::m_SamplePointer->GetMeasurementVectorSize());

  // One record per support vector: coefficient, then the vector itself.
  for(int sv = 0; sv < n_support_vectors; ++sv)
  {
    f.write(reinterpret_cast<const char *>(&sv_alpha[sv]), sizeof(double));

#ifdef debug
    cout << "*sv_alpha[" << sv << "]: " << sv_alpha[sv] << "  " ;
#endif // debug

    // Indices are folded back into [0, lm) before fetching the sample.
    const int sampleIndex = support_vectors[sv];
    mv = Superclass::m_SamplePointer->GetMeasurementVector(sampleIndex%lm);

    for(unsigned int d = 0; d < mv.Size(); ++d)
    {
      const double value = mv.GetElement(d);
      f.write(reinterpret_cast<const char *>(&value), sizeof(double));
    }
  }
}

/**
 * Read a model written by save() from the file named by svmfilename.str()
 * and make the solver ready for use().
 *
 * Skips the leading '#' comment lines, then reads the mode flags, kernel id
 * and width, the number of support vectors, the number of features and the
 * bias, followed by one (coefficient, measurement vector) record per support
 * vector. The vectors are stored in a freshly created outputSamplePointer,
 * which becomes the data set of a new RBF kernel installed in m_SVM.
 *
 * On return m_NumSamples and lm equal the number of support vectors and
 * support_vectors_is_ready is true.
 *
 * Errors (file cannot be opened, header or body truncated, negative sizes)
 * print a message and call exit(0), matching the existing behaviour for an
 * unopenable file.
 *
 * NOTE(review): an RBF kernel is always created, whatever kernel id the file
 * contains; the id is only reported on stdout.
 * NOTE(review): Superclass::m_OutputSamplePointer is dereferenced for a
 * diagnostic print; nothing here guarantees it is non-null -- confirm.
 *
 * \param svmfilename stream whose contents are the input file name
 */
template<class TXSample, class TYSample>
void SVMSolver1<TXSample,TYSample>
::load(std::ostringstream& svmfilename)
{
  cout << "SVMSolver1::load()" << endl;

  ifstream f(svmfilename.str().c_str(), ios::in | ios::binary);

  cout << "reading " << svmfilename.str().c_str() << endl;

  if(!f)
  {
    cout << "$ File error. Arg.\n" << endl;
    exit(0);
  }

  // Skip the text comment lines at the top of the model file.
  {
    char buffer[1000];
    while(f.peek() == '#')
      f.getline(buffer, 1000);
  }

  // Mode flags.
  bool sparse_mode = false;
  f.read(reinterpret_cast<char *>(&regression_mode), sizeof(bool));
  f.read(reinterpret_cast<char *>(&sparse_mode), sizeof(bool));

  if (regression_mode) cout << "Regression mode ON" << endl;
  if (sparse_mode) cout << "Sparse data mode ON" << endl;

  // Kernel description. Only the RBF (Gaussian) kernel is supported.
  // TODO: create Kernel type
  int id = 0;
  double kernelStd = 0.0;
  f.read(reinterpret_cast<char *>(&id), sizeof(int));
  f.read(reinterpret_cast<char *>(&kernelStd), sizeof(double));

  if (id==2) cout << "Gaussian(RBF) kernel" << endl; else cout << "NOT Gaussian Kernel" << endl;
  cout <<"std: " << kernelStd;

  // Model dimensions and bias.
  int n_input_dim = 0;
  double b = 0.0;

  f.read(reinterpret_cast<char *>(&n_support_vectors), sizeof(int));
  cout << "n_support_vectors: " << n_support_vectors << endl;

  f.read(reinterpret_cast<char *>(&n_input_dim), sizeof(int));
  f.read(reinterpret_cast<char *>(&b), sizeof(double));

  // Do not size anything from a header that could not be read completely.
  if(!f || n_support_vectors < 0 || n_input_dim < 0)
  {
    cout << "$ File error. Truncated or corrupt model header.\n" << endl;
    exit(0);
  }

  //TODO: Put together memory allocation for training, testing
  support_vectors.SetSize(n_support_vectors);
  sv_alpha.SetSize(n_support_vectors);

  Superclass::m_NumFeatures = n_input_dim;
  Superclass::b = b;

  cout << "n_input_dim: " << n_input_dim << endl;
  cout << "b: " << b << endl;

  XVectorType mv;
  mv.SetSize(n_input_dim);

  cout << "Superclass::m_OutputSamplePointer->GetMeasurementVectorSize():" << Superclass::m_OutputSamplePointer->GetMeasurementVectorSize() << endl;

  outputSamplePointer = SampleType::New();

  // One record per support vector: coefficient, then n_input_dim doubles.
  double value = 0.0;
  for(int t = 0; t < n_support_vectors; t++)
  {
    // After loading, support vector t is simply sample t of the new set.
    support_vectors[t] = t;

    f.read(reinterpret_cast<char *>(&sv_alpha[t]), sizeof(double));

    for (unsigned int t2 = 0; t2 < mv.Size(); t2++)
    {
      f.read(reinterpret_cast<char *>(&value), sizeof(double));
      mv.SetElement(t2, value);
    }

#ifdef debug
    cout << endl;
#endif // debug

    // A short read means the file ends before the advertised data does.
    if(!f)
    {
      cout << "$ File error. Truncated model data.\n" << endl;
      exit(0);
    }

    outputSamplePointer->PushBack(mv);
  }

  cout << "check again: outputSamplePointer->GetMeasurementVector(0)" << endl;

  // Diagnostic dump of the first vector; skipped for an empty model, where
  // there is no element 0 to fetch.
  if(n_support_vectors > 0)
  {
    mv = outputSamplePointer->GetMeasurementVector(0);
    for(int t = 0; t < n_input_dim; t++)
      cout << mv.GetElement(t) << " ";
  }
  cout << endl;

  Superclass::m_NumSamples = n_support_vectors;
  lm = n_support_vectors;

  // Only RBF. RBFKernelType depends on a template parameter, so its nested
  // Pointer type needs the typename disambiguator.
  typedef itk::Statistics::RBFSVMKernel<MeasurementVectorType,double>  RBFKernelType;
  typename RBFKernelType::Pointer rbfKernel = RBFKernelType::New();

  Superclass::m_SVM = SVMType::New();

  Superclass::m_SVM->SetKernel(rbfKernel);
  rbfKernel->SetGamma(kernelStd);

#ifdef debug
  cout << "check again 2: Superclass::m_OutputSamplePointer->GetMeasurementVector(0)" << endl;

  mv = Superclass::m_OutputSamplePointer->GetMeasurementVector(0);
  for(int t = 0; t < n_input_dim; t++)
         cout << mv.GetElement(t) << " ";
  cout << endl;

#endif // debug

  // The kernel evaluates against the vectors that were just read.
  Superclass::m_SVM->GetKernel()->SetData(outputSamplePointer);
  Superclass::m_SVM->GetKernel()->init();

  cout << "# " << n_support_vectors << " support vectors in the model" << endl;
  support_vectors_is_ready = true;
}


/**
 * Evaluate the decision function for one measurement vector.
 *
 * Returns b + sum over support vectors of sv_alpha[k] * K(x, sample_k),
 * where the kernel value is obtained by calling evalue() on the kernel and
 * then reading it back with GetEvalvalue().
 *
 * Prints a message and calls exit(0) if no support vectors are available
 * (neither train() nor load() has completed).
 *
 * \param x measurement vector to evaluate
 * \return the raw decision value (not thresholded)
 */
template<class TXSample, class TYSample>
double SVMSolver1<TXSample,TYSample>
::use(MeasurementVectorType& x)
{
#ifdef debug
   cout << "SVMSolver1::use()" << endl;
#endif // debug

  if(!support_vectors_is_ready)
  {
    cout << "$ Error. No support vectors inside.\n\n";
    exit(0);
  }

  double decision = 0;
  int sv = 0;
  while(sv < n_support_vectors)
  {
    // Indices are folded back into [0, lm) before the kernel is evaluated.
    const int sampleIndex = support_vectors[sv];

    // evalue() computes the kernel value; GetEvalvalue() retrieves it.
    Superclass::m_SVM->GetKernel()->evalue(x, sampleIndex%lm);
    decision += sv_alpha[sv]*Superclass::m_SVM->GetKernel()->GetEvalvalue();

    ++sv;
  }

  // The bias is added last, after all kernel terms.
  decision += Superclass::b;

  return decision;
}



}      // namespace Statistics
}      // namespace itk
#endif // __SVMSolver1_txx

