function [beta_hat_clust, beta_hat_conf, Cov_bc, COV_d_out, Fw, num_voxROI, freq_band] = ssfmri_est_rs(FX_seed,Y_band,FX_conf,bands_all,nband,coord_ROIs,img_dim,prior_d)
%
% [beta_hat_clust, beta_hat_conf, Cov_bc, COV_d_out, Fw, num_voxROI, freq_band]
%       = SSFMRI_EST_RS(FX_seed,Y_band,FX_conf,bands_all,nband,coord_ROIs,img_dim,prior_d)
%
% Frequency-domain estimation of ROI-level seed effects together with the
% within-ROI, between-ROI and residual covariance structure.
%
% This function is called in ssfmri_ancova. It is designed to take the
% outputs from ssfmri_band.
%
% Outline:
%   1. Voxel-wise OLS fit of the design [FX_seed FX_conf] to Y_band, pooling
%      real and imaginary parts over all kept frequencies.
%   2. Per ROI: fit a spatial covariance to the voxel-wise seed coefficients
%      with ssfmri_reml (inter-voxel Euclidean distances) and compute the
%      GLS estimate of the ROI-level seed effect.
%   3. Between-ROI correlation of the ROI-averaged residuals (pairwise
%      correlations, or spm_reml when prior_d is supplied).
%   4. Per frequency band: within-ROI residual covariance (meanCov) and
%      residual variance, combined into COV_d_out and Fw.
%
%-------------------------------------------------------------------------
% Input parameters
%
% FX_seed    - Fourier-domain seed regressor(s) [num_freqs x n_seed_cols];
%              real and imaginary parts are used separately. The coefficient
%              of the FIRST design column is treated as the seed effect.
% Y_band     - cell array (num_ROI) of Fourier-domain data matrices
%              [num_vox_in_ROI x num_freqs]
% FX_conf    - Fourier-domain confound regressors [num_freqs x num_conf]
% bands_all  - concatenated array of indices of all kept frequencies; only
%              its length is used here (it defines num_freqs)
% nband      - number of contiguous bands to split the kept frequencies
%              into (each has ceil(num_freqs/nband) entries, the last may
%              be shorter; some bands can end up empty)
% coord_ROIs - cell array (num_ROI) of voxel coordinates [num_vox_in_ROI x 3]
%              (x,y,z); only the first two columns are used when img_dim
%              has two elements
% img_dim    - image size vector, e.g. [nx ny] or [nx ny nz]; used with
%              sub2ind and to size the beta_hat_conf maps
% prior_d    - covariance components passed to spm_reml to estimate the
%              between-ROI correlation; pass [] to use pairwise correlations
%              of ROI-averaged residuals instead
%
% -------------------------------------------------------------------------
% Output parameters
%
% beta_hat_clust - [num_ROI x 1] GLS estimate of the seed effect per ROI
% beta_hat_conf  - cell array (num_desm-1) of maps of size img_dim holding the
%                  voxel-wise OLS coefficients of design columns 2..end
%                  (voxels outside all ROIs stay 0)
% Cov_bc         - cell array (num_ROI) of structs with fields
%                  .V : sig2*V_hat, estimated within-ROI covariance
%                  .h : [ssfmri_reml hyperparameters; sig2]
% COV_d_out      - [nband x 2] cell array of [num_ROI x num_ROI] between-ROI
%                  covariance matrices; column 1 real part, column 2 imaginary
% Fw             - cell array (nband) of scalars: 2 * the average of the mean
%                  positive residual variance of the real and imaginary parts
%                  (a part with no positive estimate contributes 0)
% num_voxROI     - [num_ROI x 1] number of voxels in each ROI
% freq_band      - cell array (nband) of frequency indices in each band
%
%--------------------------------------------------------------------------

num_ROI = length(Y_band);

fx = [FX_seed FX_conf];
num_desm = size(fx,2);

%% Split the kept frequencies into nband contiguous bands
freq_tot = (1:length(bands_all))';
n_freq_tot = length(freq_tot);
num_freq = ceil(n_freq_tot/nband);   % frequencies per band
freq_band = cell(nband,1);
for k = 1:nband
    freq_band{k} = freq_tot(1+(k-1)*num_freq : min(k*num_freq, n_freq_tot));
