"""Source code for ngclearn.components.neurons.graded.leakyNoiseCell."""

from jax import numpy as jnp, random, jit
#from ngcsimlib.logger import info
from ngcsimlib import deprecate_args
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.model_utils import create_function
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4

def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale): ## raw dynamics ODE
    """
    Right-hand side of the leaky, noisy state dynamics, i.e., dz/dt where

    | tau_x * dz/dt = -leak_scale * z + (j_recurrent + j_input) + sqrt(2 * tau_x * sigma_rec^2) * eps

    Args:
        z: current state value(s)

        j_input: external (bottom-up) input drive

        j_recurrent: recurrent input drive

        eps: noise sample multiplied into the stochastic term

        tau_x: state time constant (the whole expression is divided by it)

        sigma_rec: noise scale / standard deviation of the stochastic term

        leak_scale: multiplier applied to the leak term (-z)

    Returns:
        the time derivative dz/dt
    """
    leak = z * leak_scale ## scaled membrane leak
    drive = j_recurrent + j_input ## total deterministic input pressure
    noise_amp = jnp.sqrt(2. * tau_x * jnp.square(sigma_rec)) ## amplitude of injected noise
    ## NOTE(review): the noise term is divided by tau_x along with the rest of the ODE; whether the
    ## integrator scales it by dt (rather than sqrt(dt)) depends on the step function -- confirm
    return (-leak + drive + noise_amp * eps) * (1. / tau_x)

def _dfz(t, z, params): ## raw dynamics ODE wrapper
    """
    Adapter exposing `_dfz_fn` under the `(t, z, params)` signature passed to the step functions.

    Args:
        t: current time (not used by these dynamics)

        z: current state value(s)

        params: 6-tuple (j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale)

    Returns:
        the time derivative dz/dt computed by `_dfz_fn`
    """
    (inp, rec, noise, tau, sigma, leak) = params ## must be exactly six entries
    return _dfz_fn(
        z, j_input=inp, j_recurrent=rec, eps=noise, tau_x=tau, sigma_rec=sigma, leak_scale=leak
    )

