Spiking Neural Networks: Learning with Broadcast Feedback Alignment

In this exhibit, we will see how one can train a spiking neural network model using surrogate functions and a credit assignment scheme called broadcast feedback alignment (BFA) [1]. This exhibit model effectively reproduces some of the results reported in (Samadi et al., 2017) [1]. The model code for this exhibit can be found here.

The Broadcast Feedback Alignment Spiking Network (BFA-SNN)

The model proposed and studied in [1] is a multi-layer spiking neural network (SNN) composed of leaky integrators meant to engage in supervised learning, specifically the task of classification. Concretely, this means that it takes an image input \(\mathbf{x}\) and tries to predict its label, the ground truth of which is (one-hot) encoded in \(\mathbf{y}\). Additionally, the BFA-SNN exhibit is useful for thinking about how one might craft and apply backprop-alternative, biologically-plausible credit assignment schemes to spiking neuronal networks; for instance, broadcast feedback alignment (BFA), as was done in [1]. Note that, in this exhibit, we abbreviate the SNN trained with BFA to “BFA-SNN” (and the constructor inside of the bfasnn_model.py file is named BFA_SNN).

The BFA-SNN model instantiated in this exhibit is made up of three layers:

  1. a sensory input layer made up of Bernoulli encoding neuronal cells, where the probability of firing for any pixel within an image (yielding a value of one) is taken to be the normalized, scaled pixel intensity[1];

  2. one hidden layer of leaky integrate-and-fire (LIF) cells; and,

  3. one output layer of LIF cells (10 LIFs specifically, one per class category in MNIST).

Neuronal Dynamics

The sensory input layer of the BFA-SNN is rather simple: it assumes that input values are probabilities. Specifically, the observed input image vector is assumed to be a collection of spike probabilities that drive a set of Bernoulli distributions, where each pixel \(i\) parameterizes a Bernoulli distribution \(\mathcal{B}(\mathbf{x}_i(t) = 1; p=\mathbf{x}_i)\); note this means your image pixel values should be scaled to lie in \([0,1]\). The hidden and output layers are made up of leaky integrate-and-fire (LIF) neuronal cells; specifically, the BFA-SNN model utilizes ngc-learn’s simplified LIF[2] (i.e., the SLIF). The SLIF model simulates the core dynamics of what we would want from leaky integration and further sports a few conveniences useful for recovering certain modeling choices made by different research studies, such as the use of surrogate functions for approximating partial derivatives of non-differentiable spike emission functions (which will be touched on below).
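The Bernoulli encoding described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the helper name bernoulli_encode and the toy pixel values are hypothetical and are not part of ngc-learn's API.

```python
import numpy as np

def bernoulli_encode(x, rng):
    """Sample one time step of binary spikes.

    x holds per-pixel firing probabilities, assumed already scaled to [0, 1].
    (Hypothetical helper for illustration; not ngc-learn's BernoulliCell.)
    """
    return (rng.random(x.shape) < x).astype(np.float32)

rng = np.random.default_rng(0)
x = np.array([0.0, 0.25, 0.5, 1.0])  # toy normalized pixel intensities
# simulate 1000 time steps; each row is one binary spike vector
spikes = np.stack([bernoulli_encode(x, rng) for _ in range(1000)])
rates = spikes.mean(axis=0)  # empirical firing rates, which approach x
```

A pixel with intensity 0 never fires and a pixel with intensity 1 fires at every step, while intermediate intensities fire at a proportional average rate.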

The SLIF component, in effect, adheres to the following dynamics:

\[ \tau_m \frac{\partial \mathbf{v}_t}{\partial t} = (-\mathbf{v}_t + R \mathbf{j}_t) \odot \mathbf{m}_{rfr} \]

where \(\mathbf{j}_t\) is the set of electrical current values fed into the group of neuronal units within the component, \(\mathbf{v}_t\) is the vector of membrane potentials of the component’s internal neural population, \(\tau_m\) (tau_m) is the membrane time constant (shared across all units inside the component), \(R\) (R) is the membrane resistance factor (shared across all units), and \(\mathbf{m}_{rfr}\) is a mask produced by the cell’s refractory variable: if any cell within the component is in its refractory period, it will be clamped to a resting potential of \(0\) millivolts (mV). Note that Euler integration is used to get from \(\mathbf{v}_t\) to \(\mathbf{v}_{t+\Delta t}\).
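As a rough illustration of the Euler update implied by the dynamics above, the following NumPy sketch advances the voltages by one step. The function name euler_step, the parameter values (dt, tau_m, R), and the convention that the mask is 1 for non-refractory cells and 0 for refractory ones are all assumptions made for this example, not ngc-learn defaults.

```python
import numpy as np

def euler_step(v, j, m_rfr, dt=1.0, tau_m=20.0, R=1.0):
    """One Euler step of tau_m * dv/dt = (-v + R * j) * m_rfr.

    m_rfr is assumed to be 1 for cells that may integrate and 0 for
    cells in their refractory period (illustrative convention).
    """
    dv_dt = (-v + R * j) * m_rfr / tau_m
    return v + dt * dv_dt

v = np.zeros(3)                     # all cells start at 0 mV
j = np.array([1.0, 1.0, 0.0])       # toy input currents
m_rfr = np.array([1.0, 0.0, 1.0])   # the middle cell is refractory
for _ in range(5):
    v = euler_step(v, j, m_rfr)
```

In this toy run the first cell leaks toward R * j, the refractory cell stays at 0 mV despite receiving current, and the third cell stays at rest because it receives no input.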

To emit (binary) spikes in the hidden and output layers, each voltage value is compared to a (scalar) threshold value \(v_{thr}\) (thr)[3]. Specifically, the spike vector is computed as \(\mathbf{s}_{t+\Delta t} = \mathbf{v}_{t+\Delta t} > v_{thr}\) and, if any cell \(i\) emits a spike, its specific voltage value is hyperpolarized back to a base value of 0 mV.
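The thresholding and reset just described might be sketched as follows. The helper emit_spikes and the threshold value 0.4 are hypothetical choices for illustration only.

```python
import numpy as np

def emit_spikes(v, v_thr=0.4):
    """Compare voltages to a scalar threshold and reset cells that spiked.

    Returns the binary spike vector and the voltages after reset to 0 mV.
    (Hypothetical helper; the threshold value is illustrative.)
    """
    s = (v > v_thr).astype(np.float32)  # strict inequality, as in the text
    v_after = v * (1.0 - s)             # spiking cells return to 0 mV
    return s, v_after

v = np.array([0.1, 0.4, 0.9])
s, v_after = emit_spikes(v)
```

Only the last cell exceeds the threshold here; a voltage exactly equal to the threshold does not trigger a spike under the strict comparison.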

Broadcast Feedback Alignment (and In-Built Surrogate Functions)

Adaptation in our BFA-SNN model, under the neuronal dynamics described above, requires combining two simple mechanisms:

  1. we introduce a feedback synaptic cable, E2, which will be randomly initialized and fixed (the feedback synapses will not be changed throughout the course of simulation), and

  2. the forward synaptic cables, W1 and W2, will be adjusted with simple multi-factor Hebbian rules.

These two mechanisms can also be illustrated with the architecture of our BFA-SNN model, as shown below.


Specifically, this means we will instantiate W1 (\(\mathbf{W}^1\)), W2 (\(\mathbf{W}^2\)), and E2 (\(\mathbf{E}^2\)) as Hebbian synapses. For E2, the learning rate argument eta_w is set to zero, so no learning rule will ever be applied to the feedback synapses. For W1 and W2, the learning rules will effectively become:

\[\begin{split} \Delta \mathbf{W}^1 &= \mathbf{d}^1 \cdot (\mathbf{s}^0(t+\Delta t))^T \\ \Delta \mathbf{W}^2 &= \mathbf{e}^2 \cdot (\mathbf{s}^1(t+\Delta t))^T \end{split}\]

where \(\mathbf{s}^0(t+\Delta t)\) contains the binary spikes produced by the BernoulliCell units in the input layer, \(\mathbf{s}^1(t+\Delta t)\) contains the spikes produced by the hidden layer of SLIF units, and \(\mathbf{e}^2\) is a set of Gaussian error neuron cells placed at the output layer to measure the mismatch between the output layer spikes (stored in a vector \(\mathbf{s}^2(t+\Delta t)\)) and some target spike train. In this exhibit’s case, the target spike train is simple: it merely copies the one-hot encoding of the label for \(T\) steps, i.e., \(\mathbf{s}_y(t) = \mathbf{y}\). \(\mathbf{e}^2\) is specifically computed as follows:

\[ \mathbf{e}^2 = -\big( \mathbf{s}_y(t) - \mathbf{s}^2(t) \big) \]

which, as described in the Gaussian error cell API, is effectively the first derivative of the negative log likelihood of a Gaussian distribution (with unit variance) with respect to the output spike vector. This means that the output layer synaptic efficacies in W2 are adapted with a simple two-factor Hebbian rule where the first term (the pre-synaptic term) is the incoming spikes from the hidden layer and the second term (the post-synaptic term) is the mismatch values produced by comparing the output layer spikes against the target spikes.
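To make the two-factor rule concrete, the sketch below computes the error neuron values and the resulting \(\Delta \mathbf{W}^2\) for a single time step, using column-vector conventions so that \(\mathbf{W}^2\) has shape (outputs, hidden units). The function names and toy spike vectors are assumptions for illustration; how the adjustment is scaled and applied (learning rate, sign, optimizer) is deliberately left out, since that is handled by the synapse components.

```python
import numpy as np

def output_error(s_y, s2):
    """Error neuron values e2 = -(s_y - s2): mismatch between target and output spikes."""
    return -(s_y - s2)

def delta_W2(e2, s1):
    """Two-factor Hebbian adjustment: outer product of errors and hidden spikes."""
    return np.outer(e2, s1)

s_y = np.array([0.0, 1.0, 0.0])      # one-hot target spikes (3 toy classes)
s2 = np.array([1.0, 0.0, 0.0])       # output layer spikes at this step
s1 = np.array([1.0, 0.0, 1.0, 1.0])  # hidden layer spikes at this step

e2 = output_error(s_y, s2)
dW2 = delta_W2(e2, s1)               # shape (3, 4): one row per output cell
```

Only synapses whose pre-synaptic hidden cell spiked receive a non-zero adjustment, and an output cell whose spike matched its target (the third one here) contributes nothing.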

Finally, the last part to take notice of is the first synaptic update rule for W1, i.e., \(\Delta \mathbf{W}^1\); the first term is similar to the one for the rule for W2 as it is the pre-synaptic spikes \(\mathbf{s}^0(t+\Delta t)\) produced by the input layer of BernoulliCell units. It is the second, post-synaptic term that is most interesting – it is a “teaching signal” \(\mathbf{d}^1\) produced by using the feedback synapses E2 we mentioned earlier in this tutorial. Formally, the teaching signals \(\mathbf{d}^1\) are computed as follows:

\[ \mathbf{d}^1 = \big( \mathbf{E}^2 \cdot \mathbf{e}^2 \big) \odot f_{surr}(\mathbf{j}^1_{t+\Delta t}) \]

where the first term of the Hadamard product (i.e., the \(\odot\)) is merely a transmission of error neuron values down along the E2 synaptic cable while the second term is known as a surrogate function. A surrogate function is, mathematically, a substitute derivative function for the true derivative of some typically non-differentiable function, such as the binary spike emission function typically used in spiking networks. Since BFA as an algorithm can be likened to performing backpropagation of errors (backprop; an algorithm typically used to train deep neural networks) on a spiking neural network without reusing the forward synapses to back-transmit error/teaching signals (a biological criticism of backprop known as the “weight transport problem” [2]), we still need derivatives of most of the mathematical operations we took to get to the output. In a spiking network’s case, these would be its spike emission functions and the (integration of) its differential equations for updating the voltage values. If we do not use surrogate functions for differentiation-based credit assignment, we either omit the function \(f_{surr}\) above, which leads to degraded generalization performance as was investigated in [1], or we face the “dead neuron problem”, which effectively means that the true derivative of the spike function is zero, resulting in multiplications by zero.

While surrogate functions and gradients for spiking networks are covered in more detail elsewhere, ngc-learn contains several useful in-built surrogate routines for dealing with the “dead neuron problem” and among them is the very one utilized in [1] for their BFA-SNN model, i.e., ngc-learn’s secant_lif_estimator().[4] In effect, the secant LIF estimator is a specialized mathematical approximation to LIF spiking dynamics/emission patterns, producing a value of \((c_1 * c_2) \text{sech}^2(c_2 * \mathbf{j}_{t + \Delta t})\) for electrical current values \(\mathbf{j}_{t + \Delta t}\) greater than zero and zero otherwise.
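The formula for the secant LIF estimator can be written out directly in NumPy, using the identity \(\text{sech}(x) = 1/\cosh(x)\). This is a stand-alone sketch of the expression given above, not ngc-learn's secant_lif_estimator() itself; the function name secant_surrogate and the values chosen for c1 and c2 are placeholders assumed for this example.

```python
import numpy as np

def secant_surrogate(j, c1=0.82, c2=0.08):
    """Evaluate (c1 * c2) * sech^2(c2 * j) for j > 0 and 0 otherwise.

    c1 and c2 are placeholder constants chosen for illustration.
    """
    j = np.asarray(j, dtype=np.float64)
    sech_sq = 1.0 / np.cosh(c2 * j) ** 2
    return np.where(j > 0.0, c1 * c2 * sech_sq, 0.0)

j = np.array([-1.0, 0.0, 1.0, 50.0])
out = secant_surrogate(j)
```

The output is zero for non-positive currents, is bounded above by c1 * c2, and decays toward zero as the current grows large, so strongly driven cells also receive only small teaching signals.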

Using secant_lif_estimator() as the surrogate \(f_{surr}\) is the last core detail needed for implementing BFA in ngc-learn, completing the picture of what the BFA-SNN exhibit does under the hood.
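Putting the pieces together, a single-step computation of the teaching signal \(\mathbf{d}^1\) and the adjustment \(\Delta \mathbf{W}^1\) might look like the sketch below. All names, layer sizes, and the Gaussian initialization of the feedback matrix are assumptions for illustration, and the surrogate values are passed in as a precomputed array standing in for \(f_{surr}(\mathbf{j}^1_{t+\Delta t})\).

```python
import numpy as np

def teaching_signal(E2, e2, surr_vals):
    """d1 = (E2 . e2) * f_surr(j1): broadcast errors gated by the surrogate."""
    return (E2 @ e2) * surr_vals

def delta_W1(d1, s0):
    """Adjustment for W1: outer product of teaching signals and input spikes."""
    return np.outer(d1, s0)

rng = np.random.default_rng(0)
n_in, n_hid, n_out = 6, 4, 3                       # toy layer sizes
E2 = rng.normal(0.0, 0.1, size=(n_hid, n_out))     # fixed random feedback synapses
e2 = np.array([1.0, -1.0, 0.0])                    # output error neuron values
surr_vals = np.array([0.06, 0.0, 0.05, 0.0])       # stand-in for f_surr(j1)
s0 = np.array([1.0, 0.0, 0.0, 1.0, 1.0, 0.0])      # input (Bernoulli) spikes

d1 = teaching_signal(E2, e2, surr_vals)
dW1 = delta_W1(d1, s0)                             # shape (n_hid, n_in)
```

Note that E2 is never updated in this sketch, mirroring the fixed feedback synapses, and that hidden cells whose surrogate value is zero receive no teaching signal at all.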

Running the BFA-SNN Model

To fit the BFA-SNN model described above, go to the exhibits/bfa_snn sub-folder (this step assumes that you have git cloned the model museum repo code), and execute the BFA-SNN’s simulation script from the command line as follows:

$ ./sim.sh

which will execute a training process using an experimental configuration very similar to that of (Samadi et al., 2017). Specifically, the Bash script executes train_bfasnn.py, which simulates a BFA-SNN on the MNIST database for 30 epochs. This script will also save your trained SNN model to the /exp/ sub-directory (which is what the analyze_bfsnn.py script, used below, looks for). After your model finishes training, you should see output similar to the one below:

 Trial.sim_time = 0.31237981763150957 h  (1124.5673434734344 sec)  Best Acc = 0.9668000340461731

The above simulation output of our SNN displays the wall-clock simulation time (i.e., just under 19 minutes on a three-year-old NVIDIA GPU when producing the example above) as well as the best MNIST development/validation set accuracy. To use your saved/trained model and examine its performance on the MNIST test-set, you can execute the evaluation script like so:

$ python analyze_bfsnn.py --dataX=../data/mnist/testX.npy --dataY=../data/mnist/testY.npy

which will evaluate your BFA-SNN’s performance on the MNIST test-set, reporting its (negative) log likelihood[5] and label accuracy as follows:

=> NLL = 0.3721860647201538  Acc = 0.9628000259399414

In effect, we approximately recover the test performance of the single hidden layer model (with 1000 LIF units), trained with BFA (using ngc-learn’s in-built implementation of the secant surrogate function E(x) of [1]), as reported in [1]. Specifically, we observe the BFA-SNN reaches 96% test classification accuracy (bear in mind, we are counting spikes and, for each row in an evaluated test mini-batch matrix, the output LIF node with the highest spike count at the end of T * dt ms is chosen as the SNN’s predicted label).
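The spike-count decision rule mentioned above can be sketched as follows. The array layout (time steps, batch rows, output cells) and the helper names are assumptions made for this example rather than the exhibit's actual code.

```python
import numpy as np

def predict_from_spikes(spike_train):
    """Pick, for each batch row, the output cell with the highest spike count.

    spike_train is assumed to have shape (T, batch, n_classes) with binary
    entries. Ties are resolved in favor of the lowest class index by argmax.
    """
    counts = spike_train.sum(axis=0)
    return counts.argmax(axis=1)

def accuracy(pred, labels):
    """Fraction of rows whose predicted label matches the true label."""
    return float(np.mean(pred == labels))

T, batch, n_classes = 5, 2, 3
spike_train = np.zeros((T, batch, n_classes))
spike_train[:, 0, 2] = 1.0    # row 0: class 2 spikes at every step
spike_train[:3, 1, 0] = 1.0   # row 1: class 0 spikes three times
spike_train[0, 1, 1] = 1.0    # row 1: class 1 spikes once

pred = predict_from_spikes(spike_train)
acc = accuracy(pred, np.array([2, 1]))
```

In this toy batch the first row is classified correctly and the second is not, so the accuracy is one half.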

Finally, running the above analysis script also produces a (t-SNE) visualization of the estimated rate codes related to the model’s internal/hidden layer (made up of \(1000\) SLIF cells) similar to the one below:


Intriguingly, we see that the latent codes represented by the BFA-SNN’s hidden layer spikes yield a rather (piecewise) linearly-separable transformation of the input digits, making the process of mapping inputs to label vectors much easier for the model’s second layer of classification LIF units. Note that, in the BFA_SNN model exhibit class, we estimated rate codes to produce this plot by converting a spike train of the SNN for each and every sample as follows:

\[ \mathbf{z}^{1,i} = \frac{\gamma}{(T-T_{nl})} \sum^T_{t=T_{nl}} \mathbf{s}^{1,i}(t) \]

where the superscript \(i\) indexes the \(i\)-th sample (sensory data pattern).
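A minimal sketch of this rate-code estimate is given below. Since the text does not spell out \(\gamma\) and \(T_{nl}\), the code simply treats them as a scale factor and the time step at which the summation starts; the zero-based slicing (which sums exactly \(T - T_{nl}\) steps, matching the normalizer) and the function name rate_code are assumptions for illustration.

```python
import numpy as np

def rate_code(spike_train, gamma=1.0, T_nl=0):
    """Scaled average of spikes over the steps from T_nl onward.

    spike_train is assumed to have shape (T, n_units). With zero-based
    indexing, spike_train[T_nl:] covers T - T_nl steps.
    """
    T = spike_train.shape[0]
    return gamma / (T - T_nl) * spike_train[T_nl:].sum(axis=0)

spike_train = np.array([[1.0, 0.0],
                        [1.0, 0.0],
                        [0.0, 0.0],
                        [1.0, 1.0]])
z_all = rate_code(spike_train)           # averages over all 4 steps
z_late = rate_code(spike_train, T_nl=2)  # averages over the last 2 steps only
```

Each hidden unit's spike train is thereby compressed into one real value per sample, which is the kind of vector a t-SNE projection can operate on.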

One particularly noteworthy difference between the exhibit model and the one reported in [1] is that our BFA-SNN directly processes Bernoulli spike trains whereas the original focused on processing the raw real-valued pattern vectors, i.e., copying the input x to each time step (note that this could be done by replacing our input BernoulliCell with a RateCell that has its tau_m time constant set to 0). One notable limitation of both our model and the [1] model is that both allow the signs of synaptic weights to change throughout the course of learning, i.e., an initially non-negative valued synapse could, at one point in training, become negative, which is biologically-implausible as synapses should have positive values and a fixed population of excitatory and inhibitory neurons should do the work of amplification and depression (something that our Diehl and Cook model exhibit directly adheres to).

Aside: Plotting the BFA-SNN’s Negative Log Likelihood

The training script train_bfasnn.py also saves for you, in the /exp/ sub-directory, several NumPy arrays containing measurements of the model’s training and development accuracy and negative log likelihoods. You can plot the values in these arrays to produce a visualization of the BFA-SNN’s learning curves like so:

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

colors = ["red", "blue"]
# post-process likelihood learning curve data
y = np.load("exp/trNll.npy") ## training measurements
vy = np.load("exp/nll.npy") ## dev measurements
x_iter = np.asarray(list(range(0, y.shape[0])))
## make the plots with matplotlib
fontSize = 20
plt.plot(x_iter, y, '-', color=colors[0])
plt.plot(x_iter, vy, '-', color=colors[1])
plt.xlabel("Epoch", fontsize=fontSize)
plt.ylabel("Negative Log Likelihood", fontsize=fontSize)
## construct the legend/key
loss = mpatches.Patch(color=colors[0], label='Train NLL')
vloss = mpatches.Patch(color=colors[1], label='Dev NLL')
plt.legend(handles=[loss, vloss], fontsize=13, ncol=2,borderaxespad=0, frameon=False,
           loc='upper center', bbox_to_anchor=(0.5, -0.175))
plt.savefig("exp/bfasnn_mnist_likelihood.jpg") ## save plot to disk

which would produce a plot much like the one below:


where we see that the SNN has decreased its approximate negative log likelihood from a starting point of about 2.30 nats to about 0.32 nats (on the validation dataset). Bear in mind that we have estimated the class probabilities output by our SNN by probing and averaging over electrical current values from 25 simulated milliseconds per mini-batch of test patterns. We remark that this constructed SNN is not particularly deep; with additional layers of SLIF nodes, improvements to its accuracy and approximate log likelihood would be possible.[6]
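The starting value of about 2.30 nats is what one would expect from a classifier that spreads its probability uniformly over the 10 MNIST classes, since \(-\ln(1/10) \approx 2.3026\). The short check below works this out with a generic categorical negative log likelihood; the helper name categorical_nll is hypothetical, and the sketch does not reproduce how the exhibit itself converts electrical currents into class probabilities.

```python
import numpy as np

def categorical_nll(probs, labels):
    """Mean negative log probability (in nats) assigned to the true labels.

    probs has shape (batch, n_classes); labels holds integer class indices.
    """
    picked = probs[np.arange(labels.shape[0]), labels]
    return float(-np.mean(np.log(picked)))

n_classes = 10
uniform = np.full((4, n_classes), 1.0 / n_classes)  # an uninformed predictor
labels = np.array([3, 1, 4, 1])
nll_chance = categorical_nll(uniform, labels)       # equals ln(10)
```

A curve that begins near this value and then falls is therefore consistent with a model that starts out at chance and improves with training.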


