Source code for ngclearn.components.synapses.hebbian.gerstnerHebbianSynapse

import jax.numpy as jnp
from jax import random, jit

from ngclearn import compilable
from ngclearn import Compartment
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngcsimlib import deprecate_args
#from ngclearn.utils.io_utils import save_pkl, load_pkl

[docs] class GerstnerHebbianSynapse(DenseSynapse): """ A synapse component that implements Gerstner's general Hebbian learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002). Note that this synpatic update model can recover several classical forms of Hebbian-like update rules, including the covariance rule. There are other higher-order terms possible, i.e., \Theta(xy), such as x * y2 and y x^2, etc. | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update | c2_corr = 1 and c1_pre = -x_theta < 0 """ def __init__( self, name, shape, ## (post_dim, pre_dim) eta=0.01, ## global step-size coeffs=None, ## these configure which kind of Hebb learning is done weight_init=None, p_conn=1., resist_scale=1., sign_value=1., batch_size=1, **kwargs ): bias_init = None ## no biases are included in Gerster's formulation super().__init__( name, shape=shape, weight_init=weight_init, bias_init=bias_init, resist_scale=resist_scale, p_conn=p_conn, batch_size=batch_size, **kwargs ) ## General Hebbian meta-parameters self.eta = eta self.sign_value = sign_value ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr) if coeffs is None: ## Default to standard bilinear Hebb self.coeffs = { 'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0 } else: self.coeffs = coeffs self.c0 = self.coeffs['c0'] self.c1_pre = self.coeffs['c1_pre'] self.c1_post = self.coeffs['c1_post'] self.c2_corr = self.coeffs['c2_corr'] # Initialize Weights (using JAX PRNG) #init_key, _ = random.split(self.key) #w_init = random.normal(init_key, shape) * 0.05 # Compartments (ngc-learn state management) #self.weights = Compartment(w_init) self.pre = Compartment(jnp.zeros((1, shape[1]))) self.post = Compartment(jnp.zeros((1, shape[0])))
[docs] @compilable def evolve(self, **kwargs): """ Updates weights using the Gerstner general expansion. Assumes pre_act and post_act compartments have been populated. """ # Retrieve current states W = self.weights.get() x = self.pre.get() # pre-synaptic activity (batch, pre_dim) y = self.post.get() # post-synaptic activity (batch, post_dim) batch_size = self.batch_size ## Bilinear Term (c2): correlation matrix ### (post_dim, batch) @ (batch, pre_dim) -> (post_dim, pre_dim) dW_corr = jnp.matmul(x.T, y) * (1./batch_size) ## Linear pre-synaptic term (c1_pre) ### Average over batch then broadcast to match weight matrix dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size) ## Linear post-synaptic term (c1_post) dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size) ## Apply Equation 3 Taylor expansion dW = (self.c0 * W + ## synaptic decay self.c1_pre * dW_pre + ## bilinear term self.c1_post * dW_post + ## pre-synaptic gating term self.c2_corr * dW_corr ## post-synpatic gating term ) ## perform a step of Hebbian ascent W = W + self.eta * dW ## Update weights self.weights.set(W)
[docs] @compilable def reset(self, **kwargs): """Clears activity compartments""" self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) ) self.post.set( jnp.zeros((self.batch_size, self.shape[0])) )