%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Copyright (c) 2014 Gokhan Yildirim
% Version 1.0
% ===========================================================================
% GENERAL INFORMATION
% ===========================================================================
% 
% This code implements the Hierarchical Regression (HR) salient object 
% detection technique that is explained in the following paper:
% 
% G. Yildirim, A. Shaji, S. Susstrunk, "Saliency Detection Using Regression 
% Trees on Hierarchical Image Segments", IEEE ICIP, 2014
% 
% Please cite the paper if you use our source code.
% 
% This code is shared for non-commercial use only. For commercial use, please
% contact the author:
% 
% gokhan.yildirim@epfl.ch
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Start from a clean session: clear the command window, close every figure
% and wipe the workspace (function syntax, equivalent to the command forms).
clc;
close('all');
clear('all');

%% Indicate the path of the UGM framework
% Put the UGM toolbox folder, together with all of its subfolders, on the
% MATLAB search path.
addpath(genpath('/Users/gokhan/Downloads/UGM/'));

%% Input image and ground truth paths
% Folders holding the input images (*.jpg) and their ground-truth maps
% (*.bmp). Both paths keep a trailing '/' because file names are appended
% to them by plain string concatenation further below.

imdir = '/Users/gokhan/Desktop/academicWork/Saliency/MSRA/orig/';
gtdir = '/Users/gokhan/Desktop/academicWork/Saliency/MSRA/gt/';

IF = dir([imdir '*.jpg']);
GF = dir([gtdir '*.bmp']);

% Images and ground truths are paired purely by their position in the two
% DIR listings (IF(i) <-> GF(i)), so both folders must hold matching sets of
% files. Warn when the counts differ, since the pairing is then misaligned.
if numel(IF) ~= numel(GF)
    warning('HR:fileCountMismatch', ...
        'Found %d images in %s but %d ground-truth maps in %s; images and ground truths are paired by list position.', ...
        numel(IF), imdir, numel(GF), gtdir);
end

%% Parameters

% NOTE(review): these are declared global, presumably so that helper
% functions can read them. MATLAB does not share client-side global
% variables with parallel pool workers, so code executed inside the PARFOR
% loops of this script would see them as [] - confirm how the helpers
% obtain SS / whereToSave.
global whereToSave SS;

whereToSave     = '.';   % Root folder for every generated output file

numberOfImages  = 16;    % Number of (image, ground truth) pairs to process
SS              = 16;    % Image frame size used to compute the "border pixels" feature

% Fail early with a clear message instead of an index error inside PARFOR.
if numberOfImages > min(numel(IF), numel(GF))
    error('HR:notEnoughImages', ...
        'numberOfImages (%d) exceeds the available images (%d) or ground-truth maps (%d).', ...
        numberOfImages, numel(IF), numel(GF));
end

%% Saving functions

% Create the output folders. Existing folders are skipped so that re-running
% the script (and the default whereToSave = '.') does not trigger MKDIR's
% "already exists" warnings.
outputFolders = {whereToSave, ...
                 [whereToSave '/hierarchy/'], ...
                 [whereToSave '/feature/'], ...
                 [whereToSave '/map/']};

for folderIdx = 1:numel(outputFolders)
    if ~exist(outputFolders{folderIdx}, 'dir')
        mkdir(outputFolders{folderIdx});
    end
end

% SAVE cannot be called directly in a PARFOR body (transparency rules), so
% saving is wrapped in anonymous functions and invoked through FEVAL. The
% handles capture the value of whereToSave at the time they are created.
%   saveHierarchy - stores each field of the hierarchy struct as a variable
%   saveFeature   - stores the feature cell arrays I and IP
%   saveMap       - writes the saliency map rescaled to [0,1] as an image
saveHierarchy   = @(hierarchy,name)save([whereToSave '/hierarchy/' name],'-struct','hierarchy','-v7.3');
saveFeature     = @(I,IP,name)save([whereToSave '/feature/' name],'I','IP');
saveMap         = @(finalSaliencyMap,name)imwrite(mat2gray(finalSaliencyMap),[whereToSave '/map/' name]);

% Training wrapper: casts inputs to the types passed to SQBMatrixTrain.
trainFunc       = @(Input,Output,numIters,opts)SQBMatrixTrain(single(Input), Output, uint32(numIters) ,opts );


%% Extract & Save Features
% For every image: build the segment hierarchy from the image and its
% binarized ground truth, then cache the hierarchy and its features on disk
% (one .mat file per image, named after the image) for the later stages.
% NOTE(review): if the helpers read the global SS, they run on parallel
% workers here where client globals are not available - confirm.

disp('Extracting features...');

parfor i = 1:numberOfImages

    disp(num2str(i));

    im = imread([imdir IF(i).name]);

    % Ground truth: first channel, any non-zero pixel counts as foreground.
    gt = imread([gtdir GF(i).name]);
    gt = gt(:,:,1) > 0;

    % The second output (edge map) is not needed by this script.
    [hierarchy,~] = findSuperpixelChar(im,gt,250,1);
    hierarchy     = findHierarchy(hierarchy,1);

    % Cache file: the image's base name with a .mat extension.
    [~,baseName] = fileparts(IF(i).name);
    cacheName    = [baseName '.mat'];

    feval(saveHierarchy,hierarchy,cacheName);

    [I,IP] = extractFeatures(hierarchy);

    feval(saveFeature,I,IP,cacheName);

end

