/*=========================================================================

  Program:   Insight Segmentation & Registration Toolkit
  Module:    $RCSfile: itkSVMSolverBase.txx,v $
  Language:  C++
  Date:      $Date: 2006/09/08 19:19:14 $
  Version:   $Revision: 1.3 $

  Copyright (c) Insight Software Consortium. All rights reserved.
  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/

#ifndef __SVMSolverBase_txx
#define __SVMSolverBase_txx

using namespace std;

namespace itk
{
namespace Statistics
{
/** Constructor: install the default solver configuration.
 *
 *  Shrinking is enabled after 100 iterations, unshrinking is off, and at
 *  most one unshrink pass is allowed. The bound/shrink tolerances are
 *  tighter when USEDOUBLE is defined. eps_fin is only a starting value:
 *  selectVariables() overwrites it with GetEPS() on every call. */
template<class TXSample, class TYSample>
SVMSolverBase<TXSample,TYSample>
::SVMSolverBase()
{
  // Working-set / shrinking state.
  deja_shrink = false;
  sparse_mode = false;
  unshrink_mode = false;
  n_max_unshrink = 1;
  n_iter_min_to_shrink = 100;

  // Stopping criterion (replaced by GetEPS() in selectVariables()).
  eps_fin = 0.01;

  // Tolerances for the shrink test and for "alpha is at a bound".
#ifdef USEDOUBLE
  eps_shrink = 1E-9;
  eps_bornes = 1E-12;
#else
  eps_shrink = 1E-4;
  eps_bornes = 1E-4;
#endif
}


/** Destructor.
 *
 *  NOTE(review): the buffers allocated by Cache() (index_dans_liste, the
 *  node array behind cached/cached_sauve, and each node's adr column) are
 *  not released here. Confirm whether a derived class frees them; otherwise
 *  they leak. */
template<class TXSample, class TYSample>
SVMSolverBase<TXSample,TYSample>
::~SVMSolverBase()
{
  // Intentionally empty.
}

