% Reference:
% For more details see:
%   Afshin-Pour, B., Hossein-Zadeh, G. A., Strother, S. C.,
%   Soltanian-Zadeh, H. (2012). Enhancing reproducibility of fMRI
%   statistical maps using generalized canonical correlation analysis
%   in NPAIRS framework. NeuroImage.
%
% Babak Afshin-Pour (bafshinpour@research.baycrest.org)
% Date: 05-May-2014
%==========================================================================
%    This program is free software: you can redistribute it and/or modify
%    it under the terms of the GNU General Public License as published by
%    the Free Software Foundation, either version 3 of the License, or
%    (at your option) any later version.
%
%     gCCA is distributed in the hope that it will be useful,
%     but WITHOUT ANY WARRANTY; without even the implied warranty of
%     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%     GNU General Public License for more details.
%
%     You should have received a copy of the GNU General Public License
%     along with gCCA.  If not, see <http://www.gnu.org/licenses/>.
%    

function [At,Zm,sub_eff,qcorr,dd,Z,Afull,ddfull] = Generalized_CCA(Ts,Weights,npcs,Trans)
% GENERALIZED_CCA  Generalized canonical correlation analysis across subjects.
%
%   [At,Zm,sub_eff,qcorr,dd,Z,Afull,ddfull] = Generalized_CCA(Ts,Weights,npcs,Trans)
%
%   Inputs
%     Ts      - cell array with one matrix per subject (at least two).
%               The number of columns (Num_Voxels) is taken from Ts{1}.
%               Presumably rows are variables/time points and columns are
%               voxels -- confirm against callers.
%     Weights - [] or a vector with one entry per column. When its sum
%               differs from Num_Voxels, the "signal" covariance is built
%               from columns scaled by Weights and the "noise" covariance
%               from columns scaled by 1-Weights. [] means unweighted
%               (identical to all-ones weights).
%     npcs    - 0           : use Ts as given (non-finite columns are zeroed);
%               otherwise   : passed to reduce_dimension to keep the leading
%                             SVD components of each subject.
%     Trans   - optional cell array of per-subject back-projection matrices,
%               used only when npcs == 0; defaults to identity matrices.
%
%   Outputs
%     At      - cell, At{i} = Trans{i} * (subject-i block of the leading
%               min(q) generalized eigenvectors), q(i) = rows kept per subject.
%     Zm      - min(q) x Num_Voxels average of the subjects' canonical variates.
%     sub_eff - 2-norm of each subject's block of the first eigenvector.
%     qcorr   - correlation of the first canonical variate for every subject
%               pair i<j, ordered (1,2),(1,3),...,(2,3),...
%     dd      - leading min(q) generalized eigenvalues, sorted descending.
%     Z       - cell, Z{i} = canonical variates of subject i.
%     Afull   - like At but for all generalized eigenvectors.
%     ddfull  - all generalized eigenvalues, sorted descending.

inOctave = in_octave();

% Dimension-reduction flavour ('temporal' is hard-wired). Named so that it
% does not shadow the built-in MODE function.
reduction_mode = 'temporal';

Num_Subjects = length(Ts);
if Num_Subjects < 2
    % The between-subject covariance below is divided by Num_Subjects-1 and
    % qcorr needs at least one subject pair.
    error('Generalized_CCA:tooFewSubjects', ...
        'Generalized CCA needs at least two subjects (got %d).', Num_Subjects);
end
Num_Voxels = size(Ts{1},2);

% ---- per-subject (optionally reduced) data --------------------------------
rTs = cell(1,Num_Subjects);
q   = zeros(1,Num_Subjects);            % rows kept for each subject
if nargin < 4
    Trans = cell(1,Num_Subjects);
end

for i = 1:Num_Subjects
    if npcs ~= 0
        if strcmpi(reduction_mode,'temporal')
            [rTs{i},Trans{i}] = reduce_dimension(Ts{i},npcs);
        elseif strcmpi(reduction_mode,'spatial')
            [rTs{i},Trans{i}] = reduce_dimension_spatial(Ts{i},npcs);
        end
    else
        % Zero out any column containing NaN/Inf so it cannot poison the
        % covariance estimates.
        Ts{i}(:,~isfinite(sum(Ts{i},1))) = 0;
        rTs{i} = Ts{i};
        if nargin < 4
            Trans{i} = eye(size(rTs{i},1));
        end
    end
    q(i) = size(rTs{i},1);
end

ind  = cumsum([0 q]) + 1;               % subject i owns rows ind(i):ind(i+1)-1
qlen = ind(end) - 1;

% ---- signal / noise covariances --------------------------------------------
% An empty Weights means "unweighted": treat it exactly like all-ones weights
% so both covariances are normalised by Num_Voxels-1. (sum([])-1 would be -1,
% negating both matrices, which can change the algorithm EIG picks and hence
% the scaling of the returned eigenvectors.)
if isempty(Weights)
    Weights = ones(1,Num_Voxels);
end

