Source code for ngclearn.components.input_encoders.latencyCell

from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
import jax
from typing import Union

from ngclearn.utils.model_utils import clamp_min, clamp_max

from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

@partial(jit, static_argnums=[5])
def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
                             normalize=False):
    """
    Maps input intensities to spike times using a linear latency code.

    Each value receives the latency ``slope * (1 - data)``, where ``slope`` is
    ``tau`` (or ``num_steps - 1 - first_spk_t`` when normalizing). Latencies
    are capped at ``slope * (1 - thr)`` so that, for a positive slope, every
    value at or below ``thr`` shares the latest allowed latency. The result is
    then shifted by ``first_spk_t``.

    Args:
        data: pattern data to convert to spikes/times

        tau: latency coding time constant (ignored when ``normalize`` is True)

        thr: latency coding threshold value

        first_spk_t: first spike time(s) (either a scalar or an array with the
            same shape as the produced spike times; in ms)

        num_steps: number of total time steps of simulation to consider (only
            used when ``normalize`` is True)

        normalize: if True, replace ``tau`` by ``num_steps - 1 - first_spk_t``
            (linear normalization of the latency code); static under ``jit``

    Returns:
        projected spike times
    """
    ## linear normalization swaps the time constant for the remaining window
    slope = (num_steps - 1. - first_spk_t) if normalize else tau
    raw_latency = -slope * (data - 1.)  ## unbounded linear latency values
    latest_latency = -slope * (thr - 1.)  ## cap reached by sub-threshold inputs
    return clamp_max(raw_latency, latest_latency) + first_spk_t

@partial(jit, static_argnums=[6])
def _calc_spike_times_nonlinear(data, tau, thr, first_spk_t, eps=1e-7,
                                num_steps=1., normalize=False):
    """
    Maps input intensities to spike times using a logarithmic latency code.

    Inputs are floored at ``thr + eps`` (saturating everything below the
    threshold), converted to latencies via ``tau * log(x / (x - thr))`` and
    offset by ``first_spk_t``.

    Args:
        data: pattern data to convert to spikes/times

        tau: latency coding time constant

        thr: latency coding threshold value

        first_spk_t: first spike time(s) (either a scalar or an array with the
            same shape as the produced spike times; in ms)

        eps: small numerical error control factor (added to thr)

        num_steps: number of total time steps of simulation to consider (only
            used when ``normalize`` is True)

        normalize: if True, rescale latencies (measured from ``first_spk_t``)
            so the largest one over the whole array lands at
            ``num_steps - 1``; static under ``jit``.
            NOTE: if every latency is 0 (e.g. ``thr == 0``) this rescaling
            divides by zero.

    Returns:
        projected spike times
    """
    floored = clamp_min(data, thr + eps)  ## saturates all values below threshold
    spike_times = jnp.log(floored / (floored - thr)) * tau + first_spk_t
    if not normalize:
        return spike_times

    ## stretch/shrink the latencies to fill the simulation window
    latency = spike_times - first_spk_t
    window = num_steps - first_spk_t - 1.
    longest = jnp.max(spike_times - first_spk_t)  ## global max, not per-row
    return latency * (window / longest) + first_spk_t

@jit
def _extract_spike(spk_times, t, mask):
    """
    Extracts a spike from a latency-coded spike train.

    Args:
        spk_times: spike times to compare against

        t: current time

        mask: prior spike mask (1 if spike has occurred, 0 otherwise)

    Returns:
        binary spikes, boolean mask to indicate if spikes have occurred as of yet
    """
    _spk_times = jnp.round(spk_times) # snap times to nearest integer time
    spikes_t = (_spk_times <= t).astype(jnp.float32) # get spike
    spikes_t = spikes_t * (1. - mask)
    _mask = mask + (1. - mask) * spikes_t
    return spikes_t, _mask