/** Register the training sample container and cache its dimensions.
 *
 *  @param samples  sample container; it is stored as m_SamplePointer, and its
 *                  Size() and GetMeasurementVectorSize() become m_NumSamples
 *                  and m_NumFeatures. Must not be null. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::SetClassTrainingSamples(TXSample* samples)
{
  printf ("SVMSolverBase::SetClassTrainingSamples\n");

  m_SamplePointer = samples;

  // Dimensions are read back through the stored pointer.
  m_NumSamples  = m_SamplePointer->Size();
  m_NumFeatures = m_SamplePointer->GetMeasurementVectorSize();

  cout << "m_NumSamples: " << m_NumSamples
       << ", m_NumFeatures: " << m_NumFeatures << endl;
}

/** Register the testing sample container and cache its dimensions.
 *
 *  @param samples  sample container; it is stored as m_OutputSamplePointer.
 *
 *  NOTE(review): this overwrites m_NumSamples / m_NumFeatures, the same
 *  members that SetClassTrainingSamples() sets and that solve(), Cache() and
 *  rempliColonne() use as the problem size. Confirm that this is intended
 *  when the call happens after the training setup. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::SetTestingSamples(TXSample* samples)
{
  printf ("SVMSolverBase::SetTestingSamples\n");

  m_OutputSamplePointer = samples;

  m_NumSamples  = m_OutputSamplePointer->Size();
  m_NumFeatures = m_OutputSamplePointer->GetMeasurementVectorSize();

  cout << "m_NumSamples: " << m_NumSamples
       << ", m_NumFeatures: " << m_NumFeatures << endl;
}


/** Store the class labels used by the solver.
 *
 *  @param label  per-sample labels, assigned to m_Label. The solver only
 *                tests them with `> 0` to separate the two classes. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::SetLabel(LabelVectorType& label)
{
  printf ("SVMSolverBase::SetLabel\n");
  m_Label = label;
  printf ("SVMSolverBase::SetLabel -- OK\n");
}



/** Print the object state (ITK PrintSelf convention).
 *
 *  Besides the superclass state, this reports the solver configuration
 *  members that the constructor initializes. Sample counts are deliberately
 *  not printed: they stay unset until SetClassTrainingSamples() or
 *  SetTestingSamples() is called. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::PrintSelf( std::ostream& os, Indent indent ) const
{
  Superclass::PrintSelf( os, indent );

  os << indent << "SparseMode: "        << sparse_mode          << std::endl;
  os << indent << "UnshrinkMode: "      << unshrink_mode        << std::endl;
  os << indent << "AlreadyShrunk: "     << deja_shrink          << std::endl;
  os << indent << "MinItersToShrink: "  << n_iter_min_to_shrink << std::endl;
  os << indent << "MaxUnshrink: "       << n_max_unshrink       << std::endl;
  os << indent << "EpsShrink: "         << eps_shrink           << std::endl;
  os << indent << "EpsBounds: "         << eps_bornes           << std::endl;
  os << indent << "EpsStop: "           << eps_fin              << std::endl;
}


// mjkim for SVMTorch
/** Hook for solver-specific options; the base implementation does nothing. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::setOption()
{
  // No options in the base solver.
}

/** Choose the working pair (i, j) for the next analytic step.
 *
 *  For every active variable t, let s_t be the label-signed gradient:
 *  grad[t] when m_Label(t) > 0, -grad[t] otherwise.
 *   - i is the argmax of s_t over variables with isNotDown(t) (label > 0)
 *     or isNotUp(t) (label <= 0);
 *   - j is the argmin of s_t over variables with isNotUp(t) (label > 0)
 *     or isNotDown(t) (label <= 0).
 *  On ties the first variable in active_var order wins.
 *
 *  Side effects: eps_fin is refreshed from GetEPS(), and current_error is
 *  set to s_i - s_j (or 0 when no admissible pair exists).
 *
 *  @param i,j  receive the chosen indices; written only when false is returned.
 *  @return true when optimisation should stop: no admissible pair exists,
 *          or current_error < eps_fin.
 */
template<class TXSample, class TYSample>
bool SVMSolverBase<TXSample,TYSample>
::selectVariables(int &i, int &j)
{
  // The extrema are seeded from the first admissible candidate. The earlier
  // fixed sentinels (-100000 / +100000) silently skipped every variable
  // whose signed gradient lay outside that range.
  double gmax_i = 0.0;
  double gmin_j = 0.0;
  int i_ = -1;
  int j_ = -1;

#ifdef debug
  cout << "n_active_var: " << n_active_var << endl;
#endif // debug

  for(int it = 0; it < n_active_var; it++)
  {
    const int t = active_var[it];
    const bool positive = (m_Label.GetElement(t) > 0);

    // Label-signed gradient.
    const double g = positive ? grad[t] : -grad[t];

    // Which bound states make t admissible for each role depends on its label.
    const bool can_be_i = positive ? isNotDown(t) : isNotUp(t);
    const bool can_be_j = positive ? isNotUp(t)   : isNotDown(t);

    if(can_be_i && ((i_ == -1) || (g > gmax_i)))
    {
      gmax_i = g;
      i_ = t;
    }

    if(can_be_j && ((j_ == -1) || (g < gmin_j)))
    {
      gmin_j = g;
      j_ = t;
    }

#ifdef debug
    printf ("t: %d, gmax_i: %f, i_: %d, gmin_j: %f, j_: %d\n",
            t, gmax_i, i_, gmin_j, j_);
#endif // debug
  }

  eps_fin = GetEPS();

  if( (i_ == -1) || (j_ == -1) )
  {
    // Nothing left that can move: report converged with zero error.
    current_error = 0;
    return(true);
  }

  current_error = gmax_i - gmin_j;
  if(current_error < eps_fin)
    return(true);

  i = i_;
  j = j_;
  return(false);
}


