from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, nn, Array
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
triangular_estimator,
straight_through_estimator)
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
def _dfv(t, v, params): ## voltage dynamics wrapper
j, rfr, tau_m, refract_T, v_rest, g_L = params
mask = (rfr >= refract_T) * 1. # get refractory mask
## update voltage / membrane potential
dv_dt = (v_rest - v) * g_L + (j * mask)
dv_dt = dv_dt * (1. / tau_m)
return dv_dt
#@partial(jit, static_argnums=[3, 4])
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
### Runs homeostatic threshold update dynamics one step (via Euler integration).
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
#theta_plus = 0.05
#_V_theta = V_theta * theta_decay + S * theta_plus
theta_decay = jnp.exp(-dt/tau_theta)
_v_theta = v_theta * theta_decay + s * theta_plus
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
return _v_theta
class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
    """
    A spiking cell based on leaky integrate-and-fire (LIF) neuronal dynamics.

    The specific differential equation that characterizes this cell
    (for adjusting v, given current j, over time) is:

    | tau_m * dv/dt = (v_rest - v) * g_L + j * R * mask
    | where R is the membrane resistance, v_rest is the resting potential,
    | g_L is the leak conductance, and mask is 1 for cells that are out of
    | their refractory period (0 otherwise)
    | also, if a spike occurs, v is set to v_reset

    | --- Cell Input Compartments: ---
    | j - electrical current input (takes in external signals)
    | --- Cell State Compartments: ---
    | v - membrane potential/voltage state
    | rfr - (relative) refractory variable state
    | thr_theta - homeostatic/adaptive threshold increment state
    | key - JAX PRNG key
    | --- Cell Output Compartments: ---
    | s - emitted binary spikes/action potentials
    | s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s)
    | tols - time-of-last-spike

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: membrane time constant

        resist_m: membrane resistance value; must be > 0 (Default: 1)

        thr: base value for adaptive thresholds that govern short-term
            plasticity (in milliVolts, or mV; default: -52. mV)

        v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV)

        v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
            a neuronal cell's membrane potential will be set to this value;
            (default: -60 mV)

        conduct_leak: leak conductance (g_L) value or decay factor applied to voltage leak
            (Default: 1.); setting this to 0 recovers pure integrate-and-fire (IF) dynamics

        tau_theta: homeostatic threshold time constant (a value <= 0 turns
            the adaptive threshold update off)

        theta_plus: physical increment to be applied to any threshold value if
            a spike was emitted

        refract_time: relative refractory period time (ms; Default: 5 ms)

        one_spike: if True, a single-spike constraint will be enforced for
            every time step of neuronal dynamics simulated, i.e., at most, only
            a single spike will be permitted to emit per step -- this means that
            if > 1 spikes emitted, a single action potential will be randomly
            sampled from the non-zero spikes detected (Default: False)

        integration_type: type of integration to use for this cell's dynamics;
            current supported forms include "euler" (Euler/RK-1 integration)
            and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")

            :Note: setting the integration type to the midpoint method will
                increase the accuracy of the estimate of the cell's evolution
                at an increase in computational cost (and simulation time)

        surrogate_type: type of surrogate function to use for approximating a
            partial derivative of this cell's spikes w.r.t. its voltage/current
            (default: "straight_through")

            :Note: surrogate options available include: "straight_through"
                (straight-through estimator), "triangular" (triangular estimator),
                "arctan" (arc-tangent estimator), and "secant_lif" (the
                LIF-specialized secant estimator); the surrogate set-up is
                currently disabled in the constructor, so this argument is
                accepted but not used

        v_min: minimum voltage to clamp dynamics to (Default: None)

        max_one_spike: if True, emitted spikes are masked so that only the unit
            with the largest membrane voltage (as stored at the start of the
            step) in each row may spike; when set, this takes precedence over
            the random sampling performed by `one_spike` (Default: False)

        key: PRNG key handed to the parent component constructor (Default: None)
    """ ## batch_size arg?

    def __init__(
            self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
            theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
            v_min=None, max_one_spike=False, key=None
    ):
        super().__init__(name, key)

        ## Integration properties
        self.integrationType = integration_type
        self.intgFlag = get_integrator_code(self.integrationType)
        self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
        self.max_one_spike = max_one_spike ## True => only the max-voltage unit may spike

        ## membrane parameter setup (affects ODE integration)
        self.tau_m = tau_m ## membrane time constant
        self.resist_m = resist_m ## resistance value
        self.v_min = v_min ## ensures voltage is never < v_min
        self.v_rest = v_rest #-65. # mV
        self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
        self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF)
        ## basic asserts to prevent neuronal dynamics breaking...
        #assert (self.conduct_leak * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
        assert self.resist_m > 0.
        self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
        self.theta_plus = theta_plus #0.05 ## threshold increment
        self.refract_T = refract_time #5. # 2. ## refractory period # ms
        self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV

        ## Layer Size Setup
        self.batch_size = 1
        self.n_units = n_units

        ## NOTE: surrogate function set-up is currently disabled, which is why
        ##       `surrogate_type` is not read anywhere in this constructor
        # ## set up surrogate function for spike emission
        # if surrogate_type == "secant_lif":
        #     spike_fx, d_spike_fx = secant_lif_estimator()
        # elif surrogate_type == "arctan":
        #     spike_fx, d_spike_fx = arctan_estimator()
        # elif surrogate_type == "triangular":
        #     spike_fx, d_spike_fx = triangular_estimator()
        # else: ## default: straight_through
        #     spike_fx, d_spike_fx = straight_through_estimator()

        ## Compartment setup
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.j = Compartment(restVals, display_name="Current", units="mA")
        self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV")
        self.s = Compartment(restVals, display_name="Spikes")
        self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses")
        ## refractory variable starts at refract_T so that cells begin non-refractory
        self.rfr = Compartment(restVals + self.refract_T, display_name="Refractory Time Period", units="ms")
        self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", units="mV")
        self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike
        # self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")

    @compilable
    def advance_state(self, dt, t):
        ## Runs one step of LIF dynamics: integrates the voltage, emits spikes,
        ## applies reset/refractoriness, optionally filters spikes down to one
        ## per row, updates the adaptive threshold, and records spike times.
        ##   dt - integration time constant (step size)
        ##   t  - current simulation time (written into tols where spikes occur)
        ## (documentation is kept in comments so that no extra statements are
        ##  added to the body that the @compilable decorator processes)
        j = self.j.get() * self.resist_m
        _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
        ## tau_m is a plain hyperparameter (not a Compartment), so it is read
        ## directly, like the other constants packed below
        v_params = (j, self.rfr.get(), self.tau_m, self.refract_T, self.v_rest, self.g_L)
        if self.intgFlag == 1:
            _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
        else:
            _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
        s = (_v > _v_thr) * 1. ## binary spikes: voltage crossed current threshold
        _rfr = (self.rfr.get() + dt) * (1. - s) ## advance refractory clock; zero it where spikes occurred
        _v = _v * (1. - s) + s * self.v_reset ## hard-reset voltage of spiking units

        raw_s = s ## spikes before any single-spike filtering
        if self.one_spike and not self.max_one_spike:
            key, skey = random.split(self.key.get(), 2)
            m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
            ## randomly pick one of the spiking units per row via noisy argmax
            rS = s * random.uniform(skey, s.shape)
            rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32)
            s = s * (1. - m_switch) + rS * m_switch
            self.key.set(key)

        if self.max_one_spike:
            ## NOTE(review): the argmax is taken over the voltage stored at the
            ## start of this step (self.v.get()), not the freshly integrated
            ## voltage -- confirm that this is the intended selection signal
            rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike
            s = s * rS ## mask out non-max volt spikes

        if self.tau_theta > 0.:
            ## run one integration step for threshold dynamics (driven by the
            ## unfiltered/raw spikes)
            thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
            self.thr_theta.set(thr_theta)

        ## update time-of-last spike variable(s)
        self.tols.set((1. - s) * self.tols.get() + (s * t))

        if self.v_min is not None: ## ensures voltage never < v_min
            _v = jnp.maximum(_v, self.v_min)

        self.v.set(_v)
        self.s.set(s)
        self.s_raw.set(raw_s)
        self.rfr.set(_rfr)

    @compilable
    def reset(self):
        ## Restores compartments to their initial values; the input current is
        ## only cleared when nothing is wired into it (j.targeted is False).
        ## Note that the adaptive threshold (thr_theta) is not touched here.
        restVals = jnp.zeros((self.batch_size, self.n_units))
        if not self.j.targeted:
            self.j.set(restVals)
        self.v.set(restVals + self.v_rest)
        self.s.set(restVals)
        self.s_raw.set(restVals)
        self.rfr.set(restVals + self.refract_T)
        self.tols.set(restVals)

    @classmethod
    def help(cls): ## component help function
        """
        Builds a dictionary describing this component: its cell type, its
        compartments (inputs/states/outputs), its dynamics, and its
        hyperparameters.

        Returns:
            a dictionary of descriptive strings keyed by the class name,
            "compartments", "dynamics", and "hyperparameters"
        """
        properties = {
            "cell_type": "LIFCell - evolves neurons according to leaky integrate-"
                         "and-fire spiking dynamics."
        }
        compartment_props = {
            "inputs":
                {"j": "External input electrical current"},
            "states":
                {"v": "Membrane potential/voltage at time t",
                 "rfr": "Current state of (relative) refractory variable",
                 "thr": "Current state of voltage threshold at time t",
                 "thr_theta": "Current state of homeostatic adaptive threshold at time t",
                 "key": "JAX PRNG key"},
            "outputs":
                {"s": "Emitted spikes/pulses at time t",
                 "s_raw": "Raw spikes/pulses at time t, before any single-spike filtering",
                 "tols": "Time-of-last-spike"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "tau_m": "Cell membrane time constant",
            "resist_m": "Membrane resistance value",
            "thr": "Base voltage threshold value",
            "v_rest": "Resting membrane potential value",
            "v_reset": "Reset membrane potential value",
            "conduct_leak": "Conductance leak / voltage decay factor",
            "tau_theta": "Threshold/homeostatic increment time constant",
            "theta_plus": "Amount to increment threshold by upon occurrence of a spike",
            "refract_time": "Length of relative refractory period (ms)",
            "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
            "max_one_spike": "Should only the spike of the maximum-voltage unit be allowed to emit at any given time step?",
            "integration_type": "Type of numerical integration to use for the cell dynamics",
            "surrogate_type": "Type of surrogate function to use approximate "
                              "derivative of spike w.r.t. voltage/current",
            "v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    ## quick smoke-test: build a single LIF cell inside a context and print it
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        X = LIFCell(name="X", n_units=9, tau_m=0.0004, resist_m=3)
    print(X)