% This is the companion code for Sec. IV of the following paper:
% Y. Yang, M. Pesavento, Symeon Chatzinotas, and Björn Ottersten,
% "Successive convex approximation algorithms for sparse signal estimation
% with nonconvex regularizations",
% IEEE Journal of Selected Topics in Signal Processing

clear;
clc;

%% problem parameters
% observation model: Y = X + D*S + V with X: N x K, D: N x I, S: I x K
N        = 1000;
K        = 2000;
I        = 2000;
rho_real = 5;   % inner dimension of the ground-truth factors (N x rho_real times rho_real x K)
rho      = 10;  % rank of the initial point / factorization used by the algorithms: P: N x rho, Q: rho x K

% number of Monte Carlo realizations
Sample   = 10;

%% algorithms under comparison and their iteration budgets
%   suffix "_g": BCD algorithm proposed in [5]
%   suffix "_j": proposed parallel best-response algorithm (STELA)
%   suffix "_a": ADMM method proposed in [4]
MaxIter_g = 20;
MaxIter_j = 50;
MaxIter_a = 100;

% objective value and accumulated CPU time, one row per realization and one
% column per iteration (column 1 holds the initial point)
val_g  = zeros(Sample, MaxIter_g + 1);
val_j  = zeros(Sample, MaxIter_j + 1);
val_a  = zeros(Sample, MaxIter_a + 1);
time_g = zeros(Sample, MaxIter_g + 1);
time_j = zeros(Sample, MaxIter_j + 1);
time_a = zeros(Sample, MaxIter_a + 1);

