"""
Routines and co-routines for ngc-learn's differential equation integration backend.
| Currently supported back-end forms of integration in ngc-learn include:
| 0) Euler integration (RK-1);
| 1) Midpoint method (RK-2);
| 2) Heun's method (error-corrector RK-2);
| 3) Ralston's method (error-corrector RK-2);
| 4) 4th-order Runge-Kutta method (RK-4);
"""
from jax import numpy as jnp, random, jit #, nn
from functools import partial
from jax.lax import scan as _scan
import time, sys
[docs]
def get_integrator_code(integrationType): ## integrator type decoding routine
    """
    Convenience function for mapping integrator type string to ngc-learn's
    internal integer code value.

    Args:
        integrationType: string indicating integrator type
            (supported types: `rk1` or `euler`, `rk2` or `midpoint`,
            `rk2_heun` or `heun`, `rk2_ralston` or `ralston`, `rk4`)

    Returns:
        integrator type integer code (0 = Euler/RK-1, 1 = midpoint/RK-2,
        2 = Heun, 3 = Ralston, 4 = RK-4); any unrecognized name prints an
        error message and falls back to 0 (Euler/RK-1)
    """
    if integrationType in ("midpoint", "rk2"): ## midpoint method
        return 1
    if integrationType in ("rk2_heun", "heun"): ## Heun's method
        return 2
    if integrationType in ("rk2_ralston", "ralston"): ## Ralston's method
        return 3
    if integrationType == "rk4": ## Runge-Kutta 4th-order method
        return 4
    ## "euler" and "rk1" both name the default RK-1 routine, so neither of them
    ## may trigger the error message; every other name does.
    if integrationType not in ("euler", "rk1"):
        print("ERROR: unrecognized integration method {} provided! Defaulting "
              "to RK-1/Euler routine".format(integrationType))
    return 0 ## Default is Euler (RK1)
@jit
def _sum_combine(*args, **kwargs): ## fast co-routine for weighted summation
    """
    Forms the weighted sum  sum_i w_i * a_i  of the positional terms.

    Each positional term is paired with a keyword-argument value taken in the
    order the keywords were passed; the keyword names themselves are ignored.
    Pairing is done with `zip`, so surplus terms or surplus weights are dropped.

    Args:
        *args: terms a_i to combine

        **kwargs: weights w_i, matched to `args` by position/insertion order

    Returns:
        the weighted sum (zero when nothing gets paired)
    """
    weights = kwargs.values()
    return sum(w * a for a, w in zip(args, weights))
@jit
def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
    """
    Applies one explicit update of time and state.

    Args:
        t: current time

        x: current state value(s)

        dx_dt: slope used to move the state

        dt: step size

        x_scale: multiplicative factor applied to `x` before the slope is added

    Returns:
        tuple `(t + dt, x * x_scale + dx_dt * dt)`
    """
    increment = dx_dt * dt ## change contributed by the slope over this step
    t_next = t + dt
    x_next = x * x_scale + increment
    return t_next, x_next
[docs]
@partial(jit, static_argnums=(2))
def step_euler(t, x, dfx, dt, params, x_scale=1.):
    """
    Performs a single Euler (first-order Runge-Kutta, RK-1) integration step.

    Args:
        t: current time, advanced by `dt`

        x: current variable values (at time `t`) to integrate forward

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)` (as implemented in an ngc-learn component)

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        tuple `(t + dt, x_next)` holding the advanced time and variable values
    """
    (t_next, x_next), _history = _euler((t, x), dfx, dt, params, x_scale=x_scale)
    return t_next, x_next
@partial(jit, static_argnums=(1))
def _euler(carry, dfx, dt, params, x_scale=1.):
    """
    Scan-compatible Euler (RK-1) step.

    The slope is evaluated once at the current point, and the state is moved
    by `dt` along that slope.

    Args:
        carry: tuple `(t, x)` of current time and variable values

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)`

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, x_next)`;
        the second element lets `jax.lax.scan` stack post- and pre-step states
    """
    t_now, x_now = carry
    slope = dfx(t_now, x_now, params)
    t_next, x_next = _step_forward(t_now, x_now, slope, dt, x_scale)
    stepped = (t_next, x_next)
    return stepped, (stepped, carry)