/** Estimate the bias b from the free variables.
 *
 *  Averages m_Label(t) * grad[t] over active variables t that are neither
 *  isNotUp-false nor isNotDown-false (both tests pass), and stores the
 *  negated mean in b.
 *
 *  @return true if at least one such variable exists (b updated);
 *          false otherwise (b left untouched). */
template<class TXSample, class TYSample>
bool SVMSolverBase<TXSample,TYSample>
::bCompute()
{
  double acc = 0;
  int n_free = 0;

  for(int k = 0; k < n_active_var; ++k)
  {
    const int t = active_var[k];
    if( !isNotUp(t) || !isNotDown(t) )
      continue;   // variable sits at a bound

    acc += m_Label.GetElement(t)*grad[t];
    ++n_free;
  }

  if(n_free == 0)
    return(false);

  b = -acc/(double)n_free;
  return(true);
}

// Returns the number of variables that are candidates for shrinking.
/** Build active_var_new and count how many active variables would be dropped.
 *
 *  bb = (bmin + bmax) / 2 is used as the bias estimate. For each active
 *  variable t, not_at_bound_at_iter[t] is reset to iter when:
 *   - it is free (isNotDown && isNotUp), or
 *   - it is at the isNotUp side and grad[t] + y_t*bb < eps_shrink, or
 *   - it is at the other bound and grad[t] + y_t*bb > -eps_shrink.
 *  Otherwise, if the counter is older than n_iter_min_to_shrink iterations,
 *  t is left out of active_var_new.
 *
 *  @return n_active_var - n_active_var_new. */
template<class TXSample, class TYSample>
int SVMSolverBase<TXSample,TYSample>
::checkShrinking(double bmin, double bmax)
{
  const double bb = (bmin+bmax)/2.;

  n_active_var_new = 0;
  for(int k = 0; k < n_active_var; k++)
  {
    const int t = active_var[k];

    bool reset_counter;
    if(isNotDown(t) && isNotUp(t))
      reset_counter = true;
    else if(isNotUp(t))   // so the variable is at its lower bound
      reset_counter = (grad[t] + m_Label.GetElement(t)*bb < eps_shrink);
    else
      reset_counter = (grad[t] + m_Label.GetElement(t)*bb > -eps_shrink);

    if(reset_counter)
      not_at_bound_at_iter[t] = iter;
    else if( (iter - not_at_bound_at_iter[t]) > n_iter_min_to_shrink )
      continue;           // stuck at a bound long enough: drop it

    active_var_new[n_active_var_new++] = t;
  }

  return(n_active_var-n_active_var_new);
}

