from ngcsimlib.component import Component
from jax import numpy as jnp, random, jit, nn
from functools import partial
import time, sys
@jit
def update_times(t, s, tols):
    """
    Refreshes the time-of-last-spike (tols) record after one simulation step.

    Args:
        t: current time (a scalar/int value)

        s: spike vector for this step (binary; 1 marks an emitted spike)

        tols: time-of-last-spike values carried over from the previous step

    Returns:
        the new tols values -- `t` wherever `s` is 1 and the previous tols
        value wherever `s` is 0
    """
    carried = (1. - s) * tols ## silent cells retain their previous record
    stamped = s * t ## spiking cells are stamped with the current time
    return carried + stamped
@partial(jit, static_argnums=[7,8,9,10,11,12,13])
def run_cell(dt, j, v, v_thr, v_theta, rfr, skey, v_c, a0, tau_m, R_m, v_rest,
             v_reset, refract_T):
    """
    Runs one Euler step of quadratic leaky integrator neuronal dynamics.

    The voltage update applied is:

    | v(t+dt) = v + a0 * (v_rest - v) * (v - v_c) * (dt/tau_m) + j * mask

    where `mask` is 1 for cells whose refractory variable has reached
    `refract_T` and 0 otherwise (i.e., input current is gated off while a
    cell is refractory).

    Args:
        dt: integration time constant (milliseconds, or ms)

        j: electrical current value

        v: membrane potential (voltage, in milliVolts or mV) value (at t)

        v_thr: base voltage threshold value (in mV)

        v_theta: threshold shift (homeostatic) variable (at t)

        rfr: refractory variable vector (one per neuronal cell)

        skey: PRNG key which, if not None, will trigger a single-spike
            constraint (i.e., only one spike permitted to emit per row per
            single step of time); specifically used to randomly sample one
            of the possible action potentials to be an emitted spike

        v_c: critical voltage value (second root of the quadratic term)

        a0: scaling factor for voltage accumulation

        tau_m: cell membrane time constant

        R_m: membrane resistance value (accepted to keep the call signature
            stable; the update below does not currently use it)

        v_rest: membrane resting potential (in mV)

        v_reset: membrane reset potential (in mV) -- upon occurrence of a
            spike, a neuronal cell's membrane potential will be set to this
            value

        refract_T: (relative) refractory time period (in ms)

    Returns:
        voltage(t+dt), spikes, raw spikes, updated refractory variables
    """
    _v_thr = v_theta + v_thr ## calc present voltage threshold
    mask = (rfr >= refract_T).astype(jnp.float32) ## get refractory mask
    ## update voltage / membrane potential
    _v = v + ((v_rest - v) * (v - v_c) * a0) * (dt/tau_m) + (j * mask)
    ## obtain action potentials
    s = (_v > _v_thr).astype(jnp.float32)
    ## update refractory variables (reset to 0 for any cell that just spiked)
    _rfr = (rfr + dt) * (1. - s)
    ## perform hyper-polarization of neuronal cells
    _v = _v * (1. - s) + s * v_reset
    raw_s = s + 0 ## preserve un-altered spikes
    ############################################################################
    ## spike post-processing step: keep at most one spike per row
    if skey is not None:
        from jax import vmap ## local import; only needed for multi-row input
        n_rows, n_cells = s.shape

        def _sample_winner(key, spike_row):
            ## draw one index with probability proportional to the spike row
            idx = random.choice(key, n_cells, p=spike_row)
            return nn.one_hot(idx, num_classes=n_cells, dtype=jnp.float32)

        if n_rows == 1:
            ## single row: consume skey directly so draws match prior behavior
            rS = _sample_winner(skey, s[0])[None, :]
        else:
            ## mini-batch: one independent sub-key (and draw) per row
            rS = vmap(_sample_winner)(random.split(skey, n_rows), s)
        ## only rows that actually emitted spikes get replaced by their winner;
        ## silent rows are left untouched (all zeros)
        m_switch = (jnp.sum(s, axis=1, keepdims=True) > 0.).astype(jnp.float32)
        s = s * (1. - m_switch) + rS * m_switch
    ############################################################################
    return _v, s, raw_s, _rfr
