Source code for ngclearn.utils.optim.adam

from ngclearn.utils.optim.opt import Opt
from jax import jit, numpy as jnp

@jit
def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps):
    """
    Runs one step of Adam over a set of parameters given updates.
    The dynamics for any set of parameters are as follows:

    | g1 = beta1 * g1 + (1 - beta1) * update
    | g2 = beta2 * g2 + (1 - beta2) * (update)^2
    | g1_unbiased = g1 / (1 - beta1**time)
    | g2_unbiased = g2 / (1 - beta2**time)
    | param = param - lr * g1_unbiased / (sqrt(g2_unbiased) + epsilon)

    Args:
        param: parameter tensor to change/adjust

        update: update tensor to be applied to parameter tensor (must be same
            shape as "param")

        g1: first moment factor/correction factor to use in parameter update
            (must be same shape as "update")

        g2: second moment factor/correction factor to use in parameter update
            (must be same shape as "update")

        lr: global step size value to be applied to updates to parameters

        beta1: 1st moment control factor

        beta2: 2nd moment control factor

        time: current time step t, i.e., the number of iterations/calls to
            this Adam update so far

        eps: numerical stability coefficient (for calculating final update)

    Returns:
        adjusted parameter tensor (same shape as "param")
    """
    _g1 = beta1 * g1 + (1. - beta1) * update
    _g2 = beta2 * g2 + (1. - beta2) * jnp.square(update)
    g1_unb = _g1 / (1. - jnp.power(beta1, time))
    g2_unb = _g2 / (1. - jnp.power(beta2, time))
    _param = param - lr * g1_unb / (jnp.sqrt(g2_unb) + eps)
    return _param, _g1, _g2
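
## A minimal usage sketch (not part of the library source) of calling
## step_update directly on a toy parameter tensor; the names W, dW, m, and v
## are hypothetical placeholders for a parameter, its update, and the two
## Adam moment statistics.
def _demo_step_update():
    W = jnp.ones((3, 2)) * 0.5   ## parameter tensor
    dW = jnp.ones((3, 2)) * 0.1  ## update tensor (same shape as W)
    m = jnp.zeros(W.shape)       ## first-moment statistic (g1), zero-initialized
    v = jnp.zeros(W.shape)       ## second-moment statistic (g2), zero-initialized
    ## run one Adam step at t = 1 with standard hyper-parameter settings
    W, m, v = step_update(W, dW, m, v, 0.001, 0.9, 0.999, 1, 1e-8)
    return W, m, v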

class Adam(Opt):
    """
    Implements the adaptive moment estimation (Adam) algorithm as a decoupled
    update rule given adjustments produced by a credit assignment
    algorithm/process.

    Args:
        learning_rate: step size coefficient for Adam update

        beta1: 1st moment control factor

        beta2: 2nd moment control factor

        epsilon: numerical stability coefficient (for calculating final update)
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999,
                 epsilon=1e-8):
        super().__init__(name="adam")
        self.eta = learning_rate ## global step size
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = epsilon
        self.g1 = [] ## first-moment statistics, one tensor per parameter
        self.g2 = [] ## second-moment statistics, one tensor per parameter
        #self.time = 0. ## note: self.time is expected to be set up by the Opt base class
    def update(self, theta, updates): ## apply adjustment to theta
        if self.time <= 0.: ## init statistics
            for i in range(len(theta)):
                self.g1.append(jnp.zeros(theta[i].shape))
                self.g2.append(jnp.zeros(theta[i].shape))
        self.time += 1
        for i in range(len(theta)):
            px_i, g1_i, g2_i = step_update(theta[i], updates[i], self.g1[i],
                                           self.g2[i], self.eta, self.beta1,
                                           self.beta2, self.time, self.eps)
            theta[i] = px_i
            self.g1[i] = g1_i
            self.g2[i] = g2_i
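
## A minimal usage sketch (not part of the library source) of the Adam class,
## assuming the Opt base class initializes the self.time counter to zero as
## the commented-out line in __init__ suggests; theta and updates below are
## hypothetical parameter/gradient lists.
def _demo_adam():
    opt = Adam(learning_rate=0.001)
    theta = [jnp.ones((4, 3)), jnp.zeros((3,))]               ## e.g., weights and biases
    updates = [jnp.ones((4, 3)) * 0.1, jnp.ones((3,)) * 0.1]  ## same shapes as theta
    for _ in range(3): ## each call adjusts theta in place and advances statistics
        opt.update(theta, updates)
    return theta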