@partial(jit, static_argnums=(1))
def _leapfrog(carry, dfq, dt, params):
    """
    Advances a position/momentum pair by one leapfrog step of size `dt`.

    This is the kick-drift-kick (velocity-Verlet) form:

    | p_half = p + (dt/2) * dfq(t, q)
    | q'     = q + dt * p_half
    | p'     = p_half + (dt/2) * dfq(t + dt, q')

    Args:
        carry: tuple `(t, q, p)` of current time, position and momentum

        dfq: co-routine `dfq(t, q, params)` whose output is used as the rate
            of change of the momentum `p` (presumably a force, i.e. a negative
            potential-energy gradient -- confirm against callers)

        dt: integration time step

        params: extra values forwarded unchanged to `dfq`

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, q', p')`,
        shaped for use as a `jax.lax.scan` body
    """
    t, q, p = carry
    ## first half-kick of the momentum, using the force at the current position
    _p = p + dfq(t, q, params) * (dt/2.)
    ## drift: the position must move with the half-step momentum `_p`, not the
    ## stale `p`; otherwise the step is neither time-reversible nor
    ## volume-preserving (properties samplers such as HMC rely on)
    _q = q + _p * dt
    ## second half-kick, using the force at the updated position
    _p = _p + dfq(t + dt, _q, params) * (dt/2.)
    _t = t + dt
    new_carry = (_t, _q, _p)
    return new_carry, (new_carry, carry)
[docs]
@partial(jit, static_argnums=(3, 4))
def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params): ## leapfrog estimator step
    """
    Runs `L` consecutive leapfrog steps of size `step_size`.

    The rollout starts from `(t_curr, q_curr, p_curr)` and is performed with
    `jax.lax.scan`.

    Args:
        t_curr: starting time

        q_curr: starting position value(s)

        p_curr: starting momentum value(s)

        dfq: co-routine `dfq(t, q, params)` whose output drives the momentum
            update (must be hashable; it is a static argument)

        L: number of leapfrog steps to take (static argument)

        step_size: size of each leapfrog step

        params: extra values forwarded unchanged to `dfq`

    Returns:
        tuple `(t, q, p)` after the final step
    """
    ## "+ 0." produces fresh values (and promotes integer inputs to floating point)
    init_state = (t_curr + 0., q_curr + 0., p_curr + 0.)

    def _body(state, _unused):
        return _leapfrog(state, dfq, step_size, params)

    final_state, _trajectory = _scan(_body, init=init_state, xs=jnp.arange(L))
    t_out, q_out, p_out = final_state
    return t_out, q_out, p_out
[docs]
@partial(jit, static_argnums=(2))
def step_heun(t, x, dfx, dt, params, x_scale=1.):
    """
    Performs a single step of Heun's method.

    Heun's method is a second-order, error-corrected Runge-Kutta (RK-2)
    scheme that uses two (differential) function evaluations per step: an
    Euler predictor, then a corrector averaging the two slopes.
    (Note: ngc-learn internally recognizes "rk2_heun" or "heun" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        t: current time, advanced by `dt`

        x: current variable values (at time `t`) to integrate forward

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)` (as implemented in an ngc-learn component)

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        tuple `(t + dt, x_next)` holding the advanced time and variable values
    """
    (t_next, x_next), _history = _heun((t, x), dfx, dt, params, x_scale=x_scale)
    return t_next, x_next
