# Source code for ngclearn.components.input_encoders.poissonCell

from ngcsimlib.component import Component
from jax import numpy as jnp, random, jit
from functools import partial
import time

@jit
def update_times(t, s, tols):
    """
    Refreshes the time-of-last-spike (tols) record after one simulation step.

    Wherever ``s`` is 1 the stored time is overwritten with ``t``; wherever
    ``s`` is 0 the previously stored time is carried over unchanged.

    Args:
        t: current time (a scalar/int value)

        s: binary spike vector

        tols: current time-of-last-spike variable

    Returns:
        updated tols variable
    """
    # entries that did not spike keep their previously recorded time
    carried_over = (1. - s) * tols
    # entries that spiked are stamped with the current time
    stamped = s * t
    return carried_over + stamped

@partial(jit, static_argnums=[3])
def sample_poisson(dkey, data, dt, fmax=63.75):
    """
    Draws one step of a Poisson spike train on-the-fly.

    Each entry of ``data`` is turned into a per-step spike probability
    ``data * (dt/1000) * fmax``; an entry emits a spike when an independent
    uniform draw from [0, 1) falls strictly below that probability.

    Args:
        dkey: JAX key to drive stochasticity/noise

        data: sensory data (vector/matrix)

        dt: integration time constant (presumably in milliseconds, given the
            division by 1000 -- confirm against callers)

        fmax: maximum frequency (Hz)

    Returns:
        binary spikes (float32 array with the same shape as ``data``)
    """
    step_fraction = dt/1000.
    spike_prob = data * step_fraction * fmax
    # one uniform sample per input entry
    noise = random.uniform(
        dkey, shape=data.shape, dtype=jnp.float32, minval=0., maxval=1.
    )
    fired = noise < spike_prob
    return fired.astype(jnp.float32)

class PoissonCell(Component):
    """
    A Poisson cell that produces approximately Poisson-distributed spikes
    on-the-fly.

    Args:
        name: the string name of this cell

        n_units: number of cellular entities (neural population size)

        max_freq: maximum frequency (in Hertz) of this Poisson spike train
            (must be > 0.)

        key: PRNG key to control determinism of any underlying synapses
            associated with this cell; if None, a key is seeded from the
            current wall-clock time (``time.time_ns()``)

        useVerboseDict: triggers slower, verbose dictionary mode
            (Default: False)
    """

    ## Class Methods for Compartment Names
    @classmethod
    def inputCompartmentName(cls):
        """Returns the name ('in') of the compartment holding the input."""
        return 'in'

    @classmethod
    def outputCompartmentName(cls):
        """Returns the name ('out') of the compartment holding output spikes."""
        return 'out'

    @classmethod
    def timeOfLastSpikeCompartmentName(cls):
        """Returns the name ('tols') of the time-of-last-spike compartment."""
        return 'tols'

    ## Bind Properties to Compartments for ease of use
    @property
    def inputCompartment(self):
        """Value stored in the input compartment (None if it is not set)."""
        return self.compartments.get(self.inputCompartmentName(), None)

    @inputCompartment.setter
    def inputCompartment(self, inp):
        self.compartments[self.inputCompartmentName()] = inp

    @property
    def outputCompartment(self):
        """Value stored in the output compartment (None if it is not set)."""
        return self.compartments.get(self.outputCompartmentName(), None)

    @outputCompartment.setter
    def outputCompartment(self, out):
        self.compartments[self.outputCompartmentName()] = out

    @property
    def timeOfLastSpike(self):
        """Value stored in the time-of-last-spike compartment (None if unset)."""
        return self.compartments.get(self.timeOfLastSpikeCompartmentName(), None)

    @timeOfLastSpike.setter
    def timeOfLastSpike(self, t):
        self.compartments[self.timeOfLastSpikeCompartmentName()] = t

    # Define Functions
    def __init__(self, name, n_units, max_freq=63.75, key=None,
                 useVerboseDict=False, **kwargs):
        super().__init__(name, useVerboseDict, **kwargs)

        ## Random Number Set up
        self.key = key
        if self.key is None:
            # no key supplied: seed from the clock (runs are not reproducible)
            self.key = random.PRNGKey(time.time_ns())

        ## Poisson parameters
        self.max_freq = max_freq  ## maximum frequency (in Hertz/Hz)

        ## Layer Size Setup
        self.batch_size = 1
        self.n_units = n_units
        self.reset()

    def verify_connections(self):
        """No-op: this cell performs no connection checks."""
        pass

    def advance_state(self, t, dt, **kwargs):
        """
        Advances the cell by one step: samples new spikes from the current
        input compartment and refreshes the time-of-last-spike record.

        Args:
            t: current time, written into the time-of-last-spike compartment
                wherever a spike was emitted

            dt: integration time constant passed on to ``sample_poisson``
        """
        # consume a fresh sub-key each step so successive samples differ
        self.key, subkey = random.split(self.key, 2)
        self.outputCompartment = sample_poisson(
            subkey, data=self.inputCompartment, dt=dt, fmax=self.max_freq
        )
        self.timeOfLastSpike = update_times(
            t, self.outputCompartment, self.timeOfLastSpike
        )

    def reset(self, **kwargs):
        """
        Clears the input compartment and zeroes the output and
        time-of-last-spike compartments, both shaped (batch_size, n_units).
        """
        self.inputCompartment = None
        self.outputCompartment = jnp.zeros((self.batch_size, self.n_units))
        self.timeOfLastSpike = jnp.zeros((self.batch_size, self.n_units))

    def save(self, **kwargs):
        """No-op: this cell writes nothing when saved."""
        pass