class LatencyCell(JaxComponent):
    """
    A (nonlinear) latency encoding (spike) cell; produces a time-lagged set of
    spikes on-the-fly.

    | --- Cell Input Compartments: ---
    | inputs - input (takes in external signals)
    | --- Cell State Compartments: ---
    | targ_sp_times - target-spike-time
    | mask - spike-ordering mask
    | clip_mask - mask of sub-threshold inputs whose spikes are suppressed
    | key - JAX PRNG key
    | --- Cell Output Compartments: ---
    | outputs - output
    | tols - time-of-last-spike

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        tau: time constant for model used to calculate firing time (Default: 1 ms)

        threshold: sensory input features below this threshold value will fire
            at final step in time of this latency coded spike train

        first_spike_time: time of first allowable spike (ms) (Default: 0 ms)

        linearize: should the linear latency encoding scheme be used?
            (otherwise, defaults to logarithmic latency encoding)

        normalize: normalize the latency code such that final spike(s) occur
            a pre-specified number of simulation steps "num_steps"?
            (Default: False)

            :Note: if this set to True, you will need to choose a useful value
                for the "num_steps" argument (>1), depending on how many steps
                simulated

        clip_spikes: should values under threshold be removed/suppressed?
            (default: False)

        num_steps: number of discrete time steps to consider for normalized
            latency code (only useful if "normalize" is set to True)
            (Default: 1)

        batch_size: batch size dimension of this cell (Default: 1)

        key: optional JAX PRNG key, forwarded to the base component
            (extra keyword arguments are accepted but ignored)
    """

    def __init__(
            self,
            name: str,
            n_units: int,
            tau: float = 1.,
            threshold: float = 0.01,
            first_spike_time: float = 0.,
            linearize: bool = False,
            normalize: bool = False,
            clip_spikes: bool = False,
            num_steps: float = 1.,
            batch_size: int = 1,
            key: Union[jax.Array, None] = None,
            **kwargs
    ):
        super().__init__(name=name, key=key)

        ## latency meta-parameters
        self.first_spike_time = first_spike_time
        self.tau = tau
        self.threshold = threshold
        self.linearize = linearize
        self.clip_spikes = clip_spikes
        ## normalize latency code s.t. final spike(s) occur w/in num_steps
        self.normalize = normalize
        self.num_steps = num_steps

        ## Layer Size Setup
        self.batch_size = batch_size
        self.n_units = n_units

        ## Compartment setup
        restVals = jnp.zeros((batch_size, n_units))
        self.inputs = Compartment(restVals, display_name="Input Stimulus")  # input compartment
        self.outputs = Compartment(restVals, display_name="Spikes")  # output compartment
        self.mask = Compartment(restVals, display_name="Spike Time Mask")
        self.clip_mask = Compartment(restVals, display_name="Clip Mask")
        self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms")  # time of last spike
        self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")

    # Computes the target spike time of every input value and (when
    # clip_spikes is set) the mask of sub-threshold inputs to suppress.
    @compilable
    def calc_spike_times(self):
        if self.clip_spikes:
            ## 1 where the input is below threshold, 0 elsewhere
            self.clip_mask.set((self.inputs.get() < self.threshold) * 1.)
        else:
            self.clip_mask.set(self.inputs.get() * 0.)
        if self.linearize:
            self.targ_sp_times.set(
                _calc_spike_times_linear(self.inputs.get(), self.tau.get(),
                                         self.threshold.get(),
                                         self.first_spike_time.get(),
                                         self.num_steps.get(),
                                         self.normalize.get()))
        else:
            ## num_steps/normalize MUST be passed by keyword: the nonlinear
            ## helper has an extra `eps` parameter in 5th position, so passing
            ## them positionally would bind num_steps -> eps and
            ## normalize -> num_steps (silently disabling normalization).
            self.targ_sp_times.set(
                _calc_spike_times_nonlinear(self.inputs.get(), self.tau.get(),
                                            self.threshold.get(),
                                            self.first_spike_time.get(),
                                            num_steps=self.num_steps.get(),
                                            normalize=self.normalize.get()))

    # Emits spikes for units whose rounded target time has been reached at
    # time t and that have not fired yet, then records their spike time.
    # NOTE(review): tols is updated from the unclipped spikes, so clipped
    # (sub-threshold) units still record a time-of-last-spike even though
    # their output is zeroed.
    @compilable
    def advance_state(self, t):
        spikes, spike_mask = _extract_spike(self.targ_sp_times.get(), t, self.mask.get())
        self.tols.set((1. - spikes) * self.tols.get() + (spikes * t))
        self.outputs.set(spikes * (1. - self.clip_mask.get()))
        self.mask.set(spike_mask)

    # Returns the cell's compartments to zeros of shape (batch_size, n_units).
    @compilable
    def reset(self):
        restVals = jnp.zeros((self.batch_size, self.n_units))
        # BUG: the self.inputs here does not have the targeted field
        # NOTE: Quick workaround is to check if targeted is in the input or not
        # (inputs are cleared only when they expose a falsy `targeted` field;
        # otherwise they are left untouched)
        hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
        self.outputs.set(restVals)
        self.tols.set(restVals)
        self.mask.set(restVals)
        self.clip_mask.set(restVals)
        self.targ_sp_times.set(restVals)

    @classmethod
    def help(cls):  ## component help function
        properties = {
            "cell_type": "LatencyCell - samples input to produce spikes via latency "
                         "coding, where each dimension's magnitude determines how "
                         "early in the spike train a value occurs. This is a "
                         "temporal/order encoder."
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"targ_sp_times": "Target spike times",
                 "mask": "Spike ordering mask",
                 "clip_mask": "Mask of sub-threshold inputs whose spikes are "
                              "suppressed (all zeros unless clip_spikes)",
                 "key": "JAX PRNG key"},
            "outputs":
                {"tols": "Time-of-last-spike",
                 "outputs": "Binary spike values emitted at time t"},
        }
        hyperparams = {
            "n_units": "Number of neuronal cells to model in this layer",
            "batch_size": "Batch size dimension of this component",
            "threshold": "Spike threshold (constant and shared across neurons)",
            "linearize": "Should a linear latency encoding be used?",
            "normalize": "Should the latency code(s) be normalized?",
            "num_steps": "Number of total time steps of simulation to consider ("
                         "useful for target spike time computation)",
            "first_spike_time": "Time of first allowable spike",
            "tau": "Time constant used to compute firing times",
            "clip_spikes": "Should spikes of sub-threshold inputs be suppressed?",
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "~ Latency(x)",
                "hyperparameters": hyperparams}
        return info
if __name__ == '__main__':
    ## smoke check: build a small latency cell inside a context, then dump it
    ## along with the code generated for its compilable methods
    from ngcsimlib.context import Context

    with Context("Bar"):
        cell = LatencyCell("X", 9)
    print(cell)
    print(cell.calc_spike_times.compiled.code)
    print(cell.advance_state.compiled.code)