# Spiking Neural Networks: Learning with Broadcast Feedback Alignment (Samadi et al.; 2017)
In this exhibit, we will see how one can train a spiking neural network model
using surrogate functions and a credit assignment scheme called broadcast
feedback alignment (BFA) [1].
This exhibit's model effectively reproduces some of the results
reported by Samadi et al. (2017) [1]. The model code for this
exhibit can be found
[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/bfa_snn).
Note: You will need to unzip the MNIST arrays in `exhibits/data/mnist.zip` to the
folder `exhibits/data/` to work through this exhibit/walkthrough.
## The Broadcast Feedback Alignment Spiking Network (BFA-SNN)
The model proposed and studied in [1] is a multi-layer spiking neural
network (SNN) composed of leaky integrators meant to engage in supervised learning,
specifically the task of classification. Concretely, this means that it takes
an image input $\mathbf{x}$ and tries to predict its label, the ground truth of
which is (one-hot) encoded in $\mathbf{y}$.
Additionally, the BFA-SNN exhibit is useful for thinking about how one might
craft and apply backprop-alternative, biologically-plausible credit
assignment schemes, such as broadcast feedback alignment (BFA), to spiking
neuronal networks, as was done in [1]. Note that, in this
exhibit, we abbreviate the SNN trained with BFA as "BFA-SNN" (and the constructor
inside of the `bfasnn_model.py` file is named `BFA_SNN`).
The BFA-SNN model instantiated in this exhibit is made up of three layers:
1. a sensory input layer made up of [Bernoulli encoding](ngclearn.components.input_encoders.bernoulliCell)
neuronal cells, where the probability of firing for any pixel within an image
(yielding a value of one) is taken to be the normalized, scaled pixel intensity[^1];
2. one hidden layer of leaky integrate-and-fire (LIF) cells; and,
3. one output layer of LIF cells (`10` LIFs specifically, one per class category in MNIST).
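To make the input layer concrete, here is a minimal plain-Python sketch of Bernoulli spike encoding. The helper `bernoulli_encode` is hypothetical and is not ngc-learn's `BernoulliCell` API. It assumes pixels are already normalized to `[0,1]`, scales them by a gain (reusing the `input_gain = 0.25` default mentioned in the footnote on `scale_input()`), and samples one binary spike per pixel per time step.

```python
import random

def bernoulli_encode(pixels, input_gain=0.25, rng=None):
    """Sample one time step of binary input spikes (hypothetical helper).

    Assumes `pixels` are normalized to [0, 1]; each is scaled by `input_gain`
    and used as the firing probability of an independent Bernoulli variable.
    """
    rng = rng if rng is not None else random.Random()
    # rng.random() is uniform on [0, 1), so P(spike) == input_gain * pixel
    return [1 if rng.random() < input_gain * p else 0 for p in pixels]

# Encode a toy 4-pixel "image" into a spike train of T = 5 steps.
rng = random.Random(0)
image = [0.0, 0.5, 1.0, 0.25]
spike_train = [bernoulli_encode(image, rng=rng) for _ in range(5)]
```

A pixel of intensity `0` never fires. A fully bright pixel fires on roughly a quarter of the steps under the assumed gain.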
### Neuronal Dynamics
The sensory input layer of the BFA-SNN is rather simple: it assumes
that input values are probabilities (specifically, the observed input image
vector is assumed to be a collection of spike probabilities that drive a set
of Bernoulli distributions, where each pixel represents a Bernoulli distribution
$\mathcal{B}(\mathbf{x}_i(t) = 1; p=\mathbf{x}_i)$; note this means your image
pixel values should be scaled to be between `[0,1]`). The hidden and output
layers are made up of leaky integrate-and-fire (LIF) neuronal cells; specifically,
the BFA-SNN model utilizes ngc-learn's simplified LIF[^2] (i.e., the
[SLIF](ngclearn.components.neurons.spiking.sLIFCell)).
The `SLIF` model simulates the core dynamics of what we would want from leaky
integration and further sports a few conveniences useful for recovering
certain modeling choices made by different research studies, such as the use
of surrogate functions for approximating partial derivatives of non-differentiable
spike emission functions (which will be touched on below).
The SLIF component, in effect, adheres to the following dynamics:
$$
\tau_m \frac{\partial \mathbf{v}_t}{\partial t} =
(-\mathbf{v}_t + R_m \mathbf{j}_t) \odot \mathbf{m}_{rfr}
$$
where $\mathbf{j}_t$ is the set of electrical current values fed into the group of
neuronal units within the component, $\mathbf{v}_t$ contains the membrane potentials
of the component's internal neural population, $\tau_m$ (`tau_m`) is the membrane
time constant (shared for all units inside the component), $R_m$ (`R`) is the
membrane resistance factor (shared for all units), and $\mathbf{m}_{rfr}$ is a
mask produced by the cell's refractory variable -- if any cell within the component
is in its refractory period, it will be clamped to a resting potential of
$0$ milliVolts (mV). Note that Euler integration is used to get from
$\mathbf{v}_t$ to $\mathbf{v}_{t+\Delta t}$.
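The Euler step from $\mathbf{v}_t$ to $\mathbf{v}_{t+\Delta t}$ can be sketched in plain Python as follows. The helper name and the values `dt = 1` ms, `tau_m = 20` ms and `R_m = 1` are illustrative assumptions and are not the `SLIF` component's defaults. Refractory units are clamped to the `0` mV resting potential, as described above.

```python
def slif_euler_step(v, j, refractory, dt=1.0, tau_m=20.0, R_m=1.0):
    """One Euler step of tau_m dv/dt = (-v + R_m j) * m_rfr (hypothetical helper).

    Units whose refractory flag is set are clamped to the 0 mV resting potential.
    """
    v_next = []
    for v_i, j_i, in_rfr in zip(v, j, refractory):
        if in_rfr:
            v_next.append(0.0)  # refractory mask: hold at resting potential
        else:
            dv_dt = (-v_i + R_m * j_i) / tau_m
            v_next.append(v_i + dt * dv_dt)  # v(t + dt) = v(t) + dt * dv/dt
    return v_next

v = slif_euler_step(v=[0.0, 0.0, 0.5], j=[1.0, 0.0, 1.0], refractory=[False, False, True])
# v is approximately [0.05, 0.0, 0.0]
```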
To emit (binary) spikes in the hidden and output layers, each voltage value
within $\mathbf{v}_t$ is compared to a (scalar) threshold value $v_{thr}$ (`thr`)[^3].
Specifically, spikes are emitted according to $\mathbf{s}_{t+\Delta t} = \mathbf{v}_{t+\Delta t} > v_{thr}$ and,
if any cell $i$ emits a spike, its voltage value is reset back
to a base value of `0` mV.
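A minimal sketch of the threshold-and-reset rule follows. `emit_spikes` is a hypothetical helper, and the threshold value used in the example is arbitrary. The comparison is strict, so a voltage exactly at threshold does not fire.

```python
def emit_spikes(v, v_thr):
    """Threshold membrane potentials and reset spiking units to 0 mV (sketch)."""
    spikes = [1 if v_i > v_thr else 0 for v_i in v]  # strict comparison v > v_thr
    v_after = [0.0 if s_i else v_i for v_i, s_i in zip(v, spikes)]  # reset spikers
    return spikes, v_after

spikes, v_after = emit_spikes([0.1, 0.5, 0.4], v_thr=0.4)
# spikes == [0, 1, 0]; v_after == [0.1, 0.0, 0.4]
```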
### Broadcast Feedback Alignment (and In-Built Surrogate Functions)
Adaptation in our BFA-SNN model, under the neuronal dynamics described
above, requires combining two simple mechanisms:
1) we introduce a feedback synaptic cable, `E2`, which will be randomly
initialized and fixed (the feedback synapses will not be changed throughout
the course of simulation), and
2) the forward synaptic cables, `W1` and `W2`, will be adjusted with simple multi-factor
Hebbian rules.
These two mechanisms can also be illustrated with the architecture of our
BFA-SNN model, as shown below.
```{eval-rst}
.. table::
:align: center
+-------------------------------------------------------+
| .. image:: ../images/museum/bfa_snn/bfa_snn.png |
| :scale: 85% |
| :align: center |
+-------------------------------------------------------+
```
Specifically, this means we will instantiate `W1` ($\mathbf{W}^1$), `W2` ($\mathbf{W}^2$),
and `E2` ($\mathbf{E}^2$) as [Hebbian synapses](ngclearn.components.synapses.hebbian.hebbianSynapse).
For `E2`, the learning rate argument `eta_w` is set to zero, so no learning rule
will ever be applied to the feedback synapses. For `W1` and `W2`, the learning rules
effectively become:
$$
\Delta \mathbf{W}^1 &= \mathbf{d}^1 \cdot (\mathbf{s}^0(t+\Delta t))^T \\
\Delta \mathbf{W}^2 &= \mathbf{e}^2 \cdot (\mathbf{s}^1(t+\Delta t))^T
$$
where $\mathbf{s}^0(t+\Delta t)$ contains the binary spikes produced from the
`BernoulliCell` units in the input layer, $\mathbf{s}^1(t+\Delta t)$ contains the
spikes produced by the hidden layer of `SLIF` units, and $\mathbf{e}^2$ is a
set of [Gaussian error neuron cells](ngclearn.components.neurons.graded.gaussianErrorCell)
placed at the output layer to measure the mismatch between the output layer
spikes (stored in a vector $\mathbf{s}^2(t+\Delta t)$) and a target spike
train. In this exhibit's case, the target spike train is simple: it
merely copies the one-hot encoding of the label for `T` steps, i.e.,
$\mathbf{s}_y(t) = \mathbf{y}$. $\mathbf{e}^2$ is specifically computed as follows:
$$
\mathbf{e}^2 = -\big( \mathbf{s}_y(t) - \mathbf{s}^2(t) \big)
$$
which, as described in [the Gaussian error cell API](ngclearn.components.neurons.graded.gaussianErrorCell),
is effectively the first derivative of the negative log density of a Gaussian (with unit variance)
with respect to the output spike vector. This means that the output layer synaptic efficacies
in `W2` are adapted with a simple two-factor Hebbian rule, where one factor
(the pre-synaptic term) is the incoming spikes from the hidden layer and
the other factor (the post-synaptic term) is the mismatch values produced by
comparing the output layer spikes against the target spikes.
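The output-layer rule can be sketched with plain Python lists. The toy sizes, the hypothetical helpers `outer` and `output_error`, and the learning rate `eta` are illustrative only. The final descent-style step `W2 <- W2 - eta * dW2` is an assumption about how the sign of $\mathbf{e}^2$ is used. It is not taken from the exhibit's code.

```python
def outer(a, b):
    """Outer product a . b^T for plain Python lists."""
    return [[a_i * b_j for b_j in b] for a_i in a]

def output_error(s_y, s2):
    """e2 = -(s_y - s2): derivative of a unit-variance Gaussian's negative log-density."""
    return [-(y_i - s_i) for y_i, s_i in zip(s_y, s2)]

# Hypothetical toy sizes: 2 hidden units, 3 output classes.
y = [0, 1, 0]   # one-hot label; the target spike train repeats it at every step
s1 = [1, 0]     # hidden-layer spikes at t + dt (pre-synaptic factor)
s2 = [1, 0, 0]  # output-layer spikes (compared against the target)

e2 = output_error(y, s2)  # post-synaptic factor: [1, -1, 0]
dW2 = outer(e2, s1)       # [[1, 0], [-1, 0], [0, 0]], same shape as W2

# Assumed descent-style application with a hypothetical learning rate eta.
eta = 0.01
W2 = [[0.0, 0.0] for _ in range(3)]
W2 = [[w - eta * dw for w, dw in zip(row, drow)] for row, drow in zip(W2, dW2)]
```

Under this sign convention, the weight from the active hidden unit to the wrongly firing class 0 decreases. The weight to the target class 1 increases.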
Finally, the last part to take notice of is the synaptic update rule for `W1`,
i.e., $\Delta \mathbf{W}^1$. Its pre-synaptic term mirrors the one in the
rule for `W2`: it is the spike vector $\mathbf{s}^0(t+\Delta t)$ produced by the input layer of
`BernoulliCell` units. It is the post-synaptic term that is most interesting:
it is a "teaching signal" $\mathbf{d}^1$ produced by using the feedback synapses `E2` we
mentioned earlier in this exhibit. Formally, the teaching signals $\mathbf{d}^1$
are computed as follows:
$$
\mathbf{d}^1 = \big( \mathbf{E}^2 \cdot \mathbf{e}^2 \big) \odot
f_{surr}(\mathbf{j}^1_{t+\Delta t})
$$
where the first term of the Hadamard product (i.e., the $\odot$) is merely
a transmission of error neuron values down along the `E2` synaptic cable
while the second term is known as a [surrogate function](ngclearn.utils.surrogate_fx).
A surrogate function is, mathematically, a substitute derivative function for the
true derivative of some typically non-differentiable function, such as the
binary spike emission function typically used in spiking networks. Since
BFA as an algorithm can be likened to performing backpropagation of errors (backprop;
an algorithm typically used to train deep neural networks) on a spiking neural
network without reusing the forward synapses to back-transmit error/teaching signals
(a biological criticism of backprop known as the "weight transport problem"
[2]), we still need derivatives of most of the mathematical operations we took
to get to the output. In a spiking network's case, this would be its spike emission
functions and (the integration of) its differential equations for updating the
voltage values. If we do not use surrogate functions for differentiation-based
credit assignment, we either omit the function $f_{surr}$ above (which leads
to degraded generalization performance, as was investigated in [1]) or
we face the "dead neuron problem", which effectively means that the true derivative
of the spike function is zero (almost everywhere), resulting in multiplications by zero.
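The computation of the teaching signal $\mathbf{d}^1$, and of $\Delta \mathbf{W}^1$ from it, can be sketched in plain Python. Everything here is illustrative: the toy sizes, the Gaussian initialization of `E2`, and the helper names are assumptions. `boxcar_surrogate` is a hypothetical placeholder surrogate and not the one used in [1]. The last line shows the dead neuron problem: plugging in the true (zero) derivative of a hard threshold wipes out every teaching signal.

```python
import random

def matvec(M, x):
    """Dense matrix-vector product M . x for lists of lists."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def outer(a, b):
    """Outer product a . b^T."""
    return [[a_i * b_j for b_j in b] for a_i in a]

def teaching_signal(E2, e2, j1, f_surr):
    """d1 = (E2 . e2) (elementwise *) f_surr(j1)."""
    projected = matvec(E2, e2)  # errors sent down the fixed feedback cable E2
    return [p_i * f_surr(j_i) for p_i, j_i in zip(projected, j1)]

def boxcar_surrogate(j):
    """Hypothetical placeholder surrogate: passes signal only for positive current."""
    return 1.0 if j > 0.0 else 0.0

def true_spike_derivative(j):
    """Derivative of a hard threshold, which is zero almost everywhere."""
    return 0.0

# Hypothetical toy sizes: 2 inputs, 3 hidden units, 2 outputs.
rng = random.Random(42)
E2 = [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(3)]  # random, never updated
e2 = [0.5, -0.5]       # output error neurons
j1 = [0.2, -0.1, 1.0]  # hidden-layer currents at t + dt
s0 = [1, 0]            # input spikes at t + dt

d1 = teaching_signal(E2, e2, j1, boxcar_surrogate)
dW1 = outer(d1, s0)    # 3 x 2, same shape as W1

# Using the true derivative instead: every teaching signal vanishes ("dead neurons").
d1_dead = teaching_signal(E2, e2, j1, true_spike_derivative)
```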
While surrogate functions and gradients for spiking networks are covered
in [more detail elsewhere](https://arxiv.org/pdf/2109.12894.pdf), ngc-learn
contains several useful [in-built surrogate routines](ngclearn.utils.surrogate_fx)
for dealing with the "dead neuron problem" and among them is the very one
utilized in [1] for their BFA-SNN model, i.e., ngc-learn's
`secant_lif_estimator()`.[^4] The secant LIF estimator is a specialized
mathematical approximation to LIF spiking dynamics/emission patterns,
producing a value of $(c_1 c_2) \, \text{sech}^2(c_2 \mathbf{j}_{t + \Delta t})$ (for constants
$c_1$ and $c_2$) for electrical current values $\mathbf{j}_{t + \Delta t}$ greater than zero and zero
otherwise.
Using `secant_lif_estimator()` as the surrogate $f_{surr}$ is the
last core detail needed for implementing BFA in ngc-learn, completing the picture
of what the BFA-SNN exhibit does under the hood.
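For intuition, the secant-style surrogate formula can be written as a scalar function. This is a sketch of the formula only and is not ngc-learn's `secant_lif_estimator()` implementation. The constants `c1 = 1.0` and `c2 = 0.5` are placeholders and are not claimed to match the library's or [1]'s values. The sketch computes $\text{sech}^2(x)$ as $1 - \tanh^2(x)$, which avoids overflow in `cosh` for large currents.

```python
import math

def secant_lif_surrogate(j, c1=1.0, c2=0.5):
    """(c1 * c2) * sech^2(c2 * j) for j > 0, and 0 otherwise (placeholder c1, c2)."""
    if j <= 0.0:
        return 0.0
    # sech^2(x) == 1 - tanh(x)^2; tanh saturates instead of overflowing
    return (c1 * c2) * (1.0 - math.tanh(c2 * j) ** 2)

# Zero for non-positive currents, largest just above zero, then decaying.
grads = [secant_lif_surrogate(j) for j in (-1.0, 0.0, 0.5, 4.0)]
```

Unlike the true spike derivative, this function is non-zero for every positive current, so teaching signals can still flow to active units.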
## Running the BFA-SNN Model
To fit the BFA-SNN model described above, go to the `exhibits/bfa_snn`
sub-folder (this step assumes that you have git cloned the model museum repo
code), and execute the BFA-SNN's simulation script from the command line as follows:
```console
$ ./sim.sh
```
which will execute a training process using an experimental configuration very
similar to (Samadi et al., 2017). Specifically, the Bash script executes
`train_bfasnn.py`, which simulates the BFA-SNN on the MNIST database for `30`
epochs. This script will also save your trained SNN model to the `/exp/`
sub-directory (which is what the `analyze_bfsnn.py` script, used below, looks for).
After your model finishes training you should see output similar to the one below:
```console
------------------------------------
Trial.sim_time = 0.31237981763150957 h (1124.5673434734344 sec) Best Acc = 0.9668000340461731
```
The above simulation output of our SNN displays the wall-clock simulation time
(i.e., about `18` minutes on a three-year-old NVIDIA GPU when producing the
example above) as well as the MNIST development/validation set accuracy.
To use your saved/trained model and examine its performance on the MNIST test-set, you
can execute the evaluation script like so:
```console
$ python analyze_bfsnn.py --dataX=../data/mnist/testX.npy --dataY=../data/mnist/testY.npy
```
which will evaluate your BFA-SNN's performance on the MNIST test-set, reporting
its (negative) log likelihood[^5] and label accuracy as follows:
```console
=> NLL = 0.3721860647201538 Acc = 0.9628000259399414
```
In effect, we approximately recover the test performance of the single-hidden-layer
model (with `1000` LIF units) trained with BFA (using ngc-learn's
in-built implementation of the secant surrogate function `E(x)` of [1]),
as reported in [1]. Specifically, we observe that the BFA-SNN reaches about
`96`% test classification accuracy (bear in mind that we are counting spikes:
for each row in an evaluated test mini-batch matrix, the output LIF node with the
highest spike count at the end of `T * dt` ms is chosen as the SNN's predicted label).
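The spike-count decoding rule can be sketched as follows. The helpers `predict_labels` and `accuracy` are hypothetical, and the `[t][b][k]` layout (time step, sample, class) is an assumption made for this illustration. Ties resolve to the lowest class index, which matches the behavior of an `argmax`.

```python
def predict_labels(output_spikes):
    """Argmax of per-class spike counts (hypothetical helper).

    output_spikes[t][b][k] is the spike (0/1) of output unit k for sample b at step t.
    """
    n_steps = len(output_spikes)
    batch_size = len(output_spikes[0])
    n_classes = len(output_spikes[0][0])
    preds = []
    for b in range(batch_size):
        counts = [sum(output_spikes[t][b][k] for t in range(n_steps)) for k in range(n_classes)]
        # max() returns the first maximal index, so ties go to the lowest class id
        preds.append(max(range(n_classes), key=lambda k: counts[k]))
    return preds

def accuracy(preds, labels):
    """Fraction of predictions that match the labels."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# T = 3 steps, a mini-batch of 2 samples, 3 classes.
spikes = [
    [[0, 1, 0], [1, 0, 0]],
    [[0, 1, 1], [1, 0, 0]],
    [[0, 1, 1], [0, 0, 1]],
]
preds = predict_labels(spikes)  # [1, 0]
acc = accuracy(preds, [1, 2])   # 0.5
```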
Finally, running the above analysis script also produces a (t-SNE) visualization
of the estimated rate codes related to the model's internal/hidden layer (made up
of $1000$ `SLIF` cells) similar to the one below:
Intriguingly, we see that the latent codes represented by the BFA-SNN's hidden
layer spikes yield a rather (piecewise) linearly-separable representation
of the input digits, making the process of mapping inputs to label vectors
much easier for the model's second layer of classification LIF units.
Note that, in the `BFA_SNN` exhibit class, we estimated the
rate codes used to produce this plot by converting the SNN's hidden-layer spike train for
each sample as follows:
$$
\mathbf{z}^{1,i} = \frac{\gamma}{(T-T_{nl})} \sum^T_{t=T_{nl}} \mathbf{s}^{1,i}(t).
$$
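where the superscript $i$ indexes the $i$-th sensory data pattern, $\gamma$ is a
scaling constant, and the sum starts at time step $T_{nl}$ (spikes emitted before that step are ignored).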
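The rate-code estimate can be sketched directly from the equation. `rate_code` is a hypothetical helper. It assumes the recorded train is indexed so that `spike_train[t]` holds $\mathbf{s}^{1,i}(t)$ for $t = 0, \dots, T$. It follows the equation literally, including its inclusive summation bounds, and requires $T_{nl} < T$.

```python
def rate_code(spike_train, T_nl, gamma=1.0):
    """z = gamma / (T - T_nl) * sum_{t=T_nl}^{T} s(t) (hypothetical helper).

    Assumes spike_train[t] holds the hidden spike vector s(t) for t = 0, ..., T,
    so len(spike_train) == T + 1; the summation bounds are inclusive, as written.
    """
    T = len(spike_train) - 1
    n_units = len(spike_train[0])
    totals = [0] * n_units
    for t in range(T_nl, T + 1):
        for k in range(n_units):
            totals[k] += spike_train[t][k]
    return [gamma * c / (T - T_nl) for c in totals]

train = [[1, 0], [1, 0], [1, 1], [0, 1], [0, 0]]  # T = 4, two hidden units
z = rate_code(train, T_nl=2, gamma=1.0)           # [0.5, 1.0]
```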
One particularly noteworthy difference between the exhibit model and
the one reported in [1] is that our BFA-SNN directly processes
Bernoulli spike trains whereas the original focused on processing the raw
real-valued pattern vectors, i.e., copying the input x to each time step (note
that this could be done by replacing our input `BernoulliCell` with a `RateCell`
that has its `tau_m` time constant set to `0`). One notable limitation of both
our model and the model of [1] is that both allow the signs of synaptic weights
to change throughout the course of learning, i.e., an initially non-negative
valued synapse could, at some point in training, become negative. This is
biologically implausible, since synaptic efficacies should stay positive and
a fixed population of excitatory and inhibitory neurons should do the work
of amplification and depression (something that our
[Diehl and Cook model](../museum/snn_dc.md) exhibit directly adheres to).
### Aside: Plotting the BFA-SNN's Negative Log Likelihood
The training script `train_bfasnn.py` also saves, in the `/exp/` sub-directory, several
NumPy arrays containing measurements of the model's training and development accuracy
and negative log likelihoods.
You can plot the values in these numpy arrays to produce a nice visualization
of the BFA-SNN's learning curves like so:
```python
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

colors = ["red", "blue"]
# post-process likelihood learning curve data
y = np.load("exp/trNll.npy")  ## training measurements
vy = np.load("exp/nll.npy")  ## dev measurements
## make the plots with matplotlib (each curve gets its own epoch axis)
fontSize = 20
plt.plot(np.arange(y.shape[0]), y, '-', color=colors[0])
plt.plot(np.arange(vy.shape[0]), vy, '-', color=colors[1])
plt.xlabel("Epoch", fontsize=fontSize)
plt.ylabel("Negative Log Likelihood", fontsize=fontSize)
plt.grid()
## construct the legend/key (placed below the axes)
loss = mpatches.Patch(color=colors[0], label='Train NLL')
vloss = mpatches.Patch(color=colors[1], label='Dev NLL')
plt.legend(handles=[loss, vloss], fontsize=13, ncol=2, borderaxespad=0, frameon=False,
           loc='upper center', bbox_to_anchor=(0.5, -0.175))
plt.tight_layout()
## bbox_inches="tight" keeps the below-axes legend from being clipped
plt.savefig("exp/bfasnn_mnist_likelihood.jpg", bbox_inches="tight")  ## save plot to disk
plt.clf()
```
which would produce a plot much like below:
where we see that the SNN has decreased its approximate negative log likelihood
from a starting point of about `2.30` nats to about `0.32` nats (on the
validation dataset). Bear in mind that we have estimated class
probabilities output by our SNN by probing and averaging over electrical current
values from `25` simulated milliseconds per mini-batch of test patterns.
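The approximate NLL can be sketched as a softmax over a temporal average of probed output values, followed by the negative log probability of the true label. This is a plain-Python illustration, and `approx_nll` is a hypothetical helper. It does not commit to which signal (spikes or currents) the exhibit actually probes. The example also checks the starting value: a silent 10-class output layer yields a uniform distribution, so the NLL equals $\ln(10) \approx 2.30$ nats. This is consistent with the starting point of the learning curve.

```python
import math

def softmax(x):
    """Numerically stable softmax."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def approx_nll(outputs_over_time, label):
    """-log softmax(temporal mean of outputs)[label] for one sample (sketch).

    outputs_over_time[t][k] is the probed output value of class k at step t.
    """
    n_steps = len(outputs_over_time)
    n_classes = len(outputs_over_time[0])
    mean = [sum(outputs_over_time[t][k] for t in range(n_steps)) / n_steps
            for k in range(n_classes)]
    return -math.log(softmax(mean)[label])

# A silent 10-class output layer probed for 25 steps gives a uniform
# distribution, so the NLL is ln(10), roughly 2.30 nats.
silent = [[0.0] * 10 for _ in range(25)]
nll_start = approx_nll(silent, label=3)
```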
We remark that this constructed SNN is not particularly deep; with additional
layers of `SLIF` nodes, improvements to its accuracy and approximate log
likelihood would be possible.[^6]
### Computing Hardware Note:
This tutorial was tested and run on an `Ubuntu 22.04.2 LTS` operating system
using an `NVIDIA GeForce RTX 2070` GPU with `CUDA Version: 12.1`
(`Driver Version: 530.41.03`). Note that the times reported in any tutorial
screenshot/console snippets were produced on this system.
## References
[1] Samadi, Arash, Timothy P. Lillicrap, and Douglas B. Tweed. "Deep
learning with dynamic spiking neurons and fixed feedback weights." Neural
Computation 29.3 (2017): 578-602.
[2] Grossberg, Stephen. "Competitive learning: From interactive activation
to adaptive resonance." Cognitive Science 11.1 (1987): 23-63.
[^1]: In the model constructor code in `bfasnn_model.py`, there is a small
helper routine called `scale_input()` that multiplies pixel inputs, assumed
to be normalized to `[0,1]`, by a default factor of `input_gain = 0.25`.
[^2]: The simplified LIF is also covered in more detail in the tutorial on
the [SLIF component](../tutorials/neurocog/simple_leaky_integrator.md).
[^3]: Note that it is possible to configure this threshold to be an adaptive
threshold, kept per neuronal unit, that adapts/decays with time; see the API of the
[SLIF component](../tutorials/neurocog/simple_leaky_integrator.md) for details.
This modeling decision was not used in the source work [1] that proposed the BFA-SNN.
[^4]: Surrogate routines in ngc-learn take on the
$f_{surr}(\mathbf{x}, \text{other args})$ function format to create approximate
derivatives; some functions, like the secant estimator, use electrical current
for $\mathbf{x}$ while others might use the voltage/membrane potential and/or
voltage threshold.
[^5]: Inside of the `BFA_SNN` exhibit class, we, unlike [1], also report and track
the negative Categorical log likelihood (`NLL`) by approximating the SNN's
label distribution using spike outputs (from the second `SLIF` layer `z2`) and
applying a softmax function to a simple temporal average (over `T` time
steps). However, actual classification in the `BFA_SNN` does not use this approximate
label distribution; instead, to produce label predictions, we simply take the
argmax of a final output spike count/frequency vector produced at the end
of `T` time steps.
[^6]: The BFA learning approach would, in principle, work well for any
number of layers. This is motivated by the results reported in [1], where
additional layers were found to, experimentally on MNIST, improve generalization
a bit more.