#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>
#include <itkImage.h>
#include <itkCovariantVector.h>
#include <itkImageToImageFilter.h>
#include <string>
#include <iostream>
#include <vector>
#include <itkImageLinearIteratorWithIndex.h>
#include <vnl_cholesky.h>

/**
 * Print the command-line help for wmahal to standard output.
 *
 * @return always -1, so callers can write `return usage();` to exit with
 *         a failure status.
 */
int usage()
{
  // The whole help screen is kept as one concatenated literal and emitted
  // with a single call; "%s" guards against any future '%' in the text.
  static const char *const help_text =
    "wmahal : weighted means and inverse covariance for Mahalanobis distance computation\n"
    "usage:\n"
    "  wmahal [options] \n"
    "required options:\n"
    "  -d <N>           : Image dimensionality (2,3,4)\n"
    "  -f <imagelist>   : List of displacement fields for which statistics will be computed\n"
    "  -w <imagelist>   : List of weight images used to weight fields\n"
    "  -o <image>       : Output filename\n"
    "additional options:\n"
    "  -I <epsilon>     : Add an identity matrix scaled by epsilon to covariance matrix\n";

  printf("%s", help_text);
  return -1;
}

/**
 * Per-voxel weighted mean and inverse covariance of a set of vector images.
 *
 * Inputs are registered in pairs with AddInputPair(): each even-indexed input
 * is a vector image and the following odd-indexed input is its weight image.
 * For every output voxel the filter
 *   1. clamps negative weights to zero and computes the weighted mean mu;
 *   2. accumulates M = sum_i w_i (x_i - mu)(x_i - mu)^T with the weights w_i
 *      normalized to sum to one, adds IdentityWeight to the diagonal of M and
 *      scales the result by 1 / (1 - sum_i w_i^2);
 *   3. inverts M with a Cholesky factorization.
 *
 * Output pixel layout, with D = VectorType::Dimension: components [0, D) hold
 * mu, followed by the D*(D+1)/2 upper-triangular entries of the inverse
 * covariance in row-major order. The output pixel type must therefore provide
 * at least D + D*(D+1)/2 components (not checked here).
 *
 * Degenerate voxels: when the clamped weights do not sum to a positive value
 * every written component is set to zero; when the covariance cannot be
 * formed or inverted (a single effective sample, rank deficiency or poor
 * conditioning) the mean is still written and the inverse covariance entries
 * are set to zero.
 */
