"""
In-built dynamical systems built on differential equations. Note that these systems are designed such that they
directly operate with ngc-learn's ODE integration backend.
| Currently in-built dynamical systems include:
| 0) A continuous linear 2D system;
| 1) A continuous cubic 2D system;
| 2) A Lorenz attractor system;
| 3) A continuous linear 3D system;
| 4) A continuous oscillator system.
"""
import jax.numpy as jnp
[docs]
def linear_2D(t, x, params):
    """
    Right-hand side of a damped linear 2D system (eigenvalues -0.1 +/- 2i,
    i.e. a stable spiral toward the origin).

    * suggested init value - x0 = jnp.array([3, -1.5])

    Args:
        t: Unused (the system is autonomous)
        x: 2D state vector
            type: jax array
            shape: (2,) -- the last axis is contracted by a matrix product,
            so leading batch axes, e.g. (N, 2), are also accepted
        params: Unused

    Returns:
        2D vector:
        [
            -0.1 * x[0] + 2.0 * x[1],
            -2.0 * x[0] - 0.1 * x[1]
        ]; type: jax array, shape: same as x
    """
    # States are treated as row vectors, so dx/dt = x @ A^T; A^T is written
    # out directly rather than transposing A at call time.
    a_transposed = jnp.array([[-0.1, -2],
                              [2, -0.1]])
    return jnp.matmul(x, a_transposed)
[docs]
def cubic_2D(t, x, params):
    """
    Right-hand side of a cubic 2D system: the same coefficient matrix as
    `linear_2D`, applied to the element-wise cube of the state.

    suggested init value - x0 = jnp.array([2., 0.])

    Args:
        t: Unused (the system is autonomous)
        x: 2D state vector
            type: jax array
            shape: (2,) -- leading batch axes, e.g. (N, 2), also work since
            only the last axis is contracted
        params: Unused

    Returns:
        2D vector:
        [
            -0.1 * x[0] ** 3 + 2.0 * x[1] ** 3,
            -2.0 * x[0] ** 3 - 0.1 * x[1] ** 3,
        ]; type: jax array, shape: same as x
    """
    cubed = x ** 3
    # Row-vector convention: d(x)/dt = (x ** 3) @ A^T with A^T spelled out.
    a_transposed = jnp.array([[-0.1, -2],
                              [2, -0.1]])
    return jnp.matmul(cubed, a_transposed)
[docs]
def lorenz(t, x, params, sigma=10, rho=28, beta=8 / 3):
    """
    Right-hand side of the Lorenz system.

    suggested init value - x0 = jnp.array([-8, 7, 27])

    The coefficients default to the classic chaotic setting
    (sigma=10, rho=28, beta=8/3) and may be overridden by keyword, in the
    same style as the coefficients of `oscillator`.

    Args:
        t: Unused (the system is autonomous)
        x: 3D state vector
            type: jax array
            shape: (3,) -- components are read along the last axis, so
            leading batch axes, e.g. (N, 3), are also accepted
        params: Unused
        sigma: coefficient of the x-equation (default 10)
        rho: coefficient of the y-equation (default 28)
        beta: damping coefficient of the z-equation (default 8/3)

    Returns:
        3D vector:
        [
            sigma * (x[1] - x[0]),
            x[0] * (rho - x[2]) - x[1],
            x[0] * x[1] - beta * x[2],
        ]; type: jax array, shape: same as x
    """
    x_ = x[..., 0]
    y_ = x[..., 1]
    z_ = x[..., 2]
    # Expanded (not factored) forms keep the default outputs bit-identical
    # to the previous hard-coded expressions.
    dx = sigma * y_ - sigma * x_
    dy = rho * x_ - x_ * z_ - y_
    dz = x_ * y_ - beta * z_
    return jnp.stack([dx, dy, dz], axis=-1)
[docs]
def linear_3D(t, x, params):
    """
    Right-hand side of a linear 3D system: the damped rotation of
    `linear_2D` in the (x, y) plane, plus independent exponential decay
    of z.

    suggested init value - x0 = jnp.array([1, 1., -1])

    Args:
        t: Unused (the system is autonomous)
        x: 3D state vector
            type: jax array
            shape: (3,) -- components are read along the last axis, so
            leading batch axes, e.g. (N, 3), are also accepted
        params: Unused

    Returns:
        3D vector:
        [
            -0.1 * x[0] + 2 * x[1],
            -2 * x[0] - 0.1 * x[1],
            -0.3 * x[2]
        ]; type: jax array, shape: same as x
    """
    px, py, pz = x[..., 0], x[..., 1], x[..., 2]
    rates = [
        -0.1 * px + 2.0 * py,
        -2.0 * px - 0.1 * py,
        -0.3 * pz,
    ]
    return jnp.stack(rates, axis=-1)
