Lesson 4: Evolving Synaptic Efficacies

In this tutorial, we will extend a three-component controller (two cell components connected by a synaptic cable component) to incorporate a basic two-factor Hebbian adjustment process.

Adding a Learnable Synapse to a Multi-Component System

Let us start by building a controller similar to those in previous lessons, with one exception: the synaptic connection between a and b will now adapt via a simple two-factor Hebbian rule. This rule requires us to wire the output compartment of a to the pre-synaptic compartment of the synapse Wab, and the output compartment of b to the post-synaptic compartment of Wab. These two connections supply the two factors needed to compute a simple Hebbian adjustment.

We do this specifically as follows:

from ngcsimlib.controller import Controller
from jax import numpy as jnp, random, nn, jit

## create seeding keys
dkey = random.PRNGKey(1234)
dkey, *subkeys = random.split(dkey, 6)

## create simple dynamical system: a --> w_ab --> b
model = Controller()
a = model.add_component("rate", name="a", n_units=1, tau_m=0.,
                        act_fx="identity", key=subkeys[0])
b = model.add_component("rate", name="b", n_units=1, tau_m=0.,
                        act_fx="identity", key=subkeys[1])
Wab = model.add_component("hebbian", name="Wab", shape=(1, 1),
                          eta=1., signVal=-1., wInit=("constant", 1., None),
                          w_bound=0., key=subkeys[3])
## wire a to w_ab and wire w_ab to b
model.connect(a.name, a.outputCompartmentName(), Wab.name, Wab.inputCompartmentName())
model.connect(Wab.name, Wab.outputCompartmentName(), b.name, b.inputCompartmentName())

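## wire the two Hebbian factors: output of a --> pre-synaptic compartment of w_ab,
## and output of b --> post-synaptic compartment of w_ab
model.connect(a.name, a.outputCompartmentName(), Wab.name, Wab.presynapticCompartmentName())
model.connect(b.name, b.outputCompartmentName(), Wab.name, Wab.postsynapticCompartmentName())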

## configure desired commands for simulation object
model.add_command("reset", command_name="reset",
                  component_names=[a.name, Wab.name, b.name],
                  reset_name="do_reset")
model.add_command(
    "advance", command_name="advance",
    component_names=[a.name, Wab.name, b.name]
)
model.add_command("evolve", command_name="evolve", component_names=[Wab.name])
model.add_command("clamp", command_name="clamp_data",
                  component_names=[a.name], compartment=a.inputCompartmentName(),
                  clamp_name="x")
## pin the commands to the object
model.add_step("advance")
model.add_step("evolve")

With the simple system above in place, we will now run a short sequence of one-dimensional “spike” data through it and evolve the synapse at every time step, like so:

## run some data through the dynamical system
x_seq = jnp.asarray([[1, 1, 0, 0, 1]], dtype=jnp.float32)

model.reset(do_reset=True)
print("{}: Wab = {}".format(-1, model.components["Wab"].weights))
for ts in range(x_seq.shape[1]):
  x_t = jnp.expand_dims(x_seq[0,ts], axis=0) ## get data at time t
  model.clamp_data(x=x_t)
  model.runCycle(t=ts*1., dt=1.)
  print(" {}: input = {} ~> Wab = {}".format(ts, x_t, model.components["Wab"].weights))

Your code should produce the following output (towards the bottom of whatever is printed):

-1: Wab = [[1.]]
 0: input = [1.] ~> Wab = [[2.]]
 1: input = [1.] ~> Wab = [[4.]]
 2: input = [0.] ~> Wab = [[4.]]
 3: input = [0.] ~> Wab = [[4.]]
 4: input = [1.] ~> Wab = [[8.]]

Notice that for every non-spike (a value of 0), the synaptic value remains the same. This is because the product of a pre-synaptic value of 0 with any post-synaptic value (in this case, also 0) is simply 0, meaning no change is applied to the synapse. For every spike (a value of 1), we get a synaptic change equal to dW = input * (Wab * input), where input is the pre-synaptic factor (the output of a) and Wab * input is the post-synaptic factor (the output of b). So, for the first time-step, the weight changes according to W = W + eta * dW = W + dW with dW = 1 * (1 * 1) = 1, yielding W = 2. For the second time-step, W is increased by dW = 1 * (2 * 1) = 2, yielding a new synaptic strength of W = 4. The final spike, at the last time-step, adds dW = 1 * (4 * 1) = 4, yielding W = 8.
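The arithmetic above can be checked by hand with a few lines of plain Python. This is only a minimal sketch of the update rule as described in this lesson, not the library's implementation: the function name replay_hebbian is hypothetical, the synapse is treated as a single scalar, and the signVal and w_bound settings are ignored on the assumption that, in this configuration, the update reduces to W = W + eta * (pre * post).

```python
def replay_hebbian(x_seq, w0=1.0, eta=1.0):
    """Replay a scalar two-factor Hebbian update (hypothetical helper).

    Assumes identity activations and instantaneous dynamics, so that the
    output of a equals the clamped input and the output of b equals W * input.
    """
    w = w0
    history = [w]  # weight before any input, then after each time step
    for x in x_seq:
        pre = x            # pre-synaptic factor: output of a
        post = w * pre     # post-synaptic factor: output of b
        dW = pre * post    # two-factor Hebbian product
        w = w + eta * dW   # apply the adjustment
        history.append(w)
    return history


weights = replay_hebbian([1.0, 1.0, 0.0, 0.0, 1.0])
for step, w in zip(range(-1, 5), weights):
    print("{}: Wab = {}".format(step, w))
```

Under these assumptions, the printed weights follow the same sequence as the controller's output above (1, 2, 4, 4, 4, 8), with the weight doubling on every spike because dW equals the current weight whenever the input is 1.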

You have now created your first plastic, evolving neuronal system.