# Source code for ngclearn.components.neurons.spiking.WTASCell

from jax import numpy as jnp, random, jit, nn
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
from ngcsimlib import deprecate_args
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.utils.model_utils import softmax


class WTASCell(JaxComponent): ## winner-take-all spiking cell
    """
    A spiking cell based on winner-take-all neuronal dynamics ("WTAS" stands
    for "winner-take-all-spiking").

    The differential equation for adjusting this specific cell
    (for adjusting v, given current j, over time) is:

    | tau_m * dv/dt = j * R ;  v_p = softmax(v)
    | where R is membrane resistance and v_p is a voltage probability vector

    | Implementation note: `advance_state` sets v = j * R directly (gated by
    | the refractory mask) on every call; `tau_m` is stored on the component
    | but is not read by the update, and no integration over dt is performed.

    | --- Cell Input Compartments: ---
    | j - electrical current input (takes in external signals)
    | --- Cell State Compartments: ---
    | v - membrane potential/voltage state
    | rfr - (relative) refractory variable state
    | thr - (adaptive) threshold state
    | key - JAX PRNG key
    | --- Cell Output Compartments: ---
    | s - emitted binary spikes/action potentials
    | tols - time-of-last-spike

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: membrane time constant

        resist_m: membrane resistance value (Default: 1)

        thr_base: base value for adaptive thresholds that govern short-term
            plasticity (Default: 0.4); thresholds are compared against the
            softmax of the voltage in `advance_state`

        thr_gain: increment to be applied to threshold in presence of spike
            (Default: 0.002)

        refract_time: relative refractory period time (ms; Default: 0 ms)

        thr_jitter: scale of uniform jitter to add to initialization of
            thresholds (Default: 0.05)

        batch_size: number of rows in each state compartment (Default: 1)
    """

    @deprecate_args(thrBase="thr_base")
    def __init__(
            self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002,
            refract_time=0., thr_jitter=0.05, batch_size=1, **kwargs
    ):
        super().__init__(name, **kwargs)

        ## membrane parameter setup (affects ODE integration)
        self.tau_m = tau_m ## membrane time constant
        self.R_m = resist_m ## resistance value
        self.thr_gain = thr_gain ## threshold increment/decrement step
        self.thr_base = thr_base ## base value for threshold
        self.refract_T = refract_time ## refractory period length

        ## Layer Size Setup
        self.batch_size = batch_size
        self.n_units = n_units

        ## base threshold setup
        ## according to eqn 26 of the source paper, the initial condition for the
        ## threshold should technically be between:
        ## 1/n_units < threshold0 << 0.5, e.g., 0.15
        ## (only the sub-key is consumed here; the component key is left as-is)
        _, subkey = random.split(self.key.get())
        self.threshold0 = thr_base + random.uniform(
            subkey, (1, n_units), minval=-thr_jitter, maxval=thr_jitter,
            dtype=jnp.float32
        )

        ## Compartment setup
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.j = Compartment(restVals)
        self.v = Compartment(restVals)
        self.s = Compartment(restVals)
        self.thr = Compartment(self.threshold0)
        ## starting rfr at refract_T means no unit begins in a refractory state
        self.rfr = Compartment(restVals + self.refract_T)
        self.tols = Compartment(restVals) ## time-of-last-spike

    @compilable
    def advance_state(self, t, dt):
        """
        Runs one step of the cell: computes voltage from the input current,
        emits spikes, adapts thresholds, and updates the refractory and
        time-of-last-spike states.

        Args:
            t: current time (written into `tols` for units that spike)

            dt: time step (added to the refractory state of non-spiking units)
        """
        ## units whose refractory variable has reached refract_T may receive current
        mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period
        v = (self.j.get() * self.R_m) * mask
        vp = softmax(v) # convert to Categorical (spike) probabilities
        ## thresholded soft-max: a row may emit zero, one, or several spikes
        ## (a hard-max alternative would one-hot encode argmax(vp) instead)
        s = (vp > self.thr.get()) * 1. ## calculate action potential
        q = 1. ## target number of spikes per row; Note: thr_gain ==> "rho_b"
        ## raise thresholds when a row emits more than q spikes, lower them
        ## when it emits fewer; thresholds are floored at 0.025
        dthr = jnp.sum(s, axis=1, keepdims=True) - q
        thr = jnp.maximum(self.thr.get() + dthr * self.thr_gain, 0.025) ## calc new threshold
        ## non-spiking units accumulate dt; spiking units are set back to dt
        rfr = (self.rfr.get() + dt) * (1. - s) + s * dt # set refract to dt
        self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s)

        self.v.set(v)
        self.s.set(s)
        self.thr.set(thr)
        self.rfr.set(rfr)

    @compilable
    def reset(self):
        """
        Restores j (only when it is not targeted), v, s, rfr, and tols to
        their initial values. The adaptive threshold `thr` is not touched.
        """
        restVals = jnp.zeros((self.batch_size, self.n_units))
        if not self.j.targeted:
            self.j.set(restVals)
        self.v.set(restVals)
        self.s.set(restVals)
        self.rfr.set(restVals + self.refract_T)
        self.tols.set(restVals)

    @classmethod
    def help(cls): ## component help function
        """
        Returns a dictionary describing this cell: its type, compartments,
        dynamics string, and constructor hyperparameters.
        """
        properties = {
            "cell_type": "WTASCell - evolves neurons according to winner-take-all "
                         "spiking dynamics "
        }
        compartment_props = {
            "inputs":
                {"j": "External input electrical current"},
            "states":
                {"v": "Membrane potential/voltage at time t",
                 "rfr": "Current state of (relative) refractory variable",
                 "thr": "Current state of voltage threshold at time t",
                 "key": "JAX PRNG key"},
            "outputs":
                {"s": "Emitted spikes/pulses at time t",
                 "tols": "Time-of-last-spike"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "tau_m": "Cell membrane time constant",
            "resist_m": "Membrane resistance value",
            "thr_base": "Base voltage threshold value",
            "thr_gain": "Amount to increment threshold by upon occurrence of spike",
            "refract_time": "Length of relative refractory period (ms)",
            "thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "tau_m * dv/dt = j * resist_m",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    ## manual smoke test: build a single-unit WTASCell inside a context and print it
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        X = WTASCell("X", 1, 1.)
    ## NOTE(review): the original indentation was lost in extraction; the print
    ## is placed after the context block here -- confirm against upstream
    print(X)