template <class TVectorInput, class TWeightInput, class TOutputImage>
class WeightedMahalanobisStatisticsImageFilter
: public itk::ImageToImageFilter<TVectorInput, TOutputImage>
{
public:
  typedef WeightedMahalanobisStatisticsImageFilter<TVectorInput, TWeightInput, TOutputImage> Self;
  typedef itk::ImageToImageFilter<TVectorInput, TOutputImage> Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Run-time type information; the second argument names the actual superclass. */
  itkTypeMacro(WeightedMahalanobisStatisticsImageFilter, itk::ImageToImageFilter)

  itkNewMacro(Self)

  /** Some convenient typedefs. */
  typedef TVectorInput                                 VectorImageType;
  typedef TWeightInput                                 WeightImageType;
  typedef TOutputImage                                 OutputImageType;
  typedef typename OutputImageType::Pointer            OutputImagePointer;
  typedef typename OutputImageType::RegionType         OutputImageRegionType;
  typedef typename OutputImageType::PixelType          OutputImagePixelType;
  typedef typename VectorImageType::PixelType          VectorType;
  typedef typename WeightImageType::PixelType          WeightType;

  /**
   * Add a displacement and a weight map. The two images are appended as
   * consecutive indexed inputs (vector image first, weight image second).
   */
  void AddInputPair(VectorImageType *vec, WeightImageType *wgt)
    {
    this->AddInput(vec);
    this->AddInput(wgt);
    }

  /** Value added to the diagonal of the covariance matrix before it is scaled and inverted. */
  itkSetMacro(IdentityWeight, double)
  itkGetMacro(IdentityWeight, double)

protected:
  // Initialize the member: the setter generated by itkSetMacro compares the
  // new value against the current one, which must not be indeterminate.
  WeightedMahalanobisStatisticsImageFilter() : m_IdentityWeight(0.0) {}
  ~WeightedMahalanobisStatisticsImageFilter() {}

  /**
   * Compute the statistics for the voxels of outputRegionForThread.
   *
   * Throws itk::ExceptionObject if an input has an unexpected type or its
   * buffered region differs from that of the output (raw buffer offsets are
   * shared between all images, so the regions must agree).
   */
  virtual void ThreadedGenerateData(
    const OutputImageRegionType & outputRegionForThread,
    itk::ThreadIdType itkNotUsed(threadId)) ITK_OVERRIDE
    {
    // Get the output
    OutputImageType *out = this->GetOutput();

    // Get the input pointers
    const int n = this->GetNumberOfIndexedInputs() / 2;
    std::vector<VectorType *> p_vec;
    std::vector<WeightType *> p_wgt;
    p_vec.reserve(n);
    p_wgt.reserve(n);
    for(int i = 0; i < n; i++)
      {
      VectorImageType *vec = dynamic_cast<VectorImageType *>(this->GetIndexedInputs()[i*2].GetPointer());
      WeightImageType *wgt = dynamic_cast<WeightImageType *>(this->GetIndexedInputs()[i*2+1].GetPointer());

      if(!vec || vec->GetBufferedRegion() != out->GetBufferedRegion())
        {
        itkExceptionMacro(<< "Input " << i*2
          << " is not a vector image whose buffered region matches the output");
        }

      if(!wgt || wgt->GetBufferedRegion() != out->GetBufferedRegion())
        {
        itkExceptionMacro(<< "Input " << i*2+1
          << " is not a weight image whose buffered region matches the output");
        }

      p_vec.push_back(vec->GetBufferPointer());
      p_wgt.push_back(wgt->GetBufferPointer());
      }

    // Number of vector components and number of values written per output voxel
    const int nd = static_cast<int>(VectorType::Dimension);
    const int n_stats = nd + nd * (nd + 1) / 2;

    // The output buffer does not change while iterating
    OutputImagePixelType *p_out = out->GetBufferPointer();

    // This is a super-simple filter in terms of iteration. We just need to know an offset
    typedef itk::ImageLinearIteratorWithIndex<OutputImageType> Iterator;
    for(Iterator it(out, outputRegionForThread); !it.IsAtEnd(); it.NextLine())
      {
      // Offset of the first voxel of the line; it is advanced in step with the iterator
      typename OutputImageType::OffsetValueType offset = out->ComputeOffset(it.GetIndex());

      // Iterate over the line
      for(; !it.IsAtEndOfLine(); ++it, ++offset)
        {
        // Compute the weighted mean (negative weights count as zero)
        VectorType mu; mu.Fill(0.0);
        WeightType w_sum = 0.0, w_sqsum = 0.0;
        for(int i = 0; i < n; i++)
          {
          WeightType w_i = p_wgt[i][offset];
          if(w_i < 0) w_i = 0;

          const VectorType &x_i = p_vec[i][offset];
          mu += w_i * x_i;
          w_sum += w_i;
          }

        // Without any positive weight neither the mean nor the covariance is
        // defined; write zeros instead of dividing by zero. The negated
        // comparison also rejects a NaN sum.
        if(!(w_sum > 0))
          {
          for(int k = 0; k < n_stats; k++)
            p_out[offset][k] = 0;
          continue;
          }

        mu *= (1.0 / w_sum);

        // Compute the covariance coefficients and store in the covariance matrix
        vnl_matrix_fixed<double, VectorType::Dimension, VectorType::Dimension> M, M_inv;
        M.fill(0.0);
        for(int i = 0; i < n; i++)
          {
          // Normalized weight of this sample
          WeightType w_i = p_wgt[i][offset] / w_sum;
          if(w_i < 0) w_i = 0;
          w_sqsum += w_i * w_i;

          // Only the upper triangle is accumulated; it is mirrored below
          const VectorType &x_i = p_vec[i][offset];
          for(int a = 0; a < nd; a++)
            for(int b = a; b < nd; b++)
              M(a,b) += (x_i[a] - mu[a]) * (x_i[b] - mu[b]) * w_i;
          }

        // Regularize the diagonal and fill in the lower triangle
        for(int a = 0; a < nd; a++)
          {
          M(a,a) += m_IdentityWeight;
          for(int b = a+1; b < nd; b++)
            M(b,a) = M(a,b);
          }

        // Invert the covariance matrix. The inverse stays zero unless the
        // matrix can be scaled and factorized reliably.
        M_inv.fill(0.0);

        // The scale factor 1 / (1 - sum of squared normalized weights) is
        // undefined when a single sample carries all the weight.
        const double denom = 1.0 - w_sqsum;
        if(denom > 0)
          {
          M *= 1.0 / denom;

          vnl_cholesky cholesky(M, vnl_cholesky::estimate_condition);
          if(cholesky.rank_deficiency() > 0)
            {
            std::cout << "Rank-deficient matrix " << M << std::endl;
            }
          else if(cholesky.rcond() < vnl_math::sqrteps)
            {
            std::cout << "Ill-conditioned matrix encountered " << M << std::endl;
            }
          else
            {
            M_inv = cholesky.inverse();
            }
          }

        // Store the unique elements: mean first, then the upper triangle of the inverse
        int k = 0;
        for(int a = 0; a < nd; a++, k++)
          p_out[offset][k] = mu[a];
        for(int a = 0; a < nd; a++)
          for(int b = a; b < nd; b++, k++)
            p_out[offset][k] = M_inv(a,b);
        }
      }
    }

  // Value added to each diagonal entry of the covariance matrix
  double m_IdentityWeight;
};

/**
 * Command-line settings gathered by main() and passed to
 * StatisticsComputer::compute().
 */
struct Parameters
{
  // Filenames of the displacement fields (-f) and of the weight images (-w)
  std::vector<std::string> fn_fields, fn_weights;

  // Image dimensionality (-d); stays 0 until the option is parsed
  unsigned int dim;

