#include "SLIC.h"
#include "mex.h"
#include <cmath>

// method to call superpixels method
void getSuperpixels(double *img, char *gt, mwSize height, mwSize width, mwSize numSPixels, double m, double *labels, mwSize nDims,
                    mxArray*& meanColorArray,
                    mxArray*& meanPosArray,
                    mxArray*& pixelsArray,
                    mxArray*& histArray,
                    mxArray*& hogsArray,
                    mxArray*& edgeMagnitudeArray,
                    mxArray*& superpixelSizeArray,
                    mxArray*& boundaryPixelsArray,
                    mxArray*& truePositiveArray,
                    mxArray*& falsePositiveArray,
                    mxArray*& LABArray,
                    mxArray*& averageSaliencyArray)
{
    unsigned int r, g, b, p, c;
    int i, j;
    
    int w=(int)(width);
    int h=(int)(height);
    int sz=w*h;
    int k=(int)(numSPixels);
    
    //put image into 32 bit uint array
    unsigned int *ubuff=(unsigned int*)mxCalloc(sz, sizeof(unsigned int));
    
    //grayscale image input
    if(nDims==2)
    {
        for(j=0;j<h;j++)
            for(i=0;i<w;i++)
            {
                c=(unsigned int)(img[i*h+j]);
                p=(c<<16) | (c<<8) | c;
                ubuff[j*w+i]=p;
            }
    }
    
    //RGB image (put into ubuff by rows)
    //Note: img is col order
    else
    {
        for(j=0;j<h;j++)
            for(i=0;i<w;i++)
            {
                int rInd=i*h+j;
                r=(unsigned int)(img[rInd]);
                g=(unsigned int)(img[sz+rInd]);
                b=(unsigned int)(img[2*sz+rInd]);
                
                p=(r<<16) | (g<<8) | b;
                ubuff[j*w+i]=p;
            }
    }
    
    //run superpixels
    SLIC segment;
    int *klabels=(int*)mxCalloc(sz, sizeof(int));
    int numlabels=0;
    
    segment.PerformSLICO_ForGivenK(ubuff, w, h, klabels, numlabels, k, m);
    
    vector<double> clustersize(numlabels, 0);
	vector<double> inv(numlabels, 0);
    
	vector<double> sigmal(numlabels, 0);
	vector<double> sigmaa(numlabels, 0);
	vector<double> sigmab(numlabels, 0);
	vector<double> sigmax(numlabels, 0);
	vector<double> sigmay(numlabels, 0);
    
    sigmal.assign(numlabels, 0);
    sigmaa.assign(numlabels, 0);
    sigmab.assign(numlabels, 0);
    sigmax.assign(numlabels, 0);
    sigmay.assign(numlabels, 0);
    
    clustersize.assign(numlabels, 0);
    
    {int ind(0);
		for( int r = 0; r < h; r++ )
		{
			for( int c = 0; c < w; c++ )
			{
				sigmal[klabels[ind]] += segment.m_lvec[ind];
				sigmaa[klabels[ind]] += segment.m_avec[ind];
				sigmab[klabels[ind]] += segment.m_bvec[ind];
				sigmax[klabels[ind]] += c;
				sigmay[klabels[ind]] += r;

				clustersize[klabels[ind]] += 1.0;
				ind++;
			}
		}}
    
    {for( int k = 0; k < numlabels; k++ )
    {
        if( clustersize[k] <= 0 ) clustersize[k] = 1;
        inv[k] = 1.0/clustersize[k];
    }}
    
    int boolSize[] = {h,w,numlabels};
    
    meanColorArray  = mxCreateDoubleMatrix(numlabels, 3, mxREAL);
    meanPosArray    = mxCreateDoubleMatrix(numlabels, 2, mxREAL);
    pixelsArray     = mxCreateLogicalArray(3, boolSize);
    
    histArray       = mxCreateDoubleMatrix(numlabels, 64, mxREAL);
    hogsArray       = mxCreateDoubleMatrix(numlabels, 8, mxREAL);

    edgeMagnitudeArray        = mxCreateDoubleMatrix(h, w, mxREAL);
    
    superpixelSizeArray     = mxCreateDoubleMatrix(numlabels,1, mxREAL);
    boundaryPixelsArray     = mxCreateDoubleMatrix(numlabels,1, mxREAL);
    truePositiveArray       = mxCreateDoubleMatrix(numlabels,1, mxREAL);
    falsePositiveArray      = mxCreateDoubleMatrix(numlabels,1, mxREAL);
    
    averageSaliencyArray    = mxCreateDoubleMatrix(numlabels,1, mxREAL);
    
    double* meanColor   = mxGetPr(meanColorArray);
    double* meanPos     = mxGetPr(meanPosArray);
    bool* pixels        = (bool*)mxGetPr(pixelsArray);
    
    double* hist   = mxGetPr(histArray);
    double* hogs   = mxGetPr(hogsArray);

    double* edgeMagnitude    = mxGetPr(edgeMagnitudeArray);
    double* superpixelSize    = mxGetPr(superpixelSizeArray);
    double* boundaryPixels    = mxGetPr(boundaryPixelsArray);
    
    double* truePositive        = mxGetPr(truePositiveArray);
    double* falsePositive       = mxGetPr(falsePositiveArray);
    
    double* lab                 = mxGetPr(LABArray);
    
    double* averageSaliency     = mxGetPr(averageSaliencyArray);

    for( int i = 0; i < numlabels; i++ )
    {
        meanColor[i + 0*numlabels]  = sigmal[i]*inv[i];
        meanColor[i + 1*numlabels]  = sigmaa[i]*inv[i];
        meanColor[i + 2*numlabels]  = sigmab[i]*inv[i];
        
        meanPos[i + 0*numlabels]    = sigmax[i]*inv[i];
        meanPos[i + 1*numlabels]    = sigmay[i]*inv[i];
    }
    
    double maxL = 0, minL = 1e6, maxA = 0, minA = 1e6, maxB = 0, minB = 1e6;
    
    //put results into labels array
    for(j=0;j<h;j++){
        for(i=0;i<w;i++){
            labels[i*h+j]=(double)(klabels[j*w+i]);
            
            if (segment.m_lvec[j*w+i] > maxL)
                maxL = segment.m_lvec[j*w+i];
            if (segment.m_lvec[j*w+i] < minL)
                minL = segment.m_lvec[j*w+i];
            
            if (segment.m_avec[j*w+i] > maxA)
                maxA = segment.m_avec[j*w+i];
            if (segment.m_avec[j*w+i] < minA)
                minA = segment.m_avec[j*w+i];
            
            if (segment.m_bvec[j*w+i] > maxB)
                maxB = segment.m_bvec[j*w+i];
            if (segment.m_bvec[j*w+i] < minB)
                minB = segment.m_bvec[j*w+i];
            
        }
    }
    
//     printf("%f %f %f %f %f %f\n", minL, maxL, minA, maxA, minB, maxB);
    
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < w; j++) {
            
            int index = i + j * h + labels[i + j * h] * h * w;
            
            superpixelSize[int(labels[i + j * h])]++;
            
            if (i < 16 || j < 16 || i > h - 16 || j > w - 16)
                boundaryPixels[int(labels[i + j * h])]++;
            
            if (gt[i + j * h])
                truePositive[int(labels[i + j * h])]++;
            else
                falsePositive[int(labels[i + j * h])]++;
            
            averageSaliency[int(labels[i + j * h])] += gt[i + j * h];
            
            pixels[index] = true;

            int posIndex = j + i * w;
            
            lab[i + j * h + 0 * (w * h)] = segment.m_lvec[posIndex];
            lab[i + j * h + 1 * (w * h)] = segment.m_avec[posIndex];
            lab[i + j * h + 2 * (w * h)] = segment.m_bvec[posIndex];

            int Lpos, Apos, Bpos;
            
            if (segment.m_lvec[posIndex]     <= 0.25 * (maxL - minL) + minL)
                Lpos = 0;
            else if(segment.m_lvec[posIndex] <= 0.50 * (maxL - minL) + minL)
                Lpos = 1;
            else if(segment.m_lvec[posIndex] <= 0.75 * (maxL - minL) + minL)
                Lpos = 2;
            else
                Lpos = 3;
            
            if (segment.m_avec[posIndex]     <= 0.25 * (maxA - minA) + minA)
                Apos = 0;
            else if(segment.m_avec[posIndex] <= 0.50 * (maxA - minA) + minA)
                Apos = 1;
            else if(segment.m_avec[posIndex] <= 0.75 * (maxA - minA) + minA)
                Apos = 2;
            else
                Apos = 3;
            
            if (segment.m_bvec[posIndex]     <= 0.25 * (maxB - minB) + minB)
                Bpos = 0;
            else if(segment.m_bvec[posIndex] <= 0.50 * (maxB - minB) + minB)
                Bpos = 1;
            else if(segment.m_bvec[posIndex] <= 0.75 * (maxB - minB) + minB)
                Bpos = 2;
            else
                Bpos = 3;
            
//            printf("%f %f %f\n",segment.m_lvec[posIndex],segment.m_avec[posIndex],segment.m_bvec[posIndex]);
            
            int histPos = Lpos + 4 * Apos + 16 * Bpos;

            hist[int(labels[i + j * h]) + histPos * numlabels]++;

            double dxl, dyl, dxa, dya, dxb, dyb;
            
            if (i == 0){
                dyl = segment.m_lvec[(i + 1) * w + j] - segment.m_lvec[i * w + j];
                dya = segment.m_avec[(i + 1) * w + j] - segment.m_avec[i * w + j];
                dyb = segment.m_bvec[(i + 1) * w + j] - segment.m_bvec[i * w + j];
            }
            else if(i == h - 1){
                dyl = -segment.m_lvec[(i - 1) * w + j] + segment.m_lvec[i * w + j];
                dya = -segment.m_avec[(i - 1) * w + j] + segment.m_avec[i * w + j];
                dyb = -segment.m_bvec[(i - 1) * w + j] + segment.m_bvec[i * w + j];
            }
            else{
                dyl = (-segment.m_lvec[(i - 1) * w + j] + segment.m_lvec[(i + 1) * w + j])/2;
                dya = (-segment.m_avec[(i - 1) * w + j] + segment.m_avec[(i + 1) * w + j])/2;
                dyb = (-segment.m_bvec[(i - 1) * w + j] + segment.m_bvec[(i + 1) * w + j])/2;
            }
            
            if (j == 0){
                dxl = segment.m_lvec[i * w + (j + 1)] - segment.m_lvec[i * w + j];
                dxa = segment.m_avec[i * w + (j + 1)] - segment.m_avec[i * w + j];
                dxb = segment.m_bvec[i * w + (j + 1)] - segment.m_bvec[i * w + j];
            }
            else if(j == w - 1){
                dxl = -segment.m_lvec[i * w + (j - 1)] + segment.m_lvec[i * w + j];
                dxa = -segment.m_avec[i * w + (j - 1)] + segment.m_avec[i * w + j];
                dxb = -segment.m_bvec[i * w + (j - 1)] + segment.m_bvec[i * w + j];
            }
            else{
                dxl = (-segment.m_lvec[i * w + (j - 1)] + segment.m_lvec[i * w + (j + 1)])/2;
                dxa = (-segment.m_avec[i * w + (j - 1)] + segment.m_avec[i * w + (j + 1)])/2;
                dxb = (-segment.m_bvec[i * w + (j - 1)] + segment.m_bvec[i * w + (j + 1)])/2;
            }
            
            edgeMagnitude[i + j * h] = dxl*dxl + dyl*dyl + dxa*dxa + dya*dya + dxb*dxb + dyb*dyb;
            
            int hogIndexPosition = 0;
            
            double dx = dxl;
            double dy = dyl;
            
            if (dx >= 0 && dy >= 0      && abs(dx) >= abs(dy)){
                hogIndexPosition = 0;
            }
            else if(dx >= 0 && dy >= 0  && abs(dy) > abs(dx)){
                hogIndexPosition = 1;
            }
            else if(dx < 0 && dy >= 0   && abs(dy) >= abs(dx)){
                hogIndexPosition = 2;
            }
            else if(dx < 0 && dy >= 0   && abs(dy) < abs(dx)){
                hogIndexPosition = 3;
            }
            else if(dx < 0 && dy < 0    && abs(dy) <= abs(dx)){
                hogIndexPosition = 4;
            }
            else if(dx <= 0 && dy < 0    && abs(dy) > abs(dx)){
                hogIndexPosition = 5;
            }
            else if(dx > 0 && dy < 0   && abs(dy) >= abs(dx)){
                hogIndexPosition = 6;
            }
            else{
                hogIndexPosition = 7;
            }

            hogs[int(labels[i + j * h]) + hogIndexPosition * numlabels]++;
            
        }
    }

    //free memory
    if(ubuff)
        mxFree(ubuff);
    if(klabels)
        mxFree(klabels);
}