[docs]
def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5.0, sigma=1.1):
    """
    Right-hand side of a 3D nonlinear ("atmospheric") oscillator.

    suggested init value - x0 = jnp.array([0.5, 0.05, 0.1])

    Args:
        t: Unused (the system is autonomous)
        x: 3D state vector
            type: jax array
            shape: (3,) -- components are read along the last axis, so
            leading batch axes, e.g. (N, 3), are also accepted
        params: Unused
        mu1: linear growth/decay rate of x[0] (default 0.05)
        mu2: linear growth/decay rate of x[1] and x[2] (default -0.01)
        omega: base rotation rate of the (x[1], x[2]) pair (default 3.0)
        alpha: x[1]-dependence of that rotation rate (default -2.0)
        beta: x[2]-dependence of that rotation rate (default -5.0)
        sigma: strength of the x[0] <-> x[1] coupling (default 1.1)

    Returns:
        3D vector:
        [
            mu1 * x[0] + sigma * x[0] * x[1],
            mu2 * x[1] + (omega + alpha * x[1] + beta * x[2]) * x[2] - sigma * x[0] ** 2,
            mu2 * x[2] - (omega + alpha * x[1] + beta * x[2]) * x[1],
        ]; type: jax array, shape: same as x
    """
    u, v, w = (x[..., k] for k in range(3))
    du = mu1 * u + sigma * u * v
    dv = mu2 * v + (omega + alpha * v + beta * w) * w - sigma * u ** 2
    # Third component is evaluated in expanded form (algebraically equal to
    # mu2 * w - (omega + alpha * v + beta * w) * v).
    dw = mu2 * w - omega * v - alpha * v * v - beta * w * v
    return jnp.stack((du, dv, dw), axis=-1)
## some testing/driver code to check the ODEs themselves: integrate each
## in-built system with RK4 and plot its trajectory for visual inspection
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from ngclearn.utils.diffeq.ode_utils import solve_ode

    def _plot_2d(sol, title, color, label):
        """Plot a 2D trajectory (column 0 vs. column 1) in its own figure."""
        _, ax = plt.subplots()
        ax.plot(sol[:, 0], sol[:, 1], linewidth=2, color=color, label=label)
        ax.set_title(title, fontsize=20)
        ax.set_xlabel('x', fontsize=20)
        ax.set_ylabel('y', fontsize=20)
        ax.grid(True)
        ax.legend()  # labels were previously set but never displayed
        plt.show()

    def _plot_3d(sol, title, color, label):
        """Plot a 3D trajectory (columns 0, 1, 2) in its own 3D figure."""
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.plot(sol[:, 0], sol[:, 1], sol[:, 2], linewidth=1, color=color, label=label)
        ax.set_title(title, fontsize=20)
        ax.set_xlabel('x', fontsize=20)
        ax.set_ylabel('y', fontsize=20)
        ax.set_zlabel('z', fontsize=20)
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        plt.show()

    t0 = 0.
    dt = 0.01

    # 1. Linear 2D System
    x0 = jnp.array([3, -1.5], dtype=jnp.float32)
    _, x_lin2D = solve_ode('rk4', t0=t0, x0=x0, T=3000, dfx=linear_2D, dt=dt, params=None, sols_only=True)
    # 2. Cubic 2D System
    x0 = jnp.array([2, 0.], dtype=jnp.float32)
    _, x_cub2D = solve_ode('rk4', t0=t0, x0=x0, T=10000, dfx=cubic_2D, dt=dt, params=None, sols_only=True)
    # 3. Lorenz System (3D)
    x0 = jnp.array([-8, 7, 27], dtype=jnp.float32)
    _, x_lorenz = solve_ode('rk4', t0=t0, x0=x0, T=2000, dfx=lorenz, dt=dt, params=None, sols_only=True)
    # 4. Linear 3D System
    x0 = jnp.array([1, 1., -1], dtype=jnp.float32)
    _, x_lin3D = solve_ode('rk4', t0=t0, x0=x0, T=10000, dfx=linear_3D, dt=dt, params=None, sols_only=True)
    # 5. Oscillator System
    x0 = jnp.array([0.5, 0.05, 0.1], dtype=jnp.float32)
    _, x_osci = solve_ode('rk4', t0=t0, x0=x0, T=20000, dfx=oscillator, dt=dt, params=None, sols_only=True)

    _plot_2d(x_lin2D, 'Linear 2D System', 'darkorange', r'$linear-2D$')
    # label fixed: the original r'cubic-2D$' had an unbalanced mathtext '$'
    _plot_2d(x_cub2D, 'Cubic 2D System', 'royalblue', r'$cubic-2D$')
    _plot_3d(x_lorenz, 'Lorenz System', 'red', r'$lorenz$')
    _plot_3d(x_lin3D, 'Linear 3D System', 'purple', r'linear-3D')
    _plot_3d(x_osci, 'Atmospheric Oscillator', 'green', r'oscillator')