Source code for ngclearn.components.other.varTrace

# %%

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

@partial(jit, static_argnums=[4])
def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.):
    """
    Run variable trace filter (low-pass filter) dynamics one step forward.

    Args:
        dt: integration time constant (ms)

        x: variable value / stimulus input (at t)

        x_tr: currenet value of trace/filter

        decayFactor: coefficient to decay trace by before application of new value

        a_delta: increment to made to filter (multiplied w/ stimulus x)
    Returns:
        updated trace/filter value/state
    """
    _x_tr = gamma_tr * x_tr * decayFactor
    #x_tr + (-x_tr) * (dt / tau_tr) = (1 - dt/tau_tr) * x_tr
    if a_delta > 0.: ## perform additive form of trace ODE
        _x_tr = _x_tr + x * a_delta
        #_x_tr = x_tr + (-x_tr) * (dt / tau_tr) + x * a_delta
    else: ## run gated/piecewise ODE variant of trace
        _x_tr = _x_tr * (1. - x) + x
        #_x_tr = ( x_tr + (-x_tr) * (dt / tau_tr) ) * (1. - x) + x
    return _x_tr

class VarTrace(JaxComponent): ## low-pass filter
    """
    A variable trace (filter) functional node.

    | --- Cell Input Compartments: ---
    | inputs - input (takes in external signals)
    | --- Cell State Compartments: ---
    | trace - traced value signal
    | --- Cell Output Compartments: ---
    | outputs - output signal (same as "trace" compartment)
    | trace - traced value signal (can be treated as output compartment)

    Args:
        name: the string name of this operator

        n_units: number of calculating entities or units

        tau_tr: trace time constant (in milliseconds, or ms)

        a_delta: value to increment a trace by in presence of a spike; note if set
            to a value <= 0, then a piecewise gated trace will be used instead

        P_scale: if `a_delta=0`, then this scales the value that the trace snaps to
            upon receiving a pulse value

        gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1)

        decay_type: string indicating the decay type to be applied to ODE
            integration; low-pass filter configuration

            :Note: string values that this can be (Default: "exp") are:
                1) `'lin'` = linear trace filter, i.e., decay = x_tr + (-x_tr) * (dt/tau_tr);
                2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
                3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)

        n_nearest_spikes: (k) if k > 0, this makes the trace act like a nearest-neighbor
            trace, i.e., k = 1 yields the 1-nearest (neighbor) trace (Default: 0)

        batch_size: batch size dimension of this cell (Default: 1)

        key: forwarded unchanged to the `JaxComponent` constructor
            (presumably a PRNG key -- confirm against the base class)
    """

    def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1,
                 decay_type="exp", n_nearest_spikes=0, batch_size=1, key=None):
        super().__init__(name, key)

        ## Trace control coefficients
        self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
        self.tau_tr = tau_tr ## trace time constant
        self.a_delta = a_delta ## trace increment (if spike occurred)
        self.P_scale = P_scale ## trace scale if non-additive trace to be used
        self.gamma_tr = gamma_tr ## multiplier applied to the decayed trace
        self.n_nearest_spikes = n_nearest_spikes ## k for nearest-neighbor trace (0 = off)

        ## Layer Size Setup
        self.batch_size = batch_size
        self.n_units = n_units

        ## all compartments start from the same zero-valued (batch_size, n_units) array
        restVals = jnp.zeros((self.batch_size, self.n_units))
        self.inputs = Compartment(restVals) # input compartment
        self.outputs = Compartment(restVals) # output compartment
        self.trace = Compartment(restVals)

    @compilable
    def advance_state(self, dt):
        """
        Advance the trace by one step and write the result to both the
        `trace` and `outputs` compartments.

        Args:
            dt: integration time constant (same time units as `tau_tr`)
        """
        ## select the per-step decay coefficient (substring match on decay_type)
        if "exp" in self.decay_type:
            decayFactor = jnp.exp(-dt/self.tau_tr)
        elif "lin" in self.decay_type:
            decayFactor = (1. - dt/self.tau_tr)
        else: ## any other setting (e.g., "step") wipes the previous trace value
            decayFactor = 0.
        _x_tr = self.gamma_tr * self.trace.get() * decayFactor
        if self.n_nearest_spikes > 0: ## nearest-neighbor trace; increment shrinks w/ current trace
            _x_tr = _x_tr + self.inputs.get() * (self.a_delta - (self.trace.get() / self.n_nearest_spikes))
        else:
            if self.a_delta > 0.: ## additive (full convolution) trace
                _x_tr = _x_tr + self.inputs.get() * self.a_delta
            else: ## gated/piecewise trace; snaps to P_scale wherever the input is 1
                _x_tr = _x_tr * (1. - self.inputs.get()) + self.inputs.get() * self.P_scale
        self.trace.set(_x_tr)
        self.outputs.set(_x_tr)

    @compilable
    def reset(self):
        """
        Set the `outputs` and `trace` compartments back to zeros; `inputs` is
        zeroed only if it exposes a `targeted` attribute that is falsy.
        """
        restVals = jnp.zeros((self.batch_size, self.n_units))
        # BUG: the self.inputs here does not have the targeted field
        # NOTE: Quick workaround is to check if targeted is in the input or not
        ## (kept as a short-circuit expression, exactly as originally written)
        hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
        self.outputs.set(restVals)
        self.trace.set(restVals)

    @classmethod
    def help(cls): ## component help function
        """
        Build a nested dictionary describing this component: its cell type,
        compartments, dynamics, and constructor hyperparameters.

        Returns:
            dict keyed by the class name, "compartments", "dynamics", and
            "hyperparameters"
        """
        properties = {
            "cell_type": "VarTrace - maintains a low pass filter over incoming signal "
                         "values (such as sequences of discrete pulses)"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"trace": "Continuous low-pass filtered signal values, at time t"},
            "outputs":
                {"outputs": "Continuous low-pass filtered signal values, "
                            "at time t (same as `trace`)"},
        }
        ## keys below mirror the constructor's argument names
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "tau_tr": "Trace/filter time constant",
            "a_delta": "Increment to apply to trace (if not set to 0); "
                       "otherwise, traces clamp to 1 and then decay",
            "P_scale": "Max value to snap trace to if a max-clamp trace is triggered/configured",
            "gamma_tr": "Extra multiplier applied in front of the leak of the trace",
            "decay_type": "Indicator of what type of decay dynamics to use "
                          "as filter is updated at time t",
            "n_nearest_spikes": "Number of nearest pulses to affect/increment trace (if > 0)"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "tau_tr * dz/dt ~ -z + inputs * a_delta (full convolution trace); "
                            "tau_tr * dz/dt ~ -z + inputs * (a_delta - z/n_nearest_spikes) (near trace)",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    ## Minimal smoke run: build one trace component inside a context and print it.
    from ngcsimlib.context import Context

    with Context("Bar") as bar:
        X = VarTrace(name="X", n_units=9, tau_tr=0.0004, a_delta=3)
    print(X)