from ngcsimlib.component import Component
from jax import numpy as jnp, random, jit, nn
from functools import partial
import time, sys
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
@jit
def update_times(t, s, tols):
    """
    Refreshes the time-of-last-spike (tols) record.

    Entries where a spike occurred (s == 1) are overwritten with the current
    time `t`; all other entries carry their previous value forward.

    Args:
        t: current time (a scalar/int value)

        s: binary spike vector

        tols: time-of-last-spike values from before this step

    Returns:
        the refreshed tols variable
    """
    keep = 1. - s  ## 1 where no spike happened, 0 where one did
    return keep * tols + (s * t)
@jit
def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics
    """
    Izhikevich membrane-potential derivative.

    Evaluates (0.04 v^2 + 5 v + 140 - w + j) * (1 / tau_m), where `w` plays the
    role of Izhikevich's recovery variable `u`. The coupling argument `b` is
    accepted (the signature mirrors `_dfw_internal`) but is not used here.
    """
    drive = jnp.square(v) * 0.04 + v * 5. + 140. - w + j
    return drive * (1./tau_m)
def _dfv(t, v, params): ## voltage dynamics wrapper
    """
    Voltage right-hand side in the ``(t, x, params)`` callback form that
    ``run_cell`` passes to ``step_euler`` / ``step_rk2``.

    Args:
        t: time (not used; the dynamics have no explicit time dependence)

        v: current membrane potential

        params: tuple ``(j, w, b, tau_m)``

    Returns:
        dv/dt
    """
    current, recovery, coupling, tau = params
    return _dfv_internal(current, v, recovery, coupling, tau)
@jit
def _dfw_internal(j, v, w, b, tau_w): ## raw recovery dynamics
    """
    Izhikevich recovery-variable derivative.

    Evaluates (b * v - w) * (1 / tau_w), i.e. Izhikevich's (b v - u) * a with
    a = 1 / tau_w and w = u. The current argument `j` is accepted (the signature
    mirrors `_dfv_internal`) but is not used here.
    """
    pull = v * b - w
    return pull * (1./tau_w)
def _dfw(t, w, params): ## recovery dynamics wrapper
    """
    Recovery right-hand side in the ``(t, x, params)`` callback form that
    ``run_cell`` passes to ``step_euler`` / ``step_rk2``.

    Args:
        t: time (not used; the dynamics have no explicit time dependence)

        w: current recovery value

        params: tuple ``(j, v, b, tau_w)``

    Returns:
        dw/dt
    """
    j, v, b, tau_w = params
    ## this is the recovery derivative (previously mislabelled `dv_dt`)
    dw_dt = _dfw_internal(j, v, w, b, tau_w)
    return dw_dt
def _post_process(s, _v, _w, v, w, c, d): ## internal post-processing routine
    """
    Applies the Izhikevich after-spike reset to freshly integrated states.

    After the dynamics have been stepped, units that spiked (s == 1) have their
    voltage snapped to `c` and their recovery set to the pre-step recovery
    plus `d`. All other units keep the integrated values `_v` / `_w`.

    Args:
        s: binary spike vector (1 where a spike was emitted, 0 elsewhere)

        _v: voltage after one integration step

        _w: recovery after one integration step

        v: voltage before the integration step (not used; kept for interface)

        w: recovery before the integration step

        c: reset voltage

        d: recovery increment applied on a spike

    Returns:
        (v_next, w_next)
    """
    ## Select rather than gate by multiplication. For spiking units the
    ## integrated values can be very large, because the quadratic voltage term
    ## is evaluated from an above-threshold voltage. If they overflow to inf,
    ## `inf * (1 - s)` with s == 1 is NaN, which would poison the reset value.
    spiked = s > 0.
    v_next = jnp.where(spiked, c, _v)
    w_next = jnp.where(spiked, w + d, _w)
    return v_next, w_next
@jit
def _emit_spike(v, v_thr):
    """
    Produces a float32 spike mask: 1.0 where `v` strictly exceeds `v_thr`,
    0.0 everywhere else.
    """
    crossed = jnp.greater(v, v_thr)
    return crossed.astype(jnp.float32)
@jit
def _modify_current(j, R_m):
    """
    Scales the raw input current `j` by the membrane resistance `R_m`
    (element-wise product).
    """
    return R_m * j
