# Source code for ngclearn.utils.optim.adam

# %%

import numpy as np
from jax import jit, numpy as jnp, random, nn, lax
from functools import partial


def step_update(param, update, g1, g2, eta, beta1, beta2, time_step, eps):
    """Apply a single Adam step to one parameter tensor.

    Given the running first/second moment estimates and a new update, the
    rule computed here is:

    | m = beta1 * g1 + (1 - beta1) * update
    | v = beta2 * g2 + (1 - beta2) * update^2
    | m_hat = m / (1 - beta1^time_step)
    | v_hat = v / (1 - beta2^time_step)
    | param = param - eta * m_hat / (sqrt(v_hat) + eps)

    Args:
        param: parameter tensor to adjust

        update: update tensor to feed into Adam (must be same shape as "param")

        g1: running first moment estimate (must be same shape as "update")

        g2: running second moment estimate (must be same shape as "update")

        eta: global step size applied to the bias-corrected update

        beta1: 1st moment control factor

        beta2: 2nd moment control factor

        time_step: current iteration count used for bias correction; a value
            of 0 makes the bias-correction denominators zero, so callers
            should pass a count starting at 1

        eps: numerical stability coefficient added to the denominator

    Returns:
        tuple of (adjusted parameter tensor, new first moment, new second moment)
    """
    # Exponential moving averages of the update and of its elementwise square.
    m = beta1 * g1 + (1. - beta1) * update
    v = beta2 * g2 + (1. - beta2) * jnp.square(update)
    # Undo the bias toward zero introduced by zero-initialised moments.
    m_hat = m / (1. - jnp.power(beta1, time_step))
    v_hat = v / (1. - jnp.power(beta2, time_step))
    adjustment = eta * m_hat / (jnp.sqrt(v_hat) + eps)
    return param - adjustment, m, v
@jit
def adam_step(opt_params, theta, updates, eta=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to every tensor in ``theta``.

    This is a decoupled update rule: ``updates`` are adjustments produced
    elsewhere (e.g., by a credit assignment process) and are fed to Adam
    as-is.

    Args:
        opt_params: optimizer state tuple ``(g1, g2, time_step)``, where ``g1``
            and ``g2`` are sequences of per-tensor first/second moment
            estimates aligned with ``theta`` and ``time_step`` is the step
            counter (``adam_init`` produces this tuple with a counter of 0)

        theta: sequence of parameter tensors to adjust

        updates: sequence of update tensors, one per entry of ``theta``
            (each the same shape as its parameter)

        eta: (float, optional) step size coefficient (Default: 0.001)

        beta1: (float, optional) 1st moment control factor (Default: 0.9)

        beta2: (float, optional) 2nd moment control factor (Default: 0.999)

        eps: (float, optional) numerical stability coefficient
            (Default: 1e-8)

    Returns:
        tuple ``(new_opt_params, new_theta)``: ``new_opt_params`` is
        ``(new_g1, new_g2, time_step + 1)`` and ``new_theta`` is the list of
        adjusted parameter tensors
    """
    g1, g2, time_step = opt_params
    # Advance the counter before updating so bias correction uses t >= 1.
    t = time_step + 1
    # One (param, g1, g2) triple per parameter tensor, in theta's order.
    results = [
        step_update(theta[k], updates[k], g1[k], g2[k], eta, beta1, beta2, t, eps)
        for k in range(len(theta))
    ]
    new_theta = [res[0] for res in results]
    new_g1 = [res[1] for res in results]
    new_g2 = [res[2] for res in results]
    return (new_g1, new_g2, t), new_theta
@jit
def adam_init(theta):
    """Create a fresh Adam optimizer state for the given parameters.

    Args:
        theta: sequence of parameter tensors that will be optimized

    Returns:
        tuple ``(g1, g2, time_step)``: two lists of zero tensors (one per
        entry of ``theta``, matching its shape) for the first and second
        moment estimates, and a scalar step counter set to 0.0
    """
    # jnp.zeros uses JAX's default float dtype here, independent of theta's dtype.
    first_moment = [jnp.zeros(p.shape) for p in theta]
    second_moment = [jnp.zeros(p.shape) for p in theta]
    counter = jnp.asarray(0.0)
    return first_moment, second_moment, counter
if __name__ == '__main__':
    # Smoke demo: run a single Adam step on two small parameter vectors.
    weights = [jnp.asarray([3.0, 3.0]) for _ in range(2)]
    updates = [jnp.asarray([3.0, 3.0]) for _ in range(2)]
    state = adam_init(weights)
    state, new_weights = adam_step(state, weights, updates)
    print(f"opt_params: {state}, theta: {new_weights}")