Source code for ngclearn.components.neurons.spiking.sLIFCell

# %%

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
from ngclearn.utils.diffeq.ode_utils import step_euler
from ngclearn.utils.surrogate_fx import secant_lif_estimator

from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

@jit
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
    mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
    #dv_dt = ((-v + j) * (dt / tau_m)) * mask
    dv_dt = (-v + j)
    dv_dt = dv_dt * (1./tau_m) * mask
    return dv_dt

#@partial(jit, static_argnums=[2])
def _dfv(t, v, params): ## voltage dynamics wrapper
    j, rfr, tau_m, refract_T = params
    dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
    return dv_dt

@partial(jit, static_argnums=[3,4,5])
def _update_threshold(dt, v_thr, spikes, thrGain=0.002, thrLeak=0.0005, rho_b = 0.):
    ## update thresholds if applicable
    if rho_b > 0.: ## run sparsity-enforcement threshold
        dthr = jnp.sum(spikes, axis=1, keepdims=True) - 1.0
        _v_thr = jnp.maximum(v_thr + dthr * rho_b, 0.025)
    else: ## run simple adaptive threshold
        thr_gain = spikes * thrGain
        thr_leak = (v_thr * thrLeak)
        _v_thr = v_thr + thr_gain - thr_leak
    return _v_thr

@partial(jit, static_argnums=[4])
def _update_refract_and_spikes(dt, rfr, s, refract_T, sticky_spikes=False):
    mask = (rfr >= refract_T).astype(jnp.float32) ## Note: wasted repeated compute
    ## update refractory variables
    _rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt
    _s = s
    if sticky_spikes == True: ## pin refractory spikes if configured
        _s = s * mask + (1. - mask)
    return _rfr, _s

