Source code for ngclearn.components.synapses.STPDenseSynapse

from jax import random, numpy as jnp, jit
from ngcsimlib.logger import info

from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse

[docs] class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable """ A dynamic dense synaptic cable; this synapse evolves according to short-term plasticity (STP) dynamics. | --- Synapse Compartments: --- | inputs - input (takes in external signals) | outputs - output signals | weights - current value matrix of synaptic efficacies | biases - current value vector of synaptic bias values | --- Short-Term Plasticity Compartments: --- | resources - fixed value matrix of synaptic resources (U) | u - release probability; fraction of resources ready for use | x - fraction of resources available after neurotransmitter depletion | Dynamics note: | If tau_d >> tau_f and resources U are large, then synapse is STD-dominated | If tau_d << tau_f and resources U are small, then synases is STF-dominated Args: name: the string name of this cell shape: tuple specifying shape of this synaptic cable (usually a 2-tuple with number of inputs by number of outputs) weight_init: a kernel to drive initialization of this synaptic cable's values; typically a tuple with 1st element as a string calling the name of initialization to use bias_init: a kernel to drive initialization of biases for this synaptic cable (Default: None, which turns off/disables biases) resist_scale: a fixed (resistance) scaling factor to apply to synaptic transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) p_conn: probability of a connection existing (default: 1.); setting this to < 1 and > 0. will result in a sparser synaptic structure (lower values yield sparse structure) tau_f: short-term facilitation (STF) time constant (default: `750` ms); note that setting this to `0` ms will disable STF tau_d: shoft-term depression time constant (default: `50` ms); note that setting this to `0` ms will disable STD resources_int: initialization kernel for synaptic resources matrix """ def __init__( self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., tau_f=750., tau_d=50., resources_init=None, **kwargs ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) ## STP meta-parameters self.resources_init = resources_init self.tau_f = tau_f self.tau_d = tau_d ## Set up short-term plasticity / dynamic synapse compartment values tmp_key, *subkeys = random.split(self.key.get(), 4) preVals = jnp.zeros((self.batch_size, shape[0])) self.u = Compartment(preVals) ## release prob variables self.x = Compartment(preVals + 1) ## resource availability variables self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values if self.resources_init is None: info(self.name, "is using default resources value initializer!") #self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15 self.resources_init = DistributionGenerator.uniform(low=0.125, high=0.175) self.resources = Compartment( self.resources_init(shape, subkeys[2]) #initialize_params(subkeys[2], self.resources_init, shape) ) ## matrix U - synaptic resources matrix
[docs] @compilable def advance_state(self, t, dt): s = self.inputs.get() ## compute short-term facilitation #u = u - u * (1./tau_f) + (resources * (1. - u)) * s if self.tau_f > 0.: ## compute short-term facilitation u = self.u.get() - self.u.get() * (1./self.tau_f) + (self.resources.get() * (1. - self.u.get())) * s else: u = self.resources.get() ## disabling STF yields fixed resource u variables ## compute dynamic synaptic values/conductances Wdyn = (self.weights.get() * u * self.x.get()) * s + self.Wdyn.get() * (1. - s) ## OR: -W/tau_w + W * u * x ## compute short-term depression x = self.x.get() if self.tau_d > 0.: x = x + (1. - x) * (1./self.tau_d) - u * x * s ## else, do nothing with x (keep it pointing to current x compartment) outputs = jnp.matmul(self.inputs.get(), Wdyn * self.resist_scale) + self.biases.get() self.outputs.set(outputs) self.u.set(u) self.x.set(x) self.Wdyn.set(Wdyn)
[docs] @compilable def reset(self): preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) if not self.inputs.targeted: self.inputs.set(preVals) self.outputs.set(postVals) self.u.set(preVals) self.x.set(preVals + 1) self.Wdyn.set(jnp.zeros(self.shape.get()))
# def save(self, directory, **kwargs): # file_name = directory + "/" + self.name + ".npz" # if self.bias_init != None: # jnp.savez(file_name, # weights=self.weights.value, # biases=self.biases.value, # resources=self.resources.value) # else: # jnp.savez(file_name, # weights=self.weights.value, # resources=self.resources.value) # # def load(self, directory, **kwargs): # file_name = directory + "/" + self.name + ".npz" # data = jnp.load(file_name) # self.weights.set(data['weights']) # self.resources.set(data['resources']) # if "biases" in data.keys(): # self.biases.set(data['biases'])
[docs] @classmethod def help(cls): ## component help function properties = { "synapse_type": "STPDenseSynapse - performs a synaptic transformation of inputs to produce " "output signals (e.g., a scaled linear multivariate transformation); " "this synapse is dynamic, adapting via a form of short-term plasticity" } compartment_props = { "inputs": {"inputs": "Takes in external input signal values"}, "states": {"weights": "Synapse efficacy/strength parameter values", "biases": "Base-rate/bias parameter values", "resources": "Synaptic resource paramter values (U)", "u": "Release probability variables", "x": "Resource depletion variables", "key": "JAX PRNG key"}, "outputs": {"outputs": "Output of synaptic transformation"}, } hyperparams = { "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", "weight_init": "Initialization conditions for synaptic weight (W) values", "bias_init": "Initialization conditions for bias/base-rate (b) values", "resist_scale": "Resistance level scaling factor (applied to output of transformation)", "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", "tau_f": "Short-term facilitation time constant", "tau_d": "Short-term depression time constant" } info = {cls.__name__: properties, "compartments": compartment_props, "dynamics": "outputs = [(W * Rscale) * inputs] + b; " "dW/dt = W_full * u * x * inputs", "hyperparameters": hyperparams} return info