%% Load Features
% Reload the cached features of every image. For image i:
%   inputData{i}  - the feature matrices of all levels, stacked vertically
%   outputData{i} - the matching IP entries, stacked the same way
%   offsetData{i} - row count of each level's block, used later to split
%                   inputData{i} back into its levels

inputData       = cell(numberOfImages,1);
outputData      = cell(numberOfImages,1);
offsetData      = cell(numberOfImages,1);

disp('Loading data...');

for i = 1:numberOfImages

    % In-place progress bar: one '|' per 5% followed by the percentage.
    percentile = i/numberOfImages * 100;
    showText   = [repmat('|',1,round(percentile/5)) ' (' num2str(percentile,'%.2f') '%)'];

    % Print as data (not as a format string) so the text is shown verbatim.
    fprintf('%s', showText);

    [~,baseName] = fileparts(IF(i).name);
    I2 = load([whereToSave '/feature/' baseName '.mat'], 'I', 'IP');

    offsetData{i} = reshape(cellfun(@(levelFeatures) size(levelFeatures,1), I2.I), 1, []);
    inputData{i}  = cell2mat(I2.I(:));
    outputData{i} = cell2mat(I2.IP(:));

    % Erase the bar so the next iteration can redraw it; keep the final one.
    if i < numberOfImages
        fprintf(repmat('\b',1,length(showText)));
    end

end

fprintf('\n');   % Terminate the progress-bar line

%% Train the Model
% K-fold cross-validation: the images are shuffled and split into
% numberOfCrossvalidations folds. model{i} is trained on every image outside
% fold i, and testGroup(k) records the fold (hence the model) that image k
% is evaluated with.

setenv OMP_NUM_THREADS 8

testGroup  = zeros(numberOfImages,1);

randomList = randperm(numberOfImages);

numberOfCrossvalidations = 4;

% Fold boundaries into randomList. When numberOfImages is a multiple of
% numberOfCrossvalidations all folds have equal size; otherwise their sizes
% differ by at most one, and every image still lands in exactly one fold.
foldEdges = round(linspace(0, numberOfImages, numberOfCrossvalidations + 1));

test     = cell(1, numberOfCrossvalidations);
training = cell(1, numberOfCrossvalidations);
model    = cell(1, numberOfCrossvalidations);

% Options passed to SQBMatrixTrain; identical for every fold except the
% random seed, which is drawn per fold below.
baseOpts                    = [];
baseOpts.loss               = 'squaredloss';
baseOpts.shrinkageFactor    = 0.01;
baseOpts.subsamplingFactor  = 0.2;
baseOpts.maxTreeDepth       = uint32(2);

numIters                    = 600;

for i = 1:numberOfCrossvalidations

    test{i}             = randomList(foldEdges(i)+1 : foldEdges(i+1));
    training{i}         = setdiff(randomList,test{i});

    testGroup(test{i})  = i;

    Input               = cell2mat(inputData(training{i}(:)));
    Output              = cell2mat(outputData(training{i}(:)));

    opts                = baseOpts;
    opts.randSeed       = uint32(rand()*1000);

    model{i}            = trainFunc(Input,Output,numIters,opts);

end

save([whereToSave '/models.mat'],'model');

%% Test the Model
% Each image is scored by the model of the fold it was held out from. The
% per-level predictions are combined by inferSaliency into one saliency map,
% which is saved as a PNG and evaluated against the ground truth.

parfor k = 1:numberOfImages

    disp(num2str(k));

    gt = imread([gtdir GF(k).name]);
    gt = gt(:,:,1) > 0;

    [hh,ww] = size(gt);

    [~,baseName] = fileparts(IF(k).name);

    hierarchyData = load([whereToSave '/hierarchy/' baseName '.mat'],'whereat','newPixels','connections');

    numLevels = length(hierarchyData.whereat);
    pred      = cell(1,numLevels);

    % inputData{k} holds the levels' feature rows stacked in order;
    % offsetData{k}(level) is the number of rows belonging to each level.
    offset = 0;
    for level = 1:numLevels
        rows        = offset + (1:offsetData{k}(level));
        pred{level} = SQBMatrixPredict(model{testGroup(k)}, single(inputData{k}(rows,:)));
        offset      = offset + offsetData{k}(level);
    end

    finalSaliencyMap = inferSaliency(hierarchyData,pred,hh,ww);

    feval(saveMap,finalSaliencyMap,[baseName '.png']);

    performance(k) = calculateAUC(finalSaliencyMap,gt);

end

save([whereToSave '/performance.mat'],'performance');

%% Plot the average precision-recall curve
% NOTE(review): the fields of calculateAUC's output are selected by position
% (4 = precision, 5 = recall, 6 = AUC, 7 = MAE, as used below); re-check
% these indices if calculateAUC's output struct changes.

fontsize = 24;

close all
figure;

U   = struct2cell(performance);   % fields x 1 x numberOfImages

P   = squeeze(cell2mat(U(4,1,:)));
R   = squeeze(cell2mat(U(5,1,:)));
AUC = squeeze(cell2mat(U(6,1,:)));
MAE = squeeze(cell2mat(U(7,1,:)));

% Average the curves over images (one column per image).
plot(mean(R,2),mean(P,2))

% Set the axes font size AFTER plotting: with the default NextPlot='replace'
% PLOT resets the axes properties, which discarded the earlier setting.
set(gca,'fontsize',fontsize);

xlabel('Recall','fontsize',fontsize)
ylabel('Precision','fontsize',fontsize)
title(['AUC = ' num2str(mean(AUC),3) ' MAE = ' num2str(mean(MAE),3)],'fontsize',fontsize);
axis equal
axis([0 1 0 1])