Source code for ngclearn.components.synapses.hebbian.eventSTDPSynapse

from jax import random, numpy as jnp, jit
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses.denseSynapse import DenseSynapse

class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
    """
    A synaptic cable that adjusts its efficacies via event-driven,
    post-synaptic spike-timing-dependent plasticity (STDP).

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals (transformation induced by synapses)
    | weights - current value matrix of synaptic efficacies
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | pre_tols - pre-synaptic time-of-last-spike (tols) to drive 1st term of STDP update (takes in external values)
    | postSpike - post-synaptic spike to drive 2nd term of STDP update (takes in external signals)
    | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
    | eta - global learning rate (stored as a 1x1 compartment value)

    | References:
    | Tavanaei, Amirhossein, Timothée Masquelier, and Anthony Maida.
    | "Representation learning using event-based STDP." Neural Networks 105
    | (2018): 294-303.

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        eta: global learning rate initial condition

        lmbda: controls degree of synaptic disconnect ("lambda") (default: 0.01)

        A_plus: strength of long-term potentiation (LTP) (default: 1)

        A_minus: strength of long-term depression (LTD) (default: 1)

        presyn_win_len: pre-synaptic window time, or how far back in time to
            look for the presence of a pre-synaptic spike, in milliseconds;
            must be non-negative (default: 2 ms)

        w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)

        weight_init: a kernel to drive initialization of this synaptic cable's
            values; typically a tuple with 1st element as a string calling the
            name of initialization to use

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1), i.e., yields: out = ((W * Rscale) * in) + b

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure

        batch_size: batch size dimension of this component (default: 1)
    """

    def __init__(
            self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
            presyn_win_len=2., w_bound=1., weight_init=None, resist_scale=1.,
            p_conn=1., batch_size=1, **kwargs
    ):
        ## no bias initializer is handed to the parent (4th positional arg is None)
        super().__init__(
            name, shape, weight_init, None, resist_scale, p_conn,
            batch_size=batch_size, **kwargs
        )

        ## Synaptic hyper-parameters
        self.lmbda = lmbda ## controls scaling of STDP rule
        self.presyn_win_len = presyn_win_len
        assert self.presyn_win_len >= 0., \
            "presyn_win_len (pre-synaptic window) must be non-negative"
        self.Aplus = A_plus
        self.Aminus = A_minus
        self.Rscale = resist_scale ## post-transformation scale factor
        self.w_bound = w_bound ## soft weight constraint

        ## Compartment setup
        preVals = jnp.zeros((self.batch_size, shape[0]))
        postVals = jnp.zeros((self.batch_size, shape[1]))
        self.pre_tols = Compartment(preVals)
        self.postSpike = Compartment(postVals)
        self.dWeights = Compartment(self.weights.get() * 0)
        ## global learning rate governing plasticity; stored only once, as a
        ## compartment (the earlier plain-attribute assignment was a dead store)
        self.eta = Compartment(jnp.ones((1, 1)) * eta)

    def _compute_update(self, t, dt): ## synaptic adjustment calculation co-routine
        """
        Computes the event-driven STDP adjustment matrix for the current step.

        Args:
            t: current simulation time (compared against stored pre-synaptic
                times-of-last-spike)

            dt: integration time constant (unused by this rule; kept for
                signature compatibility)

        Returns:
            the synaptic adjustment matrix dW (not yet scaled by eta)
        """
        pre_tols = self.pre_tols.get()
        weights = self.weights.get()
        ## check if a spike occurred in window of (t - presyn_win_len, t]
        m = (pre_tols > 0.) * 1. ## ignore default value of tols = 0 ms
        if self.presyn_win_len > 0.:
            lbound = ((t - self.presyn_win_len) < pre_tols) * 1.
            preSpike = lbound * m
        else: ## zero-length window => only spikes emitted exactly at time t count
            check_spike = (pre_tols == t) * 1.
            preSpike = check_spike * m
        ## this implements a generalization of the rule in eqn 18 of the paper
        pos_shift = self.w_bound - (weights * (1. + self.lmbda))
        pos_shift = pos_shift * self.Aplus
        neg_shift = -weights * (1. + self.lmbda)
        neg_shift = neg_shift * self.Aminus
        ## NOTE(review): preSpike.T is (n_inputs x batch_size); this appears to
        ## broadcast against the weight matrix only when batch_size = 1 -- confirm
        dW = jnp.where(preSpike.T, pos_shift, neg_shift) # at pre-spikes => LTP, else decay
        dW = (dW * self.postSpike.get()) ## gate to make sure only post-spikes trigger updates
        return dW

    @compilable
    def evolve(self, t, dt):
        """
        Applies one step of event-driven STDP to the synaptic efficacies and
        stores the (unscaled) adjustment matrix in `dWeights`.

        Args:
            t: current simulation time

            dt: integration time constant
        """
        dWeights = self._compute_update(t, dt)
        ## read the learning rate *value* out of its compartment, like every
        ## other compartment access in this component
        weights = self.weights.get() + dWeights * self.eta.get() # * (1. - w) * eta
        weights = jnp.clip(weights, 0.01, self.w_bound) ## Note: this step not in source paper
        self.weights.set(weights)
        self.dWeights.set(dWeights)

    @compilable
    def reset(self):
        """
        Zeroes out the input/output signals, the pre-synaptic
        time-of-last-spike record, the post-synaptic spike buffer, and the
        stored adjustment matrix (weights and eta are left untouched).
        """
        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
        if not self.inputs.targeted: ## only clear inputs when they are not flagged as targeted
            self.inputs.set(preVals)
        self.outputs.set(postVals)
        self.pre_tols.set(preVals) ## pre-synaptic time-of-last-spike(s) record
        self.postSpike.set(postVals)
        self.dWeights.set(jnp.zeros(self.shape.get()))

    @classmethod
    def help(cls): ## component help function
        """
        Builds a dictionary describing this component's compartments,
        dynamics, and hyper-parameters.

        Returns:
            a dict keyed by the class name, "compartments", "dynamics", and
            "hyperparameters"
        """
        properties = {
            "synapse_type": "EventSTDPSynapse - performs an adaptable synaptic "
                            "transformation of inputs to produce output signals; "
                            "synapses are adjusted with event-based post-synaptic "
                            "spike-timing-dependent plasticity (STDP)"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre_tols": "Pre-synaptic time-of-last-spike (`tols` for s_j)",
                 "postSpike": "Post-synaptic spike compartment value/term for STDP (s_i)"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
                 "key": "JAX PRNG key"},
            "analytics":
                {"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "lmbda": "Degree of synaptic disconnect",
            "A_plus": "Strength of long-term potentiation (LTP)",
            "A_minus": "Strength of long-term depression (LTD)",
            "presyn_win_len": "Pre-synaptic window time (ms) to look back for a pre-synaptic spike",
            "eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
            "w_bound": "Maximum value/magnitude that any single synapse can take"
        }
        ## NOTE(review): the dynamics text below is kept verbatim; the update
        ## computed by _compute_update is additionally gated by postSpike and
        ## scaled by A_plus/A_minus/w_bound -- confirm the intended wording
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [(W * Rscale) * inputs] ;"
                            "dW_{ij}/dt = eta * [ (1 - W_{ij}(1 + lmbda)) * s_j - W_{ij} * (1 + lmbda) * s_j ]",
                "hyperparameters": hyperparams}
        return info
if __name__ == "__main__":
    ## minimal smoke test: build a single 2x3 synapse (eta = 1) inside a
    ## fresh simulation context and print the resulting component
    from ngcsimlib.context import Context

    with Context("Bar") as bar:
        Wab = EventSTDPSynapse(name="Wab", shape=(2, 3), eta=1.)
    print(Wab)