"""
This file contains the functions that solve Burgers' equation
in 1D or 2D using a WENO scheme.
"""

import numpy as np
from Burgers import *
from Weno import *
from helpers import extend

## 1D ##
###################################################################################
def BurgersWENOrhs1D(x,u,h,k,m,Crec,dw,beta,maxvel):
    """Evaluate the RHS of the 1D Burgers equation using a WENO reconstruction.

    Computes the conservative residual ``-(F[j+1/2] - F[j-1/2]) / h`` for every
    cell, where the interface fluxes come from ``BurgersLF`` applied to the
    WENO-reconstructed interface values.

    Parameters
    ----------
    x : 1D array of grid coordinates (length N).
    u : 1D array of solution values on the grid (length N).
    h : grid spacing.
    k : time step. Not used here; kept so the signature stays unchanged.
    m : WENO order parameter; each reconstruction stencil has 2*m-1 points.
    Crec, dw, beta : precomputed reconstruction weights, linear weights and
        smoothness-indicator matrices, passed straight through to ``WENO``.
    maxvel : wave-speed bound forwarded to ``BurgersLF``.

    Returns
    -------
    1D array of length N with the spatial residual du/dt.
    """
    N = len(x)

    # Extend the data with "P" boundary conditions on both ends (presumably
    # periodic ghost cells; m of them per side is consistent with the stencil
    # slicing below -- TODO confirm against helpers.extend).
    xe, ue = extend(x, u, m, "P", 0, "P", 0)

    # Left/right interface values for the N+2 cells that border the N+1
    # interfaces of the physical domain.
    width = 2*m - 1   # points per WENO stencil
    ul = np.zeros(N+2)
    ur = np.zeros(N+2)
    for i in range(N+2):
        ul[i], ur[i] = WENO(xe[i:i+width], ue[i:i+width], m, Crec, dw, beta)

    # Change numerical flux here.
    # flux[j] is built from ur[j] and ul[j+1], for j = 0..N. Evaluating it once
    # and differencing gives the same values as two overlapping calls.
    # This assumes BurgersLF acts pointwise on its array arguments.
    flux = BurgersLF(ur[:N+1], ul[1:N+2], maxvel)
    return -(flux[1:] - flux[:-1])/h

def BurgersWENO1D(x,u,h,m,CFL,FinalTime):
    """Integrate 1D Burgers equation until FinalTime using a WENO
       scheme and 3rd order SSP-RK method.

    Parameters
    ----------
    x : 1D array of grid coordinates.
    u : 1D array with the initial solution. The caller's array is not
        modified; each stage builds a new array.
    h : grid spacing.
    m : WENO order parameter.
    CFL : CFL number used to choose the time step.
    FinalTime : time at which integration stops.

    Returns
    -------
    The solution array at ``FinalTime``.
    """
    t = 0.0

    # Initialize reconstruction weights: row r+1 holds the weights for
    # stencil shift r = -1..m-1.
    Crec = np.zeros((m+1,m))
    for r in range(-1,m):
        Crec[r+1,:] = ReconstructWeights(m,r)

    # Initialize linear weights
    dw = LinearWeights(m,0)

    # Compute smoothness indicator matrices, one (m, m) slice per shift r.
    # xl holds the m+1 points -r-1/2, ..., m-r-1/2 (presumably interface
    # positions in cell-width units).
    beta = np.zeros((m,m,m))
    for r in range(m):
        xl = -1.0/2 + np.arange(-r,m-r+1)
        beta[:,:,r] = betarcalc(xl,m)

    # Integrate scheme
    while t < FinalTime:
        # Decide on timestep from the wave-speed bound max|2u|
        maxvel = (2*np.abs(u)).max()
        if maxvel == 0:
            # u is identically zero, so any step is stable. Finish in one
            # step instead of dividing by zero (same step as before, no warning).
            k = FinalTime - t
        else:
            k = min(CFL*h/maxvel, FinalTime-t)

        # Three-stage SSP-RK update
        rhsu = BurgersWENOrhs1D(x,u,h,k,m,Crec,dw,beta,maxvel)
        u1 = u + k*rhsu
        rhsu = BurgersWENOrhs1D(x,u1,h,k,m,Crec,dw,beta,maxvel)
        u2 = (3*u + u1 + k*rhsu)/4
        rhsu = BurgersWENOrhs1D(x,u2,h,k,m,Crec,dw,beta,maxvel)
        u = (u + 2*u2 + 2*k*rhsu)/3

        t += k

    return u

