
%% function for segmenting hyperintensities in a new image using the outputs of a previous training run

%%
function names = WhyD_detect(names, training, training_path)
% WHYD_DETECT Segment hyperintensities in a new subject using a trained model.
%
%   names = WhyD_detect(names, training, training_path)
%
%   Inputs
%     names         - subject struct; fields read here: directory_path,
%                     WM_mod (image file name), folder_name, folder_id.
%     training      - struct from training; training.method selects the
%                     model: 'svm_class' (uses training.b, training.alphas),
%                     'rf_regress' or 'rf_class' (use training.trees).
%     training_path - folder holding options_training.mat, which must define
%                     the patch size K and the kernel parameter width.
%
%   Output
%     names         - input struct with .method and .seg_out filled in.
%
%   Side effects: writes <directory_path>/<SHORT>_out_<folder_id>.nii and
%   <directory_path>/names_<folder_id>.mat. Only voxels brighter than 60% of
%   the image maximum are classified (in chunks of at most 500); all other
%   voxels keep the method's background value.

%% loading image data and options for training
nii_in = load_nii(fullfile(names.directory_path, names.WM_mod));
sub_image = double(nii_in.img); sub_dim = size(sub_image);
% Load into a struct so variables stored in the .mat file cannot silently
% overwrite locals such as names or training.
opts = load(fullfile(training_path, 'options_training.mat'), 'K', 'width');
K = opts.K;
[ker, width_vec] = getKernels(opts.width);
% Zero a K-voxel border so every K^3 patch centred on a candidate voxel lies
% inside the volume (patch offsets are (K-1)/2, so K is assumed odd).
sub_image(1:K,:,:) = 0; sub_image(end-K+1:end,:,:) = 0;
sub_image(:,1:K,:) = 0; sub_image(:,end-K+1:end,:) = 0;
sub_image(:,:,1:K) = 0; sub_image(:,:,end-K+1:end) = 0;

%% segmenting new subject
% candidate voxels: brighter than 60% of the (border-cleared) maximum
fg_thresh = 0.6*max(sub_image(:)); fg = find(sub_image > fg_thresh);
N_tot = numel(fg); lnloop = 500; folds = ceil(N_tot/lnloop);
chunk = ceil(N_tot/max(folds, 1));   % balanced chunk size, <= lnloop

% configure the chosen model
switch training.method
    case 'svm_class'
        scale = 2; names.method = 'SVM Classification';
        method_info = struct('name','SVM Classification','short','SVM');
        svm_eps = 1e-3;
        sg('new_svm', 'LIGHT'); sg('svm_epsilon', svm_eps); sg('c', 1);
        sg('set_kernel', 'LINEAR', 'REAL', 100, scale*N_tot);
        sg('set_classifier', training.b, training.alphas);
    case 'rf_regress'
        trees = training.trees; scale = 1; names.method = 'RF Regression';
        method_info = struct('name','RF Regression','short','RFREG');
    case 'rf_class'
        trees = training.trees; names.method = 'RF Classification';
        method_info = struct('name','RF Classification','short','RFCLA');
    otherwise
        error('WhyD_detect:unknownMethod', ...
            'Unknown training method ''%s''.', training.method);
end

fprintf('Segmenting subject : %s_%s using %s method \n',names.folder_name,names.folder_id,method_info.name);
fprintf('Total folds in processing : %d \n',folds);
oo = zeros(N_tot,1);
% compute features for each fold and classify it; every fold, including the
% last one, must be processed or its voxels stay unsegmented
for j = 1:folds
    first = (j-1)*chunk + 1;
    if first > N_tot, break; end
    last = min(j*chunk, N_tot);
    sub_ind = first:last;
    [Xr, Yr, Zr] = ind2sub(sub_dim, fg(sub_ind));
    D3 = extract_features(sub_image, Xr, Yr, Zr, K, ker, width_vec);  % samples x features
    switch training.method
        case 'svm_class'
            % sg receives one sample per column
            sg('set_features', 'TEST', D3'); oo(sub_ind) = sg('classify');
        case 'rf_regress'
            oo(sub_ind) = eval_Stochastic_Bosque(D3,trees,'method','r');
        case 'rf_class'
            oo(sub_ind) = eval_Stochastic_Bosque(D3,trees,'method','c');
    end
    if mod(j, 50) == 0, fprintf('... %d done ... \n',j); end
end
fprintf('... all folds done ... \n');
fprintf('Done segmenting subject : %s_%s using %s method \n',names.folder_name,names.folder_id,method_info.name);

% packaging the segmented output: candidate voxels (linear indices fg) get
% the model output, every other voxel the method's background value
switch training.method
    case 'svm_class'
        out = zeros(sub_dim); out(fg) = oo;
    case 'rf_regress'
        out = -1*ones(sub_dim); out(fg) = oo; out = out/scale; out(out<=0) = 0;
    case 'rf_class'
        % NOTE(review): oo is double, so its values are stored as char codes;
        % presumably eval_Stochastic_Bosque returns char labels - confirm.
        out = repmat('n', sub_dim); out(fg) = oo;
end

% saving the segmented output and updating the names file
nii = nii_in; nii.img = out;
names.seg_out = sprintf('%s_out_%s.nii',method_info.short,names.folder_id);
save_nii(nii, fullfile(names.directory_path, names.seg_out));
save(fullfile(names.directory_path, sprintf('names_%s.mat',names.folder_id)),'names');


function feats = extract_features(img, Xr, Yr, Zr, K, ker, width_vec)
% Build one feature row per voxel (Xr(v),Yr(v),Zr(v)): the raw K^3 patch,
% followed by one K^3 block per kernel in ker holding the patch convolved
% with that kernel and divided by width_vec(k)^3 (and by 3 more for the last
% kernel). Kernels 3-8 have the raw patch added, kernels 9-14 subtracted.
half = (K-1)/2;
n_vox = numel(Xr); n_ker = size(ker,1); patch_len = K^3;
feats = zeros(n_vox, (n_ker+1)*patch_len);
for v = 1:n_vox
    patch = img(Xr(v)-half:Xr(v)+half, Yr(v)-half:Yr(v)+half, Zr(v)-half:Zr(v)+half);
    raw = reshape(patch, 1, []);
    feats(v, 1:patch_len) = raw;
    for k = 1:n_ker
        scalediv = 1; if k == n_ker, scalediv = 3; end
        resp = reshape(convn(patch, ker{k,1}, 'same'), 1, []) / (scalediv*(width_vec(k)^3));
        if k >= 3 && k <= 8
            resp = resp + raw;
        elseif k >= 9 && k <= 14
            resp = resp - raw;
        end
        feats(v, k*patch_len + (1:patch_len)) = resp;
    end
end

%% end