[docs] class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell """ A spiking cell based on a simplified leaky integrate-and-fire (sLIF) model. This neuronal cell notably contains functionality required by the computational model employed by (Samadi et al., 2017, i.e., a surrogate derivative function and "sticky spikes") as well as the additional incorporation of an adaptive threshold (per unit) scheme. (Note that this particular spiking cell only supports Euler integration of its voltage dynamics.) | --- Cell Input Compartments: --- | j - electrical current input (takes in external signals) | --- Cell State Compartments: --- | v - membrane potential/voltage state | rfr - (relative) refractory variable state | thr - (adaptive) threshold state | key - JAX PRNG key | --- Cell Output Compartments: --- | s - emitted binary spikes/action potentials | surrogate - state of surrogate function output signals (currently, the secant LIF estimator) | tols - time-of-last-spike | Reference: | Samadi, Arash, Timothy P. Lillicrap, and Douglas B. Tweed. "Deep learning with | dynamic spiking neurons and fixed feedback weights." Neural computation 29.3 | (2017): 578-602. Args: name: the string name of this cell n_units: number of cellular entities (neural population size) tau_m: membrane time constant resist_m: membrane resistance value thr: base value for adaptive thresholds (initial condition for per-cell thresholds) that govern short-term plasticity resist_inh: lateral modulation factor (DEFAULT: 6.); if >0, this will trigger a heuristic form of lateral inhibition via an internally integrated hollow matrix multiplication thr_persist: are adaptive thresholds persistent? (Default: False) :Note: depending on the value of this boolean variable: True = adaptive thresholds are NEVER reset upon call to reset False = adaptive thresholds are reset to "thr" upon call to reset thr_gain: how much adaptive thresholds increment by thr_leak: how much adaptive thresholds are decremented/decayed by refract_time: relative refractory period time (ms; Default: 1 ms) rho_b: threshold sparsity factor (Default: 0); note that setting rho_b > 0 will force the adaptive threshold to follow dynamics that ignore `thr_grain` and `thr_leak` sticky_spikes: if True, spike variables will be pinned to action potential value (i.e, 1) throughout duration of the refractory period; this recovers a key setting used by Samadi et al., 2017 thr_jitter: scale of uniform jitter to add to initialization of thresholds batch_size: batch size dimension of this cell (Default: 1) """ def __init__( self, name, n_units, tau_m, resist_m, thr, resist_inh=0., thr_persist=False, thr_gain=0.0, thr_leak=0.0, rho_b=0., refract_time=0., sticky_spikes=False, thr_jitter=0.05, batch_size=1, **kwargs ): super().__init__(name, **kwargs) ## membrane parameter setup (affects ODE integration) self.tau_m = tau_m ## membrane time constant self.R_m = resist_m ## resistance value self.refract_T = refract_time #5. # 2. ## refractory period # ms self.v_min = -3. ## variable below determines if spikes pinned at 1 during refractory period? self.sticky_spikes = sticky_spikes ## set up surrogate function for spike emission self.spike_fx, self.d_spike_fx = secant_lif_estimator() ## create simple recurrent inhibitory pressure self.inh_R = resist_inh ## lateral inhibitory magnitude key, subkey = random.split(self.key.get()) self.inh_weights = random.uniform(subkey, (n_units, n_units), minval=0.025, maxval=1.) self.inh_weights = self.inh_weights * (1. - jnp.eye(n_units)) ## Layer Size Setup self.n_units = n_units self.batch_size = batch_size ## Adaptive threshold setup self.rho_b = rho_b self.thr_persist = thr_persist ## are adapted thresholds persistent? True (persistent) self.thrGain = thr_gain #0.0005 self.thrLeak = thr_leak #0.00005 # thr_jitter: some random jitter to ensure thresholds start off different key, subkey = random.split(key) self.threshold0 = thr + random.uniform(subkey, (1, n_units), minval=-thr_jitter, maxval=thr_jitter, dtype=jnp.float32) ## Compartments restVals = jnp.zeros((self.batch_size, self.n_units)) self.j = Compartment(restVals) ## electrical current, input self.s = Compartment(restVals) ## spike/action potential, output self.tols = Compartment(restVals) ## time-of-last-spike (record vector) self.v = Compartment(restVals) ## membrane potential/voltage self.thr = Compartment(self.threshold0 + 0.) ## action potential threshold self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s) self.surrogate = Compartment(restVals + 1.) ## surrogate signal
[docs] @compilable def advance_state(self, t, dt): ##################################################################################### #The following 3 lines of code modify electrical current j via application of a #scalar membrane resistance value and an approximate form of lateral inhibition. #Functionally, this routine carries out the following piecewise equation: #| j * R_m - [Wi * s(t-dt)] * inh_R, if inh_R > 0 #| j * R_m, otherwise #| where j: electrical current value, spikes: previous binary spike vector (for t-dt), # inh_weights: lateral recurrent inhibitory synapses (typically should be chosen # to be a scaled hollow matrix), #| R_m: membrane resistance (to multiply/scale j by), #| inh_R: inhibitory resistance to scale lateral inhibitory current by; if inh_R = 0, # NO lateral inhibitory pressure will be applied # First, get the relevant compartment values j = self.j.get() # s = self.s.get() # NOTE: This is unused tols = self.tols.get() v = self.v.get() thr = self.thr.get() rfr = self.rfr.get() surrogate = self.surrogate.get() ## modify electrical current j via membrane resistance and lateral inhibition j = j * self.R_m if self.inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied j = j - (jnp.matmul(self.s.get(), self.inh_weights) * self.inh_R) ##################################################################################### surrogate = self.d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes ## transition to: voltage(t+dt), spikes, threshold(t+dt), refractory_variables(t+dt) v_params = (j, rfr, self.tau_m, self.refract_T) _, _v = step_euler(0., v, _dfv, dt, v_params) spikes = self.spike_fx(_v, thr) #_v = _hyperpolarize(_v, spikes) _v = (1. - spikes) * _v ## hyper-polarize cells new_thr = _update_threshold(dt, thr, spikes, self.thrGain, self.thrLeak, self.rho_b) _rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, self.refract_T, self.sticky_spikes) v = _v s = spikes thr = new_thr rfr = _rfr ## update tols tols = (1. - s) * tols + (s * t) # return j, s, tols, v, thr, rfr, surrogate self.j.set(j) self.s.set(s) self.tols.set(tols) self.v.set(v) self.thr.set(thr) self.rfr.set(rfr) self.surrogate.set(surrogate)
[docs] @compilable def reset(self): # refract_T, thr_persist, threshold0, batch_size, n_units, thr restVals = jnp.zeros((self.batch_size, self.n_units)) voltage = restVals refract = restVals + self.refract_T current = restVals surrogate = restVals + 1. timeOfLastSpike = restVals spikes = restVals if not self.thr_persist: ## if thresh non-persistent, reset to base value thr = self.threshold0 + 0 self.thr.set(thr) # return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate self.j.set(current) self.s.set(spikes) self.tols.set(timeOfLastSpike) self.v.set(voltage) self.rfr.set(refract) self.surrogate.set(surrogate)
[docs] def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" if self.thr_persist == False: jnp.savez(file_name, threshold=self.threshold0) # save threshold0 else: jnp.savez(file_name, threshold=self.thr.get()) # save the actual threshold param/compartment
[docs] def load(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" data = jnp.load(file_name) self.thr.set(data['threshold']) self.threshold0 = self.thr.get() + 0
[docs] @classmethod def help(cls): ## component help function properties = { "cell_type": "SLIFCell - evolves neurons according to simplified " "leaky integrate-and-fire spiking dynamics." } compartment_props = { "inputs": {"j": "External input electrical current"}, "states": {"v": "Membrane potential/voltage at time t", "rfr": "Current state of (relative) refractory variable", "thr": "Current state of voltage threshold at time t", "key": "JAX PRNG key"}, "outputs": {"s": "Emitted spikes/pulses at time t", "tols": "Time-of-last-spike", "surrogate": "State/value of surrogate function at time t"}, } hyperparams = { "n_units": "Number of neuronal cells to model in this layer", "tau_m": "Cell membrane time constant", "resist_m": "Membrane resistance value", "thr": "Base voltage threshold value", "resist_inh": "Inhibitory resistance value", "thr_persist": "Should adaptive threshold persist across reset calls?", "thr_gain": "Amount to increment threshold by upon occurrence of spike", "thr_leak": "Amount to decay threshold upon occurrence of spike", "rho_b": "Shared threshold sparsity control parameter (if using shared threshold)", "refract_time": "Length of relative refractory period (ms)", "thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold", "sticky_spikes": "Should spikes be allowed to persist during refractory period?" } info = {cls.__name__: properties, "compartments": compartment_props, "dynamics": "tau_m * dv/dt = -v + j * resist_m", "hyperparameters": hyperparams} return info
if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: X = SLIFCell("X", 9, 0.0004, 3, 0.3) print(X)