Source code for ngclearn.components.neurons.spiking.fitzhughNagumoCell

from ngcsimlib.component import Component
from jax import numpy as jnp, random, jit
from functools import partial
import time
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
                                            step_euler, step_rk2

@jit
def update_times(t, s, tols):
    """
    Updates time-of-last-spike (tols) variable.

    Args:
        t: current time (a scalar/int value)

        s: binary spike vector

        tols: current time-of-last-spike variable

    Returns:
        updated tols variable
    """
    _tols = (1. - s) * tols + (s * t)
    return _tols

@jit
def _dfv_internal(j, v, w, a, b, g, tau_m): ## raw voltage dynamics
    dv_dt = v - jnp.power(v, 3)/g - w + j ## dv/dt
    dv_dt = dv_dt * (1./tau_m)
    return dv_dt

def _dfv(t, v, params): ## voltage dynamics wrapper
    j, w, a, b, g, tau_m = params
    dv_dt = _dfv_internal(j, v, w, a, b, g, tau_m)
    return dv_dt

@jit
def _dfw_internal(j, v, w, a, b, g, tau_w): ## raw recovery dynamics
    dw_dt = v + a - b * w ## dw/dt
    dw_dt = dw_dt * (1./tau_w)
    return dw_dt

def _dfw(t, w, params): ## recovery dynamics wrapper
    j, v, a, b, g, tau_m = params
    dv_dt = _dfw_internal(j, v, w, a, b, g, tau_m)
    return dv_dt

@jit
def _emit_spike(v, v_thr):
    s = (v > v_thr).astype(jnp.float32)
    return s

