Source code for ngclearn.components.synapses.denseSynapse

from jax import random, numpy as jnp, jit
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.distribution_generator import DistributionGenerator
from ngcsimlib.logger import info

from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment

class DenseSynapse(JaxComponent):  ## base dense synaptic cable
    """
    A dense synaptic cable; no form of synaptic evolution/adaptation is
    in-built to this component.

    | --- Synapse Compartments: ---
    | inputs - input (takes in external signals)
    | outputs - output signals
    | weights - current value matrix of synaptic efficacies (strength values)
    | biases - current value vector of synaptic bias values

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs)

        weight_init: a kernel, i.e., a callable invoked as
            ``weight_init(shape, key)`` (for example one produced by
            ``DistributionGenerator``), that returns the initial weight matrix
            of this synaptic cable (Default: None, which falls back to
            ``DistributionGenerator.uniform(0.025, 0.8)``)

        bias_init: a kernel, i.e., a callable invoked as
            ``bias_init((1, shape[1]), key)``, to drive initialization of
            biases for this synaptic cable (Default: None, which turns
            off/disables biases by fixing them to the scalar 0.)

        resist_scale: a fixed (resistance) scaling factor to apply to synaptic
            transform (Default: 1.), i.e., yields:
            out = ((in @ W) * resist_scale) + bias

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1 and > 0. will result in a sparser synaptic structure
            (lower values yield sparse structure); any value outside the open
            interval (0, 1), including 0, leaves the matrix fully dense

        batch_size: number of rows (samples) of the zero-initialized
            inputs/outputs compartments (Default: 1)
    """

    def __init__(
            self, name, shape, weight_init=None, bias_init=None, resist_scale=1.,
            p_conn=1., batch_size=1, **kwargs
    ):
        super().__init__(name, **kwargs)

        self.batch_size = batch_size
        ## Synapse meta-parameters
        self.shape = shape
        self.resist_scale = resist_scale

        ## Set up synaptic weight values
        # The key is split into 4 and the leading key is discarded;
        # subkeys[0] -> weights, subkeys[1] -> connectivity mask,
        # subkeys[2] -> biases. Keep the split count at 4: changing it may
        # change every drawn subkey (and thus the initial parameters).
        # NOTE(review): the component's own key compartment is not advanced
        # here; confirm that this is intended.
        _, *subkeys = random.split(self.key.get(), 4)
        if weight_init is None:
            info(self.name, "is using default weight initializer!")
            # Default kernel: uniform weights in [0.025, 0.8].
            weight_init = DistributionGenerator.uniform(0.025, 0.8)
        weights = weight_init(shape, subkeys[0])

        # Modifier/constraint: only probabilities strictly inside (0, 1)
        # trigger masking; anything else keeps the full dense matrix.
        if 0. < p_conn < 1.:
            p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
            weights = weights * p_mask  ## sparsify matrix

        ## Compartment setup
        preVals = jnp.zeros((self.batch_size, shape[0]))
        postVals = jnp.zeros((self.batch_size, shape[1]))
        self.inputs = Compartment(preVals)
        self.outputs = Compartment(postVals)
        self.weights = Compartment(weights)

        ## Set up (optional) bias values; without a kernel the bias is the
        ## scalar 0.0, which broadcasts over the output in advance_state.
        if bias_init is None:
            info(self.name, "is using default bias value of zero (no bias kernel provided)!")
        self.biases = Compartment(
            bias_init((1, shape[1]), subkeys[2]) if bias_init is not None else 0.0
        )

        ## pin weight/bias initializers to component
        self.weight_init = weight_init
        self.bias_init = bias_init

    @compilable
    def advance_state(self):
        # outputs = ((inputs @ weights) * resist_scale) + biases
        self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get())

    @compilable
    def reset(self):
        # Inputs are zeroed only when the compartment is not `targeted`
        # (presumably: not wired to an upstream source -- confirm in
        # ngcsimlib); outputs are always zeroed.
        if not self.inputs.targeted:
            self.inputs.set(jnp.zeros((self.batch_size, self.shape[0])))
        self.outputs.set(jnp.zeros((self.batch_size, self.shape[1])))

    @classmethod
    def help(cls):  ## component help function
        """
        Return a dictionary describing this component's properties,
        compartments, dynamics, and hyperparameters.
        """
        properties = {
            "synapse_type": "DenseSynapse - performs a synaptic transformation "
                            "of inputs to produce output signals (e.g., a "
                            "scaled linear multivariate transformation)"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values"},
            "states":
                {"weights": "Synapse efficacy/strength parameter values",
                 "biases": "Base-rate/bias parameter values",
                 "key": "JAX PRNG key"},
            "outputs":
                {"outputs": "Output of synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "bias_init": "Initialization conditions for bias/base-rate (b) values",
            "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
        }
        # Named help_info (not `info`) so it does not shadow the module-level
        # `info` logging function.
        help_info = {cls.__name__: properties,
                     "compartments": compartment_props,
                     "dynamics": "outputs = [W * inputs] * Rscale + b",
                     "hyperparameters": hyperparams}
        return help_info
if __name__ == '__main__':
    # Smoke test: build a small 2x3 dense synapse inside a throwaway context
    # and print the resulting component.
    from ngcsimlib.context import Context

    with Context("Bar"):
        demo_synapse = DenseSynapse("Wab", (2, 3))
    print(demo_synapse)