classdef NystromSketch
    %NYSTROMSKETCH implements a class definition for the sketching method
    %described in [TYUC2017Nys].
    %
    %[TYUC2017Nys] J.A. Tropp, A. Yurtsever, M. Udell and V. Cevher. Fixed-
    %Rank Approximation of a Positive-Semidefinite Matrix from Streaming
    %Data. In Proc. 31st Conference on Neural Information Processing Systems
    %(NIPS), Long Beach, CA, USA, December 2017.
    %
    %Coded by: Alp Yurtsever
    %Ecole Polytechnique Federale de Lausanne, Switzerland.
    %Laboratory for Information and Inference Systems, LIONS.
    %contact: alp.yurtsever@epfl.ch
    %Created: April 12, 2017
    %Last modified: October 24, 2017
    %
    %Nys—SKETCHv1.0
    %Copyright (C) 2017 Laboratory for Information and Inference Systems
    %(LIONS), Ecole Polytechnique Federale de Lausanne, Switzerland.
    %This code is a part of Nys—SKETCH toolbox.
    %Please read COPYRIGHT before using this file.


    %% properties
    properties
        Field     % 'real' or 'complex' field (default 'complex')
        Model     % 'gaussian' (default), 'orthonormal', or 'ssft'
        Omega     % test matrix for the range of A. For 'gaussian' it is a
                  % std Gaussian (n x k) array; for 'orthonormal' it is that
                  % array after economy QR ((n x k) when k <= n). For 'ssft'
                  % it is a function handle with Omega(X) equal to X*Omega2.
        Omega2    % 'ssft' only: explicit (n x k) test matrix represented by
                  % the Omega function handle; left empty for other models
        Y         % (n x k) dimensional range sketch
        n         % dimension of the (n x n) input matrix
        k         % number of columns of the sketch
    end

    %% methods
    methods
        % Constructor
        function obj = NystromSketch(A, k, varargin)
            %NYSTROMSKETCH Construct a Nystrom sketch with k test vectors.
            %   obj = NystromSketch(A, k) sketches the square matrix A,
            %   i.e. sets Y = A*Omega.
            %   obj = NystromSketch(n, k), with a scalar first argument,
            %   only fixes the dimension n and initializes Y = zeros(n,k)
            %   (e.g. to be filled in later by streaming updates).
            %   obj = NystromSketch(A, k, field) selects the field, 'real'
            %   or 'complex' (case-insensitive; default 'complex').
            %   obj = NystromSketch(A, k, field, model) selects the test
            %   matrix model: 'gaussian' (default), 'orthonormal' (Gaussian
            %   followed by economy QR), or 'ssft' (random signs and
            %   permutations combined with DCT ('real') or unitary DFT
            %   ('complex'), then k sampled columns).
            %
            %   The 'ssft' model samples k distinct columns, so it requires
            %   k <= n; its 'real' variant calls dct/idct.
            narginchk(2,4);
            obj.k = k;
            % optional inputs: 1. Field, 2. Model
            if nargin >= 3
                obj.Field  = lower(varargin{1});
            else
                obj.Field = 'complex';
            end
            if nargin == 4
                obj.Model  = lower(varargin{2});
            else
                obj.Model = 'gaussian';
            end
            % construct the sketch
            if size(A,1) ~= size(A,2)
                error('Input matrix must be square positive-semidefinite.');
            end
            n = size(A, 1);
            if n == 1
                % A scalar first argument is interpreted as the dimension.
                n = A;
                A = [];
            end
            obj.n = n;
            if strcmp(obj.Model, 'ssft')
                % k distinct column indices; randperm is base MATLAB
                % (randsample would require the Statistics Toolbox).
                indR = sort(randperm(n,k));
                opR  = @(x) x(:,indR);
                % Signed permutation matrices: random +/-1 diagonal, rows shuffled.
                Pi1  = spdiags(sign((randn(n,1)<0)-0.5),0,n,n);
                Pi1  = Pi1(randperm(n),:);
                Pi2  = spdiags(sign((randn(n,1)<0)-0.5),0,n,n);
                Pi2  = Pi2(randperm(n),:);
                if strcmp(obj.Field,'real')
                    % dctLeftMult(x) = x*C and dctRightMult(x) = C*x, where
                    % C = dct(eye(n)); so Omega(X) = X*Omega2.
                    dctLeftMult = @(x) idct(x')';
                    dctRightMult = @(x) dct(x);
                    obj.Omega = @(x) opR(dctLeftMult(dctLeftMult(x*Pi1)*Pi2));
                    R = speye(n);
                    R = full(R(:,indR));
                    obj.Omega2 = Pi1*(dctRightMult(Pi2*(dctRightMult(R))));
                elseif strcmp(obj.Field,'complex')
                    % Unitary DFT: dftLeftMult(x) = x*F/sqrt(n) and
                    % dftRightMult(x) = F*x/sqrt(n) with F = fft(eye(n)).
                    dftLeftMult = @(x) (ifft(full(x'))*sqrt(n))';
                    dftRightMult = @(x) (fft(full(x))/sqrt(n));
                    obj.Omega = @(x) opR(dftLeftMult(dftLeftMult(x*Pi1)*Pi2));
                    R = speye(n);
                    R = full(R(:,indR));
                    obj.Omega2 = Pi1*(dftRightMult(Pi2*(dftRightMult(R))));
                else
                    error('Field should be ''real'' or ''complex''.');
                end
            elseif strcmp(obj.Model, 'orthonormal') || strcmp(obj.Model, 'gaussian')
                if strcmp(obj.Field,'real')
                    obj.Omega = randn(n,k);
                elseif strcmp(obj.Field,'complex')
                    obj.Omega = randn(n,k) + 1i*randn(n,k);
                else
                    error('Field should be ''real'' or ''complex''.');
                end
                if strcmp(obj.Model, 'orthonormal')
                    [obj.Omega, ~] = qr(obj.Omega,0);
                end
            else
                error('Model should be ''gaussian'', ''orthonormal'', or ''ssft''.');
            end
            % initialize the sketch
            if isempty(A)
                obj.Y = zeros(n,k);
            else
                if isa(obj.Omega,'function_handle')
                    obj.Y = obj.Omega(A);   % Note that Omega(A) computes A*Omega
                else
                    obj.Y = A*obj.Omega;
                end
            end
        end

        % Other methods
        [Q, D]  = FixedRankPSDApprox(obj, r)
        [Q, D]  = GittensMahoneyApprox(obj, r)
        [Q, D]  = FixedRankPSDApproxUnstable(obj, r)
    end
end