@partial(jit, static_argnums=[3,4])
def update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
    """
    Advances the homeostatic threshold variable by one step: the current
    value is decayed exponentially and then bumped wherever a spike occurred.

    Args:
        dt: integration time constant (milliseconds, or ms)

        v_theta: current value of homeostatic threshold variable

        s: current spikes (at t)

        tau_theta: homeostatic threshold time constant

        theta_plus: physical increment to be applied to any threshold value
            if a spike was emitted

    Returns:
        updated homeostatic threshold variable
    """
    decay_factor = jnp.exp(-dt/tau_theta) ## multiplicative decay over dt
    decayed = v_theta * decay_factor
    spike_bump = s * theta_plus ## increment applied only where s is non-zero
    return decayed + spike_bump
class QuadLIFCell(Component): ## quadratic leaky integrate-and-fire cell
    """
    A spiking cell based on quadratic leaky integrate-and-fire (LIF) neuronal
    dynamics.

    Each call to `advance_state` performs one Euler step of the membrane
    potential via `run_cell`, which applies:

    | V(t+dt) = V + a0 * (V_rest - V) * (V - V_c) * (dt/tau_m) + J * mask

    where:

    | a0 - scaling factor for voltage accumulation
    | V_c - critical voltage (value)
    | mask - 1 for cells whose refractory variable is >= refract_T, else 0

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: membrane time constant

        R_m: membrane resistance value (stored and forwarded to `run_cell`)

        thr: base value for adaptive thresholds that govern short-term
            plasticity (in milliVolts, or mV; Default: -52)

        v_rest: membrane resting potential (in mV; Default: -65)

        v_reset: membrane reset potential (in mV) -- upon occurrence of a
            spike, a neuronal cell's membrane potential will be set to this
            value (Default: 60)

        v_c: critical voltage value (Default: -41.6)

        a0: scaling factor for voltage accumulation (Default: 1)

        tau_theta: homeostatic threshold time constant; a value <= 0 turns
            the adaptive threshold update off (Default: 1e7)

        theta_plus: physical increment to be applied to any threshold value
            if a spike was emitted (Default: 0.05)

        refract_T: relative refractory period time (ms; Default: 5 ms)

        key: PRNG key to control determinism of any underlying random values
            associated with this cell; if None, a key is seeded from the
            current time

        one_spike: if True, a single-spike constraint will be enforced for
            every time step of neuronal dynamics simulated, i.e., at most,
            only a single spike will be permitted to emit per step -- this
            means that if > 1 spikes emitted, a single action potential will
            be randomly sampled from the non-zero spikes detected

        useVerboseDict: triggers slower, verbose dictionary mode
            (Default: False)

        directory: string indicating directory on disk to load previously
            saved adaptive threshold values from (the file read is
            `<directory>/<name>.npz`); if None, thresholds start at zero
    """

    ## Class Methods for Compartment Names
    @classmethod
    def outputCompartmentName(cls):
        """Returns the compartment key under which emitted spikes are stored."""
        return 's' ## spike/action potential

    @classmethod
    def timeOfLastSpikeCompartmentName(cls):
        """Returns the compartment key of the time-of-last-spike record."""
        return 'tols' ## time-of-last-spike (record vector)

    @classmethod
    def voltageCompartmentName(cls):
        """Returns the compartment key of the membrane potential."""
        return 'v' ## membrane potential/voltage

    @classmethod
    def thresholdThetaName(cls):
        """Returns the compartment key of the adaptive threshold shift."""
        return 'thrTheta' ## action potential threshold

    @classmethod
    def refractCompartmentName(cls):
        """Returns the compartment key of the refractory variable(s)."""
        return 'rfr' ## refractory variable(s)

    ## Bind Properties to Compartments for ease of use
    @property
    def current(self):
        return self.compartments.get(self.inputCompartmentName(), None)

    @current.setter
    def current(self, inp):
        self.compartments[self.inputCompartmentName()] = inp

    @property
    def spikes(self):
        return self.compartments.get(self.outputCompartmentName(), None)

    @spikes.setter
    def spikes(self, out):
        self.compartments[self.outputCompartmentName()] = out

    @property
    def timeOfLastSpike(self):
        return self.compartments.get(self.timeOfLastSpikeCompartmentName(), None)

    @timeOfLastSpike.setter
    def timeOfLastSpike(self, t):
        self.compartments[self.timeOfLastSpikeCompartmentName()] = t

    @property
    def voltage(self):
        return self.compartments.get(self.voltageCompartmentName(), None)

    @voltage.setter
    def voltage(self, v):
        self.compartments[self.voltageCompartmentName()] = v

    @property
    def refract(self):
        return self.compartments.get(self.refractCompartmentName(), None)

    @refract.setter
    def refract(self, rfr):
        self.compartments[self.refractCompartmentName()] = rfr

    @property
    def threshold_theta(self):
        return self.compartments.get(self.thresholdThetaName(), None)

    @threshold_theta.setter
    def threshold_theta(self, thr):
        self.compartments[self.thresholdThetaName()] = thr

    # Define Functions
    def __init__(self, name, n_units, tau_m, R_m, thr=-52., v_rest=-65., v_reset=60.,
                 v_c=-41.6, a0=1., tau_theta=1e7, theta_plus=0.05, refract_T=5.,
                 key=None, one_spike=True, useVerboseDict=False, directory=None,
                 **kwargs):
        super().__init__(name, useVerboseDict, **kwargs)

        ## Random Number Set up
        self.key = key
        if self.key is None:
            self.key = random.PRNGKey(time.time_ns())

        ## membrane parameter setup (affects ODE integration)
        self.tau_m = tau_m ## membrane time constant
        self.R_m = R_m ## resistance value
        self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step

        self.v_rest = v_rest ## mV
        ## NOTE(review): the default is +60 mV, whereas earlier inline notes
        ## here mentioned -60/-65 mV -- confirm the intended sign
        self.v_reset = v_reset ## mV (milli-volts)
        self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
        self.theta_plus = theta_plus ## threshold increment
        self.refract_T = refract_T ## refractory period # ms
        ## these previously read undefined names (v_scale / critical_V), which
        ## raised a NameError on construction
        self.v_c = v_c ## critical voltage
        self.a0 = a0 ## voltage accumulation scale

        ## Layer Size Setup
        self.n_units = n_units
        self.threshold = thr ## (fixed) base value for threshold # mV

        ## adaptive threshold setup
        if directory is None:
            self.threshold_theta = jnp.zeros((1, n_units))
        else:
            self.load(directory)
        self.reset()

    def verify_connections(self):
        """Checks that the input compartment has at least one incoming connection."""
        self.metadata.check_incoming_connections(self.inputCompartmentName(), min_connections=1)

    def advance_state(self, t, dt, **kwargs):
        """
        Advances this cell's compartments by one simulation step.

        Args:
            t: current simulation time (written into the time-of-last-spike
                record of any cell that spikes)

            dt: integration time constant
        """
        if self.spikes is None:
            self.spikes = jnp.zeros((1, self.n_units))
        if self.refract is None:
            self.refract = jnp.zeros((1, self.n_units)) + self.refract_T
        skey = None ## stays None if single-spike mode is turned off
        if self.one_spike:
            ## a non-None key is what switches on run_cell's single-spike
            ## sampling, so one is only drawn when the constraint is requested
            self.key, skey = random.split(self.key, 2)
        ## run one step of Euler integration over neuronal dynamics
        self.voltage, self.spikes, raw_spikes, self.refract = \
            run_cell(dt, self.current, self.voltage, self.threshold,
                     self.threshold_theta, self.refract, skey, self.v_c, self.a0,
                     self.tau_m, self.R_m, self.v_rest, self.v_reset, self.refract_T)
        if self.tau_theta > 0.:
            ## run one step of threshold dynamics; driven by the raw
            ## (un-filtered) spikes rather than the post-processed ones
            self.threshold_theta = update_theta(dt, self.threshold_theta, raw_spikes,
                                                self.tau_theta, self.theta_plus)
        ## update tols
        self.timeOfLastSpike = update_times(t, self.spikes, self.timeOfLastSpike)

    def reset(self, **kwargs):
        """
        Restores voltage, refractory, current, time-of-last-spike and spike
        compartments to their initial values; the adaptive threshold is left
        untouched.
        """
        self.voltage = jnp.zeros((1, self.n_units)) + self.v_rest
        self.refract = jnp.zeros((1, self.n_units)) + self.refract_T
        self.current = jnp.zeros((1, self.n_units))
        self.timeOfLastSpike = jnp.zeros((1, self.n_units))
        self.spikes = jnp.zeros((1, self.n_units))

    def save(self, directory, **kwargs):
        """Writes the adaptive threshold values to `<directory>/<name>.npz`."""
        file_name = directory + "/" + self.name + ".npz"
        jnp.savez(file_name, threshold_theta=self.threshold_theta)

    def load(self, directory, **kwargs):
        """Reads the adaptive threshold values from `<directory>/<name>.npz`."""
        file_name = directory + "/" + self.name + ".npz"
        data = jnp.load(file_name)
        self.threshold_theta = data['threshold_theta']