from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, nn, Array, jit
from ngcsimlib import deprecate_args
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
triangular_estimator,
straight_through_estimator)
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
## update voltage / membrane potential
dv_dt = (j * mask) ## integration only involves electrical current
dv_dt = dv_dt * (1./tau_m)
return dv_dt
def _dfv(t, v, params): ## voltage dynamics wrapper
j, rfr, tau_m, refract_T = params
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
return dv_dt
[docs]
class IFCell(JaxComponent): ## integrate-and-fire cell
"""
A spiking cell based on integrate-and-fire (IF) neuronal dynamics.
The specific differential equation that characterizes this cell
is (for adjusting v, given current j, over time) is:
| tau_m * dv/dt = j * R
| where R is the membrane resistance and v_rest is the resting potential
| also, if a spike occurs, v is set to v_reset
| --- Cell Input Compartments: ---
| j - electrical current input (takes in external signals)
| --- Cell State Compartments: ---
| v - membrane potential/voltage state
| rfr - (relative) refractory variable state
| key - JAX PRNG key
| --- Cell Output Compartments: ---
| s - emitted binary spikes/action potentials
| s_raw - raw spike signals before post-processing (only if one_spike = True, else s_raw = s)
| tols - time-of-last-spike
Args:
name: the string name of this cell
n_units: number of cellular entities (neural population size)
tau_m: membrane time constant
resist_m: membrane resistance value (default: 1)
thr: base value for adaptive thresholds that govern short-term
plasticity (in milliVolts, or mV; default: -52. mV)
v_rest: membrane resting potential (in mV; default: -65 mV)
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
a neuronal cell's membrane potential will be set to this value;
(default: -60 mV)
refract_time: relative refractory period time (ms; default: 0 ms)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
:Note: setting the integration type to the midpoint method will
increase the accuracy of the estimate of the cell's evolution
at an increase in computational cost (and simulation time)
surrogate_type: type of surrogate function to use for approximating a
partial derivative of this cell's spikes w.r.t. its voltage/current
(default: "straight_through")
:Note: surrogate options available include: "straight_through"
(straight-through estimator), "triangular" (triangular estimator),
and "arctan" (arc-tangent estimator)
lower_clamp_voltage: if True, this will ensure voltage never is below
the value of `v_rest` (default: True)
"""
@deprecate_args(thr_jitter=None)
def __init__(
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0.,
integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
):
super().__init__(name, **kwargs)
## Integration properties
self.integrationType = integration_type
self.intgFlag = get_integrator_code(self.integrationType)
## membrane parameter setup (affects ODE integration)
self.tau_m = tau_m ## membrane time constant
self.resist_m = resist_m ## resistance value
self.v_rest = v_rest #-65. # mV
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
## basic asserts to prevent neuronal dynamics breaking...
assert self.resist_m > 0.
self.refract_T = refract_time #5. # 2. ## refractory period # ms
self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV
self.lower_clamp_voltage = lower_clamp_voltage
## Layer Size Setup
self.batch_size = 1
self.n_units = n_units
## set up surrogate function for spike emission
# if surrogate_type == "arctan":
# self.spike_fx, self.d_spike_fx = arctan_estimator()
# elif surrogate_type == "triangular":
# self.spike_fx, self.d_spike_fx = triangular_estimator()
# else: ## default: straight_through
# self.spike_fx, self.d_spike_fx = straight_through_estimator()
## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.j = Compartment(restVals, display_name="Current", units="mA")
self.v = Compartment(restVals + self.v_rest,
display_name="Voltage", units="mV")
self.s = Compartment(restVals, display_name="Spikes")
self.rfr = Compartment(restVals + self.refract_T,
display_name="Refractory Time Period", units="ms")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike
#self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
[docs]
@compilable
def advance_state(
self, dt, t
):
## run one integration step for neuronal dynamics
j = self.j.get() * self.resist_m
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
## update voltage / membrane potential
v_params = (j, self.rfr.get(), self.tau_m, self.refract_T)
if self.intgFlag == 1:
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
else:
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
## obtain action potentials/spikes
s = (_v > self.thr) * 1.
## update refractory variables
rfr = (self.rfr.get() + dt) * (1. - s)
## perform hyper-polarization of neuronal cells
v = _v * (1. - s) + s * self.v_reset
#surrogate = d_spike_fx(v, self.thr)
## update tols
self.tols.set((1. - s) * self.tols.get() + (s * t))
if self.lower_clamp_voltage: ## ensure voltage never < v_rest
_v = jnp.maximum(v, self.v_rest)
self.v.set(_v)
self.s.set(s)
self.rfr.set(rfr)
[docs]
@compilable
def reset(self):
restVals = jnp.zeros((self.batch_size, self.n_units))
if not self.j.targeted:
self.j.set(restVals)
self.v.set(restVals + self.v_rest)
self.s.set(restVals)
self.rfr.set(restVals + self.refract_T)
self.tols.set(restVals)
#surrogate = restVals + 1.
[docs]
def load(self, directory, seeded=False, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
## constants loaded in
self.tau_m = data['tau_m']
self.thr = data['thr']
self.v_rest = data['v_rest']
self.v_reset = data['v_reset']
self.v_decay = data['v_decay']
self.resist_m = data['resist_m']
self.tau_theta = data['tau_theta']
self.theta_plus = data['theta_plus']
if seeded:
self.key.set(data['key'])
[docs]
@classmethod
def help(cls): ## component help function
properties = {
"cell_type": "IFCell - evolves neurons according to integrate-"
"and-fire spiking dynamics."
}
compartment_props = {
"inputs":
{"j": "External input electrical current"},
"states":
{"v": "Membrane potential/voltage at time t",
"rfr": "Current state of (relative) refractory variable",
"thr": "Current state of voltage threshold at time t",
"key": "JAX PRNG key"},
"outputs":
{"s": "Emitted spikes/pulses at time t",
"tols": "Time-of-last-spike"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"tau_m": "Cell membrane time constant",
"resist_m": "Membrane resistance value",
"thr": "Base voltage threshold value",
"v_rest": "Resting membrane potential value",
"v_reset": "Reset membrane potential value",
"refract_time": "Length of relative refractory period (ms)",
"integration_type": "Type of numerical integration to use for the cell dynamics",
"surrgoate_type": "Type of surrogate function to use approximate "
"derivative of spike w.r.t. voltage/current",
"lower_bound_clamp": "Should voltage be lower bounded to be never be below `v_rest`"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m",
"hyperparameters": hyperparams}
return info