@partial(jit, static_argnums=(1))
def _heun(carry, dfx, dt, params, x_scale=1.):
    """
    Scan-compatible step of Heun's method (error-corrected RK-2).

    A full Euler predictor step is taken first. The slope at the predicted
    point is then averaged with the initial slope, and that average is used
    for the final step from the original point.
    (Note: ngc-learn internally recognizes "rk2_heun" or "heun" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        carry: tuple `(t, x)` of current time and variable values

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)`

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, x_next)`
    """
    t_n, x_n = carry
    k1 = dfx(t_n, x_n, params)
    ## predictor: plain Euler step to the end of the interval
    t_next, x_pred = _step_forward(t_n, x_n, k1, dt, x_scale)
    k2 = dfx(t_next, x_pred, params)
    ## corrector: (k1 + k2) applied over dt/2, i.e. the mean slope over dt
    slope_sum = _sum_combine(k1, k2, weight1=1, weight2=1)
    _, x_next = _step_forward(t_n, x_n, slope_sum, dt * 0.5, x_scale)
    stepped = (t_next, x_next)
    return stepped, (stepped, carry)
[docs]
@partial(jit, static_argnums=(2))
def step_rk2(t, x, dfx, dt, params, x_scale=1.):
    """
    Performs a single step of the midpoint method, i.e. a second-order
    Runge-Kutta (RK-2) step.
    (Note: ngc-learn internally recognizes "rk2" or "midpoint" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        t: current time, advanced by `dt`

        x: current variable values (at time `t`) to integrate forward

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)` (as implemented in an ngc-learn component)

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        tuple `(t + dt, x_next)` holding the advanced time and variable values
    """
    (t_next, x_next), _history = _rk2((t, x), dfx, dt, params, x_scale=x_scale)
    return t_next, x_next
@partial(jit, static_argnums=(1))
def _rk2(carry, dfx, dt, params, x_scale=1.):
    """
    Scan-compatible midpoint-method (RK-2) step.

    A half step gives the midpoint of the interval. The slope evaluated there
    is then used for the full step from the original point.
    (Note: ngc-learn internally recognizes "rk2" or "midpoint" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        carry: tuple `(t, x)` of current time and variable values

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)`

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, x_next)`
    """
    t_n, x_n = carry
    k1 = dfx(t_n, x_n, params)
    ## half step along k1 to reach the midpoint of the interval
    t_half, x_half = _step_forward(t_n, x_n, k1, dt * 0.5, x_scale)
    k2 = dfx(t_half, x_half, params)
    ## full step from the start point along the midpoint slope
    t_next, x_next = _step_forward(t_n, x_n, k2, dt, x_scale)
    stepped = (t_next, x_next)
    return stepped, (stepped, carry)
[docs]
@partial(jit, static_argnums=(2))
def step_rk4(t, x, dfx, dt, params, x_scale=1.):
    """
    Iteratively integrates one step forward via the classical fourth-order
    Runge-Kutta (RK-4) method.

    The step uses four (differential) function evaluations (k1..k4) and the
    weighted combination (k1 + 2*k2 + 2*k3 + k4) / 6.
    (Note: ngc-learn internally recognizes "rk4" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        t: current time variable to advance by dt

        x: current variable values to advance/iteratively integrate (at time `t`)

        dfx: (ordinary) differential equation co-routine (as implemented in an
            ngc-learn component), called as `dfx(t, x, params)`

        dt: integration time step (also referred to as `h` in mathematics)

        params: tuple containing configuration values/hyper-parameters for the
            (ordinary) differential equation an ngc-learn component will provide

        x_scale: dampening factor to scale `x` by (Default: 1)

    Returns:
        tuple `(t + dt, x_next)`: variable values iteratively integrated/advanced
        to the next step
    """
    carry = (t, x)
    next_state, *_ = _rk4(carry, dfx, dt, params, x_scale=x_scale)
    _t, _x = next_state
    return _t, _x
