# %%
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class RewardErrorCell(JaxComponent): ## Reward prediction error cell
    """
    A reward prediction error (RPE) cell.

    | --- Cell Input Compartments: ---
    | reward - current reward signal at time `t`
    | accum_reward - current accumulated episodic reward signal
    | --- Cell Output Compartments: ---
    | mu - current moving average prediction of reward at time `t`
    | rpe - current reward prediction error (RPE) signal
    | accum_reward - running sum of the rewards seen since the last `reset`
    | n_ep_steps - number of `advance_state` steps taken since the last `reset`

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        alpha: decay factor to apply to (exponential) moving average prediction

        ema_window_len: exponential moving average window length -- for use only
            in `evolve` step for updating episodic reward signals; (default: 10)

        use_online_predictor: use online prediction of reward signal (default: True)
            -- if True, `advance_state` moves `mu` toward the current reward on
            every step and `evolve` blends the mean per-step episodic reward
            into `mu`; if False, neither method modifies `mu`

        batch_size: batch size dimension of this cell (default: 1)
    """

    def __init__(self, name, n_units, alpha, ema_window_len=10,
                 use_online_predictor=True, batch_size=1, **kwargs):
        super().__init__(name, **kwargs)

        ## RPE meta-parameters
        self.alpha = alpha
        self.ema_window_len = ema_window_len
        self.use_online_predictor = use_online_predictor

        ## Layer Size Setup
        self.n_units = n_units
        self.batch_size = batch_size

        ## Compartment setup (every per-unit compartment starts at zero)
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.mu = Compartment(restVals) ## reward predictor state(s)
        self.reward = Compartment(restVals) ## target reward signal(s)
        self.rpe = Compartment(restVals) ## reward prediction error(s)
        self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
        self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken

    @compilable
    def advance_state(self, dt):
        ## One simulation step: accumulate the reward, emit the RPE and, if the
        ## online predictor is enabled, nudge the prediction toward the reward.
        ## (`dt` is accepted but not used by this update.)
        # Get the variables
        mu = self.mu.get()
        reward = self.reward.get()
        n_ep_steps = self.n_ep_steps.get()
        accum_reward = self.accum_reward.get()

        ## compute/update RPE and predictor values
        accum_reward = accum_reward + reward
        rpe = reward - mu ## error is measured against the *pre-update* prediction
        if self.use_online_predictor:
            mu = mu * (1. - self.alpha) + reward * self.alpha
        n_ep_steps = n_ep_steps + 1

        # Update compartments
        self.mu.set(mu)
        self.rpe.set(rpe)
        self.n_ep_steps.set(n_ep_steps)
        self.accum_reward.set(accum_reward)

    @compilable
    def evolve(self, dt):
        ## Episode-level update: blend the mean per-step reward of the tracked
        ## episode into the prediction `mu` with weight `1/ema_window_len`.
        ## (`dt` is accepted but not used by this update.)
        # Get the variables
        mu = self.mu.get()
        n_ep_steps = self.n_ep_steps.get()
        accum_reward = self.accum_reward.get()

        ## NOTE(review): earlier documentation stated that with
        ## `use_online_predictor=False` the prediction would be updated only
        ## here, yet this update runs only when the flag is True -- confirm
        ## which behavior is intended before changing the condition.
        if self.use_online_predictor:
            ## Guard against a zero step count (e.g., `evolve` called right
            ## after construction or `reset`): dividing by it would write
            ## NaN/inf into `mu` and corrupt every later prediction. Batch rows
            ## with no recorded steps keep their current `mu`.
            has_steps = n_ep_steps > 0
            ## mean per-step episodic reward signal (divisor made safe first so
            ## that no NaN is produced even in the discarded branch)
            r = accum_reward / jnp.where(has_steps, n_ep_steps, 1.)
            w = 1. / self.ema_window_len
            mu = jnp.where(has_steps, (1. - w) * mu + w * r, mu)

        # Update compartment
        self.mu.set(mu)

    @compilable
    def reset(self): ## reset core components/statistics
        ## Zero the predictor, the RPE, the accumulated reward and the step
        ## counter; the `reward` input compartment is left untouched.
        restVals = jnp.zeros((self.batch_size, self.n_units))
        mu = restVals
        rpe = restVals
        accum_reward = restVals
        n_ep_steps = jnp.zeros((self.batch_size, 1))
        self.mu.set(mu)
        self.rpe.set(rpe)
        self.accum_reward.set(accum_reward)
        self.n_ep_steps.set(n_ep_steps)

    @classmethod
    def help(cls): ## component help function
        """
        Build a description of this component.

        Returns:
            a dict with the cell-type summary (keyed by the class name), the
            input/output compartments, the dynamics, and the hyperparameters
        """
        properties = {
            "cell_type": "RewardErrorCell - computes the reward prediction error "
                         "at each time step `t`; this is an online RPE estimator"
        }
        compartment_props = {
            "inputs":
                {"reward": "External reward signals/values"},
            "outputs":
                {"mu": "Current state of reward predictor",
                 "rpe": "Current value of reward prediction error at time `t`",
                 "accum_reward": "Current accumulated episodic reward signal (generally "
                                 "produced at the end of a control episode of `n_steps`)",
                 "n_ep_steps": "Number of episodic steps taken/tracked thus far "
                               "(since last `reset` call)"},
        }
        ## `use_online_predictor` is a constructor flag rather than a
        ## compartment, so it is reported with the other hyperparameters
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "alpha": "Moving average decay factor",
            "ema_window_len": "Exponential moving average window length",
            "use_online_predictor": "Should an online reward predictor be used/maintained?",
            "batch_size": "Batch size dimension of this component"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha; "
                            "accum_reward = accum_reward + reward",
                "hyperparameters": hyperparams}
        return info

if __name__ == '__main__':
    ## Smoke test: instantiate a single cell inside a simulation context and
    ## print it.
    from ngcsimlib.context import Context

    with Context("Bar") as bar:
        cell = RewardErrorCell("X", 9, 0.03)
    print(cell)