%% AUTHORS: Mathieu Xhonneux (UCLouvain), Orion Afisiadis (EPFL)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function main_MC_MU(RunID, TxRx)
% MAIN_MC_MU  Monte-Carlo simulation of a Hamming-coded LoRa link with one
% synchronized user (u) and one asynchronous interferer (i). User u is
% soft-demodulated/decoded iteratively, then cancelled (hard SIC) so that the
% interferer i can be soft-demodulated and decoded.
%
%   main_MC_MU(RunID, TxRx)
%
%   RunID - integer used to seed the client random generator and to build
%           the output file name.
%   TxRx  - parameter struct. Fields read directly here: SNRs, n_packets, Pi,
%           sf, os_factor, F, nr_codewords, CR, PPM, HammingCode.Rate,
%           Gray.Reverse, sym_delay, cfo_interf, fs_interf, fs,
%           timing_offset, timing_error_est, phase_offset_user,
%           phase_offset_interf, n_iterations, basename (the struct is also
%           forwarded to helper functions defined elsewhere).
%
%   BER/PER/SER/CWER per SNR for both users (plus TxRx) are saved as the
%   struct "Results" in results/Monte-Carlo/<basename>_<RunID>.mat, and three
%   figures (SER, PER, BER) compare them with uncoded single-user curves.

%% some init
% NOTE(review): rng() only seeds the client's generator. If the parfor below
% runs on a parallel pool, workers draw from their own streams, so RunID
% does not make the Monte-Carlo draws reproducible -- confirm if that matters.
rng('default')
rng(RunID);

SNRs = TxRx.SNRs;
n_SNRs = length(SNRs);
n_packets = TxRx.n_packets;
Pi = TxRx.Pi;
sf = TxRx.sf;
os_factor = TxRx.os_factor;

F = TxRx.F;
N = 2^sf;

% The decoded codewords (c_hat_u) and the LLRs used for SIC are only produced
% inside the iterative soft loop, so at least one iteration is required.
if TxRx.n_iterations < 1
    error('main_MC_MU:nIterations', 'TxRx.n_iterations must be >= 1.');
end

% Number of info bits for the desired packet (4 info bits per codeword)
nr_info_bits = TxRx.nr_codewords * 4;

n = 0:N-1;
n = n.';

% Gray mapping
n_gray = bin2dec(num2str(nongray2gray(de2bi(n, sf,2,'left-msb'), TxRx)));

% linear value of the interferer power
Pi_lin = 10^(Pi/10);
E_u = 1;
E_i = sqrt(Pi_lin);

% Reference symbol (s = 0)
y_ref = exp(1j*2*pi*(n.^2/(2^(sf+1))-n/2));

% log(I0(x)) computed through the exponentially scaled Bessel function,
% besseli(0,x,1) = I0(x)*exp(-x) for real x >= 0, so large arguments (high
% SNR and/or large N) no longer overflow to Inf (which made LLRs NaN).
% Every argument passed below is a non-negative real (scaled magnitude).
log_besseli0 = @(x) log(besseli(0, x, 1)) + x;

%% Generation of Hamming matrices G and H
if(TxRx.CR==1)
    G = [1,0,0,0,1
         0,1,0,0,1
         0,0,1,0,1
         0,0,0,1,1];
    
    H = [1,1,1,1,1];
    HT = H.';
else
    G = [1,0,0,0,1,0,1,1
         0,1,0,0,1,1,1,0
         0,0,1,0,1,1,0,1
         0,0,0,1,0,1,1,1];
          
    G = G(:, 1:TxRx.PPM);
    H = gen2par(G);             % generate H from G
    HT = H.';                   % Transpose
end

%% Preparing error patterns for soft-decoding
% Permutation of 1,2 and 3 bit errors
[e1, e2, e3] = permute_e(TxRx.PPM, 3);
e=[e1;e2;e3];

% syndrome calculation
syn=NaN(size(e,1), TxRx.CR);
for k=1:size(e,1)
    syn(k,:)=mod(e(k,:)*HT,2);
