Source code for ngclearn.components.other.expKernel

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
from ngclearn.utils import tensorstats
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

@partial(jit, static_argnums=[5,6])
def _apply_kernel(tf_curr, s, t, tau_w, win_len, krn_start, krn_end):
    idx_sub = tf_curr.shape[0]-1
    tf_new = (s * t) ## track new spike time(s)
    tf = tf_curr.at[idx_sub,:,:].set(tf_new)
    tf = jnp.roll(tf, shift=-1, axis=0) # 1,2,3  2,3,1

    ## apply exp time kernel
    ## EPSP = sum_{tf} exp(-(t - tf)/tau_w)
    _tf = tf[krn_start:krn_end,:,:]
    mask = jnp.greater(_tf, 0.).astype(jnp.float32) # 1 for every valid tf (non-zero)
    epsp = jnp.sum( jnp.exp( -(t - _tf)/tau_w ) * mask, axis=0 )
    return tf, epsp

[docs] class ExpKernel(JaxComponent): ## exponential kernel """ A spiking function based on an exponential kernel applied to a moving window of spike times. | --- Cell Input Compartments: --- | inputs - input (takes in external signals) | --- Cell State Compartments: --- | tf - maintained local window of pulse signals | --- Cell Output Compartments: --- | epsp - excitatory postsynaptic potential/pulse Args: name: the string name of this operator n_units: number of calculating entities or units dt: integration time constant (the kernel needs access to this value) nu: (ms, spike time interval for window) tau_w: spike window time constant (in micro-secs, or nano-s) batch_size: batch size dimension of this cell (Default: 1) """ def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs): super().__init__(name, **kwargs) self.tau_w = tau_w ## kernel window time constant self.nu = nu self.win_len = int(nu/dt) + 1 ## window length ## Layer Size Setup self.batch_size = batch_size self.n_units = n_units restVals = jnp.zeros((self.batch_size, self.n_units)) self.inputs = Compartment(restVals) # input comp self.epsp = Compartment(restVals) ## output comp ## window of spike times self.tf = Compartment(jnp.zeros((self.win_len, self.batch_size, self.n_units)))
[docs] @compilable def advance_state(self, t): # Get the variables inputs = self.inputs.get() tf = self.tf.get() s = inputs ## update spike time window and corresponding window volume tf, epsp = _apply_kernel( tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1 ) #0:win_len-1) # Update compartments self.epsp.set(epsp) self.tf.set(tf)
[docs] @compilable def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.epsp.set(restVals) self.tf.set(restTensor)
[docs] @classmethod def help(cls): ## component help function properties = { "cell_type": "ExpKernel - maintains an exponential kernel over " "incoming signal values (such as sequences of discrete pulses)" } compartment_props = { "inputs": {"inputs": "Takes in external input signal values"}, "states": {"tr": "Value signal (rolling) time window"}, "outputs": {"epsp": "Excitatory postsynaptic potential/pulse emitted at time t"}, } hyperparams = { "n_units": "Number of neuronal cells to model in this layer", "batch_size": "Batch size dimension of this component", "dt": "Integration time constant (kernel needs knowledge of `dt`)", "nu": "Spike time interval for window", "tau_w": "Spike window time constant" } info = {cls.__name__: properties, "compartments": compartment_props, "dynamics": "epsp ~ Sum_{tf} exp(-(t - tf)/tau_w)", "hyperparameters": hyperparams} return info
def __repr__(self): comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: stats = tensorstats(getattr(self, c).value) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) else: line = "None" lines += f" {f'({c})'.ljust(maxlen)}{line}\n" return lines
if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: X = ExpKernel("X", 1, 1.) print(X)