end
n_empty = nnz(cellfun(@isempty, freq_band));
if n_empty > 0
    % e.g. 10 frequencies with nband = 6 gives bands of 2 and an empty 6th band
    warning('ssfmri_est_rs:emptyBand', ...
        'nband = %d leaves %d empty frequency band(s); estimates for those bands are not meaningful.', ...
        nband, n_empty);
end

num_voxROI = zeros(num_ROI,1);
for clust = 1:num_ROI
    num_voxROI(clust) = size(coord_ROIs{clust},1);
end

%% Voxel-wise OLS for each ROI
% gamma_OLS = pinv(sum X'X) * (sum X'Y), summing real and imaginary parts.
fxR = real(fx(freq_tot,:));
fxI = imag(fx(freq_tot,:));
Firstsump_pinv = pinv(fxR'*fxR + fxI'*fxI);   % same for every ROI

% gamma_OLS123{clust} is [num_vox x num_desm]; column 1 is the seed coefficient.
gamma_OLS123 = cell(num_ROI,1);
for clust = 1:num_ROI
    Yc = Y_band{clust}(:,freq_tot);
    Secondsump = fxR'*real(Yc)' + fxI'*imag(Yc)';
    gamma_OLS123{clust} = (Firstsump_pinv*Secondsump)';
end

%% Maps of the non-seed coefficients
beta_hat_conf = cell(num_desm-1,1);
for i = 1:(num_desm-1)
    beta_hat_conf{i} = zeros(img_dim);
end
for clust = 1:num_ROI
    if length(img_dim) == 2
        indclust = sub2ind(img_dim,coord_ROIs{clust}(:,1),coord_ROIs{clust}(:,2));
    else
        indclust = sub2ind(img_dim,coord_ROIs{clust}(:,1),coord_ROIs{clust}(:,2),coord_ROIs{clust}(:,3));
    end
    for i = 1:(num_desm-1)
        beta_hat_conf{i}(indclust) = gamma_OLS123{clust}(:,i+1);
    end
end

%% Estimate beta_c (cluster-specific seed effect) by spatial GLS
% By definition gamma_OLS = beta_c + b_vc; the spatial covariance of b_vc
% is fitted with ssfmri_reml on the inter-voxel distance matrix.
dist_matrix1 = cell(num_ROI,1);
for clust = 1:num_ROI
    dist_matrix1{clust} = squareform(pdist(coord_ROIs{clust},'euclidean'));
end