## NOTE: this function is not jit-compiled as a whole. If it is ever wrapped
## with `jit`, `integType` (positional index 12) must be marked static, because
## it selects the integrator through Python control flow.
def run_cell(dt, j, v, s, w, v_thr=30., tau_m=1., tau_w=50., b=0.2, c=-65., d=8.,
             R_m=1., integType=0):
    """
    Runs one step of Izhikevich neuronal dynamics.

    Spikes are determined from the incoming voltage `v`. Both state variables
    are then advanced one step, and finally spiking units are reset (voltage to
    `c`, recovery to `w + d`).

    Args:
        dt: integration time constant (milliseconds, or ms)

        j: electrical current value

        v: membrane potential (voltage, in milliVolts or mV) value (at t)

        s: previously measured spikes/action potentials (binary values); note
            that this argument is not used, since spikes are recomputed from `v`

        w: recovery variable/state

        v_thr: voltage threshold value (in mV)

        tau_m: membrane time constant

        tau_w: (tau_recovery) time scale/constant of recovery variable; note
            that this is the inverse of Izhikevich's scale variable `a`
            (tau_w = 1/a). E.g., tau_w = 10 (a = 0.1) gives fast spikes,
            tau_w = 50 (a = 0.02) gives regular spikes

        b: (coupling factor) how sensitive is recovery to subthreshold voltage
            fluctuations

        c: (reset_voltage) voltage to reset to after spike emitted (in mV)

        d: (reset_recovery) amount added to the recovery variable after a spike

        R_m: membrane resistance value (Default: 1 mega-Ohm)

        integType: integration type to use (1 --> Midpoint/RK2; any other
            value --> Euler/RK1)

    Returns:
        updated voltage, updated recovery, spikes
    """
    _j = _modify_current(j, R_m)
    #_j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
    ## check for spikes using the voltage carried in from the previous step
    s = _emit_spike(v, v_thr)
    ## evolve according to dynamics. v and w are stepped separately, each one
    ## holding the other fixed at its pre-step value
    step = step_rk2 if integType == 1 else step_euler
    _, _v = step(0., v, _dfv, dt, (_j, w, b, tau_m))
    _, _w = step(0., w, _dfw, dt, (_j, v, b, tau_w))
    ## for spikes, snap to particular states
    _v, _w = _post_process(s, _v, _w, v, w, c, d)
    return _v, _w, s
