# Source code for ngclearn.components.synapses.modulated.REINFORCESynapse

# %%

from jax import random, numpy as jnp, jit
from ngclearn import compilable, Compartment

from ngclearn.utils.model_utils import clip, d_clip
import jax
#import numpy as np

from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngclearn.utils.model_utils import create_function

def _gaussian_logpdf(event, mean, stddev):
    """
    Elementwise log-density of ``event`` under a Gaussian N(mean, stddev**2).

    Evaluates ``-0.5 * (log(2*pi*stddev**2) + (event - mean)**2 / stddev**2)``.
    No reduction is performed, so the result broadcasts like the inputs.
    """
    variance = stddev ** 2
    sq_z = (event - mean) ** 2 / variance
    return -0.5 * (jnp.log(2 * jnp.pi * variance) + sq_z)


def _compute_update(
        dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
):
    """
    Sample Gaussian actions and compute one hand-derived REINFORCE weight delta.

    ``weights`` is split in two along its last axis: the first half projects the
    (activated) inputs to the Gaussian mean, the second half to the log standard
    deviation.

    Args:
        dt: not used by this function; accepted for call-site compatibility.
        inputs: pre-synaptic values; ``inputs.shape[0]`` is used as the batch size.
        rewards: per-example reward signal; indexed as ``rewards[:, None]``, so it
            is expected to be 1-D with one entry per batch row.
        act_fx: activation applied to ``inputs`` before both projections.
        weights: parameter matrix, halved along the last axis into (W_mu, W_logstd).
        seed: JAX PRNG key used to draw the sampling noise.
        mu_act_fx: activation applied to the projected mean.
        dmu_act_fx: derivative of ``mu_act_fx`` (used in the backward pass).
        mu_out_min: lower clip bound for the sampled actions.
        mu_out_max: upper clip bound for the sampled actions.
        scalar_stddev: if <= 0, the std is ``exp(clip(logstd, -10, 2))`` from the
            log-std weights; otherwise this fixed value is used as the std and the
            log-std gradient is zeroed.

    Returns:
        ``(dW, objective, outputs)``:
            dW -- same layout as ``weights``; the NEGATED gradient of ``objective``.
            objective -- scalar ``mean(-log_prob * rewards) * 1e-2``.
            outputs -- the clipped action samples.
    """
    # 1.0 => std is learned from the log-std weights; 0.0 => fixed scalar std is used
    learn_std = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
    # (input_dim, 2 * output_dim) -> two (input_dim, output_dim) halves
    w_mean, w_log_sigma = jnp.split(weights, 2, axis=-1)

    # ---- forward pass
    features = act_fx(inputs)
    pre_mean = features @ w_mean
    mu = mu_act_fx(pre_mean)
    raw_log_sigma = features @ w_log_sigma
    sigma = jnp.exp(clip(raw_log_sigma, -10.0, 2.0))
    # select learned vs. fixed std arithmetically (no Python branch)
    sigma = learn_std * sigma + (1.0 - learn_std) * scalar_stddev

    # reparameterized Gaussian sample, clipped into the allowed action range;
    # this is the action actually emitted
    noise = jax.random.normal(seed, mu.shape)
    action = jnp.clip(noise * sigma + mu, mu_out_min, mu_out_max)

    # log-likelihood of each row's action (summed over action dims) and the
    # scaled negative REINFORCE objective
    log_prob = _gaussian_logpdf(action, mu, sigma).sum(-1)
    objective = (-log_prob * rewards).mean() * 1e-2

    # ---- backward pass (manual; the sampled action is treated as a constant)
    n_batch = inputs.shape[0]
    grad_logp = -rewards[:, None] * 1e-2 / n_batch  # dObjective/dlog_prob per row, (B, 1)
    residual = action - mu

    # mean path: dlogp/dmu = (a - mu) / sigma^2, chained through mu_act_fx
    grad_pre_mean = grad_logp * (residual / (sigma ** 2)) * dmu_act_fx(pre_mean)
    grad_w_mean = features.T @ grad_pre_mean

    # std path: dlogp/dsigma = -1/sigma + (a - mu)^2 / sigma^3;
    # dsigma/dlogsigma = sigma, gated by the derivative of the [-10, 2] clip
    grad_sigma = grad_logp * (- 1.0 / sigma + residual**2 / sigma**3)
    grad_log_sigma = d_clip(raw_log_sigma, -10.0, 2.0) * grad_sigma * sigma
    grad_w_log_sigma = features.T @ grad_log_sigma  # (I, B) @ (B, A) = (I, A)
    grad_w_log_sigma = grad_w_log_sigma * learn_std  # fixed-std mode learns nothing here

    # the caller adds the delta to the weights, so hand back the negated gradient
    dW = jnp.concatenate([-grad_w_mean, -grad_w_log_sigma], axis=-1)
    return dW, objective, action