beta_hat_clust = zeros(num_ROI,1);
Cov_bc = cell(num_ROI,1);
c0_init = 0;
c1_init = 1;
phi_init = 1;
for clust = 1:num_ROI
    fprintf('Working on the %dth Cluster Spatial Estimation\n', clust);

    gam_seed = gamma_OLS123{clust}(:,1);
    n_vox = length(gam_seed);
    [V_hat, h_clust] = ssfmri_reml(gam_seed*gam_seed', dist_matrix1{clust}, c0_init, c1_init, phi_init);

    % Rescale so that trace(V_hat) equals the number of voxels
    V_hat = V_hat*n_vox/trace(V_hat);
    V_hat_inv = pinv(V_hat);

    % GLS estimate of a common mean: (1' V^-1 1)^-1 * (1' V^-1 gamma)
    firstsum = 1/sum(sum(V_hat_inv));
    secondsum = sum(V_hat_inv,1)*gam_seed;
    beta_clust = firstsum*secondsum;
    beta_hat_clust(clust) = beta_clust;

    rRes = gam_seed - beta_clust;
    sig2 = sum(rRes.^2)/(n_vox-1);
    Cov_bc{clust}.V = sig2*V_hat;
    Cov_bc{clust}.h = [h_clust; sig2];
end

%% Off-diagonal elements of the covariance of d_c (between-ROI correlation)
% Residuals Ycv(w) - X(w)*[beta_c, gamma_conf], averaged over the voxels of each ROI.
m_ResidR = cell(num_ROI,1);
m_ResidI = cell(num_ROI,1);
for clust = 1:num_ROI
    tempY = Y_band{clust};
    tempgam = [repmat(beta_hat_clust(clust),num_voxROI(clust),1) gamma_OLS123{clust}(:,2:end)];
    % Average over voxels (dim 1) explicitly: without the dim argument a
    % single-voxel ROI would be averaged over frequencies instead.
    m_ResidR{clust} = mean(real(tempY) - tempgam*real(fx)', 1)';
    m_ResidI{clust} = mean(imag(tempY) - tempgam*imag(fx)', 1)';
end

m_meancor = zeros(num_ROI);   % upper triangle only
if isempty(prior_d)
    for clust = 1:num_ROI-1
        d_clust = [m_ResidR{clust}(freq_tot); m_ResidI{clust}(freq_tot)];
        for clust1 = clust+1:num_ROI
            d_clust1 = [m_ResidR{clust1}(freq_tot); m_ResidI{clust1}(freq_tot)];
            m_meancor(clust,clust1) = corr(d_clust, d_clust1);
        end
    end
else
    d_y = [cell2mat(m_ResidR'); cell2mat(m_ResidI')];
    Vd = spm_reml(d_y'*d_y, ones(num_ROI,1), prior_d);
    [~, m_meancor] = cov2corr(Vd);
    m_meancor = triu(m_meancor,1);
end

%% Diagonal elements of the covariance of d_c and residual variances
% Residuals Zcv(w) = Ycv(w) - X(w)*gamma_OLS at each frequency and ROI.
ResidgammaR_band = cell(num_ROI,1);
ResidgammaI_band = cell(num_ROI,1);
for clust = 1:num_ROI
    tempY = Y_band{clust};
    tempgam = gamma_OLS123{clust};
    ResidgammaR_band{clust} = real(tempY) - tempgam*(real(fx))';
    ResidgammaI_band{clust} = imag(tempY) - tempgam*(imag(fx))';
end

sig2d_R = cell(nband,1);
sig2d_I = cell(nband,1);
Sig2_Resid_R = cell(nband,1);
Sig2_Resid_I = cell(nband,1);
for k = 1:nband
    sig2d_R{k} = zeros(num_ROI,1);
    sig2d_I{k} = zeros(num_ROI,1);
    Sig2_Resid_R{k} = zeros(num_ROI,1);
    Sig2_Resid_I{k} = zeros(num_ROI,1);
end

% VAR(Y - X*gamma) = VAR(d + epsilon) = sig2d_c + sig2_resid, so the
% residual variance is the pooled variance minus the estimated sig2d_c.
for clust = 1:num_ROI
    for k = 1:nband
        resR = ResidgammaR_band{clust}(:,freq_band{k});
        resI = ResidgammaI_band{clust}(:,freq_band{k});

        % meanCov presumably averages the between-voxel covariances
        % (see meanCov) -- NOTE(review): confirm
        sig2d_R{k}(clust) = meanCov(resR);
        sig2d_I{k}(clust) = meanCov(resI);

        Sig2_Resid_R{k}(clust) = var(reshape(resR',[],1)) - sig2d_R{k}(clust);
        Sig2_Resid_I{k}(clust) = var(reshape(resI',[],1)) - sig2d_I{k}(clust);
    end
end

%% Construct the Sig_d matrix for each band
disp('Constructing Sig_d matrix');

cor_d = eye(num_ROI) + m_meancor + m_meancor';
SIG_d_R = cell(nband,1);
SIG_d_I = cell(nband,1);
for k = 1:nband
    % cov(x,y) = corr(x,y)*sig(x)*sig(y)
    % NOTE(review): a negative sig2d value would make these complex.
    SIG_d_R{k} = (sqrt(sig2d_R{k})*sqrt(sig2d_R{k}')).*cor_d;
    SIG_d_I{k} = (sqrt(sig2d_I{k})*sqrt(sig2d_I{k}')).*cor_d;
end
COV_d_out = [SIG_d_R SIG_d_I];

%% Fw: scaled average of the positive residual-variance estimates
Fw = cell(nband,1);
for k = 1:nband
    posR = Sig2_Resid_R{k}(Sig2_Resid_R{k} > 0);
    % Mask the imaginary part with its OWN positivity test
    posI = Sig2_Resid_I{k}(Sig2_Resid_I{k} > 0);
    temp_fw = [mean(posR); mean(posI)];   % row 1 real, row 2 imaginary
    temp_fw(isnan(temp_fw)) = 0;          % no positive estimates -> 0
    Fw{k} = 2 * mean(temp_fw);
end

return