/* The gateway function */
void mexFunction( int nlhs, mxArray *plhs[],
                 int nrhs, const mxArray *prhs[])
{
    double *img;        //image
    char *gt;
    mwSize numSPixels;  //number of superpixels
    double m;           //compactness
    double *labels;     //labels
    const mwSize *dims;	//image dimensions
    mwSize nDims;       //number of image dimensions (2 for gray, 3 for color)
    
    /* check for proper number of arguments */
//    if(nrhs!=3)
//        mexErrMsgIdAndTxt("MyToolbox:runSuperpixels:nrhs",
//                          "Three inputs required.");
//    if(nlhs!=1)
//        mexErrMsgIdAndTxt("MyToolbox:runSuperpixels:nlhs",
//                          "One output required.");
    
    /* check 1st input argument */
    /* check if input is image matrix of size hxwx3 or hxwx1 */
    if(!mxIsDouble(prhs[0]) ||
       !mxIsNumeric(prhs[0]) ||
       mxIsComplex(prhs[0]) ||
       mxIsEmpty(prhs[0]) ||
       ( mxGetNumberOfDimensions(prhs[0])!=2 &&
        (mxGetNumberOfDimensions(prhs[0])!=3 || mxGetDimensions(prhs[0])[2]!=3) ) )
    {
        mexErrMsgIdAndTxt("MyToolbox:runSuperpixels:invalidInput1",
                          "First input invalid.");
    }
    
    /* check 2nd input argument */
    /* check if input has positive value */
    if(!mxIsNumeric(prhs[1]) ||
       mxIsComplex(prhs[1]) ||
       mxIsEmpty(prhs[1]) ||
       ((mwSize)mxGetScalar(prhs[1]))<1 )
    {
        mexErrMsgIdAndTxt("MyToolbox:runSuperpixels:invalidInput2",
                          "Second input invalid.");
    }
    
    /* check 3rd input argument */
    if( mxGetM(prhs[2])!=1 ||
       mxGetN(prhs[2])!=1 ||
       !mxIsNumeric(prhs[2]) ||
       mxIsComplex(prhs[2]) ||
       mxIsEmpty(prhs[2]) )
    {
        mexErrMsgIdAndTxt("MyToolbox:runSuperpixels:invalidInput3",
                          "Third input invalid.");
    }
    
    //get inputs
    img         = mxGetPr(prhs[0]);
    numSPixels  = (mwSize)mxGetScalar(prhs[1]);
    m           = mxGetScalar(prhs[2]);
    
    gt          = (char*)mxGetPr(prhs[3]);
    
    //image dims
    dims=mxGetDimensions(prhs[0]);
    nDims=mxGetNumberOfDimensions(prhs[0]);
    
    //output
    plhs[0]=mxCreateDoubleMatrix(dims[0], dims[1], mxREAL);
    labels=mxGetPr(plhs[0]);
    
    plhs[11] = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
    
    //call superpixels method
    getSuperpixels(img, gt, dims[0], dims[1], numSPixels, m, labels, nDims, plhs[1], plhs[2], plhs[3], plhs[4], plhs[5], plhs[6], plhs[7], plhs[8], plhs[9], plhs[10], plhs[11],plhs[12]);

}

