# Source code for ngclearn.components.synapses.competitive.hopfieldSynapse

from jax import random, numpy as jnp, jit, vmap
from ngclearn import compilable #from ngcsimlib.parser import compilable
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.utils.model_utils import softmax, bkwta

from ngclearn.components.synapses.denseSynapse import DenseSynapse

class HopfieldSynapse(DenseSynapse): # (Modern) Hopfield synaptic cable
    """
    A synaptic cable that emulates a modern Hopfield network (MHN). Note that this model has been
    generalized a bit, a la NAC-Lab style, and comes equipped with two non-standard local plasticity
    update rules to alter the underlying memory matrix W from scratch (or to fine-tune an existing
    preloaded one); note that a mixed MHN can be created (one where initial patterns are stored but
    portions / elements of the memory matrix are further adapted in accordance to a local adjustment
    rule). This model currently only implements the exponential coupling/energy function.

    | --- Synapse Compartments: ---
    | inputs - input probe (takes in external signals)
    | outputs - output signals (retrieved memory / updated probe)
    | weights - current value matrix of synaptic efficacies
    | similarities - current max-shifted similarity scores computed (pre-softmax)
    | memory_weights - current similarity scores computed (post-softmax)
    | i_tick - current internal tick / marker (gets incremented by 1 for each call to `advance_state`)
    | energy - current energy functional reading (given current clamped input probe)
    | key - JAX PRNG key
    | --- Synaptic Plasticity Compartments: ---
    | dWeights - current delta matrix containing changes to be applied to synapses

    | References:
    | Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models.
    | Morgan Kaufmann, 1991. 10-17.
    |
    | Krotov, Dmitry, and John Hopfield. "Large associative memory problem in neurobiology and machine learning."
    | arXiv preprint arXiv:2008.06996 (2020).
    |
    | Hintzman, Douglas L. "MINERVA 2: A simulation model of human memory." Behavior Research Methods, Instruments,
    | & Computers 16.2 (1984): 96-101.

    Args:
        name: the string name of this cell

        shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
            with number of inputs by number of outputs); the second entry is the
            number of memory slots (columns of the memory matrix)

        eta: learning rate / step-size applied to the memory-matrix update

        reg_lambda: weight decay coefficient applied to Hebbian update

        beta: (inverse) temperature to control sharpness of memory similarity calculation

        initial_patterns: seed patterns to store within memory matrix, one pattern
            per column (Default: None)

        update_rule: local plasticity rule to use to adjust/update memory matrix
            (Default: "delta"); currently, two rules are encoded - a custom delta
            rule (prescribed error rule) and a custom contrastive Hebbian rule
            (Movellan/NAC-Lab-style), selected with "contrastive"; any other value
            falls back to the delta rule

        weight_init: a kernel to drive initialization of this synaptic cable's values;
            typically a tuple with 1st element as a string calling the name of
            initialization to use

        resist_scale: a fixed scaling factor to apply to synaptic transform
            (Default: 1.), i.e., yields: out = ((W * Rscale) * in)

        p_conn: probability of a connection existing (default: 1.); setting
            this to < 1. will result in a sparser synaptic structure

        batch_size: batch size dimension of this component (Default: 1)
    """

    def __init__(
            self, name, shape, eta, reg_lambda=0., beta=8., initial_patterns=None,
            update_rule="delta",  ## memory plasticity rule
            weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
    ):
        super().__init__(
            name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
        )

        ### Synapse and Hopfield hyper-parameters
        self.eta = eta
        self.reg_lambda = reg_lambda  ## regularization (L2 decay) co-efficient
        self.l1_lambda = 0.  ## coefficient for L1 decay (currently fixed / disabled)
        self.beta = beta

        if initial_patterns is not None:  ## preload memory synaptic matrix
            W = self.weights.get()
            n_slots = W.shape[1]
            n_patterns = initial_patterns.shape[1]
            ## same 3-way split as before so the shuffle below is unchanged for a given key
            perm_key = random.split(self.key.get(), 3)[1]
            if n_patterns < n_slots:
                ## fill a random subset of memory slots with the stored patterns/templates;
                ## the remaining slots keep their initialized values
                ptrs = random.permutation(perm_key, n_slots)
                W = jnp.concat([initial_patterns, W[:, 0:(n_slots - n_patterns)]], axis=1)
                W = W[:, ptrs]  ## shuffle memories
                self.weights.set(W)
            else:  ## memory is exactly the set of stored patterns/templates
                self.weights.set(initial_patterns)

        self.rule_fx = 2  ## Default: delta-rule
        if update_rule == "contrastive":
            self.rule_fx = 1

        ## Hopfield Compartment setup
        inputVals = jnp.zeros((self.batch_size, shape[0]))
        ## similarity scores are (batch x number-of-memories), i.e., probe @ W,
        ## which matches what `advance_state` writes and what `reset` restores
        simVals = jnp.zeros((self.batch_size, shape[1]))
        self.inputs = Compartment(inputVals)  ## input shape = output shape
        self.outputs = Compartment(inputVals)  ## output shape = input shape
        self.similarities = Compartment(simVals)  ## "hidden layer"
        self.memory_weights = Compartment(simVals)
        self.energy = Compartment(jnp.zeros((1, 1)), display_name="Energy")
        self.i_tick = Compartment(jnp.zeros((1, 1)))
        self.dWeights = Compartment(self.weights.get() * 0)

    @compilable
    def advance_state(self):
        """
        Forward-inference (memory retrieval) step of this Hopfield synapse: scores
        the input probe against every memory, writes the softmax-weighted
        recombination of memories to `outputs`, updates the `energy` reading, and
        increments `i_tick`.
        """
        WX = self.weights.get()
        probe_t = self.inputs.get()
        ## TODO: what about power/quadratic functions instead? (integrate Minerva power coupling)
        raw_sims = jnp.matmul(probe_t, WX)  ## similarities (w/ xn as probe)
        ## shift by the per-row max before the softmax (softmax is shift-invariant)
        sims = raw_sims - jnp.max(raw_sims, axis=1, keepdims=True)
        self.similarities.set(sims)  ## similarities = "hidden layer"
        memory_weights = softmax(sims * self.beta)
        self.memory_weights.set(memory_weights)
        probe_tp1 = jnp.matmul(memory_weights, WX.T)  ## calc probe update
        self.outputs.set(probe_tp1)

        ## Calculate (modern) Hopfield energy functional
        N = WX.shape[1]  ## how many neural memories are there
        ## log-sum-exp must be taken over the *unshifted* similarities; using the
        ## max-shifted ones would leak a per-sample offset into the energy
        scaled_sims = self.beta * raw_sims
        scaled_max = jnp.max(scaled_sims, axis=1, keepdims=True)
        lse = scaled_max + jnp.log(
            jnp.sum(jnp.exp(scaled_sims - scaled_max), axis=1, keepdims=True)
        )
        term1 = -(1. / self.beta) * lse
        ## row-wise inner product <probe_t, probe_tp1>; same value as the diagonal
        ## of probe_t @ probe_tp1.T without forming the (batch x batch) matrix
        ## NOTE(review): the textbook energy uses <probe, probe> here - confirm the
        ## probe/updated-probe pairing is intended
        term2 = 0.5 * jnp.sum(probe_t * probe_tp1, axis=1, keepdims=True)
        ## constant C; memories are the columns of WX, so norms are taken down axis 0
        term3 = (1. / self.beta) * jnp.log(N) + 0.5 * jnp.max(
            jnp.linalg.norm(WX, ord=2, axis=0) ** 2
        )
        Ex = jnp.mean(term1 + term2 + term3, axis=0, keepdims=True)  ## calc batch avg energy
        self.energy.set(Ex)

        self.i_tick.set(self.i_tick.get() + 1.)  ## march internal tick forward

    @compilable
    def evolve(self, t, dt):
        """
        Plasticity rule for changing this Hopfield network's memory matrix; stores
        the batch-averaged update in `dWeights` and writes the adjusted matrix to
        `weights`.

        Args:
            t: current simulation time (not used by this rule)

            dt: integration time constant (not used by this rule)
        """
        x = self.inputs.get()
        x_hat = self.outputs.get()
        s = self.memory_weights.get()
        W = self.weights.get()
        beta = self.beta
        ## TODO: make updates noisy? (perturbative)
        ## TODO: also, make a perturbation-based update synapse?
        if self.rule_fx == 1:  ## contrastive (Movellan) Hebbian style plasticity
            ## TODO: add a loop to iterate over negative term several times
            ## we propagate the updated probe (negative) through memory to get a negative weighted state
            sims_hat = jnp.matmul(x_hat, W)
            sims_hat = sims_hat - jnp.max(sims_hat, axis=1, keepdims=True)
            ## beta scales the shifted similarities (as in `advance_state`); scaling
            ## only the max would leave the negative phase at temperature 1
            s_hat = softmax(sims_hat * beta)
            ## positive Hebbian prod of probe+pos-state against negative Hebbian prod of updated-probe+neg-state
            dW = (x.T @ s) - (x_hat.T @ s_hat)
        else:  ## delta-rule (prescribed error rule) is the default
            dW = (x - x_hat).T @ s  ## (deriv of MSE w.r.t. x_hat/updated probe)
        Ns = x.shape[0]  ## get batch size
        dW = dW * (1. / Ns)  ## we average batch updates
        ## TODO: add a term that checks if we need to append to memory W
        ## actually adjust synaptic efficacies (step + L2 decay + L1 decay)
        W = W + dW * self.eta - W * self.reg_lambda - jnp.sign(W) * self.l1_lambda

        self.dWeights.set(dW)
        self.weights.set(W)

    @compilable
    def reset(self):
        """
        Zeroes out this synapse's input/output, similarity, energy, and update
        compartments; `weights` and `i_tick` are left untouched.
        """
        inputVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
        outputVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
        if not self.inputs.targeted:  ## only clear inputs when nothing is wired into them
            self.inputs.set(inputVals)
        self.outputs.set(inputVals)  ## retrieved probe lives in input space
        self.similarities.set(outputVals)
        self.memory_weights.set(outputVals)
        self.energy.set(self.energy.get() * 0)
        self.dWeights.set(jnp.zeros(self.shape.get()))

    @classmethod
    def help(cls): ## component help function
        """
        Returns a dictionary describing this component's properties, compartments,
        dynamics, and hyper-parameters.
        """
        properties = {
            "synapse_type": "HopfieldSynapse - performs an adaptable synaptic transformation of inputs to produce output "
                            "signals; synapses are adjusted via Hebbian learning in accordance with a Hopfield network"
        }
        compartment_props = {
            "input_compartments":
                {"inputs": "Takes in external input signal values",
                 "key": "JAX PRNG key"},
            "parameter_compartments":
                {"weights": "Synapse efficacy/strength parameter values"},
            "output_compartments":
                {"outputs": "Output of synaptic transformation"}
        }
        hyperparams = {
            "shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
            "batch_size": "Batch size dimension of this component",
            "weight_init": "Initialization conditions for synaptic weight (W) values",
            "resist_scale": "Resistance level scaling factor (applied to output of transformation)",
            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
            "eta": "Global learning rate (to control update to memory matrix)",
            "beta": "Inverse temperature (controls softmax sharpness)",
            "reg_lambda": "Weight decay coefficient to apply to local memory matrix updates",
            "update_rule": "What type of rule to use to update memory matrix (locally)",
            "initial_patterns": "Matrix containing a series of concatenated vectors to store into memory explicitly",
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = Hopfield memory retrieval ;"
                            "dW = Hopfield Hebbian update",
                "hyperparameters": hyperparams}
        return info
# if __name__ == '__main__':
#     from ngcsimlib.context import Context
#     with Context("Bar") as bar:
#         Wab = HopfieldSynapse("Wab", (2, 3), 4, 4, 1.)
#     print(Wab)