  // Output filename (-o)
  std::string fn_output;

  // Scale of the identity matrix added to the covariance (-I)
  double epsilon;

  // Initializers are listed in declaration order, which is the order in
  // which members are actually initialized (avoids -Wreorder warnings).
  Parameters() : dim(0), epsilon(0.0) {}
};

/**
 * Holder for the statistics pipeline, parameterized on the floating-point
 * type (TFloat) and on the image dimension (VDim) so that main() can select
 * an instantiation at run time.
 */
template <typename TFloat, unsigned int VDim>
struct StatisticsComputer
{
  /** Run the pipeline configured by param. */
  static void compute(const Parameters &param);
};


/**
 * This is the real main method: it wires one reader per displacement field
 * and per weight image into a WeightedMahalanobisStatisticsImageFilter and
 * writes the filter output to param.fn_output.
 *
 * Throws itk::ExceptionObject if no field is given, if the numbers of fields
 * and weight images differ, or if the ITK pipeline update fails.
 */
template <typename TFloat, unsigned int VDim>
void
StatisticsComputer<TFloat, VDim>
::compute(const Parameters &param)
{
  // Fields and weights are consumed pairwise by index, so the two lists must
  // be non-empty and of equal length.
  if(param.fn_fields.empty())
    throw itk::ExceptionObject(__FILE__, __LINE__,
      "No displacement fields were specified");

  if(param.fn_fields.size() != param.fn_weights.size())
    throw itk::ExceptionObject(__FILE__, __LINE__,
      "The number of weight images must match the number of displacement fields");

  // Read all the input images
  typedef itk::CovariantVector<TFloat, VDim> VectorType;
  typedef itk::Image<VectorType, VDim> VectorImageType;
  typedef itk::Image<TFloat, VDim> WeightImageType;

  // Define the output type: VDim mean components followed by the
  // VDim * (VDim + 1) / 2 unique entries of a symmetric VDim x VDim matrix
  typedef itk::CovariantVector<TFloat, VDim + VDim * (VDim + 1) / 2> StatsType;
  typedef itk::Image<StatsType, VDim> StatsImageType;

  // Define the filters
  typedef itk::ImageFileReader<VectorImageType> VectorImageReaderType;
  typedef itk::ImageFileReader<WeightImageType> WeightImageReaderType;
  typedef WeightedMahalanobisStatisticsImageFilter<
    VectorImageType, WeightImageType, StatsImageType> StatsFilterType;

  // Create the filter
  typename StatsFilterType::Pointer stats = StatsFilterType::New();
  stats->SetIdentityWeight(param.epsilon);

  // Lists to save the reader pointers
  std::vector<typename VectorImageReaderType::Pointer> rf_list;
  std::vector<typename WeightImageReaderType::Pointer> rw_list;
  rf_list.reserve(param.fn_fields.size());
  rw_list.reserve(param.fn_weights.size());

  // Create the readers
  for(std::size_t i = 0; i < param.fn_fields.size(); i++)
    {
    typename VectorImageReaderType::Pointer rf = VectorImageReaderType::New();
    rf->SetFileName(param.fn_fields[i]);
    rf_list.push_back(rf);

    typename WeightImageReaderType::Pointer rw = WeightImageReaderType::New();
    rw->SetFileName(param.fn_weights[i]);
    rw_list.push_back(rw);

    stats->AddInputPair(rf->GetOutput(), rw->GetOutput());
    }

  // Create a writer
  typedef itk::ImageFileWriter<StatsImageType> WriterType;
  typename WriterType::Pointer writer = WriterType::New();
  writer->SetInput(stats->GetOutput());
  writer->SetFileName(param.fn_output);

  // Updating the writer pulls the data through the readers and the filter
  writer->Update();
}

int main(int argc, char *argv[])
{
  if(argc < 3) return usage();
  Parameters param;

  for(int iarg = 1; iarg < argc; iarg++)
    {
    std::string sarg = argv[iarg];

    if(sarg == "-f")
      {
      for(; iarg < argc - 1 && argv[iarg+1][0] != '-'; iarg++)
        param.fn_fields.push_back(argv[iarg+1]);
      printf("Read %d fields\n", (int) param.fn_fields.size());
      }
    else if(sarg == "-w")
      {
      for(; iarg < argc - 1 && argv[iarg+1][0] != '-'; iarg++)
        param.fn_weights.push_back(argv[iarg+1]);
      printf("Read %d weights\n", (int) param.fn_weights.size());
      }
    else if(sarg == "-o")
      {
      param.fn_output = argv[++iarg];
      }
    else if(sarg == "-d")
      {
      param.dim = atoi(argv[++iarg]);
      }
    else if(sarg == "-I")
      {
      param.epsilon = atof(argv[++iarg]);
      }
    else
      {
      std::cerr << "Unknown parameter " << sarg << std::endl;
      return -1;
      }
    }
 
  if(param.dim == 2)
    StatisticsComputer<float, 2>::compute(param);
  else if(param.dim == 3)
    StatisticsComputer<float, 3>::compute(param);
  else if(param.dim == 4)
    StatisticsComputer<float, 4>::compute(param);
}