/** Adopt the reduced active set prepared by checkShrinking().
 *
 *  active_var and active_var_new swap contents, n_active_var takes
 *  n_active_var_new, and deja_shrink is set. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::shrink()
{
  IntVectorType previous = active_var;
  active_var = active_var_new;
  active_var_new = previous;

  n_active_var = n_active_var_new;
  deja_shrink = true;
}

/** Make every sample active again.
 *
 *  After n_max_unshrink calls, unshrink_mode is turned off and
 *  n_iter_min_to_shrink is set to 666666666 so that no further shrinking
 *  is attempted. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::unShrink()
{
  for(int k = 0; k < m_NumSamples; ++k)
    active_var[k] = k;

  n_active_var = m_NumSamples;
  deja_shrink = false;

  ++n_unshrink;
  if(n_unshrink != n_max_unshrink)
    return;

  unshrink_mode = false;
  n_iter_min_to_shrink = 666666666;
  cout << "shrinking and unshrinking desactived...";
}

/** Main optimisation loop.
 *
 *  Each iteration:
 *   1. asks selectVariables() for a working pair (xi, xj). When it reports
 *      convergence and unshrink_mode is set, all variables are restored via
 *      unShrink() and the test is repeated once;
 *   2. once iter >= n_iter_min_to_shrink, refreshes the shrink candidates;
 *   3. fetches the kernel columns of xi and xj from the cache;
 *   4. lets analyticSolve() update alpha[xi] and alpha[xj];
 *   5. adds their changes into grad (active entries only when shrunk and
 *      unshrink_mode is off, all entries otherwise);
 *   6. every n_iter_min_to_shrink iterations, shrinks when enough variables
 *      qualify.
 *  Progress is printed every 1000 iterations and once at the end.
 */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::solve()
{
  printf ("\nSVMSolverBase::solve()\n");

  int xi = -1;
  int xj = -1;
  int n_to_shrink = 0;

#ifdef I_WANT_TIME
  long t_start = getRuntime();
#endif

  n_unshrink = 0;
  b = 0;

  // The problem size does not change during the optimisation, so it is
  // handed to the SVM once instead of on every iteration.
  // TODO: pass solver type
  m_SVM->SetL(m_NumSamples);

  iter = 0;
  while(1)
  {
    if(selectVariables(xi, xj))
    {
      if(unshrink_mode)
      {
        // Converged on the shrunk problem: re-check on all variables.
        cout << "# Unshrink...";
        cout.flush();
        unShrink();
        if(selectVariables(xi, xj))
        {
          cout << "finished.\n";
          break;
        }
        else
          cout << "restart.\n";
      }
      else
        break;
    }

#ifdef debug
    cout << "xi: " << xi << ", xj: " << xj << endl;
#endif // debug

    if(iter >= n_iter_min_to_shrink)
      n_to_shrink = checkShrinking(-m_Label.GetElement(xi)*grad[xi],
                                   -m_Label.GetElement(xj)*grad[xj]);

    k_xi = adresseCache(xi);
    k_xj = adresseCache(xj);

    old_alpha_xi = alpha[xi];
    old_alpha_xj = alpha[xj];

    analyticSolve(xi, xj);

    double delta_alpha_xi = alpha[xi] - old_alpha_xi;
    double delta_alpha_xj = alpha[xj] - old_alpha_xj;

    // Gradient update: while shrunk (and not planning to unshrink) only the
    // active entries are maintained.
    if(deja_shrink && !unshrink_mode)
    {
      for(int t = 0; t < n_active_var; t++)
      {
        int it = active_var[t];
        grad[it] += k_xi[it]*delta_alpha_xi + k_xj[it]*delta_alpha_xj;
      }
    }
    else
    {
      for(int t = 0; t < m_NumSamples; t++)
        grad[t] += k_xi[t]*delta_alpha_xi + k_xj[t]*delta_alpha_xj;
    }

    iter++;
    if(! (iter % 1000) )
    {
      // Clamp negative values so the report does not alarm the user.
      if(current_error < 0)
        current_error = 0;
      cout << "  + Iteration " <<  iter << "\n";
      cout << "   --> Current error    = " <<  current_error << "\n";
      cout << "   --> Active variables = " <<  n_active_var  << "\n";
      cout.flush();
    }

    // Shrinking. The positivity check avoids a modulo by zero.
    if( (n_iter_min_to_shrink > 0) && !(iter % n_iter_min_to_shrink) )
    {
      if( (n_to_shrink > n_active_var/10) && (n_active_var-n_to_shrink > 100) )
        shrink();
    }
  }

  // Clamp negative values so the report does not alarm the user.
  if(current_error < 0)
    current_error = 0;
  cout << "  + Iteration " <<  iter << "\n";
  cout << "   --> Current error    = "  << current_error << "\n";
  cout << "   --> Active variables = "  << n_active_var  << "\n";
  cout.flush();
}