@partial(jit, static_argnums=(1))
def _rk4(carry, dfx, dt, params, x_scale=1.):
    """
    Iteratively integrates one step forward via the classical fourth-order
    Runge-Kutta (RK-4) method; scan-compatible form of `step_rk4`.
    (Note: ngc-learn internally recognizes "rk4" for this routine)

    | Reference:
    | Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
    | differential equations and differential-algebraic equations. Society for
    | Industrial and Applied Mathematics, 1998.

    Args:
        carry: a tuple containing current time and data, i.e., (t, x)

        dfx: (ordinary) differential equation co-routine (as implemented in an
            ngc-learn component), called as `dfx(t, x, params)`

        dt: integration time step (also referred to as `h` in mathematics)

        params: tuple containing configuration values/hyper-parameters for the
            (ordinary) differential equation an ngc-learn component will provide

        x_scale: dampening factor to scale `x` by (Default: 1)

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, x_next)`
    """
    t, x = carry
    ## carry out the 4 stage evaluations of RK-4; every stage restarts from (t, x)
    dfx_1 = dfx(t, x, params) ## k1: slope at the start
    t2, x2 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
    dfx_2 = dfx(t2, x2, params) ## k2: slope at the midpoint, reached via k1
    t3, x3 = _step_forward(t, x, dfx_2, dt * 0.5, x_scale)
    dfx_3 = dfx(t3, x3, params) ## k3: slope at the midpoint, reached via k2
    t4, x4 = _step_forward(t, x, dfx_3, dt, x_scale)
    dfx_4 = dfx(t4, x4, params) ## k4: slope at the end, reached via k3
    ## produce final estimate (k1 + 2*k2 + 2*k3 + k4) / 6 and move forward
    _dx_dt = _sum_combine(dfx_1, dfx_2, dfx_3, dfx_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1)
    _t, _x = _step_forward(t, x, _dx_dt / 6, dt, x_scale)
    new_carry = (_t, _x)
    return new_carry, (new_carry, carry)
[docs]
@partial(jit, static_argnums=(2))
def step_ralston(t, x, dfx, dt, params, x_scale=1.):
    """
    Performs a single step of Ralston's method.

    Ralston's method is a second-order, error-corrected Runge-Kutta (RK-2)
    scheme that uses two (differential) function evaluations per step.
    (Note: ngc-learn internally recognizes "rk2_ralston" or "ralston" for this
    routine)

    | Reference:
    | Ralston, Anthony. "Runge-Kutta methods with minimum error bounds."
    | Mathematics of computation 16.80 (1962): 431-437.

    Args:
        t: current time, advanced by `dt`

        x: current variable values (at time `t`) to integrate forward

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)` (as implemented in an ngc-learn component)

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        tuple `(t + dt, x_next)` holding the advanced time and variable values
    """
    (t_next, x_next), _history = _ralston((t, x), dfx, dt, params, x_scale=x_scale)
    return t_next, x_next
@partial(jit, static_argnums=(1))
def _ralston(carry, dfx, dt, params, x_scale=1.):
    """
    Scan-compatible step of Ralston's method (error-corrected RK-2).

    The second slope is sampled after moving 3/4 of the step along the first
    slope. The final step uses the weighted slope (1/3) * k1 + (2/3) * k2.
    (Note: ngc-learn internally recognizes "rk2_ralston" or "ralston" for this
    routine)

    NOTE(review): Ralston's minimum-error RK-2 is often quoted with a 2/3 node
    and weights 1/4, 3/4. The 3/4-node / (1/3, 2/3)-weight variant used here
    is also second-order; confirm which variant is intended.

    | Reference:
    | Ralston, Anthony. "Runge-Kutta methods with minimum error bounds."
    | Mathematics of computation 16.80 (1962): 431-437.

    Args:
        carry: tuple `(t, x)` of current time and variable values

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)`

        dt: integration time step (`h` in the mathematical notation)

        params: configuration values/hyper-parameters forwarded to `dfx`

        x_scale: dampening factor applied to `x` (Default: 1)

    Returns:
        `(new_carry, (new_carry, carry))`, where `new_carry = (t + dt, x_next)`
    """
    t_n, x_n = carry
    k1 = dfx(t_n, x_n, params)
    ## probe point three quarters of the way through the interval
    t_probe, x_probe = _step_forward(t_n, x_n, k1, dt * 0.75, x_scale)
    k2 = dfx(t_probe, x_probe, params)
    blended = _sum_combine(k1, k2, weight1=(1./3.), weight2=(2./3.))
    t_next, x_next = _step_forward(t_n, x_n, blended, dt, x_scale)
    stepped = (t_next, x_next)
    return stepped, (stepped, carry)