%%
for s = 1:Sample
    disp(['sample ' num2str(s)]);

    %% generate the data
    % D (N x I): exactly one unit entry per column, placed in a random row
    D = zeros(N,I);
    for i = 1:I
        D(randi(N),i) = 1;
    end
    S0 = sprandn(I,K,0.05); % sparse component, density 0.05

    P0    = sqrt(100/I) * randn(N, rho_real);
    Q0    = sqrt(100/K) * randn(rho_real, K);
    X0    = P0 * Q0;            % perfect (noiseless) X
    sigma = 0.01;
    V     = sigma * randn(N,K); % noise

    Y = X0 + D * S0 + V; % observation

    lambda = 2.5 * 10^-1 * norm(Y);         % scaled spectral norm of Y
    mu     = 2 * 10^-4 * norm(D' * Y, inf); % scaled max absolute row sum of D'*Y

    % objective shared by all algorithms:
    %   0.5*||Y - P*Q - D*S||_F^2 + 0.5*lambda*(||P||_F^2 + ||Q||_F^2) + mu*||S(:)||_1
    % S(:) stacks the columns of S, so the last term is the elementwise l1 norm
    % (this avoids the non-core helper vec()).
    obj = @(P, Q, S) 0.5 * norm(Y - P * Q - D * S, 'fro')^2 ...
        + 0.5 * lambda * (norm(P, 'fro')^2 + norm(Q, 'fro')^2) ...
        + mu * norm(S(:), 1);

    % diagonal of D'*D (constant for all iterations), as a vector and as a
    % sparse diagonal matrix; computed outside the timed loops
    dtd   = sum(D.^2, 1).';
    d_DtD = spdiags(dtd, 0, I, I);

    % initial point (common for all algorithms)
    initial_P = sqrt(100/I) * randn(N,rho);
    initial_Q = sqrt(100/K) * randn(rho,K);
    initial_S = zeros(I,K);

    % initial value
    val0 = obj(initial_P, initial_Q, initial_S);

    %% STELA algorithm
    P           = initial_P;
    Q           = initial_Q;
    S           = initial_S;
    val_j(s,1)  = val0;
    time_j(s,1) = 0;
%     disp(['STELA, iteration ' num2str(0) ', time ' num2str(0) ', value ' num2str(val_j(s,1))]);

    for t = 1:MaxIter_j
        tic;

        % best responses of P, Q and S, all computed at the current point
        Y_DS  = Y - D * S;
        P_new = (Y_DS * Q') / (Q * Q' + lambda * eye(rho));
        Q_new = (P' * P + lambda * eye(rho)) \ (P' * Y_DS);
        C     = P * Q - Y_DS;  % residual P*Q + D*S - Y, reused by the line search
        clear Y_DS
        G     = d_DtD * S - D' * C;
        S_new = d_DtD \ (max(G - mu, 0) - max(-G - mu, 0)); % soft thresholding
        clear G

        cP = P_new - P;
        cQ = Q_new - Q;
        cS = S_new - S;

        %-------------------- stepsize by exact line search ----------------
        % Along the update direction the (upper-bounded) objective equals, up to a
        % constant, f(g) = a/4*g^4 + b/3*g^3 + c/2*g^2 + d*g, whose derivative
        % is the cubic a*g^3 + b*g^2 + c*g + d.
        A = cP * cQ;
        B = P * cQ + cP * Q + D * cS;

        a = 2 * sum(A(:).^2);
        b = 3 * sum(A(:) .* B(:));
        c = sum(B(:).^2) + 2 * sum(A(:) .* C(:)) ...
            + lambda * (sum(cP(:).^2) + sum(cQ(:).^2));
        d = sum(B(:) .* C(:)) + lambda * (sum(cP(:) .* P(:)) + sum(cQ(:) .* Q(:))) ...
            + mu * (norm(S_new(:), 1) - norm(S(:), 1));
        clear A B C

        % The minimizer of f over [0,1] is an endpoint or a real stationary point
        % inside (0,1); compare all candidates. This also covers a = 0 and the
        % case where no stationary point is positive (step 0 instead of an error).
        r_st   = roots([a b c d]);
        r_st   = real(r_st(abs(imag(r_st)) <= 1e-8 * max(1, abs(r_st))));
        cand   = [0; 1; r_st(r_st > 0 & r_st < 1)];
        f_cand = a/4 * cand.^4 + b/3 * cand.^3 + c/2 * cand.^2 + d * cand;
        [~, idx] = min(f_cand);
        step   = cand(idx);
        clear a b c d r_st cand f_cand idx

        % variable update
        P = P + step * cP;
        Q = Q + step * cQ;
        S = S + step * cS;
        clear cP cQ cS P_new Q_new S_new

        time_j(s,t+1) = toc + time_j(s,t);

        val_j(s,t+1) = obj(P, Q, S);
        disp(['Jacobi algorithm, iteration ' num2str(t) ', time ' num2str(time_j(s,t+1)) ', value ' num2str(val_j(s,t+1))...
            ', stepsize ' num2str(step)]);
    end
    % spectral norm of the final residual, printed next to lambda
    disp(['check the optimality of solution: ' num2str([norm(Y - P * Q - D * S) lambda])]);
    clear P Q S step

    %% BCD (Gauss-Seidel) algorithm
    P           = initial_P;
    Q           = initial_Q;
    S           = initial_S;
    val_g(s,1)  = val0;
    time_g(s,1) = 0;
%     disp(['Gauss-Seidel algorithm, iteration ' num2str(0) ', time ' num2str(0) ', value ' num2str(val_g(s,1))]);
    for t = 1:MaxIter_g
        tic;

        Y_DS = Y - D * S; % S is unchanged by the P and Q updates
        P    = (Y_DS * Q') / (Q * Q' + lambda * eye(rho));
        Q    = (P' * P + lambda * eye(rho)) \ (P' * Y_DS);
        clear Y_DS

        % cyclic update of the rows of S; the residual R = Y - P*Q - D*S is
        % updated incrementally instead of being rebuilt for every row
        R = Y - P * Q - D * S;
        for i = 1:I
            d_i    = D(:,i);
            q_i    = d_i' * R + dtd(i) * S(i,:);  % = -d_i'*(P*Q + D*S - d_i*S(i,:) - Y)
            s_i    = (max(q_i - mu, 0) - max(-q_i - mu, 0)) / dtd(i);
            R      = R - d_i * (s_i - S(i,:));
            S(i,:) = s_i;
        end
        clear R d_i q_i s_i

        time_g(s,t+1) = toc + time_g(s,t);

        val_g(s,t+1) = obj(P, Q, S);

        disp(['Gauss-Seidel algorithm, iteration ' num2str(t) ', time ' num2str(time_g(s,t+1)) ', value ' num2str(val_g(s,t+1))]);
    end
    clear P Q S
    clear P0 Q0 S0
    clear d_DtD dtd X0 V

    %% ADMM algorithm
    P  = initial_P;
    Q  = initial_Q;
    E  = initial_S;
    F  = initial_S;
    Pi = zeros(I,K);
    c  = 10^4;
    % D'*D + c*I is constant, so its inverse is computed once here (outside the
    % timed loop, like the diagonal of D'*D for STELA and BCD) and reused for
    % every iteration's E-update.
    M_inv = inv(D' * D + c * eye(I));
    val_a(s,1)  = val0;
    time_a(s,1) = 0;
    for t = 1:MaxIter_a
        tic;
        Y_DE = Y - D * E; % E is unchanged by the Q, F and P updates
        Q    = (P' * P + lambda * eye(rho)) \ (P' * Y_DE);
        F    = max(E + Pi/c - mu/c, 0) - max(-E - Pi/c - mu/c, 0);
        P    = (Y_DE * Q') / (Q * Q' + lambda * eye(rho));
        E    = M_inv * (c * F - D' * (P * Q - Y) - Pi);
        Pi   = Pi + c * (E - F);
        clear Y_DE

        % stop the clock before evaluating the objective, as for STELA and BCD
        time_a(s,t+1) = toc + time_a(s,t);

        val_a(s,t+1) = obj(P, Q, E);
%         disp(['ADMM algorithm, iteration ' num2str(t) ', time ' num2str(time_a(s,t+1)) ', value ' num2str(val_a(s,t+1))]);
    end
    clear P Q E F Pi M_inv
    clear obj
end

%% plot the convergence curves averaged over the Monte Carlo realizations
% one entry per algorithm, drawn in the order STELA, BCD, ADMM
curve_style = {'r', 'k--o', 'b-.'};
curve_name  = {'STELA (proposed)', 'BCD algorithm (state-of-the-art)', 'ADMM (state-of-the-art)'};
curve_iter  = {0:MaxIter_j, 0:MaxIter_g, 0:MaxIter_a};
curve_val   = {mean(val_j, 1), mean(val_g, 1), mean(val_a, 1)};
curve_time  = {mean(time_j, 1)/60, mean(time_g, 1)/60, mean(time_a, 1)/60};

% top panel: objective value versus iteration index
subplot(2,1,1);
for n_curve = 1:numel(curve_style)
    semilogy(curve_iter{n_curve}, curve_val{n_curve}, curve_style{n_curve}, 'LineWidth', 1.5);
    if n_curve == 1
        hold on; box on;
    end
end
legend(curve_name{:});
xlabel('number of iterations');
ylabel('function value');

% bottom panel: objective value versus accumulated CPU time (seconds / 60)
subplot(2,1,2);
for n_curve = 1:numel(curve_style)
    semilogy(curve_time{n_curve}, curve_val{n_curve}, curve_style{n_curve}, 'LineWidth', 1.5);
    if n_curve == 1
        hold on; box on;
    end
end
legend(curve_name{:});
xlabel('CPU time (minutes)');
ylabel('function value');

clear curve_style curve_name curve_iter curve_val curve_time n_curve