Source code for ngclearn.components.neurons.spiking.fitzhughNagumoCell

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
                                            step_euler, step_rk2

from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

@jit
def _dfv_internal(j, v, w, a, b, g, tau_m):
    """
    Raw (jit-compiled) voltage dynamics of the Fitzhugh-Nagumo cell.

    Evaluates ``dv/dt = (v - v^3 / g - w + j) / tau_m``.

    Args:
        j: input electrical current
        v: membrane potential/voltage
        w: recovery variable
        a: recovery shift factor (accepted for a uniform signature; not
            used by the voltage equation)
        b: recovery scale factor (accepted for a uniform signature; not
            used by the voltage equation)
        g: divisor applied to the cubic voltage term
        tau_m: membrane time constant

    Returns:
        the time derivative of the voltage
    """
    cubic_term = jnp.power(v, 3) / g
    drive = v - cubic_term - w + j ## un-scaled right-hand side of dv/dt
    return drive * (1. / tau_m)

def _dfv(t, v, params):
    """
    Voltage dynamics wrapper with the ``(t, state, params)`` call shape that
    is handed to the integrator step routines.

    Args:
        t: time argument (not used by the dynamics)
        v: membrane potential/voltage
        params: 6-tuple ``(j, w, a, b, g, tau_m)``

    Returns:
        the time derivative of the voltage
    """
    current, recovery, shift, scale, divisor, time_const = params
    return _dfv_internal(current, v, recovery, shift, scale, divisor, time_const)

@jit
def _dfw_internal(j, v, w, a, b, g, tau_w):
    """
    Raw (jit-compiled) recovery dynamics of the Fitzhugh-Nagumo cell.

    Evaluates ``dw/dt = (v + a - b * w) / tau_w``.

    Args:
        j: input electrical current (accepted for a uniform signature; not
            used by the recovery equation)
        v: membrane potential/voltage
        w: recovery variable
        a: recovery shift factor
        b: recovery scale factor
        g: cubic-term divisor (accepted for a uniform signature; not used
            by the recovery equation)
        tau_w: recovery time constant

    Returns:
        the time derivative of the recovery variable
    """
    recovery_drive = v + a - b * w ## un-scaled right-hand side of dw/dt
    return recovery_drive * (1. / tau_w)

def _dfw(t, w, params):
    """
    Recovery dynamics wrapper with the ``(t, state, params)`` call shape that
    is handed to the integrator step routines.

    Args:
        t: time argument (not used by the dynamics)
        w: recovery variable
        params: 6-tuple ``(j, v, a, b, g, tau)``; the last entry is forwarded
            as the time constant of the recovery equation

    Returns:
        the time derivative of the recovery variable
    """
    current, voltage, shift, scale, divisor, tau_w = params
    return _dfw_internal(current, voltage, w, shift, scale, divisor, tau_w)