/** Recompute the bound-status bits of alpha[i].
 *
 *  bit 1 (value 1): alpha[i] < Cx[i] - eps_bornes
 *  bit 2 (value 2): alpha[i] > eps_bornes
 *  So status 1 is what isNotDown() rejects, 2 is what isNotUp() rejects,
 *  and 3 means both tests pass. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::updateStatus(int i)
{
  int bits = 0;
  if(alpha[i] < Cx[i] - eps_bornes)
    bits |= 1;
  if(alpha[i] > eps_bornes)
    bits |= 2;

  status_alpha[i] = bits;
}



// for Cache
/** Allocate the kernel-column cache.
 *
 *  The cache is a circular doubly linked list of `taille` Liste nodes. Each
 *  node owns a buffer of m_NumSamples doubles (one column) in `adr`, and
 *  `index` is -1 while the node is empty. index_dans_liste[k] points to the
 *  node holding column k, or NULL.
 *
 *  @param taille_en_megs  memory budget for the column buffers, in MiB.
 *
 *  The column count is capped at max(m_NumSamples, 2), since extra columns
 *  could never be used. If fewer than two columns fit, a message is printed
 *  and the process exits with EXIT_FAILURE (solve() keeps two columns live
 *  at once). The check happens before anything is allocated.
 */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::Cache(double taille_en_megs)
{
  printf ("SVMSolverBase::Cache()\n");

  int l = m_NumSamples;

  // Columns of l doubles that fit in the budget. The value is computed in
  // double and clamped before the int conversion, so l == 0 (division by
  // zero), a negative/NaN budget, or a huge budget cannot overflow `taille`.
  double n_columns = 0.0;
  if(l > 0)
    n_columns = taille_en_megs*1048576./((double)sizeof(double)*l);
  const double max_columns = (l > 2) ? (double)l : 2.0;
  if(n_columns > max_columns)
    n_columns = max_columns;
  if(!(n_columns >= 0.0))
    n_columns = 0.0;
  taille = (int)n_columns;

  cout << "# Max columns in cache: " << taille << "\n";
  if(taille < 2)
  {
    cout << "$ Change cache size : it's too small.\n\n";
    // Non-zero status so calling scripts can detect the failure.
    exit(EXIT_FAILURE);
  }

  index_dans_liste = new Liste *[l];
  for(int i = 0; i < l; i++)
    index_dans_liste[i] = NULL;

  cached = new Liste[taille];
  cached_sauve = cached;

  // Link the nodes into a ring; every node starts empty.
  for(int i = 0; i < taille; i++)
  {
    Liste &node = cached[i];
    node.adr = new double [l];
    node.index = -1;
    node.prev = &cached[(i == 0) ? taille - 1 : i - 1];
    node.suiv = &cached[(i == taille - 1) ? 0 : i + 1];
  }
}

/** Return the kernel column for sample `index`, computing it on a miss.
 *
 *  The cache ring is kept in most-recently-used order, with `cached` as
 *  the front.
 *   - Hit: the node is moved to the front if it is not already there, and
 *     its buffer is returned unchanged.
 *   - Miss: the least recently used node (cached->prev) is recycled. Its
 *     previous owner's mapping is cleared, it becomes the front, and
 *     rempliColonne() fills it.
 *
 *  A hit on the node that is already at the front used to be handled as a
 *  miss. That recomputed the column, wasted an eviction, and left a
 *  duplicate node whose later eviction cleared the live mapping.
 *
 *  Note: in regression one should avoid recomputing two related columns,
 *  but the -1/+1 sign would have to be inverted in the matrix, so this is
 *  not done.
 *
 *  @return pointer to the column buffer, owned by the cache.
 */
