// If you're using (parts of) this work, please cite the corresponding
// publication: ivrg.epfl.ch/Lindner_IEEE_MM_2015
//
// For any questions contact the author:
// ajl.epfl@gmail.com  http://ivrg.epfl.ch/people/lindner

#include "mex.h"
#include <math.h>

inline double linRGB2sRGB(double in) {
    if (in <= 0.0031308) {
        return 12.92*in;
	}
	else {
        return 1.055*pow(in, 1.0/2.4) - 0.055;
	}
}

/* Clamp *v into [0, 1] in place; return true if the value was changed.
   A NaN is left as-is because both comparisons are false for NaN. */
static inline bool clamp_unit(double *v) {
    if (*v > 1) {
        *v = 1.0;
        return true;
    }
    if (*v < 0) {
        *v = 0.0;
        return true;
    }
    return false;
}

/*
 Clamp the three colour components *R, *G, *B into [0, 1] in place.
 Returns true if at least one component was out of range (i.e. clipped).

 Declared static: a file-scope function marked only `inline` has no external
 definition in C and may fail to link when a call is not inlined.
*/
static inline bool normalize(double *R, double *G, double *B) {
    /* Evaluate every clamp before combining, so `||` short-circuiting
       cannot skip clipping a later channel. */
    const bool clip_r = clamp_unit(R);
    const bool clip_g = clamp_unit(G);
    const bool clip_b = clamp_unit(B);
    return clip_r || clip_g || clip_b;
}

/*
 Converts CIELAB to sRGB by inverting the forward transforms described below.
 The conversion between sRGB and XYZ is based on the formulas in:
 Michael Stokes, Matthew Anderson, Srinivasan Chandrasekar, Ricardo Motta (1996). "A Standard Default Color Space for the Internet - sRGB"
 The conversion between XYZ and CIELAB is based on the formulas in:
 Robert Hunt, "Measuring Color", 3rd edition, 1998
 
 see also: http://www.brucelindbloom.com/
*/
/*
 Convert a planar CIELAB image (D65 reference white) to gamma-encoded sRGB.

 Lab  : input, three consecutive planes of H*W doubles each -- first all L
        values, then all a values, then all b values.
 sRGB : output, same planar layout (R plane, G plane, B plane); every value
        is clamped to [0, 1] after gamma encoding.
 H, W : image height and width; H*W is the number of pixels per plane.
        Nothing is written when either is <= 0.
*/
void Lab2sRGB(const double *Lab, double *sRGB, int H, int W) {
    if (H <= 0 || W <= 0) {
        return;
    }
    /* Pixels per plane, computed in size_t so large images cannot overflow int. */
    const size_t HW = (size_t)H * (size_t)W;

    /* white point D65 */
    const double Xn = 0.9504;
    const double Yn = 1.0;
    const double Zn = 1.0889;

    /* CIE constants selecting the linear segment near black. */
    const double epsilon = 0.008856;
    const double kappa = 903.3;
    const double kappaepsilon = kappa * epsilon;
    /* f^3 > epsilon  <=>  f > epsilon^(1/3): compare in the f domain. */
    const double epsilon3 = pow(epsilon, 1 / 3.0);

    for (size_t i = 0; i < HW; i++) {
        const double L = Lab[i];
        const double a = Lab[i + HW];
        const double b = Lab[i + 2 * HW];

        /* Lab to XYZ */
        const double f_y = (L + 16) / 116;
        const double f_x = a / 500 + f_y;
        const double f_z = f_y - b / 200;

        const double Y = Yn * ((L > kappaepsilon) ? pow(f_y, 3.0) : L / kappa);
        const double X = Xn * ((f_x > epsilon3) ? pow(f_x, 3.0) : (116 * f_x - 16) / kappa);
        const double Z = Zn * ((f_z > epsilon3) ? pow(f_z, 3.0) : (116 * f_z - 16) / kappa);

        /* XYZ to linear RGB, then gamma-encode each channel */
        double R = linRGB2sRGB( 3.2404542 * X - 1.5371385 * Y - 0.4985314 * Z);
        double G = linRGB2sRGB(-0.9692660 * X + 1.8760108 * Y + 0.0415560 * Z);
        double B = linRGB2sRGB( 0.0556434 * X - 0.2040259 * Y + 1.0572252 * Z);

        /* Out-of-gamut colours are clipped to [0, 1]; whether clipping
           occurred is not needed here. */
        (void)normalize(&R, &G, &B);

        sRGB[i] = R;
        sRGB[i + HW] = G;
        sRGB[i + 2 * HW] = B;
    }
}


/*
 MATLAB entry point: sRGB = Lab2sRGB_mex(Lab)

 Lab must be a real, full (non-sparse) double array of size [H, W, 3]
 or [H*W, 3]. The result has the same size and holds sRGB values in [0, 1].
*/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    const mwSize *dims;
    mwSize number_of_dimensions;
    int H, W, D;

    /* check input */
    if (nrhs != 1) {
        mexErrMsgTxt("Exactly one input argument required.\n");
    }
    if (nlhs > 1) {
        mexErrMsgTxt("At most one output argument is supported.\n");
    }

    /* refuse anything but doubles */
    if (!mxIsDouble(prhs[0])) {
        mexErrMsgTxt("Input data has to be double.\n");
    }
    /* Sparse storage only holds the non-zero entries, so the dense plane
       indexing in Lab2sRGB would read past the end of the data buffer. */
    if (mxIsSparse(prhs[0])) {
        mexErrMsgTxt("Input data has to be a full (non-sparse) matrix.\n");
    }
    /* Only the real part would be read; refuse instead of silently
       discarding the imaginary part. */
    if (mxIsComplex(prhs[0])) {
        mexErrMsgTxt("Input data has to be real.\n");
    }

    /* get dimensions of image */
    dims = mxGetDimensions(prhs[0]);
    number_of_dimensions = mxGetNumberOfDimensions(prhs[0]);

    if (number_of_dimensions == 2) {
        /* [H*W, 3] pixel list: process it as an image of width 1 */
        H = (int)dims[0];
        W = 1;
        D = (int)dims[1];
    }
    else if (number_of_dimensions == 3) {
        H = (int)dims[0];
        W = (int)dims[1];
        D = (int)dims[2];
    }
    else {
        mexErrMsgTxt("Lab input image has to be a [H, W, 3] matrix or a [H*W 3] matrix.\n");
        return; /* not reached; keeps H/W/D from being used uninitialised */
    }

    if (D != 3) {
        mexErrMsgTxt("Lab input image has to be a [H, W, 3] matrix or a [H*W 3] matrix.\n");
    }

    /* input ok, start processing: the output has exactly the input's shape */
    plhs[0] = mxCreateNumericArray(number_of_dimensions, dims, mxDOUBLE_CLASS, mxREAL);

    Lab2sRGB(mxGetPr(prhs[0]), mxGetPr(plhs[0]), H, W);
}