class FitzhughNagumoCell(JaxComponent): ## F-H cell
    """
    The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
    of the Hodgkin-Huxley (squid axon) model. This cell model iteratively
    evolves voltage "v" and recovery "w" (which represents the combined
    effects of sodium channel deinactivation and potassium channel
    deactivation in the Hodgkin-Huxley model).

    The specific pair of differential equations that characterize this cell
    are (for adjusting v and w, given current j, over time):

    | tau_m * dv/dt = v - (v^3)/gamma - w + j * resist_m
    | tau_w * dw/dt = v + a - b * w

    | --- Cell Input Compartments: ---
    | j - electrical current input (takes in external signals)
    | --- Cell State Compartments: ---
    | v - membrane potential/voltage state
    | w - recovery variable state
    | key - JAX PRNG key (not created in this class; presumably supplied by
    |       the JaxComponent base class)
    | --- Cell Output Compartments: ---
    | s - emitted binary spikes/action potentials
    | tols - time-of-last-spike

    | References:
    | FitzHugh, Richard. "Impulses and physiological states in theoretical
    | models of nerve membrane." Biophysical journal 1.6 (1961): 445-466.
    |
    | Nagumo, Jinichi, Suguru Arimoto, and Shuji Yoshizawa. "An active pulse
    | transmission line simulating nerve axon." Proceedings of the IRE 50.10
    | (1962): 2061-2070.

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau_m: membrane time constant (Default: 1.)

        resist_m: membrane resistance value; scales the input current
            (Default: 1.)

        tau_w: recovery variable time constant (Default: 12.5 ms)

        alpha: dimensionless recovery variable shift factor "a" (Default: 0.7)

        beta: dimensionless recovery variable scale factor "b" (Default: 0.8)

        gamma: power-term divisor (Default: 3.)

        v0: initial condition / reset for voltage (Default: 0.)

        w0: initial condition / reset for recovery (Default: 0.)

        v_thr: voltage/membrane threshold (to obtain action potentials in
            terms of binary spikes) (Default: 1.07)

        spike_reset: if True, once voltage crosses threshold, then dynamics
            of voltage and recovery are reset/snapped to initial conditions
            (Default: False)

        integration_type: type of integration to use for this cell's dynamics;
            current supported forms include "euler" (Euler/RK-1 integration)
            and "midpoint" or "rk2" (midpoint method/RK-2 integration)
            (Default: "euler")

            :Note: setting the integration type to the midpoint method will
                increase the accuracy of the estimate of the cell's evolution
                at an increase in computational cost (and simulation time)
    """

    def __init__(
            self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
            beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
            integration_type="euler", **kwargs
    ):
        super().__init__(name, **kwargs)

        ## Integration properties
        self.integrationType = integration_type
        self.intgFlag = get_integrator_code(self.integrationType)

        ## Cell properties
        self.tau_m = tau_m
        self.resist_m = resist_m ## resistance R_m
        self.tau_w = tau_w
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.v0 = v0 ## initial membrane potential/voltage condition
        self.w0 = w0 ## initial w-parameter condition
        self.v_thr = v_thr
        self.spike_reset = spike_reset

        ## Layer Size Setup
        self.batch_size = 1
        self.n_units = n_units

        ## Compartment setup
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.j = Compartment(restVals)
        self.v = Compartment(restVals + self.v0)
        self.w = Compartment(restVals + self.w0)
        self.s = Compartment(restVals)
        self.tols = Compartment(restVals) ## time-of-last-spike

    @compilable
    def advance_state(self, t, dt):
        ## Advances v and w by one integration step of size dt, emits spikes,
        ## and records the time-of-last-spike. Statements are intentionally
        ## left as-is (only comments added) since this method is processed
        ## by the @compilable decorator.
        j_mod = self.j.get() * self.resist_m ## resistance-scaled input current
        if self.intgFlag == 1: ## flag value 1 selects the RK-2 step routine
            ## both updates read the pre-step v and w (neither sees the
            ## other's freshly integrated value)
            v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
            _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
            w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
            _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params)
        else: # integType == 0 (default -- Euler)
            ## every flag value other than 1 lands on the Euler step routine
            v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
            _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
            w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
            _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params)
        ## a spike is emitted wherever the new voltage strictly exceeds v_thr
        s = (_v > self.v_thr) * 1.
        v = _v
        w = _w
        if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions
            v = v * (1. - s) + s * self.v0
            w = w * (1. - s) + s * self.w0
        ## update time-of-last spike variable(s): keep the old value where no
        ## spike occurred, write the current time t where one did
        self.tols.set((1. - s) * self.tols.get() + (s * t))

        # self.j.set(j) ## j is not getting modified in these dynamics
        self.v.set(v)
        self.w.set(w)
        self.s.set(s)

    @compilable
    def reset(self):
        ## Restores compartments to their rest/initial values; the input
        ## current j is only zeroed when its `targeted` flag is falsy.
        restVals = jnp.zeros((self.batch_size, self.n_units))
        if not self.j.targeted:
            self.j.set(restVals)
        self.v.set(restVals + self.v0)
        self.w.set(restVals + self.w0)
        self.s.set(restVals)
        self.tols.set(restVals)

    @classmethod
    def help(cls): ## component help function
        """
        Builds a nested dictionary describing this component: its cell type,
        compartments, governing dynamics, and hyperparameters.

        Returns:
            dict keyed by the class name, "compartments", "dynamics", and
            "hyperparameters"
        """
        properties = {
            "cell_type": "FitzhughNagumoCell - evolves neurons according to nonlinear, "
                         "Fitzhugh-Nagumo dual-ODE spiking cell dynamics."
        }
        compartment_props = {
            "inputs":
                {"j": "External input electrical current",
                 "key": "JAX PRNG key"},
            "states":
                {"v": "Membrane potential/voltage at time t",
                 "w": "Recovery variable at time t"},
            "outputs":
                {"s": "Emitted spikes/pulses at time t",
                 "tols": "Time-of-last-spike"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "tau_m": "Cell membrane time constant",
            "resist_m": "Membrane resistance value",
            "tau_w": "Recovery variable time constant",
            "v_thr": "Base voltage threshold value",
            "spike_reset": "Should voltage/recover be snapped to initial condition(s) if spike emitted?",
            "alpha": "Dimensionless recovery variable shift factor `a`",
            "beta": "Dimensionless recovery variable scale factor `b`",
            "gamma": "Power-term divisor constant",
            "v0": "Initial condition for membrane potential/voltage",
            "w0": "Initial condition for recovery variable",
            "integration_type": "Type of numerical integration to use for the cell dynamics"
        }
        ## named `summary` so the imported logger function `info` is not shadowed;
        ## the dynamics text names gamma since the cubic term is divided by it
        summary = {cls.__name__: properties,
                   "compartments": compartment_props,
                   "dynamics": "tau_m * dv/dt = v - (v^3)/gamma - w + j * resist_m; "
                               "tau_w * dw/dt = v + a - b * w",
                   "hyperparameters": hyperparams}
        return summary
if __name__ == '__main__':
    # Script entry point: builds a 9-unit cell named "X" inside a context
    # named "Bar" and prints the resulting component.
    from ngcsimlib.context import Context
    with Context("Bar") as bar:
        X = FitzhughNagumoCell("X", 9)
    # NOTE(review): the original line breaks/indentation were lost; the print
    # is placed after the context block -- confirm against the upstream file.
    print(X)