template<class TXSample, class TYSample>
double* SVMSolverBase<TXSample,TYSample>
::adresseCache(int index)
{
#ifdef debug
  printf ("SVMSolverBase::adresseCache()\n");
#endif // debug

  Liste *ptr = index_dans_liste[index];

  if(ptr != NULL)
  {
    // Hit. Promote to the front unless it is already there.
    if(ptr != cached)
    {
      // Unlink ptr from its current position...
      ptr->prev->suiv = ptr->suiv;
      ptr->suiv->prev = ptr->prev;

      // ...and reinsert it just before the current front, then make it the front.
      ptr->suiv = cached;
      ptr->prev = cached->prev;
      cached->prev->suiv = ptr;
      cached->prev = ptr;
      cached = ptr;
    }
  }
  else
  {
    // Miss. Recycle the least recently used node.
    cached = cached->prev;
    if(cached->index != -1)
      index_dans_liste[cached->index] = NULL;
    cached->index = index;
    index_dans_liste[index] = cached;

    rempliColonne(index, cached->adr);
  }

  return(cached->adr);
}

/** Cache-clearing hook; the base implementation does nothing. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::efface()
{
  // Intentionally empty.
}

/** Placeholder for a classification-time cache. It only prints a trace;
 *  taille_en_megs is currently ignored. */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::CacheClassification(double taille_en_megs)
{
  (void)taille_en_megs;   // unused for now
  printf ("SVMSolverBase::CacheClassification()\n");
}

/** Fill `adr` with the signed kernel column of sample `index`.
 *
 *  adr[t] = s * m_Label(t) * K(index, t), where s = +1 if
 *  m_Label(index) > 0 and s = -1 otherwise, and K is
 *  m_SVM->GetKernel()->evalue.
 *  When the problem is shrunk and unshrink_mode is off, only the entries
 *  listed in active_var[0 .. n_active_var) are written. Otherwise all
 *  m_NumSamples entries are written.
 *
 *  @param index  sample whose column is computed.
 *  @param adr    destination buffer of at least m_NumSamples doubles.
 */
template<class TXSample, class TYSample>
void SVMSolverBase<TXSample,TYSample>
::rempliColonne(int index, double *adr)
{
  const bool onlyActive = deja_shrink && !unshrink_mode;

#ifdef debug
  if(onlyActive)
    cout << "deja_shrink && !unshrink_mode" << endl;
  else
    cout << "NOT deja_shrink && !unshrink_mode" << endl;
#endif // debug

  const bool positive = (m_Label.GetElement(index) > 0);
  const int count = onlyActive ? n_active_var : (int)m_NumSamples;

  for(int k = 0; k < count; k++)
  {
    // In shrunk mode k walks the active list; otherwise it is the sample itself.
    const int t = onlyActive ? (int)active_var[k] : k;

    adr[t] = (positive ? m_Label.GetElement(t) : -m_Label.GetElement(t))
             * m_SVM->GetKernel()->evalue(index, t);

#ifdef debug
    if(positive)
      cout << "m_Label.GetElement(" << t << "): " << m_Label.GetElement(t) << " --  " ;
    else
      cout << "-m_Label.GetElement(" << t << "): " << -m_Label.GetElement(t) << " --  " ;
    cout << "m_SVM->GetKernel()->evalue(" <<index<<","<<t<<"): "<<m_SVM->GetKernel()->evalue(index, t) << " ";
#endif // debug
  }

#ifdef debug
  cout << endl;

  for(int i = 0; i < m_NumSamples; i++)
    cout << "adr[" << i << "]: " << adr[i] << endl;
#endif // debug
}


/** False only when status_alpha[i] == 2. Per updateStatus(), that means
 *  alpha[i] > eps_bornes but not below Cx[i] - eps_bornes. */
template<class TXSample, class TYSample>
bool SVMSolverBase<TXSample,TYSample>
::isNotUp(int i)
{
  return !(status_alpha[i] == 2);
}

/** False only when status_alpha[i] == 1. Per updateStatus(), that means
 *  alpha[i] < Cx[i] - eps_bornes but not above eps_bornes. */
template<class TXSample, class TYSample>
bool SVMSolverBase<TXSample,TYSample>
::isNotDown(int i)
{
  return !(status_alpha[i] == 1);
}


} // namespace Statistics
} // namespace itk







#endif