class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
    """
    A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state.

    Reference:
    https://pmc.ncbi.nlm.nih.gov/articles/PMC4771709/

    The specific differential equation that characterizes this cell (for adjusting x) is:

    | tau_x * dx/dt = -leak_scale * x + j_rec + j_in + sqrt(2 * tau_x * (sigma_pre)^2) * eps_pre; and,
    | r = f(x) + (eps_post * sigma_post).
    | where j_in is the set of incoming input signals
    | and j_rec is the set of recurrent input signals
    | and eps_pre, eps_post are independent samples of unit Gaussian noise, i.e., eps ~ N(0, 1)
    | and f(x) is the rectification function
    | and sigma_pre is the pre-rectification noise applied to membrane x
    | and sigma_post is the post-rectification noise applied to rates f(x)

    | --- Cell Input Compartments: ---
    | j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
    | j_recurrent - recurrent electrical/stimulus pressure
    | --- Cell State Compartments ---
    | x - noisy rate activity / current value of state
    | --- Cell Output Compartments: ---
    | r - post-rectified activity (plus post-rectification noise), e.g., fx(x) = relu(x)
    | r_prime - derivative of the rectifier evaluated at x (no noise added), e.g., dfx(x) = d_relu(x)

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_x: state membrane time constant (milliseconds)

        act_fx: rectification function (Default: "relu")

        integration_type: type of integration to use for this cell's dynamics;
            current supported forms include "euler" (Euler/RK-1 integration) and
            "midpoint" or "rk2" (midpoint method/RK-2 integration); an RK-4 scheme is
            also dispatched (to `step_rk4`) for whatever name `get_integrator_code` maps
            to it (Default: "euler")

            :Note: setting the integration type to the midpoint method will increase the accuracy of
                the estimate of the cell's evolution at an increase in computational cost (and simulation time)

        batch_size: batch size dimension of this cell's compartments (Default: 1)

        sigma_pre: pre-rectification noise scaling factor / standard deviation (Default: 0.1);
            `sigma_rec` is accepted as a deprecated alias (via `deprecate_args`)

        sigma_post: post-rectification noise scaling factor / standard deviation (Default: 0.1)

        leak_scale: degree to which membrane leak should be scaled (Default: 1)

        shape: optional 3-tuple giving the per-sample compartment shape; when given, compartments
            are 4D with shape (batch_size, shape[0], shape[1], shape[2]); when None, compartments
            are 2D with shape (batch_size, n_units) (Default: None)
    """

    @deprecate_args(sigma_rec="sigma_pre")
    def __init__(
            self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1,
            sigma_post=0.1, leak_scale=1., shape=None, **kwargs
    ):
        super().__init__(name, **kwargs)

        self.tau_x = tau_x
        self.sigma_pre = sigma_pre ## a pre-rectification scaling factor
        self.sigma_post = sigma_post ## a post-rectification scaling factor
        self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)

        ## integration properties
        self.integrationType = integration_type
        self.intgFlag = get_integrator_code(self.integrationType)

        ## Layer size setup
        _shape = (batch_size, n_units) ## default shape is 2D/matrix
        if shape is None:
            shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
        else:
            _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
        self.shape = shape
        self.n_units = n_units
        self.batch_size = batch_size

        self.fx, self.dfx = create_function(fun_name=act_fx)

        # compartments (state of the cell & parameters will be updated through stateless calls)
        restVals = jnp.zeros(_shape)
        self.j_input = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
        self.j_recurrent = Compartment(restVals, display_name="Recurrent Stimulus Current", units="mA") # electrical current
        self.x = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
        self.r = Compartment(restVals, display_name="(Rectified) Rate Activity") # rectified output
        self.r_prime = Compartment(restVals, display_name="Derivative of rate activity")

    @compilable
    def advance_state(self, t, dt):
        ## run a step of integration over neuronal dynamics
        ### Note: self.fx is the "rectifier" (rectification function)
        ## Draw two *independent* noise samples. The second split must chain off the key
        ## returned by the first split: re-splitting self.key.get() would reproduce the
        ## very same subkey, making eps_post an exact copy of eps_pre.
        key, skey = random.split(self.key.get(), 2)
        eps_pre = random.normal(skey, shape=self.x.get().shape) ## pre-rectifier distributional noise
        key, skey = random.split(key, 2)
        eps_post = random.normal(skey, shape=self.x.get().shape) ## post-rectifier distributional noise

        _step_fns = {
            0: step_euler,
            1: step_rk2,
            2: step_rk4,
        }
        _step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
        params = (self.j_input.get(), self.j_recurrent.get(), eps_pre, self.tau_x, self.sigma_pre, self.leak_scale)
        _, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
        r = self.fx(x) + (eps_post * self.sigma_post) ## calculate (rectified) activity rates; f(x)
        r_prime = self.dfx(x) ## calculate local deriv of activity rates; f'(x)

        ## set compartments to next state values in accordance with dynamics
        self.key.set(key) ## carry noise key over transition (to next state of component)
        self.x.set(x)
        self.r.set(r)
        self.r_prime.set(r_prime)

    @compilable
    def reset(self): ## reset core components/statistics
        self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member

    @compilable
    def batched_reset(self, batch_size):
        ## rebuild zeroed rest values for the requested batch size (2D default, 4D if a 3-tuple shape was given)
        _shape = (batch_size, self.shape[0])
        if len(self.shape) > 1:
            _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
        restVals = jnp.zeros(_shape)
        ## input compartments that are wired to (targeted by) another source are left untouched
        if not self.j_input.targeted:
            self.j_input.set(restVals)
        if not self.j_recurrent.targeted:
            self.j_recurrent.set(restVals)
        self.x.set(restVals)
        self.r.set(restVals)
        self.r_prime.set(restVals)

    @classmethod
    def help(cls): ## component help function
        properties = {
            "cell_type": "LeakyNoiseCell - evolves neurons according to continuous-time noisy/leaky dynamics "
        }
        compartment_props = {
            "inputs":
                {"j_input": "External input stimulus value(s)",
                 "j_recurrent": "Recurrent/prior-state stimulus value(s)"},
            "states":
                {"x": "Update to continuous noisy, leaky dynamics; value at time t"},
            "outputs":
                {"r": "A linear rectifier applied to rate-coded dynamics; f(z)",
                 "r_prime": "Temporal derivative applied to rate-coded dynamics; f'(z)"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "tau_x": "State time constant",
            "act_fx": "Type of rectification function to use",
            "integration_type": "Type of numerical integration to use for the cell dynamics",
            "sigma_pre": "The non-zero degree/scale of pre-rectification noise to inject into this neuron",
            "sigma_post": "The non-zero degree/scale of post-rectification noise to inject into this neuron",
            "leak_scale": "Degree to which the membrane leak is scaled"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "tau_x * dx/dt = -leak_scale * x + j_input + j_recurrent + sqrt(2 tau_x sigma_pre^2) * eps; "
                            "r = f(x) + sigma_post * eps', where eps, eps' ~ N(0, 1)",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    ## quick smoke check: build a small cell inside a simulation context and print it
    from ngcsimlib.context import Context

    with Context("Bar") as ctx:
        cell = LeakyNoiseCell("X", 9, 0.03) ## 9 units, tau_x = 0.03
        print(cell)