class REINFORCESynapse(DenseSynapse):
    """
    A stochastic synapse implementing the REINFORCE algorithm (policy gradient
    method). This synapse uses Gaussian distributions for generating actions and
    performs gradient-based updates.

    The underlying weight matrix has shape (input_dim, 2 * output_dim): the first
    output_dim columns (W_mu) produce the Gaussian mean, the last output_dim
    columns (W_logstd) produce the log standard deviation.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (sampled actions from Gaussian distribution)
    | weights - current value matrix of synaptic efficacies (contains both mean and log-std parameters)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | rewards - reward signals used to modulate weight updates (takes in external signals)
    | objective - scalar value of the current loss/objective
    | accumulated_gradients - running average of dWeights (older values additionally scaled by `decay`)
    | step_count - counter for number of learning steps (zeroed whenever learning_mask is 1)
    | learning_mask - binary mask determining when learning occurs
    | seed - JAX PRNG key for random sampling

    Args:
        name: the string name of this component

        shape: tuple specifying shape of this synaptic cable (number of inputs by
            number of outputs); internally the weight matrix is (inputs, 2 * outputs)

        eta: learning rate for weight updates (Default: 1e-4)

        decay: decay factor applied to the previous accumulated gradients (Default: 0.99)

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        resist_scale: a fixed scaling factor to apply to synaptic transform (Default: 1.)

        act_fx: activation function to apply to inputs (Default: "identity")

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure

        w_bound: upper bound for weight clipping; weights are also clipped below
            at 0 on every evolve() call (Default: 1.)

        batch_size: batch size dimension of this component (Default: 1)

        seed: random seed for reproducibility; also restored by reset() (Default: 42)

        mu_act_fx: activation function to apply to the mean of the Gaussian
            distribution (Default: "identity")

        mu_out_min: lower clip bound applied to sampled outputs (Default: -inf)

        mu_out_max: upper clip bound applied to sampled outputs (Default: inf)

        scalar_stddev: if > 0, a fixed standard deviation for the Gaussian (the
            log-std weights then receive no updates); if <= 0, the std is learned
            as exp(clip(act_fx(inputs) @ W_logstd, -10, 2)) (Default: -1.0)
    """

    def __init__(
            self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None, p_conn=1.,
            w_bound=1., batch_size=1, seed=None, mu_act_fx=None, mu_out_min=-jnp.inf, mu_out_max=jnp.inf,
            scalar_stddev=-1.0, **kwargs
    ) -> None:
        # The parent cable holds both parameter sets side by side, so it is built
        # with twice the requested number of output columns.
        input_dim, output_dim = shape
        super().__init__(
            name, (input_dim, output_dim * 2), weight_init, None, resist_scale, p_conn, batch_size=batch_size,
            **kwargs
        )

        ## Synaptic hyper-parameters
        self.shape = shape  ## user-facing (inputs, outputs) shape; reset() builds its buffers from it
        self.Rscale = resist_scale  ## post-transformation scale factor
        self.w_bound = w_bound  ## upper clip bound for weights (lower clip bound is 0)
        self.eta = eta  ## learning rate
        self.mu_act_fx, self.dmu_act_fx = create_function(mu_act_fx if mu_act_fx is not None else "identity")
        self.mu_out_min = mu_out_min
        self.mu_out_max = mu_out_max
        self.scalar_stddev = scalar_stddev
        # Remember the seed so reset() restores the same PRNG stream as construction.
        self.init_seed = seed if seed is not None else 42

        ## Compartment setup
        self.dWeights = Compartment(self.weights.get() * 0)
        self.objective = Compartment(jnp.zeros(()))
        self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
        # the normalized reward (r - r_hat), input compartment
        self.rewards = Compartment(jnp.zeros((batch_size,)))
        self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")
        self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
        self.decay = decay
        self.step_count = Compartment(jnp.zeros(()))
        self.learning_mask = Compartment(jnp.zeros(()))
        self.seed = Compartment(jax.random.PRNGKey(self.init_seed))

    @compilable
    def evolve(self, dt):
        # Sample actions, compute the REINFORCE delta, and (where learning_mask
        # is 1) apply it to the weights. `dt` is forwarded to _compute_update.
        weights = self.weights.get()
        accumulated_gradients = self.accumulated_gradients.get()
        step_count = self.step_count.get()
        learning_mask = self.learning_mask.get()
        seed = self.seed.get()
        inputs = self.inputs.get()
        rewards = self.rewards.get()

        # `sub_seed` drives this step's sampling; `main_seed` is carried forward
        main_seed, sub_seed = jax.random.split(seed)
        dWeights, objective, outputs = _compute_update(
            dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx,
            self.mu_out_min, self.mu_out_max, self.scalar_stddev
        )
        ## shift the weights by eta * dWeights only where learning_mask is 1.0;
        ## a mask of 0.0 keeps the previous weights
        weights = (weights + dWeights * self.eta) * learning_mask + weights * (1.0 - learning_mask)
        ## clip weights into [eps, w_bound - eps] (eps = 0 => non-negativity);
        ## note: applied on every call, independent of learning_mask
        eps = 0.0
        weights = jnp.clip(weights, eps, self.w_bound - eps)

        # running average of the deltas, with older contributions scaled by `decay`
        step_count = step_count + 1
        accumulated_gradients = (
            (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights
        )
        # the step count restarts from 0 on steps where learning happened
        step_count = step_count * (1 - learning_mask)

        self.weights.set(weights)
        self.dWeights.set(dWeights)
        self.objective.set(objective)
        self.outputs.set(outputs)
        self.accumulated_gradients.set(accumulated_gradients)
        self.step_count.set(step_count)
        self.seed.set(main_seed)

    @compilable
    def reset(self):
        # Zero the per-step buffers and restore the construction-time PRNG key.
        # learning_mask and weights are left untouched.
        preVals = jnp.zeros((self.batch_size, self.shape[0]))
        postVals = jnp.zeros((self.batch_size, self.shape[1]))
        objective = jnp.zeros(())
        rewards = jnp.zeros((self.batch_size,))
        # deltas share the full (inputs, 2 * outputs) layout of the weights, matching
        # the constructor's dWeights and accumulated_gradients
        dWeights = jnp.zeros((self.shape[0], self.shape[1] * 2))
        accumulated_gradients = jnp.zeros((self.shape[0], self.shape[1] * 2))
        step_count = jnp.zeros(())
        seed = jax.random.PRNGKey(self.init_seed)

        # only clear `inputs` when it is not wired to another component's output
        hasattr(self.inputs, 'targeted') and not self.inputs.targeted and self.inputs.set(preVals)
        self.outputs.set(postVals)
        self.objective.set(objective)
        self.rewards.set(rewards)
        self.dWeights.set(dWeights)
        self.accumulated_gradients.set(accumulated_gradients)
        self.step_count.set(step_count)
        self.seed.set(seed)

    @classmethod
    def help(cls):  ## component help function
        """Return a dict describing this component's type, compartments, dynamics and hyperparameters."""
        properties = {
            "synapse_type": "REINFORCESynapse - implements a stochastic synaptic cable that uses "
                            "the REINFORCE algorithm (policy gradient) to update weights based on rewards"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "rewards": "Takes in reward signals for modulating weight updates. The reward is often normalized by baseline reward (r - r_hat)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values (mean and log-std)",
                 "dWeights": "Weight update values",
                 "accumulated_gradients": "EMA of gradients over time",
                 "step_count": "Counter for learning steps",
                 "learning_mask": "Binary mask determining when learning occurs",
                 "seed": "a single integer as initial jax PRNG key for this component"},
            "outputs":
                {"outputs": "Output samples from Gaussian distribution",
                 "objective": "Current value of the loss/objective function"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "eta": "Learning rate for weight updates",
            "decay": "Decay factor for EMA of gradients",
            "weight_init": "Initialization conditions for synaptic weight values",
            "resist_scale": "Resistance level scaling factor applied to output",
            "act_fx": "Activation function to apply to inputs",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "w_bound": "Upper bound for weight clipping",
            "batch_size": "Batch size dimension of this component",
            "seed": "Random seed for reproducibility",
            "mu_act_fx": "Activation function to apply to the mean of the Gaussian distribution",
            "mu_out_min": "Lower clip bound applied to sampled outputs",
            "mu_out_max": "Upper clip bound applied to sampled outputs",
            "scalar_stddev": "Fixed standard deviation if > 0; if <= 0, the std is learned via log-std weights",
        }
        # NOTE: adjacent string literals are concatenated by Python, so every piece
        # of the dynamics text must sit before the comma that ends its value.
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "mean = act_fx(inputs) @ W_mu; fx_mean = mu_act_fx(mean); logstd = act_fx(inputs) @ W_logstd; "
                            "outputs ~ N(fx_mean, exp(logstd)); "
                            "dW = -grad_reinforce(rewards, log_prob(outputs)). "
                            "Check _compute_update() for more details.",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    # Minimal smoke demo: build a 3-input / 2-output synapse inside a context and print it.
    from ngcsimlib.context import Context

    with Context("Bar"):
        syn = REINFORCESynapse(name="reinforce_syn", shape=(3, 2))
    print(syn)