/*=====================================================================

Program: Masking of a DTI/DWI nrrd vector image with a mask image
Language: C++
Date: 11/13/2007
Author: Madhura A Ingalhalikar

======================================================================*/




#include "itkVectorImage.h"
#include "itkNrrdImageIO.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkMetaDataDictionary.h"
#include "itkImageSeriesReader.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkVectorIndexSelectionCastImageFilter.h"
#include "vnl/vnl_math.h"
#include "vnl/vnl_matrix.h"
#include "itkTimeProbe.h"
#include "itkVariableLengthVector.h"
#include "itkImageFileWriter.h"
#include "itkMaskImageFilter.h"
#include "MaskDWIFilterCLP.h"
#include <vnl/algo/vnl_svd.h>
#include <iostream>

/**
 * Entry point: applies a mask image to every component of a DTI/DWI vector
 * nrrd image and writes the masked vector image.
 *
 * Arguments are parsed by PARSE_ARGS (generated CLP code from
 * MaskDWIFilterCLP.h), which provides InputImageFilename, MaskImageFilename
 * and OutputImageFilename.
 *
 * For each vector component i the component is extracted as a scalar image
 * (VectorIndexSelectionCastImageFilter), masked with itk::MaskImageFilter,
 * and written back into component i of the output vector image. The output
 * copies the input's regions, spacing, origin, direction, vector length and
 * metadata dictionary.
 *
 * @return EXIT_SUCCESS on success; EXIT_FAILURE if reading, masking or
 *         writing throws, or if the mask's voxel grid size differs from the
 *         input image's.
 */
int main( int argc, char *argv[] )
{
  PARSE_ARGS;

  std::cout << "Input DTI nrrd Image: " << InputImageFilename << std::endl;
  std::cout << "Output Masked DTI nrrd Image: " << OutputImageFilename << std::endl;
  std::cout << "Mask Image: " << MaskImageFilename << std::endl;

  typedef itk::VectorImage<unsigned short, 3> ImageType;
  typedef itk::VariableLengthVector<double>   OutputImagePixelType;
  typedef itk::Image<unsigned short, 3>       BvalueImageType;
  typedef itk::Image<unsigned char, 3>        MaskImageType;

  // Read the input DTI/DWI nrrd file.
  itk::ImageFileReader<ImageType>::Pointer reader = itk::ImageFileReader<ImageType>::New();
  reader->SetFileName( InputImageFilename.c_str() );
  ImageType::Pointer img;
  try
    {
    std::cout << "Reading DTI....." << std::endl;
    reader->Update();
    img = reader->GetOutput();
    }
  catch( itk::ExceptionObject & ex )
    {
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  const unsigned int vectorLength = img->GetVectorLength();

  // Read the mask image.
  itk::ImageFileReader<MaskImageType>::Pointer maskRead = itk::ImageFileReader<MaskImageType>::New();
  maskRead->SetFileName( MaskImageFilename.c_str() );
  MaskImageType::Pointer mask;
  try
    {
    std::cout << "Reading Mask" << std::endl;
    maskRead->Update();
    mask = maskRead->GetOutput();
    }
  catch( itk::ExceptionObject & ex )
    {
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  // The masking below pairs voxels over the DTI image's region, so a mask on
  // a different voxel grid cannot be applied meaningfully; reject it up front
  // with a clear message instead of failing (or misaligning) in the pipeline.
  if( mask->GetLargestPossibleRegion() != img->GetLargestPossibleRegion() )
    {
    std::cerr << "Mask image region " << mask->GetLargestPossibleRegion()
              << " does not match DTI image region " << img->GetLargestPossibleRegion()
              << std::endl;
    return EXIT_FAILURE;
    }

  // Allocate the output with the same geometry and metadata as the input.
  ImageType::Pointer output = ImageType::New();
  output->SetRegions( img->GetLargestPossibleRegion() );
  output->SetSpacing( img->GetSpacing() );
  output->SetOrigin( img->GetOrigin() );
  output->SetDirection( img->GetDirection() );
  output->SetVectorLength( img->GetVectorLength() );
  output->SetMetaDataDictionary( img->GetMetaDataDictionary() );
  output->Allocate();

  // Per-component pipeline: select component -> mask.
  typedef itk::VectorIndexSelectionCastImageFilter<ImageType, BvalueImageType> VectorSelectFilterType;
  VectorSelectFilterType::Pointer extractImageFilter = VectorSelectFilterType::New();
  extractImageFilter->SetInput( img );

  typedef itk::MaskImageFilter<BvalueImageType, MaskImageType, BvalueImageType> FilterType;
  FilterType::Pointer filter = FilterType::New();
  filter->SetInput2( mask );

  typedef itk::ImageRegionIterator< ImageType >            IteratorType;
  typedef itk::ImageRegionConstIterator< BvalueImageType > ConstIteratorType;
  IteratorType ot( output, output->GetRequestedRegion() );
  OutputImagePixelType vectorImagePixel;

  for( unsigned int i = 0; i < vectorLength; ++i )
    {
    extractImageFilter->SetIndex( i );
    std::cout << " Masking image no :  " << i << std::endl;

    BvalueImageType::Pointer image;
    try
      {
      extractImageFilter->Update();
      filter->SetInput1( extractImageFilter->GetOutput() );
      filter->Update();
      image = filter->GetOutput();
      }
    catch( itk::ExceptionObject & ex )
      {
      std::cerr << ex << std::endl;
      return EXIT_FAILURE;
      }

    // Write the masked scalar values into component i of each output voxel.
    // NOTE(review): each voxel's whole vector is copied to update one
    // component, i.e. O(voxels * components^2) work overall.
    ConstIteratorType it( image, image->GetRequestedRegion() );
    for( ot.GoToBegin(), it.GoToBegin(); !it.IsAtEnd(); ++ot, ++it )
      {
      vectorImagePixel = ot.Get();
      vectorImagePixel[i] = it.Value();
      ot.Set( vectorImagePixel );
      }

    // Detach this component's result from the mask filter before the next
    // component is processed.
    image->DisconnectPipeline();
    }

  // Write the masked vector image.
  typedef itk::ImageFileWriter< ImageType > WriterType;
  WriterType::Pointer writer = WriterType::New();
  writer->SetFileName( OutputImageFilename );
  writer->SetInput( output );
  try
    {
    writer->Update();
    }
  catch( itk::ExceptionObject & ex )
    {
    std::cerr << ex << std::endl;
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}
    
    
    
    
