Source code for ngclearn.utils.optim.sgd

from jax import jit, numpy as jnp

[docs] def step_update(param, update, eta): """ Runs one step of SGD over a set of parameters given updates. Args: eta: global step size to apply when adjusting parameters Returns: adjusted parameter tensor (same shape as "param") """ _param = param - update * eta return _param
[docs] @jit def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to theta """ Returns updated parameters in accordance to a stochastic gradient descent (SGD) recipe Args: opt_params: (ArrayLike) parameters of the optimization algorithm theta: (ArrayLike) the weights of neural networks updates: (ArrayLike) the updates of neural networks eta: (float, optional) hyperparams. Defaults to 0.001. Returns: ArrayLike: opt_params. New opt params, ArrayLike: theta. The updated weights """ time_step = opt_params time_step = time_step + 1 new_theta = [] for i in range(len(theta)): px_i = step_update(theta[i], updates[i], eta) new_theta.append(px_i) new_opt_params = time_step return new_opt_params, new_theta
[docs] @jit def sgd_init(theta): return jnp.asarray(0.0)
if __name__ == '__main__': opt_params, theta = sgd_step((2.0), [1.0, 1.0], [3.0, 4.0], 3e-2) print(f"opt_params: {opt_params}, theta: {theta}")