## 2D ##
###################################################################################
def BurgersWENOrhs2D(x,y,u,hx,hy,k,m,Crec,dw,beta,maxvel):
    """Evaluate right hand side for 2D Burgers equation
        using a WENO method, dimension by dimension.

    Each row ``i`` is handled as a 1D line in the x-direction using
    ``x[i, :]``, ``u[i, :]`` and spacing ``hx``. Each column ``j`` is handled
    as a 1D line in the y-direction using ``y[:, j]``, ``u[:, j]`` and spacing
    ``hy``. Both directions reuse ``BurgersWENOrhs1D``, so they share one
    periodic-extension / WENO / flux implementation.

    Parameters
    ----------
    x, y : 2D coordinate arrays of shape (Ny, Nx).
    u : 2D solution array with the same shape as ``x``.
    hx, hy : grid spacings in the x- and y-directions.
    k : time step. Not used here; kept for interface compatibility.
    m, Crec, dw, beta, maxvel : forwarded unchanged to ``BurgersWENOrhs1D``.

    Returns
    -------
    2D array of shape (Ny, Nx) holding the sum of the x- and y-direction
    residuals.
    """
    Ny, Nx = x.shape

    du = np.zeros((Ny,Nx))

    # x-direction contribution, one row at a time
    for i in range(Ny):
        du[i,:] = BurgersWENOrhs1D(x[i,:], u[i,:], hx, k, m,
                                   Crec, dw, beta, maxvel)

    # y-direction contribution, added column by column
    for j in range(Nx):
        du[:,j] += BurgersWENOrhs1D(y[:,j], u[:,j], hy, k, m,
                                    Crec, dw, beta, maxvel)

    return du


def BurgersWENO2D(x,y,u,hx,hy,m,CFL,FinalTime):
    """Integrate 2D Burgers equation until FinalTime
           using a WENO scheme and 3rd order SSP-RK method.

    Parameters
    ----------
    x, y : 2D coordinate arrays of shape (Ny, Nx). Rows of ``x`` are used as
        x-lines and columns of ``y`` as y-lines by ``BurgersWENOrhs2D``.
    u : 2D array with the initial solution, same shape as ``x``. The caller's
        array is not modified; each stage builds a new array.
    hx, hy : grid spacings in the x- and y-directions.
    m : WENO order parameter.
    CFL : CFL number used to choose the time step.
    FinalTime : time at which integration stops.

    Returns
    -------
    The solution array at ``FinalTime``.
    """
    t = 0.0
    delta = min(hx,hy)   # the smaller spacing limits the time step

    # Initialize reconstruction weights: row r+1 holds the weights for
    # stencil shift r = -1..m-1.
    Crec = np.zeros((m+1,m))
    for r in range(-1,m):
        Crec[r+1,:] = ReconstructWeights(m,r)

    # Initialize linear weights
    dw = LinearWeights(m,0)

    # Compute smoothness indicator matrices, one (m, m) slice per shift r
    beta = np.zeros((m,m,m))
    for r in range(m):
        xl = -1.0/2 + np.arange(-r,m-r+1)
        beta[:,:,r] = betarcalc(xl,m)

    # Integrate scheme
    while t < FinalTime:
        # Decide on timestep from the wave-speed bound 2*sqrt(2)*max|u|
        maxvel = (2*np.sqrt(2.0)*np.abs(u)).max()
        if maxvel == 0:
            # u is identically zero, so any step is stable. Finish in one
            # step instead of dividing by zero (same step as before, no warning).
            k = FinalTime - t
        else:
            k = min(FinalTime-t, CFL*delta/maxvel/2)

        # Three-stage SSP-RK update
        rhsu = BurgersWENOrhs2D(x,y,u,hx,hy,k,m,Crec,dw,beta,maxvel)
        u1 = u + k*rhsu
        rhsu = BurgersWENOrhs2D(x,y,u1,hx,hy,k,m,Crec,dw,beta,maxvel)
        u2 = (3*u + u1 + k*rhsu)/4
        rhsu = BurgersWENOrhs2D(x,y,u2,hx,hy,k,m,Crec,dw,beta,maxvel)
        u = (u + 2*u2 + 2*k*rhsu)/3

        t += k

    return u