class IzhikevichCell(Component): ## Izhikevich neuronal cell
    """
    A spiking cell based on Izhikevich's model of neuronal dynamics. Note that
    this is a two-variable simplification of more complex multi-variable systems
    (e.g., Hodgkin-Huxley model). This cell model iteratively evolves
    voltage "v" and recovery "w", the last of which represents the combined
    effects of sodium channel deinactivation and potassium channel deactivation.

    The specific pair of differential equations that characterize this cell
    are (for adjusting v and w, given current j, over time):

    | tau_m * dv/dt = 0.04 v^2 + 5v + 140 - w + j * R_m
    | tau_w * dw/dt = (v * b - w), where tau_w = 1/a

    | References:
    | Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE Transactions
    | on neural networks 14.6 (2003): 1569-1572.

    Note: Izhikevich's constants/hyper-parameters 'a', 'b', 'c', and 'd' have
    been remapped/renamed (see arguments below). Note that the default settings
    for this component cell are for a regular spiking cell; to recover other
    types of spiking cells (depending on what neuronal circuitry one wants to
    model), one can recover specific models with the following particular values:

    | Intrinsically bursting neurons: ``v_reset=-55, w_reset=4``
    | Chattering neurons: ``v_reset = -50, w_reset = 2``
    | Fast spiking neurons: ``tau_w = 10``
    | Low-threshold spiking neurons: ``tau_w = 10, coupling_factor = 0.25, w_reset = 2``
    | Resonator neurons: ``tau_w = 10, coupling_factor = 0.26``

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: membrane time constant (Default: 1 ms)

        R_m: membrane resistance value

        v_thr: voltage threshold value to cross for emitting a spike
            (in milliVolts, or mV) (Default: 30 mV)

        v_reset: voltage value to reset to after a spike (in mV)
            (Default: -65 mV), i.e., 'c'

        tau_w: recovery variable time constant (Default: 50 ms), i.e., 1/'a'

        w_reset: amount added to the recovery value after a spike (Default: 8), i.e., 'd'

        coupling_factor: degree to which recovery is sensitive to any
            subthreshold fluctuations of voltage (Default: 0.2), i.e., 'b'

        v0: initial condition / reset for voltage (Default: -65 mV)

        w0: initial condition / reset for recovery (Default: -14)

        key: PRNG key to control determinism of any underlying random values
            associated with this cell

        integration_type: type of integration to use for this cell's dynamics;
            current supported forms include "euler" (Euler/RK-1 integration)
            and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")

            :Note: setting the integration type to the midpoint method will
                increase the accuracy of the estimate of the cell's evolution
                at an increase in computational cost (and simulation time)

        useVerboseDict: triggers slower, verbose dictionary mode (Default: False)
    """

    ## Class Methods for Compartment Names
    @classmethod
    def outputCompartmentName(cls):
        return 'out'

    @classmethod
    def voltageName(cls):
        return 'v'

    @classmethod
    def recoveryName(cls):
        return 'w'

    @classmethod
    def timeOfLastSpikeCompartmentName(cls):
        return 'tols'

    ## Bind Properties to Compartments for ease of use
    ## NOTE(review): `inputCompartmentName()` is not defined in this class;
    ## presumably it is provided by `Component` -- confirm.
    @property
    def inputCompartment(self):
        return self.compartments.get(self.inputCompartmentName(), None)

    @inputCompartment.setter
    def inputCompartment(self, inp):
        self.compartments[self.inputCompartmentName()] = inp

    @property
    def outputCompartment(self):
        return self.compartments.get(self.outputCompartmentName(), None)

    @outputCompartment.setter
    def outputCompartment(self, out):
        self.compartments[self.outputCompartmentName()] = out

    @property
    def voltage(self):
        return self.compartments.get(self.voltageName(), None)

    @voltage.setter
    def voltage(self, t):
        self.compartments[self.voltageName()] = t

    @property
    def recovery(self):
        return self.compartments.get(self.recoveryName(), None)

    @recovery.setter
    def recovery(self, t):
        self.compartments[self.recoveryName()] = t

    @property
    def timeOfLastSpike(self):
        return self.compartments.get(self.timeOfLastSpikeCompartmentName(), None)

    @timeOfLastSpike.setter
    def timeOfLastSpike(self, t):
        self.compartments[self.timeOfLastSpikeCompartmentName()] = t

    # Define Functions
    def __init__(self, name, n_units, tau_m=1., R_m=1., v_thr=30., v_reset=-65.,
                 tau_w=50., w_reset=8., coupling_factor=0.2, v0=-65., w0=-14.,
                 integration_type="euler", key=None, useVerboseDict=False, **kwargs):
        super().__init__(name, useVerboseDict, **kwargs)

        ## Cell properties
        self.R_m = R_m
        self.tau_m = tau_m
        self.tau_w = tau_w
        self.coupling = coupling_factor
        self.v_reset = v_reset
        self.w_reset = w_reset
        self.v0 = v0 ## initial membrane potential/voltage condition
        self.w0 = w0 ## initial recovery w-parameter condition
        self.v_thr = v_thr

        ## Integration properties (string name mapped to the integer code run_cell expects)
        self.integrationType = integration_type
        self.intgFlag = get_integrator_code(self.integrationType)

        ## Random Number Set up (seeded from the wall clock when no key is given)
        self.key = key
        if self.key is None:
            self.key = random.PRNGKey(time.time_ns())

        ## Layer Size Setup
        self.batch_size = 1
        self.n_units = n_units
        self.reset()

    def verify_connections(self):
        """Requires at least one incoming connection on the input compartment."""
        self.metadata.check_incoming_connections(self.inputCompartmentName(), min_connections=1)

    def advance_state(self, t, dt, **kwargs):
        """
        Advances the cell by one step of size `dt` at time `t`, updating the
        voltage, recovery, output spike, and time-of-last-spike compartments.
        """
        j = self.inputCompartment
        v = self.voltage
        w = self.recovery
        s = self.outputCompartment
        v, w, s = run_cell(dt, j, v, s, w, v_thr=self.v_thr, tau_m=self.tau_m,
                           tau_w=self.tau_w, b=self.coupling, c=self.v_reset,
                           d=self.w_reset, R_m=self.R_m, integType=self.intgFlag)
        self.voltage = v
        self.recovery = w
        self.outputCompartment = s
        ## stamp the current time onto every unit that just spiked; without this
        ## the tols compartment would remain at its reset value forever
        self.timeOfLastSpike = update_times(t, s, self.timeOfLastSpike)

    def reset(self, **kwargs):
        """
        Clears the input and re-initializes voltage to `v0`, recovery to `w0`,
        and the spike output and time-of-last-spike to zeros, each of shape
        (batch_size, n_units).
        """
        self.inputCompartment = None
        self.voltage = jnp.zeros((self.batch_size, self.n_units)) + self.v0
        self.recovery = jnp.zeros((self.batch_size, self.n_units)) + self.w0
        self.outputCompartment = jnp.zeros((self.batch_size, self.n_units)) #None
        self.timeOfLastSpike = jnp.zeros((self.batch_size, self.n_units))

    def save(self, directory, **kwargs):
        """Currently a no-op: nothing is written to `directory`."""
        pass