/*
//
//  Copyright 1997-2009 Torsten Rohlfing
//
//  Copyright 2004-2010 SRI International
//
//  This file is part of the Computational Morphometry Toolkit.
//
//  http://www.nitrc.org/projects/cmtk/
//
//  The Computational Morphometry Toolkit is free software: you can
//  redistribute it and/or modify it under the terms of the GNU General Public
//  License as published by the Free Software Foundation, either version 3 of
//  the License, or (at your option) any later version.
//
//  The Computational Morphometry Toolkit is distributed in the hope that it
//  will be useful, but WITHOUT ANY WARRANTY; without even the implied
//  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License along
//  with the Computational Morphometry Toolkit.  If not, see
//  <http://www.gnu.org/licenses/>.
//
//  $Revision: 4882 $
//
//  $LastChangedDate: 2013-09-27 18:16:36 -0400 (Fri, 27 Sep 2013) $
//
//  $LastChangedBy: torstenrohlfing $
//
*/

#include <System/cmtkException.h>

#include <Base/cmtkMathUtil.h>
#include <Base/cmtkTypes.h>

#include <algorithm>

namespace
cmtk
{

/** \addtogroup Registration */
//@{

template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
void
MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
::ClearAllChannels()
{
  // Forget the per-channel hash key parameters (offsets and scales) that
  // were recorded for every reference and floating channel ...
  this->m_HashKeyOffsRef.clear();
  this->m_HashKeyScaleRef.clear();

  this->m_HashKeyOffsFlt.clear();
  this->m_HashKeyScaleFlt.clear();

  // ... then defer to the superclass implementation for everything else.
  this->Superclass::ClearAllChannels();
}

template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
void
MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
::AddReferenceChannel( UniformVolume::SmartPtr& channel )
{
  // Add a reference channel: record the scale/offset used to turn its values
  // into hash key bin indices, update the reference key shift, register the
  // channel with the superclass, and abort if the total number of channels
  // no longer fits into the bits of THashKeyType.

  // Largest bin index representable with NBitsPerChannel bits.
  const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;

  const Types::DataItemRange range = channel->GetData()->GetRange();

  // Map the channel's value range width onto maxBinIndex. A constant channel
  // (zero-width range) would otherwise yield an infinite scale, which is
  // undefined to convert when TDataType is an integer type; fall back to
  // unit scale so the stored scale and offset remain finite.
  const Types::DataItem width = range.Width();
  const Types::DataItem scale = ( width > 0 ) ? maxBinIndex / width : 1;

  // NOTE(review): offset is -(lower/scale), not -lower or -lower*scale;
  // confirm this matches how MetricData combines offset and scale.
  const Types::DataItem offset = -(range.m_LowerBound/scale);

  this->m_HashKeyScaleRef.push_back( static_cast<TDataType>( scale ) );
  this->m_HashKeyOffsRef.push_back( static_cast<TDataType>( offset ) );

  // NOTE(review): evaluated before the superclass adds the channel, so this
  // counts the previously added reference channels only -- confirm intent.
  this->m_HashKeyShiftRef = NBitsPerChannel*this->m_ReferenceChannels.size();

  this->Superclass::AddReferenceChannel( channel );

  // All channels' bin indices must fit side by side into one hash key.
  const size_t hashKeyBits = 8 * sizeof( THashKeyType );
  if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
    {
    // NBitsPerChannel is a 'char' template parameter; cast it so it is
    // printed as a number rather than as a (control) character.
    StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
	      << "  Cannot represent total of " << this->m_NumberOfChannels << " channels with "
	      << static_cast<int>( NBitsPerChannel ) << " bits per channel using hash key type with "
	      << hashKeyBits << " bits.\n";
    exit( 1 );
    }
}

template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
void
MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
::AddFloatingChannel( UniformVolume::SmartPtr& channel )
{
  // Add a floating channel: record the scale/offset used to turn its values
  // into hash key bin indices, register the channel with the superclass, and
  // abort if the total number of channels no longer fits into the bits of
  // THashKeyType.

  // Largest bin index representable with NBitsPerChannel bits.
  const Types::DataItem maxBinIndex = (1<<NBitsPerChannel) - 1;

  const Types::DataItemRange range = channel->GetData()->GetRange();

  // Map the channel's value range width onto maxBinIndex. A constant channel
  // (zero-width range) would otherwise yield an infinite scale, which is
  // undefined to convert when TDataType is an integer type; fall back to
  // unit scale so the stored scale and offset remain finite.
  const Types::DataItem width = range.Width();
  const Types::DataItem scale = ( width > 0 ) ? maxBinIndex / width : 1;

  // NOTE(review): offset is -(lower/scale), not -lower or -lower*scale;
  // confirm this matches how MetricData combines offset and scale.
  const Types::DataItem offset = -(range.m_LowerBound/scale);

  this->m_HashKeyScaleFlt.push_back( static_cast<TDataType>( scale ) );
  this->m_HashKeyOffsFlt.push_back( static_cast<TDataType>( offset ) );

  this->Superclass::AddFloatingChannel( channel );

  // All channels' bin indices must fit side by side into one hash key.
  const size_t hashKeyBits = 8 * sizeof( THashKeyType );
  if ( this->m_NumberOfChannels * NBitsPerChannel > hashKeyBits )
    {
    // Report the template parameter (as in AddReferenceChannel), cast from
    // 'char' so it prints as a number rather than as a character.
    StdErr << "ERROR in MultiChannelHistogramRegistrationFunctional:\n"
	      << "  Cannot represent total of " << this->m_NumberOfChannels << " channels with "
	      << static_cast<int>( NBitsPerChannel ) << " bits per channel using hash key type with "
	      << hashKeyBits << " bits.\n";
    exit( 1 );
    }
}

template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
void
MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
::ContinueMetric( MetricData& metricData, const size_t rindex, const Vector3D& fvector )
{
  // Assemble one multi-channel sample: reference channel values at index
  // rindex come first, followed by floating channel values obtained from
  // the floating interpolators at fvector. If any single value cannot be
  // obtained, the whole sample is discarded and metricData is left untouched.
  std::vector<Types::DataItem> sample( this->m_NumberOfChannels );

  const size_t nReference = this->m_ReferenceChannels.size();
  for ( size_t r = 0; r < nReference; ++r )
    {
    const bool valid = this->m_ReferenceChannels[r]->GetDataAt( sample[r], rindex );
    if ( ! valid )
      return;
    }

  for ( size_t f = 0; f < this->m_FloatingChannels.size(); ++f )
    {
    const bool valid = this->m_FloatingInterpolators[f]->GetDataAt( fvector, sample[nReference + f] );
    if ( ! valid )
      return;
    }

  metricData += sample;
}

template<class TDataType,class TInterpolator,class THashKeyType,char NBitsPerChannel>
Functional::ReturnType
MultiChannelHistogramRegistrationFunctional<TDataType,TInterpolator,THashKeyType,NBitsPerChannel>
::GetMetric( const MetricData& metricData ) const
{
  // Mutual information (or, if m_NormalizedMI is set, normalized mutual
  // information) from the joint, reference, and floating histograms.
  // Without any samples, -FLT_MAX is returned.
  if ( ! metricData.m_TotalNumberOfSamples )
    return static_cast<Functional::ReturnType>( -FLT_MAX );

  typedef typename MetricData::HashTableType HashTableType;

  // Entropy -sum(p log p) of one sparse histogram, where p = norm * count;
  // bins with zero count are skipped.
  struct Entropy
  {
    static double Of( const HashTableType& histogram, const double norm )
    {
      double h = 0;
      for ( typename HashTableType::const_iterator bin = histogram.begin(); bin != histogram.end(); ++bin )
	{
	if ( bin->second )
	  {
	  const double p = norm * bin->second;
	  h -= p * log( p );
	  }
	}
      return h;
    }
  };

  const double norm = 1.0 / metricData.m_TotalNumberOfSamples;

  const double hXY = Entropy::Of( metricData.m_JointHash, norm );
  const double hX = Entropy::Of( metricData.m_ReferenceHash, norm );
  const double hY = Entropy::Of( metricData.m_FloatingHash, norm );

  // NOTE(review): hXY is 0 when all samples share a single joint bin; the
  // normalized ratio below then divides by zero -- confirm callers cope.
  return this->m_NormalizedMI
    ? static_cast<Functional::ReturnType>( (hX + hY) / hXY )
    : static_cast<Functional::ReturnType>( hX + hY - hXY );
}

} // namespace cmtk