[docs]
@partial(jit, static_argnums=(0, 3, 4, 5, 6, 7, 8))
def solve_ode(method_name, t0, x0, T, dfx, dt, params=None, x_scale=1., sols_only=True):
    """
    Integrates an (ordinary) differential equation for `T` steps with a fixed
    step size, using `jax.lax.scan`.

    Args:
        method_name: integrator name; one of "euler"/"rk1", "heun"/"rk2_heun",
            "rk2"/"midpoint", "rk4", "ralston"/"rk2_ralston"
            (static argument)

        t0: initial time

        x0: initial variable values

        T: number of integration steps (static argument)

        dfx: (ordinary) differential equation co-routine, called as
            `dfx(t, x, params)` (static argument, must be hashable)

        dt: integration time step (static argument)

        params: configuration values forwarded to `dfx` (static argument,
            must be hashable; Default: None)

        x_scale: dampening factor applied to `x` at each step (Default: 1)

        sols_only: if True, only the stacked post-step states are returned

    Returns:
        if `sols_only`: `(ts, xs)`, the post-step times/states stacked along a
        leading axis of length `T`;
        otherwise: `((t_T, x_T), (ts, xs), (ts_prev, xs_prev))`, i.e. the final
        state, the stacked post-step states and the stacked pre-step states

    Raises:
        ValueError: if `method_name` is not a recognized integrator name
    """
    ## method_name is static, so this dispatch runs once at trace time
    if method_name in ("euler", "rk1"):
        method = _euler
    elif method_name in ("heun", "rk2_heun"):
        method = _heun
    elif method_name in ("rk2", "midpoint"):
        method = _rk2
    elif method_name == "rk4":
        method = _rk4
    elif method_name in ("ralston", "rk2_ralston"):
        method = _ralston
    else:
        ## previously an unknown name left `method` unbound (UnboundLocalError)
        raise ValueError(
            "unrecognized integration method {!r}; expected one of 'euler', "
            "'rk1', 'heun', 'rk2_heun', 'rk2', 'midpoint', 'rk4', 'ralston', "
            "'rk2_ralston'".format(method_name))

    def scanner(carry, _):
        return method(carry, dfx, dt, params, x_scale)

    final_state, (xs_next, xs_carry) = _scan(scanner, init=(t0, x0), xs=jnp.arange(T))
    if not sols_only:
        return final_state, xs_next, xs_carry
    return xs_next
########################################################################################
########################################################################################
if __name__ == '__main__':
    ## Demo: integrate a 2D linear system with every scheme and plot each state dimension
    import matplotlib.pyplot as plt
    from odes import linear_2D
    dfx = linear_2D
    x0 = jnp.array([3, -1.5])
    dt = 1e-2
    t0 = 0.
    T = 800
    schemes = (('euler', 'Euler'), ('heun', 'Heun'), ('rk2', 'RK2'),
               ('rk4', 'RK4'), ('ralston', 'Ralston'))
    trajectories = []
    ts_sol = None
    for method_name, tag in schemes:
        _, (ts_method, xs_method), _ = solve_ode(method_name, t0, x0, T=T, dfx=dfx, dt=dt,
                                                 params=None, sols_only=False)
        if ts_sol is None: ## the plotted time axis comes from the first (Euler) run
            ts_sol = ts_method
        trajectories.append((tag, xs_method))
    for dim in (0, 1): ## all x0 curves first, then all x1 curves
        for tag, xs_method in trajectories:
            plt.plot(ts_sol, xs_method[:, dim], label='x{}-{}'.format(dim, tag))
    plt.legend(loc='best')
    plt.grid()
    #plt.show()
    plt.savefig("integrator_plot.jpg")