% This function reads the input arguments and produces the Data and ConstVars structures
% It has to do the following:
%     1. read the config file
%     2. make sure that there are no contradictions
%     3. read the data and normalize it according to the options
%
%       Written by  Kayhan  Batmanghelich
%                   March 2012
%                   Section of Biomedical Image Analysis (SBIA)
%                   University of Pennsylvania


function [Data,ConstVars] = Initialization(argOpt)
% INITIALIZATION  Build the Data and ConstVars structures for one experiment.
%
%   [Data,ConstVars] = Initialization(argOpt)
%
%   Fields of argOpt read by this function:
%       configFile   - path handed to readConfigFile (together with ':' and '#';
%                      presumably the key/value separator and comment marker)
%       expnum       - experiment number; a string when deployed (it is compared
%                      with config.expnum), otherwise it must be numeric
%       ResFilename  - MAT-file used to load/save ConstVars and Data
%       ImageList    - stored as ConstVars.Datapath for the data reader
%       what         - 'learn', 'featureextr' or 'show' (case-insensitive)
%       continueFlag - only read when what is 'learn'; if true, B, C, w and iter
%                      are loaded from ResFilename as a warm start
%
%   Outputs:
%       Data      - [] in 'show' mode; otherwise a struct with V (normalized data)
%                   and y (labels), plus W0/H0/w0/iter0 when continuing a run
%       ConstVars - struct of constants/options for the run
%
%   Side effects: sets the global ZSBT, may write ResFilename, prints blank lines.

    % read the config file for the experiment and check the mandatory options
    config = readConfigFile(argOpt.configFile,':','#') ;
    checkConfig(config) ;

    % names of the config fields that have not been consumed explicitly yet;
    % whatever is left is copied verbatim into ConstVars in 'learn' mode
    configFieldsNames = fieldnames(config) ;

    expnum = config.expnum ;
    configFieldsNames = setdiff(configFieldsNames,{'expnum'}) ;
    if isdeployed()
        % str2double (not str2num): it does not evaluate its input, and a
        % non-numeric string becomes NaN, which triggers the mismatch error
        % below instead of silently passing the check.
        if (expnum~=str2double(argOpt.expnum))
            error('experiment number ("expnum") provided in the command line and that of config file should be the same !!') ;
        end
    else
        if ~isnumeric(argOpt.expnum)
            usage() ;
            error('experiment number must be a number (e.g. 1001) !!!  ')  ;
        end
    end

    algo = strtrim(config.algo) ;
    Data = [] ;
    ConstVars.algo = algo ;
    configFieldsNames = setdiff(configFieldsNames,{'algo'}) ;
    FeatureFilename = argOpt.ResFilename ;
    DataFcnHandler = @ReadDataFromCOMPAREConfigFile ;    % function used to read the file list
    ConstVars.Datapath = argOpt.ImageList ;
    ConstVars.FeatureFilename = FeatureFilename ;
    ConstVars.DownSampleRatio = config.DownSampleRatio ;
    configFieldsNames = setdiff(configFieldsNames,{'DownSampleRatio'}) ;
    ConstVars.saveAfterEachIteration = config.saveAfterEachIteration ;
    configFieldsNames = setdiff(configFieldsNames,{'saveAfterEachIteration'}) ;
    ConstVars.nrm_mode = config.nrm_mode ;                 % normalization mode
    configFieldsNames = setdiff(configFieldsNames,{'nrm_mode'}) ;
    if isfield(config,'randSeed')
        ConstVars.randSeed = config.randSeed ;             % random seed
        configFieldsNames = setdiff(configFieldsNames,{'randSeed'}) ;
    else
        ConstVars.randSeed = 0  ;                          % default random seed
    end
    if isfield(config,'nrm_scale')
        ConstVars.nrm_scale = config.nrm_scale ;
        configFieldsNames = setdiff(configFieldsNames,{'nrm_scale'}) ;
    end

    if ~strcmpi(argOpt.what,'show')    % to save basis vectors, we don't need to read images again
        [V,y,Dim] = feval(DataFcnHandler,ConstVars) ;
        if (min(y)<0)
            error('minimum value for the lables can be zero for unlabeled data ') ;
        end
        % label 0 marks unlabeled samples and is not counted as a class
        numClasses = length(unique(y)) - any(y==0) ;
        class_N = zeros(numClasses,1) ;
        classWeight = ones(numClasses,1) ;
        % NOTE(review): counting y==classCnt assumes the labels are 1..numClasses - confirm
        for classCnt=1:numClasses
            class_N(classCnt) = length(find(y==classCnt))    ;
            classWeight(classCnt) = 1/class_N(classCnt) ;
        end
        N = length(y) ;
        ConstVars.classWeight = classWeight ;
        ConstVars.numClasses = numClasses ;
        ConstVars.nullWeight = 1/sum(class_N) ;   % this value is used to weight part of the loss function that penalizes not being class(i) in one_vs_all multi-class scenario
    else
        % 'show' mode: replace ConstVars by the one stored in the result file
        load(ConstVars.FeatureFilename,'ConstVars') ;
    end

    % normalizing data
    if strcmpi(argOpt.what,'featureextr')   % it is in the testing process
        % replace ConstVars by the stored one before normalizing
        load(ConstVars.FeatureFilename,'ConstVars')
        V = normalizeData(V,ConstVars) ;
        Data.V = V ;
        Data.y = y ;
    elseif  strcmpi(argOpt.what,'learn')
        [V,normalizeParams] = normalizeData(V,ConstVars) ;
        ConstVars.normalizeParams = normalizeParams ;
        Data.V = V ;
        Data.y = y ;
    end

    % general variables for problem
    global ZSBT ;
    ZSBT = 1e-20 ;   % zero substitute

    if strcmpi(argOpt.what,'learn')   % it is in the training process
        MAXITR = config.MAXITR ;
        configFieldsNames = setdiff(configFieldsNames,{'MAXITR'}) ;
        tol =  config.tol ;                                             %--- used to be 1e-3
        configFieldsNames = setdiff(configFieldsNames,{'tol'}) ;
        r = config.numBasisVectors      ;
        configFieldsNames = setdiff(configFieldsNames,{'numBasisVectors'}) ;
        D = prod(Dim) ;
        ConstVars.r = r ;                                               % number of basis vectors
        ConstVars.N = N ;                                               % number of samples to be used for training
        ConstVars.numChannels = size(V,3) ;
        ConstVars.lambda_gen = config.lambda_gen/N ;                    % weight for Frobenius norm
        configFieldsNames = setdiff(configFieldsNames,{'lambda_gen'}) ;
        ConstVars.lambda_disc = config.lambda_disc ;                    % weight for the second class loss function
        configFieldsNames = setdiff(configFieldsNames,{'lambda_disc'}) ;
        % decide how to interpret lambda_const: group-wise or voxel-wise
        if isfield(config,'GroupIndexImage')  % group-wise
            [tmpimg,~] = readMedicalImage(config.GroupIndexImage) ;
            numGrp = checkGroupIndexImage(tmpimg,Dim) ;    % in case group-sparsity is used, make sure everything is correct
            ConstVars.lambda_const = numGrp*config.lambda_const ;
        else    % voxel-wise
            ConstVars.lambda_const = D*config.lambda_const ;
        end
        configFieldsNames = setdiff(configFieldsNames,{'lambda_const'}) ;
        ConstVars.lambda_stab = config.lambda_stab ;                    % stabilizer regularizer for the coefficients
        configFieldsNames = setdiff(configFieldsNames,{'lambda_stab'}) ;
        % membership test: the scalar name goes first so the result is a
        % scalar logical (the reversed order yields an array and the branch
        % would practically never be taken)
        if ismember('lambda_laplac',configFieldsNames)
            warning('you are using laplacian regularization of the method, be careful, this is not an official part of the package !!!') ;
            ConstVars.lambda_laplac = config.lambda_laplac ;            % regularizer for the Laplacian term if used
            configFieldsNames = setdiff(configFieldsNames,{'lambda_laplac'}) ;
        end
        ConstVars.MAXITR = MAXITR  ;                                    % MAXITR for optimization
        ConstVars.ZSBT = ZSBT ;                                         % zero substitute
        ConstVars.tol = tol ;                                           % tolerance for optimization
        ConstVars.numBatchBasisVector = config.numBatchBasisVector ;    % number of batch of basis vectors to be optimized together
        configFieldsNames = setdiff(configFieldsNames,{'numBatchBasisVector'}) ;
        ConstVars.D1 = Dim(1) ;    % x: original image size
        ConstVars.D2 = Dim(2) ;    % y: original image size
        ConstVars.D3 = Dim(3) ;    % z: original image size
        ConstVars.D = D ;
        % options for the Bsolver
        ConstVars.BSolver_opt.BBMethod = config.BBMethod ;
        configFieldsNames = setdiff(configFieldsNames,{'BBMethod'}) ;
        % optimizer parameter
        ConstVars.Monitor_Bsol = config.Monitor_Bsol ;   % this option monitors the solution for B and if it increases (instead of decrease), returns it back to the old B
        configFieldsNames = setdiff(configFieldsNames,{'Monitor_Bsol'}) ;
        ConstVars.Monitor_Csol = config.Monitor_Csol ;   % this option monitors the solution for C and if it increases (instead of decrease), returns it back to the old C
        configFieldsNames = setdiff(configFieldsNames,{'Monitor_Csol'}) ;
        % every remaining config field is copied into ConstVars unchanged
        for cnt=1:length(configFieldsNames)
            ConstVars.(configFieldsNames{cnt}) = config.(configFieldsNames{cnt}) ;
        end
    end

    % method to create features from basis vectors.
    % The test must be made on the config struct: isfield on the cell array of
    % names is always false, which made the config value impossible to use.
    if isfield(config,'ProjectionMode')
        ConstVars.ProjectionMode = config.ProjectionMode ;
    else
        ConstVars.ProjectionMode = 'inner_product' ;    % default method for projection
    end

    % an existing result file is only updated while learning; a missing one is created
    if exist(ConstVars.FeatureFilename,'file')
        if strcmpi(argOpt.what,'learn')
            save(ConstVars.FeatureFilename,'-append','ConstVars','Data')
        end
    else
        save(ConstVars.FeatureFilename,'ConstVars','Data')
    end

    % warm start: reuse the previously stored solution as initial values
    if strcmpi(argOpt.what,'learn')
        if (argOpt.continueFlag)
            % load into a struct rather than adding variables to the workspace
            previousRun = load(ConstVars.FeatureFilename,'B','C','w','iter') ;
            Data.W0 = previousRun.B ;
            Data.H0 = previousRun.C ;
            Data.w0 = previousRun.w ;
            Data.iter0 = previousRun.iter ;
        end
    end

    % to fix a stupid mcc bug: the results of these calls are not used
    % (presumably the explicit calls make the MATLAB compiler bundle these
    % functions - confirm before removing)
    model = train(sparse(ones(10,1)), sparse(ones(10,1)), '-c 1'); %#ok<NASGU>
    [tmp1,tmp2,tmp3]=blockDiag(1,1,0) ;
    clear tmp1 tmp2 tmp3
    mosekopt
    fprintf('\n \n \n \n \n \n \n') ;