end

% sort rows of syndromes and the associated error patterns e
[s_sort,s_idx] = sortrows(syn);     % sort syn to s_sort
e_sort=e(s_idx,:);                  % sort e to e_sort in the same way as syn

e_sort=e_sort(sum(s_sort,2)>0,:);   % delete zero rows
s_sort=s_sort(sum(s_sort,2)>0,:);   % delete zero rows

s_de=bem_bi2de(s_sort);             % calculate decimal values from s_sort

% Count how often each non-zero syndrome value occurs (hist with bin
% centers 1..2^CR-1)
s_pdf=hist(s_de,1:2^TxRx.CR-1);

%% Initialize counters
BE_u = zeros(n_packets, n_SNRs);
PE_u = zeros(n_packets, n_SNRs);
SE_u = zeros(n_packets, n_SNRs);
CWE_u = zeros(n_packets, n_SNRs);

BE_i = zeros(n_packets,n_SNRs);
PE_i = zeros(n_packets,n_SNRs);
SE_i = zeros(n_packets,n_SNRs);
CWE_i = zeros(n_packets,n_SNRs);

% For the Monte-Carlo division
totMontCarlo = n_packets*ones(1,n_SNRs);

%% Monte-Carlo Simulation
parfor trial = 1:n_packets
    if (mod(trial,100) == 0)
        trial   % progress indicator
    end
    
    % Generate a random information bitstream
    bits_info_u   = randi([0 1],nr_info_bits,1);
    bits_info_i   = randi([0 1],nr_info_bits,1);
    
    % Reshape in a matrix (one codeword's info bits per row)
    bits_mat_info_u   = reshape(bits_info_u, TxRx.HammingCode.Rate*TxRx.PPM, TxRx.nr_codewords).';
    bits_mat_info_i   = reshape(bits_info_i, TxRx.HammingCode.Rate*TxRx.PPM, TxRx.nr_codewords).';
    
    % -- Hamming encoding
    c_u = Hamming_enc_Or(bits_mat_info_u,TxRx,G);
    c_i = Hamming_enc_Or(bits_mat_info_i,TxRx,G);
    
    % -- Interleaving
    [c_interleaved_u] = LoRa_interleaver_Or(c_u,TxRx);
    [c_interleaved_i] = LoRa_interleaver_Or(c_i,TxRx);
    
    % -- Gray encoding
    if TxRx.Gray.Reverse==1
        % -- reverse Gray coding according to LoRa standard
        [~,c_gray_u] = gray2nongray(c_interleaved_u,TxRx);
        [~,c_gray_i] = gray2nongray(c_interleaved_i,TxRx);
    else
        % -- Normal Gray coding
        c_gray_u = nongray2gray(c_interleaved_u,TxRx);
        c_gray_i = nongray2gray(c_interleaved_i,TxRx);
    end
    
    % Payload symbols of the sync. user, followed by 1+sym_delay zero
    % padding symbols so both frames span F_pad+1 symbol slots
    data_u = bin2dec(num2str(c_gray_u));         % convert bits to decimal values
    s_u = [data_u.', zeros(1, 1+TxRx.sym_delay)];
    
    % Create the full async. user frame (only payload), preceded by
    % 1+sym_delay random symbols and followed by one zero symbol
    data_i = bin2dec(num2str(c_gray_i));         % convert bits to decimal values
    s_i = [randi(N, 1, 1+TxRx.sym_delay)-1, data_i.', 0];
        
    % Create the matrix that contains the frame to be transmitted. Each column is a LoRa symbol
    y_matrix = exp(1j*2*pi*(repmat(n, [1 length(s_u)]).^2/(N*2)+n*(s_u/N-1/2)));
    % Vectorize
    y = y_matrix(:);
    
    % Oversampled interferer chirps, built in two pieces around the
    % frequency fold-over sample s_fold_os
    y_interf_full_os = zeros(N*os_factor, length(s_i));
    for kk = 1:length(s_i)      
        % Initialize to 0
        symb_os = zeros(1, N*os_factor);
        % kth symbol
        s_fold_os = (N-s_i(kk))*os_factor;

        n_osA= 0:s_fold_os-1;
        n_osB= s_fold_os:N*os_factor-1;

        symb_os(1:s_fold_os)      =  exp( 2i * pi *(1/(2*N)*(1/os_factor)^2*n_osA.^2+(s_i(kk)/N-1/2)*(1/os_factor)*n_osA));
        symb_os(s_fold_os+1:N*os_factor) =  exp( 2i * pi *(1/(2*N)*(1/os_factor)^2*n_osB.^2+(s_i(kk)/N-3/2)*(1/os_factor)*n_osB));
        y_interf_full_os(:,kk) = symb_os;
    end
    
    % Add the Carrier Frequency Offset to the interferer (random if NaN)
    if (isnan(TxRx.cfo_interf))
        cfo_interf = (rand() - 0.5)*TxRx.fs_interf/(2^TxRx.sf); 
    else
        cfo_interf = TxRx.cfo_interf;
    end
    y_interf_os_cfo = add_cfo_packet(TxRx, y_interf_full_os, cfo_interf);
    
    % Vectorize the interfering packet
    y_interf_full_os_vec = y_interf_os_cfo(:);

    % Choose the offset (sample misalignment) randomly in the range [0 (os_factor*2^sf)-1]
    if TxRx.timing_offset == -1
        offset = randi(os_factor*N) - 1;
    else
        offset = TxRx.timing_offset;
    end
    tau = offset / os_factor;
    tau_hat = (offset + TxRx.timing_error_est) / os_factor;
    
    % full frame length with inter-symbol delay
    F_pad = F + TxRx.sym_delay;

    % Get the portion of the interference according to the offset value
    % (decimate back to one sample per chip)
    y_interf_unitpower = y_interf_full_os_vec(os_factor*N-offset+1:os_factor:(F_pad+2)*os_factor*N-offset);
    
    % Add a random phase to the user, which corresponds to the phi in the paper
    if (TxRx.phase_offset_user == 1)
        h_u = exp(1j*rand*2*pi);
        y = h_u * y;
    end
    
    % Add a random phase to the interf
    if (TxRx.phase_offset_interf == 1)
        h_i = exp(1j*rand*2*pi);
        y_interf_unitpower = h_i * y_interf_unitpower;
    end
    
    % Adjust the power of the interferer
    y_interf = sqrt(Pi_lin) * y_interf_unitpower;
    
    % Reshape the interf to a matrix
    y_interf_mat = reshape(y_interf, [N F_pad+1]);
    
    % Matched filters M1 and M2 pre-computation
    % NOTE(review): the matched filters use the true offset tau and the true
    % CFO (delta_cfo), while the segment split uses tau_hat -- confirm this
    % genie knowledge is intended.
    delta_cfo = cfo_interf*N/TxRx.fs;
    
    n1 = N-ceil(tau_hat):N-1;   % NOTE: n1 is reused with another meaning in the demod loop below
    mf1 = exp(1j*2*pi*n1*(tau-delta_cfo)/N); % Matched filter for S_i,1
    f_S1 = (repmat(n.', [N 1]) + repmat(n, [1 N])) >= N + ceil(tau_hat);
    W1 = dftmtx(N) .* exp(-1j*2*pi*tau_hat*f_S1); % DFT matrix for S_i,1
    W1 = W1(ceil(tau_hat)+1:end, :);

    n2 = 0:N-ceil(tau_hat)-1;
    mf2 = exp(1j*2*pi*n2*(tau-delta_cfo)/N); % Matched filter for S_i,2
    f_S2 = (repmat(n.', [N 1]) + repmat(n, [1 N])) > ceil(tau_hat) - 1;
    W2 = dftmtx(N) .* exp(-1j*2*pi*tau_hat*f_S2); % DFT matrix for S_i,2
    W2 = W2(1:ceil(tau_hat), :);

    % AWGN generation (same realization reused for every SNR point)
    noise = randn(size(y)) + 1i*randn(size(y));
    
    for idx_SNR = 1:n_SNRs
        SNR = SNRs(idx_SNR);
        SNRlin = 10^(SNR/10);
        
        % Add the noise
        y_noisy = y + sqrt(1/(2*SNRlin))*noise;
        
        % Reshape the noisy signal to matrix
        y_noisy_mat = reshape(y_noisy, [N length(s_u)]);
        
        % Add the interference only on the payload part
        y_noisy_interf = y_noisy_mat + y_interf_mat;
              
        %Dechirp
        y_dechirped = y_noisy_interf.*conj(repmat(y_ref, [1 F_pad+1]));
        
        % State vectors sync. user (extrinsic LLRs fed back by the decoder)
        llr_u_ext = zeros(TxRx.nr_codewords, TxRx.PPM);
           
        %%% Soft-demodulation of sync. user
        Ys = fft(y_dechirped(:, 1:F_pad)); % no SIC yet
        
        for n_iter=1:TxRx.n_iterations
            llr_u_ext_interleaved = LoRa_interleaver_Or(llr_u_ext, TxRx);        
            llr_u_interleaved = zeros(F, sf);
            for k=0:sf-1
                % All symbols where bit k is equal to 1/0
                ind1 = n(bitand(n_gray, 2^k) ~= 0)+1;
                ind0 = n(bitand(n_gray, 2^k) == 0)+1;

                % All codewords where bit k is equal to 1/0
                n1 = n_gray(bitand(n_gray, 2^k) ~= 0);
                n0 = n_gray(bitand(n_gray, 2^k) == 0);

                % A-priori information from the other bits only
                llr_u_ext_k = llr_u_ext_interleaved.';
                llr_u_ext_k(k+1,:) = 0; 
                z1 = log_besseli0(SNRlin*abs(Ys(ind1,1:F))) + ...
                     de2bi(n1, sf, 2, 'left-msb') * -llr_u_ext_k;

                z0 = log_besseli0(SNRlin*abs(Ys(ind0,1:F))) + ...
                     de2bi(n0, sf, 2, 'left-msb') * -llr_u_ext_k;
                % Max-log LLR; column sf-k stores bit k in MSB-first order
                llr_u_interleaved(:, sf-k) = max(z0) - max(z1);
            end     
            
            % Deinterleaving
            llr_u_demod = LoRa_deinterleaver_Or(llr_u_interleaved, TxRx);

            % Soft-decoding
            [c_hat_u, llr_u_code] = hamm_soft_out(HT, e_sort, s_pdf, llr_u_demod);
            llr_u_ext = llr_u_code;
            
            llr_u_code_interleaved = LoRa_interleaver_Or(llr_u_code, TxRx);
        end
            
        %%% Interference cancellation and demodulation of B
        su_hat = zeros(1, length(data_u));
        si_hat = zeros(1, length(data_i));

        % Second-segment metric of the previous slot: an interfering symbol
        % straddles two consecutive slots of the sync. user
        m2_prev = zeros(1, N);

        M_i = zeros(N, F);
        for kk=1:F_pad+1
            kk_i = kk - TxRx.sym_delay - 1;

            Y = fft(y_dechirped(:,kk));
            if (kk >= F+1) % we have a padding symbol for sync. user ...
                pU = zeros(N,1);
                pU(s_u(kk)+1) = 1; % so we use an oracle to demodulate it
            else % default case: get probabilities from LLR
                pU = prod( exp(de2bi(n_gray, sf,2,'left-msb') ...
                       .* repmat(-llr_u_code_interleaved(kk,:), [N 1])) ...
                       ./ repmat(1 + exp(-llr_u_code_interleaved(kk,:)), [N 1]), 2 );
            end

            % Hard interference cancellation, select most likely symbol
            % from pU
            [~, ind_peaks] = sort(pU, 'descend');
            s_u_cand = ind_peaks(1)-1;
            if (kk <= F)
                su_hat(kk) = s_u_cand;
            end

            phi_u_hat = angle(Y(s_u_cand+1)); % estimate phase
            y_i = y_dechirped(:,kk) - E_u * exp(1j*phi_u_hat) * exp(1j*2*pi*n*s_u_cand/N); % SIC version of the received signal

            % Computing both part of the "correlation-metric" of
            % interfering symbols and the "SIC" vector

            y_i1 = y_i(1:ceil(tau_hat)).' .* mf1; % part of interfering symbol s_i1
            y_i2 = y_i(ceil(tau_hat)+1:end).'  .* mf2; % part of interfering symbol s_i2

            m1 = y_i1 * W2;
            m2 = y_i2 * W1;

            if (kk_i > 0)
                M_i(:,kk_i) = m1 + m2_prev;

                [~, idx_si] = max(M_i(:,kk_i));
                si_hat(kk_i) = idx_si - 1;
            end

            m2_prev = m2;
        end

        %%% Soft-demodulation of user I
        llr_i_interleaved = zeros(F, sf);
 
        for k=0:sf-1
            % All symbols where bit k is equal to 1/0
            ind1 = n(bitand(n_gray, 2^k) ~= 0)+1;
            ind0 = n(bitand(n_gray, 2^k) == 0)+1;

            z1 = log_besseli0(E_i*SNRlin*abs(M_i(ind1,:)));
            z0 = log_besseli0(E_i*SNRlin*abs(M_i(ind0,:)));
            llr_i_interleaved(:, sf-k) = max(z0) - max(z1);
        end

        % Deinterleaving
        llr_i_demod = LoRa_deinterleaver_Or(llr_i_interleaved, TxRx);

        % Soft-decoding of user I
        [c_hat_i, ~] = hamm_soft_out(HT, e_sort, s_pdf, llr_i_demod);

        % Count symbol errors
        n_errs_u = sum(su_hat~=data_u.');
        n_errs_i = sum(si_hat~=data_i.');  
        SE_u(trial, idx_SNR) = n_errs_u;
        SE_i(trial, idx_SNR) = n_errs_i;
        
        % Bit errors on the 4 info bits of each codeword
        bit_errors_u = c_hat_u(:, 1:4) ~= bits_mat_info_u;
        bit_errors_i = c_hat_i(:, 1:4) ~= bits_mat_info_i;
        BE_u(trial, idx_SNR) = sum(sum(bit_errors_u));
        BE_i(trial, idx_SNR) = sum(sum(bit_errors_i));
        
        % Vector containing the number of bit errors in every estimated codeword       
        errorsPerCodeword_u = sum((c_hat_u ~= c_u(:,1:TxRx.PPM)),2);
        errorsPerCodeword_i = sum((c_hat_i ~= c_i(:,1:TxRx.PPM)),2);
         
        % Codeword Error count for debugging
        CWE_u(trial, idx_SNR) = sum(errorsPerCodeword_u ~= 0);
        CWE_i(trial, idx_SNR) = sum(errorsPerCodeword_i ~= 0);
                
        % If any of the codewords were erroneously decoded then consider the whole frame wrong
        PE_u(trial, idx_SNR) = any(errorsPerCodeword_u);
        PE_i(trial, idx_SNR) = any(errorsPerCodeword_i);
    end
    
end

%% Post-processing of statistics
BER_u = sum(BE_u, 1)./(nr_info_bits *totMontCarlo);
PER_u = sum(PE_u, 1)./totMontCarlo;
SER_u = sum(SE_u, 1)./(TxRx.F*totMontCarlo);
CWER_u = sum(CWE_u, 1)./(TxRx.nr_codewords*totMontCarlo);

BER_i = sum(BE_i, 1)./(nr_info_bits *totMontCarlo);
PER_i = sum(PE_i, 1)./totMontCarlo;
SER_i = sum(SE_i, 1)./(TxRx.F*totMontCarlo);
CWER_i = sum(CWE_i, 1)./(TxRx.nr_codewords*totMontCarlo);

%%%%%%%%%%%%%%%%%%
% Save results
Results.TxRx = TxRx;

Results.BER_u = BER_u;
Results.PER_u = PER_u;
Results.SER_u = SER_u;
Results.CWER_u = CWER_u;

Results.BER_i = BER_i;
Results.PER_i = PER_i;
Results.SER_i = SER_i;
Results.CWER_i = CWER_i;

Results.fileName = sprintf('results/Monte-Carlo/%s_%d.mat',TxRx.basename,RunID);
% save() fails if the target folder does not exist yet, so create it first
results_dir = fileparts(Results.fileName);
if ~exist(results_dir, 'dir')
    mkdir(results_dir);
end
save(Results.fileName,'Results');
%%%%%%%%%%%%%%%%%%

%% Plots: simulated rates vs. uncoded single-user (1U NC) approximations
% Theoretical Symbol Error Rate for uncoded LoRa at the given SNR(s) in dB
theo_ser = @(snr_db) 0.5 * qfunc(sqrt(10.^(snr_db/10).*2.^(TxRx.sf+1)) - sqrt(1.386*TxRx.sf + 1.154)) * (2^TxRx.sf - 1) / 2^(TxRx.sf-1);
theo_SERu = theo_ser(TxRx.SNRs);
theo_SERi = theo_ser(TxRx.SNRs + TxRx.Pi);

% Frame error: at least one of the F symbols is wrong
theo_PERu = 1 - (1 - theo_SERu).^TxRx.F;
theo_PERi = 1 - (1 - theo_SERi).^TxRx.F;

% Theoretical Bit Error Rates for uncoded LoRa
theo_BERu = theo_SERu / (2^TxRx.sf - 1) * 2^(TxRx.sf-1);
theo_BERi = theo_SERi / (2^TxRx.sf - 1) * 2^(TxRx.sf-1);

x_axis = TxRx.SNRs + TxRx.Pi;
plot_rates(x_axis, SER_u, SER_i, theo_SERu, theo_SERi, 'SER', 1e-4, TxRx);
plot_rates(x_axis, PER_u, PER_i, theo_PERu, theo_PERi, 'PER', 1e-4, TxRx);
plot_rates(x_axis, BER_u, BER_i, theo_BERu, theo_BERi, 'BER', 1e-5, TxRx);

return

function plot_rates(x, sim_u, sim_i, theo_u, theo_i, metric, y_min, TxRx)
% PLOT_RATES  Open one figure comparing the simulated rates of users u and i
% (solid lines) with their theoretical curves (dashed, matching colors).
%   metric - label prefix, e.g. 'SER', used for legend entries and y-label.
%   y_min  - lower y-axis limit (upper limit is 1).
figure;
semilogy(x, sim_u, 'LineWidth', 2, 'DisplayName', [metric 'u'])
hold on;
semilogy(x, sim_i, 'LineWidth', 2, 'DisplayName', [metric 'i'])

% Restart the color order so each theoretical curve matches its simulation
set(gca,'ColorOrderIndex',1)
semilogy(x, theo_u, '--', 'LineWidth', 2, 'DisplayName', [metric 'u 1U NC'])
semilogy(x, theo_i, '--', 'LineWidth', 2, 'DisplayName', [metric 'i 1U NC'])
ylim([y_min 1])
% xlim needs a strictly increasing range (not the case for a single SNR)
if numel(x) > 1 && x(end) > x(1)
    xlim([x(1) x(end)])
end
legend('show');
xlabel('SNR weakest user (dB)')
ylabel(metric)
title(sprintf('SF = %d, Pi=%d, CR=4/%d', TxRx.sf, TxRx.Pi, TxRx.PPM));