"""
This file contains the functions used to solve the 2D KPP equation
with a WENO scheme.
"""

import numpy as np
from KPP import *
from Weno import *
from helpers import extend

## 2D ##
###################################################################################
def _weno_line_interfaces(coord, vals, n, m, Crec, dw, beta):
    """Reconstruct WENO cell-interface values along one grid line.

    The line is padded with ``extend`` using boundary type "D" with value
    pi/4 at both ends (presumably Dirichlet data -- confirm against
    ``helpers.extend``). ``WENO`` is then applied to every stencil of width
    ``2*m - 1`` of the padded line.

    Parameters
    ----------
    coord, vals : 1-D slices holding the cell coordinates and cell values
        of the line.
    n : number of physical cells on the line.
    m, Crec, dw, beta : WENO order parameter and precomputed tables,
        forwarded unchanged to ``WENO``.

    Returns
    -------
    (ul, ur) : two arrays of length ``n + 2`` with the left and right
        interface values of each cell. Entries 1..n are the physical cells.
        Entries 0 and n+1 are the neighbouring ghost cells (this assumes
        ``extend`` pads ``m`` ghost cells per side, which is what the
        ``n + 2`` windows of width ``2*m - 1`` exactly cover).
    """
    ce, ve = extend(coord, vals, m, "D", np.pi/4, "D", np.pi/4)
    width = 2*m - 1
    ul = np.zeros(n + 2)
    ur = np.zeros(n + 2)
    for j in range(n + 2):
        ul[j], ur[j] = WENO(ce[j:j + width], ve[j:j + width], m, Crec, dw, beta)
    return ul, ur


def KPPWENOrhs2D(x,y,u,hx,hy,k,m,Crec,dw,beta,maxvel):
    """Purpose: Evaluate right hand side for 2D KPP equation
        using a WENO method.

    The residual is assembled dimension by dimension. Each row (x-direction)
    and then each column (y-direction) is WENO-reconstructed at the cell
    interfaces. The flux differences are then accumulated into ``du``.

    Parameters
    ----------
    x, y : 2-D arrays of cell coordinates. Rows index y and columns index x;
        ``x.shape`` gives ``(Ny, Nx)``.
    u : cell values on the same grid.
    hx, hy : grid spacings in x and y.
    k : time step. It is not used here and is kept only so existing callers
        keep working.
    m : WENO order parameter (stencils have width ``2*m - 1``).
    Crec, dw, beta : precomputed WENO reconstruction weights, linear weights
        and smoothness-indicator matrices, forwarded to ``WENO``.
    maxvel : forwarded to ``KPPxLF``/``KPPyLF``. It is presumably the maximum
        wave speed used by the (Lax-Friedrichs?) numerical flux -- confirm
        in the KPP module.

    Returns
    -------
    du : array of shape ``(Ny, Nx)`` holding the semi-discrete right-hand
        side.
    """
    Ny, Nx = x.shape[0], x.shape[1]
    du = np.zeros((Ny, Nx))

    # x-direction sweep, one row at a time.
    for i in range(Ny):
        ul, ur = _weno_line_interfaces(x[i, :], u[i, :], Nx, m, Crec, dw, beta)
        # Interface between cells j and j+1 pairs ur[j] with ul[j+1].
        # Indices 1..Nx are the physical cells.
        du[i, :] = -(KPPxLF(ur[1:Nx+1], ul[2:Nx+2], maxvel)
                     - KPPxLF(ur[:Nx], ul[1:Nx+1], maxvel))/hx

    # y-direction sweep, one column at a time, added onto the x residual.
    for j in range(Nx):
        ul, ur = _weno_line_interfaces(y[:, j], u[:, j], Ny, m, Crec, dw, beta)
        du[:, j] -= (KPPyLF(ur[1:Ny+1], ul[2:Ny+2], maxvel)
                     - KPPyLF(ur[:Ny], ul[1:Ny+1], maxvel))/hy
    return du


def KPPWENO2D(x,y,u,hx,hy,m,CFL,FinalTime):
    """Purpose: Integrate 2D KPP equation until FinalTime
           using a WENO scheme.

    Time stepping uses the three-stage SSP Runge-Kutta method (Shu-Osher
    form). ``KPPWENOrhs2D`` supplies the spatial residual. The step is
    ``CFL*min(hx, hy)/2``, and the last step is shortened so the
    integration stops exactly at ``FinalTime``.

    Parameters
    ----------
    x, y : 2-D arrays of cell coordinates (see ``KPPWENOrhs2D``).
    u : initial cell values on the same grid.
    hx, hy : grid spacings.
    m : WENO order parameter.
    CFL : CFL number used to size the time step.
    FinalTime : time at which integration stops.

    Returns
    -------
    u : the solution at ``FinalTime``. If ``FinalTime <= 0``, the input is
        returned unchanged.

    Raises
    ------
    ValueError
        If ``FinalTime > 0`` but the time step is not strictly positive (for
        example ``CFL <= 0``, a non-positive spacing, or NaN). A zero or
        negative step would otherwise never let the loop terminate.
    """
    t = 0.0
    delta = min(hx, hy)
    k = CFL*delta/2
    maxvel = 1.0

    # Fail fast: a zero or negative step loops forever, and a NaN step
    # would make min() jump straight to FinalTime in one CFL-violating step.
    if t < FinalTime and not k > 0:
        raise ValueError(
            f"time step must be positive, got k={k!r} "
            f"(CFL={CFL!r}, hx={hx!r}, hy={hy!r})"
        )

    # Reconstruction weights for stencil shifts r = -1..m-1 (row r+1).
    Crec = np.zeros((m+1, m))
    for r in range(-1, m):
        Crec[r+1, :] = ReconstructWeights(m, r)

    # Linear (optimal) WENO weights.
    dw = LinearWeights(m, 0)

    # Smoothness-indicator matrices, one m x m slice per stencil shift r.
    beta = np.zeros((m, m, m))
    for r in range(m):
        xl = -1/2 + np.arange(-r, m-r+1)
        beta[:, :, r] = betarcalc(xl, m)

    while t < FinalTime:
        # Shorten the final step so t lands on FinalTime.
        k = min(FinalTime - t, k)

        # SSP-RK3 stages.
        rhsu = KPPWENOrhs2D(x, y, u, hx, hy, k, m, Crec, dw, beta, maxvel)
        u1 = u + k*rhsu
        rhsu = KPPWENOrhs2D(x, y, u1, hx, hy, k, m, Crec, dw, beta, maxvel)
        u2 = (3*u + u1 + k*rhsu)/4
        rhsu = KPPWENOrhs2D(x, y, u2, hx, hy, k, m, Crec, dw, beta, maxvel)
        u = (u + 2*u2 + 2*k*rhsu)/3
        t += k

    return u