Source code for ngclearn.components.synapses.hebbian.expSTDPSynapse

from ngcsimlib.component import Component
from jax import random, numpy as jnp, jit
from functools import partial
import numpy as np
import time

@partial(jit, static_argnums=[6,7,8,9,10,11,12])
def evolve(dt, pre, x_pre, post, x_post, W, w_bound=1., eta=0.00005,
            x_tar=0.7, exp_beta=1., Aplus=1., Aminus=0., w_norm=None):
    """
    Evolves/changes the synaptic value matrix underlying this synaptic cable,
    given relevant statistics.

    Args:
        dt: integration time step (unused by this update rule)

        pre: pre-synaptic statistic to drive update

        x_pre: pre-synaptic trace value

        post: post-synaptic statistic to drive update

        x_post: post-synaptic trace value

        W: synaptic weight values (at time t)

        w_bound: maximum value to enforce over newly computed efficacies

        eta: global learning rate to apply to the Hebbian update

        x_tar: pre-synaptic trace target value; controls degree of pre-synaptic
            disconnect, i.e., amount of decay (higher -> lower synaptic values)

        exp_beta: controls effect of exponential Hebbian shift/dependency

        Aplus: strength of long-term potentiation (LTP)

        Aminus: strength of long-term depression (LTD)

        w_norm: if not None, applies an L2 norm constraint to synapses

    Returns:
        the newly evolved synaptic weight value matrix
    """
    ## equation (4) from Diehl and Cook - full exponential weight-dependent STDP
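    ## In equation form (writing beta = exp_beta and w_max = w_bound):
    ##   dW_post = Aplus  * [exp(-beta * W) * (x_pre^T @ post)
    ##                       - exp(-beta * (w_max - W)) * (x_tar^T @ post)]
    ##   dW_pre  = -Aminus *  exp(-beta * W) * (pre^T @ x_post)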
    ## calculate post-synaptic term
    post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
    x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
    post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T, post)
    dWpost = (post_term1 - post_term2) * Aplus
    ## calculate pre-synaptic term
    dWpre = 0.
    if Aminus > 0.: ## Aminus is a static jit argument, so this branch resolves at trace time
        dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus

    ## calc final weighted adjustment
    dW = (dWpost + dWpre) * eta
    _W = W + dW
    if w_norm is not None:
        _W = _W * (w_norm/(jnp.linalg.norm(_W, axis=1, keepdims=True) + 1e-5))
    _W = jnp.clip(_W, 0.01, w_bound) # not in source paper
    return _W

@jit
def compute_layer(inp, weight):
    """
    Applies the transformation/projection induced by the synaptic efficacies
    associated with this synaptic cable.

    Args:
        inp: signal input to run through this synaptic cable

        weight: this cable's synaptic value matrix

    Returns:
        a projection/transformation of input "inp"
    """
    return jnp.matmul(inp, weight)

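## For instance (illustrative shapes, not part of the original module): a
## (1 x 10) spike/rate vector pushed through a (10 x 5) synaptic matrix yields
## a (1 x 5) post-synaptic drive, i.e.,
## compute_layer(jnp.ones((1, 10)), W) has shape (1, 5).
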
class ExpSTDPSynapse(Component):
    """
    A synaptic cable that adjusts its efficacies via a trace-based form of
    spike-timing-dependent plasticity (STDP) with an exponential weight
    dependence (the strength of which is controlled by a factor).

    | References:
    | Nessler, Bernhard, et al. "Bayesian computation emerges in generic
    | cortical microcircuits through spike-timing-dependent plasticity." PLoS
    | computational biology 9.4 (2013): e1003037.
    |
    | Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modification by correlated
    | activity: Hebb's postulate revisited." Annual review of neuroscience 24.1
    | (2001): 139-166.

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        eta: global learning rate

        exp_beta: controls effect of exponential Hebbian shift/dependency

        Aplus: strength of long-term potentiation (LTP)

        Aminus: strength of long-term depression (LTD)

        preTrace_target: controls degree of pre-synaptic disconnect, i.e.,
            amount of decay (higher -> lower synaptic values)

        wInit: a 2-tuple of lower and upper bounds driving initialization of
            this synaptic cable's values, e.g., (0.025, 0.8) samples
            U(0.025, 0.8) for each dimension/value of this cable's underlying
            value matrix

        key: PRNG key to control determinism of any underlying random values
            associated with this synaptic cable

        useVerboseDict: triggers slower, verbose dictionary mode (Default: False)

        directory: string indicating directory on disk to save synaptic
            parameter values to (i.e., initial and any persistent adapted
            synaptic weight values)
    """

    ## Class Methods for Compartment Names
    @classmethod
    def inputCompartmentName(cls):
        return 'in'

    @classmethod
    def outputCompartmentName(cls):
        return 'out'

    @classmethod
    def presynapticTraceName(cls):
        return 'x_pre'

    @classmethod
    def postsynapticTraceName(cls):
        return 'x_post'
    ## Bind Properties to Compartments for ease of use
    @property
    def inputCompartment(self):
        return self.compartments.get(self.inputCompartmentName(), None)

    @inputCompartment.setter
    def inputCompartment(self, x):
        self.compartments[self.inputCompartmentName()] = x

    @property
    def outputCompartment(self):
        return self.compartments.get(self.outputCompartmentName(), None)

    @outputCompartment.setter
    def outputCompartment(self, x):
        self.compartments[self.outputCompartmentName()] = x

    @property
    def presynapticTrace(self):
        return self.compartments.get(self.presynapticTraceName(), None)

    @presynapticTrace.setter
    def presynapticTrace(self, x):
        self.compartments[self.presynapticTraceName()] = x

    @property
    def postsynapticTrace(self):
        return self.compartments.get(self.postsynapticTraceName(), None)

    @postsynapticTrace.setter
    def postsynapticTrace(self, x):
        self.compartments[self.postsynapticTraceName()] = x

    # Define Functions
    def __init__(self, name, shape, eta, exp_beta, Aplus, Aminus,
                 preTrace_target, wInit=(0.025, 0.8), key=None,
                 useVerboseDict=False, directory=None, **kwargs):
        super().__init__(name, useVerboseDict, **kwargs)

        ## Random number set up
        self.key = key
        if self.key is None:
            self.key = random.PRNGKey(time.time_ns())

        ## Parameters
        self.shape = shape ## shape of synaptic efficacy matrix W
        self.eta = eta ## global learning rate governing plasticity
        self.exp_beta = exp_beta ## if not None, will trigger the exp-dependent STDP rule
        self.preTrace_target = preTrace_target ## target (pre-synaptic) trace activity value (e.g., 0.7)
        self.Aplus = Aplus ## LTP strength
        self.Aminus = Aminus ## LTD strength
        self.w_bound = 1. ## soft weight constraint
        self.w_norm = None ## normalization constant for synaptic matrix after update

        if directory is None:
            self.key, subkey = random.split(self.key)
            lb, ub = wInit
            self.weights = random.uniform(subkey, shape, minval=lb, maxval=ub)
        else:
            self.load(directory)

        ## Reset to initialize core compartments
        self.reset()
    def verify_connections(self):
        self.metadata.check_incoming_connections(self.inputCompartmentName(),
                                                 min_connections=1)
    def advance_state(self, dt, t, **kwargs):
        ## run signals across synapses
        self.outputCompartment = compute_layer(self.inputCompartment, self.weights)
    def evolve(self, dt, t, **kwargs):
        pre = self.inputCompartment
        post = self.outputCompartment
        x_pre = self.presynapticTrace
        x_post = self.postsynapticTrace
        self.weights = evolve(dt, pre, x_pre, post, x_post, self.weights,
                              w_bound=self.w_bound, eta=self.eta,
                              x_tar=self.preTrace_target, exp_beta=self.exp_beta,
                              Aplus=self.Aplus, Aminus=self.Aminus,
                              w_norm=self.w_norm) ## pass w_norm so the L2 constraint is honored
    def reset(self, **kwargs):
        self.inputCompartment = None
        self.outputCompartment = None
        self.presynapticTrace = None
        self.postsynapticTrace = None
    def save(self, directory, **kwargs):
        file_name = directory + "/" + self.name + ".npz"
        np.savez(file_name, weights=self.weights) ## jax.numpy provides no savez

    def load(self, directory, **kwargs):
        file_name = directory + "/" + self.name + ".npz"
        data = np.load(file_name)
        self.weights = jnp.asarray(data['weights'])
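
## Minimal smoke-test sketch (illustrative, not part of the original module):
## exercises the pure `compute_layer`/`evolve` functions directly with toy
## shapes and clamped trace values, sidestepping the ngcsimlib simulation
## machinery; all shapes and values below are assumptions for demonstration.
if __name__ == '__main__':
    demo_key = random.PRNGKey(1234)
    demo_key, subkey = random.split(demo_key)
    W = random.uniform(subkey, (10, 5), minval=0.025, maxval=0.8)
    pre = jnp.ones((1, 10)) ## toy pre-synaptic spike vector
    post = compute_layer(pre, W) ## post-synaptic drive, as in advance_state
    x_pre = jnp.ones((1, 10)) * 0.5 ## toy pre-synaptic trace
    x_post = jnp.ones((1, 5)) * 0.5 ## toy post-synaptic trace
    W_new = evolve(1., pre, x_pre, post, x_post, W, w_bound=1., eta=0.00005,
                   x_tar=0.7, exp_beta=1., Aplus=1., Aminus=0.5)
    print("max |dW| =", jnp.abs(W_new - W).max())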