#@partial(jit, static_argnums=[10])
[docs] def run_cell(dt, j, v, w, v_thr, tau_m, tau_w, a, b, g=3., integType=0): """ Args: dt: integration time constant j: electrical current v: membrane potential / voltage w: recovery variable value(s) v_thr: voltage/membrane threshold (to obtain action potentials in terms of binary spikes) tau_m: membrane time constant tau_w: recover variable time constant (Default: 12.5 ms) a: dimensionless recovery variable shift factor "alpha" (Default: 0.7) b: dimensionless recovery variable scale factor "beta" (Default: 0.8) g: power-term divisor 'gamma' (Default: 3.) integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2) Returns: updated voltage, updated recovery, spikes """ if integType == 1: v_params = (j, w, a, b, g, tau_m) _, _v = step_rk2(0., v, _dfv, dt, v_params) #_v = step_rk2(v, v_params, _dfv, dt) w_params = (j, v, a, b, g, tau_w) _, _w = step_rk2(0., w, _dfw, dt, w_params) #_w = step_rk2(w, w_params, _dfw, dt) else: # integType == 0 (default -- Euler) v_params = (j, w, a, b, g, tau_m) _, _v = step_euler(0., v, _dfv, dt, v_params) #_v = step_euler(v, v_params, _dfv, dt) w_params = (j, v, a, b, g, tau_w) _, _w = step_euler(0., w, _dfw, dt, w_params) #_w = step_euler(w, w_params, _dfw, dt) #s = (_v > v_thr).astype(jnp.float32) s = _emit_spike(_v, v_thr) return _v, _w, s
[docs] class FitzhughNagumoCell(Component): """ The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification of the Hodgkin-Huxley (squid axon) model. This cell model iteratively evolves voltage "v" and recovery "w" (which represents the combined effects of sodium channel deinactivation and potassium channel deactivation in the Hodgkin-Huxley model). The specific pair of differential equations that characterize this cell are (for adjusting v and w, given current j, over time): | tau_m * dv/dt = v - (v^3)/3 - w + j | tau_w * dw/dt = v + a - b * w | References: | FitzHugh, Richard. "Impulses and physiological states in theoretical | models of nerve membrane." Biophysical journal 1.6 (1961): 445-466. | | Nagumo, Jinichi, Suguru Arimoto, and Shuji Yoshizawa. "An active pulse | transmission line simulating nerve axon." Proceedings of the IRE 50.10 | (1962): 2061-2070. Args: name: the string name of this cell n_units: number of cellular entities (neural population size) tau_m: membrane time constant tau_w: recover variable time constant (Default: 12.5 ms) alpha: dimensionless recovery variable shift factor "a" (Default: 0.7) beta: dimensionless recovery variable scale factor "b" (Default: 0.8) gamma: power-term divisor (Default: 3.) v_thr: voltage/membrane threshold (to obtain action potentials in terms of binary spikes) v0: initial condition / reset for voltage w0: initial condition / reset for recovery integration_type: type of integration to use for this cell's dynamics; current supported forms include "euler" (Euler/RK-1 integration) and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") :Note: setting the integration type to the midpoint method will increase the accuray of the estimate of the cell's evolution at an increase in computational cost (and simulation time) key: PRNG key to control determinism of any underlying synapses associated with this cell useVerboseDict: triggers slower, verbose dictionary mode (Default: False) """ ## Class Methods for Compartment Names
[docs] @classmethod def inputCompartmentName(cls): return 'in'
[docs] @classmethod def outputCompartmentName(cls): return 'out'
[docs] @classmethod def voltageName(cls): return 'v'
[docs] @classmethod def recoveryName(cls): return 'w'
[docs] @classmethod def timeOfLastSpikeCompartmentName(cls): return 'tols'
## Bind Properties to Compartments for ease of use @property def inputCompartment(self): return self.compartments.get(self.inputCompartmentName(), None) @inputCompartment.setter def inputCompartment(self, inp): self.compartments[self.inputCompartmentName()] = inp @property def outputCompartment(self): return self.compartments.get(self.outputCompartmentName(), None) @outputCompartment.setter def outputCompartment(self, out): self.compartments[self.outputCompartmentName()] = out @property def voltage(self): return self.compartments.get(self.voltageName(), None) @voltage.setter def voltage(self, t): self.compartments[self.voltageName()] = t @property def recovery(self): return self.compartments.get(self.recoveryName(), None) @recovery.setter def recovery(self, t): self.compartments[self.recoveryName()] = t @property def timeOfLastSpike(self): return self.compartments.get(self.timeOfLastSpikeCompartmentName(), None) @timeOfLastSpike.setter def timeOfLastSpike(self, t): self.compartments[self.timeOfLastSpikeCompartmentName()] = t # Define Functions def __init__(self, name, n_units, tau_m=1., tau_w=12.5, alpha=0.7, beta=0.8, gamma=3., v_thr=1.07, v0=0., w0=0., integration_type="euler", key=None, useVerboseDict=False, **kwargs): super().__init__(name, useVerboseDict, **kwargs) ## Random Number Set up self.key = key if self.key is None: self.key = random.PRNGKey(time.time_ns()) ## Integration properties self.integrationType = integration_type self.intgFlag = get_integrator_code(self.integrationType) ## Cell properties self.tau_m = tau_m self.tau_w = tau_w self.alpha = alpha self.beta = beta self.gamma = gamma self.v0 = v0 ## initial membrane potential/voltage condition self.w0 = w0 ## initial w-parameter condition self.v_thr = v_thr ## Layer Size Setup self.batch_size = 1 self.n_units = n_units self.reset()
[docs] def verify_connections(self): pass
[docs] def advance_state(self, t, dt, **kwargs): self.key, *subkeys = random.split(self.key, 2) j = self.inputCompartment v = self.voltage w = self.recovery v, w, s = run_cell(dt, j, v, w, self.v_thr, self.tau_m, self.tau_w, self.alpha, self.beta, self.gamma, self.intgFlag) self.voltage = v self.recovery = w self.outputCompartment = s self.timeOfLastSpike = update_times(t, self.outputCompartment, self.timeOfLastSpike)
[docs] def reset(self, **kwargs): self.inputCompartment = None self.voltage = jnp.zeros((self.batch_size, self.n_units)) + self.v0 self.recovery = jnp.zeros((self.batch_size, self.n_units)) + self.w0 self.outputCompartment = jnp.zeros((self.batch_size, self.n_units)) #None self.timeOfLastSpike = jnp.zeros((self.batch_size, self.n_units))
[docs] def save(self, **kwargs): pass