end


% Verify that the config structure holds every mandatory option; raise an
% error naming the first mandatory option (in the order listed below) that is missing.
function checkConfig(config)
    mandatory = {'expnum','algo','DownSampleRatio','saveAfterEachIteration','nrm_mode','nrm_scale',...
                 'MAXITR','tol','numBasisVectors','numBatchBasisVector','lambda_gen',...
                 'lambda_disc','lambda_const','lambda_stab','BBMethod','Monitor_Bsol','Monitor_Csol'} ;

    % one vectorised membership test instead of a loop over the options
    isPresent = ismember(mandatory,fieldnames(config)) ;
    firstMissing = find(~isPresent,1) ;

    if ~isempty(firstMissing)
        error([ 'This option is necessary, make sure that the config has it : ' mandatory{firstMissing}]) ;
    end
end

% In case group-sparsity is used, this function makes sure that the group
% index image is usable and returns the number of distinct values it contains.
%
%   img    - group index image (numeric array)
%   Dim    - size of the individual images (row or column vector; trailing
%            singleton dimensions are allowed)
%   numGrp - number of unique values found in img
%
% Raises an error when img holds values that do not survive a uint16 cast
% or when its size differs from Dim.
function numGrp = checkGroupIndexImage(img,Dim)
    % all indices should be integer: the uint16 cast rounds and saturates, so
    % fractional, negative, NaN or too large values no longer compare equal
    if any(uint16(img(:))~=img(:))
        error('Group index image must contain only integer (>0) numbers !!!') ;
    end

    % image size of group index and individual images should match.
    % size() drops trailing singleton dimensions and always returns a row
    % vector, so both size vectors are brought to a common row form padded
    % with ones before comparing.
    imgSize = size(img) ;
    refSize = reshape(Dim,1,[]) ;
    numDims = max(numel(imgSize),numel(refSize)) ;
    imgSize(end+1:numDims) = 1 ;
    refSize(end+1:numDims) = 1 ;
    if ~isequal(imgSize,refSize)
        error('size of the GroupIndex Image and individual images should match !!') ;
    end

    numGrp = length(unique(img(:))) ;
end