T = cell2mat(rTs');                     % stack all subjects vertically
weight_signal = ones(size(Weights));
weight_noise  = ones(size(Weights));
Vsignal = T;
Vnoise  = T;
if sum(Weights) ~= Num_Voxels
    weight_signal = Weights;
    if strcmp(reduction_mode,'original')
        weight_noise = ones(size(Weights));
    else
        weight_noise = 1 - Weights;
    end
    Vsignal = bsxfun(@times,T,weight_signal);
    Vnoise  = bsxfun(@times,T,weight_noise);
end
Csignal = Vsignal*Vsignal'/(sum(weight_signal)-1);
Cnoise  = Vnoise*Vnoise'/(sum(weight_noise)-1);

% Block-diagonal (within-subject) parts of both covariances.
Dsignal = zeros(qlen,qlen);
Dnoise  = zeros(qlen,qlen);
for i = 1:Num_Subjects
    blk = ind(i):ind(i+1)-1;
    Dsignal(blk,blk) = Csignal(blk,blk);
    Dnoise(blk,blk)  = Cnoise(blk,blk);
end

% ---- generalized eigenproblem: between-subject vs. within-subject ---------
[V,D] = eig((Csignal-Dsignal)/(Num_Subjects-1),Dnoise);
ddall = diag(D);
[~,dind] = sort(ddall,'descend');
ncomp  = min(q);
dd     = real(ddall(dind(1:ncomp)));
A      = real(V(:,dind(1:ncomp)));
Vfull  = real(V(:,dind));
ddfull = real(ddall(dind));

% ---- per-subject canonical variates and back-projected weights ------------
Z       = cell(1,Num_Subjects);
At      = cell(1,Num_Subjects);
Afull   = cell(1,Num_Subjects);
sub_eff = zeros(1,Num_Subjects);
for i = 1:Num_Subjects
    blk = ind(i):ind(i+1)-1;
    Z{i}       = A(blk,:)'*rTs{i};
    At{i}      = Trans{i}*A(blk,:);
    Afull{i}   = Trans{i}*Vfull(blk,:);
    sub_eff(i) = sqrt(sum(A(blk,1).^2));
end

% ---- pairwise correlation of the first canonical variate ------------------
qcorr = zeros(1,Num_Subjects*(Num_Subjects-1)/2);
pair = 0;
for i = 1:Num_Subjects
    for j = i+1:Num_Subjects
        pair = pair + 1;
        if inOctave
            qcorr(pair) = corr(Z{i}(1,:),Z{j}(1,:));
        else
            R = corrcoef(Z{i}(1,:),Z{j}(1,:));
            qcorr(pair) = R(1,2);
        end
    end
end

% ---- group-average canonical variates --------------------------------------
Zm = zeros(ncomp,Num_Voxels);
for j = 1:Num_Subjects
    Zm = Zm + Z{j}/Num_Subjects;
end

% TODO: use a held-out test data set to evaluate how well the solution generalizes.

function [rTs,Trans] = reduce_dimension(Ts,q)
% REDUCE_DIMENSION  Project one subject's data onto its leading SVD components.
%
%   [rTs,Trans] = reduce_dimension(Ts,q), with Ts = U*S*V' (economy SVD):
%     q integer  : keep the first q components;
%     q fraction : keep the leading components whose cumulative share of the
%                  summed singular values does not exceed q -- but always at
%                  least one, and all of them if the share never exceeds q.
%   rTs   = V(:,1:k)'             (k x size(Ts,2), orthonormal rows)
%   Trans = U(:,1:k)/S(1:k,1:k)   (size(Ts,1) x k), so that Trans'*Ts == rTs.

[U,S,V] = svd(Ts,'econ');
sv = diag(S);                   % SVD already returns these in descending order

if fix(q) == q
    k = q;
else
    frac = cumsum(sv)/sum(sv);
    k = find(frac > q,1) - 1;   % components strictly before the crossing
    if isempty(k)
        k = numel(sv);          % threshold never exceeded: keep everything
    end
    k = max(k,1);               % never return an empty projection
end

idx   = 1:k;
Trans = U(:,idx)/S(idx,idx);    % right-division instead of U*inv(S)
rTs   = V(:,idx)';

function [rTs,Trans] = reduce_dimension_spatial(Ts,q)
% REDUCE_DIMENSION_SPATIAL  SVD reduction of Ts' with standardised components.
%
%   With Ts' = U*S*V' (economy SVD), the first q components are kept:
%     rTs   - q x size(Ts,1): each row of V(:,1:q)' divided by its own
%             standard deviation, then mean-centred along the row.
%     Trans - size(Ts,2) x q: U(:,1:q)*S(1:q,1:q)*diag(those standard deviations).
%   q is used as a component count (indices 1:q).

[U,S,V] = svd(Ts','econ');

Vq  = V(:,1:q)';
sn  = std(Vq,0,2);                      % per-component standard deviation
rTs = bsxfun(@rdivide,Vq,sn);           % unit variance per row (no inv())
rTs = bsxfun(@minus,rTs,mean(rTs,2));   % zero mean per row
Trans = U(:,1:q)*S(1:q,1:q)*diag(sn);

function Cab = xcov(A,B)
% XCOV  Sample cross-covariance between the columns of A and B.
%   Cab = xcov(A,B) subtracts mean(A) and mean(B) (column means for inputs
%   with more than one row) and returns Acen'*Bcen/(n-1), where
%   n = size(A,1). Within this file it shadows any toolbox function XCOV.

nobs = size(A,1);
Acen = A - repmat(mean(A),nobs,1);
Bcen = B - repmat(mean(B),nobs,1);
Cab  = (Acen'*Bcen)/(nobs-1);

function [inOctave] = in_octave()
% IN_OCTAVE  Return 1 when running under GNU Octave, 0 otherwise (MATLAB).
%   Checks for Octave's built-in OCTAVE_VERSION function rather than relying
%   on try/catch. A user-defined OCTAVE_VERSION variable or function on the
%   MATLAB path therefore cannot produce a false positive. The result stays
%   a double (1/0), as before.
inOctave = double(exist('OCTAVE_VERSION','builtin') ~= 0);

