Source code for ngclearn.utils.viz.spike_plot

import matplotlib #.pyplot as plt
matplotlib.use('Agg')
import matplotlib.pyplot as plt
cmap = plt.cm.jet
import jax.numpy as jnp
import numpy as np
from ngclearn.utils.viz.raster import _create_raster_plot

[docs] def plot_spiking_neuron(curr, mem, spike, ref, dt, thr_line=False, title=False, min_mem_val=0.0, max_mem_val=1.25, fname=None): """ Simple plotting function for visualizing the trajectory of a single neuron (where its input electrical current, membrane potential value, output spike readings, and refractory variable states have been recorded into arrays). This is particularly useful for single neuronal dynamics demonstrations and analyses. Args: curr: the recorded electrical current (over T steps) mem: the recorded membrane potential (over T steps) spike: the recorded spike train (over T steps) ref: the recorded refractory state value (over T steps) dt: the integration time constant thr_line: optional vertical threshold line to plot for voltage title: the title of the plot min_mem_val: minimum value bound on membrane potential subplot (min of y-axis) max_mem_val: maximum value bound on membrane potential subplot (max of y-axis) fname: the filename to save this plot as, i.e., /path/to/name.png (Default: lif_analysis.png) """ x_lim = curr.shape[0] y_curr_lim = float(np.amax(curr)) + 0.2 fig, ax = plt.subplots(3, figsize=(8,6), sharex=True, gridspec_kw = {'height_ratios': [1, 1, 0.4]}) # plot input current ax[0].plot(curr, c="tab:blue") ax[0].set_ylim([0, y_curr_lim]) ax[0].set_xlim([0, x_lim]) ax[0].set_ylabel("Input Current ($J_t$)") if title: ax[0].set_title(title) # plot membrane potential ax[1].plot(mem, c="tab:red") ax[1].set_ylim([min_mem_val, max_mem_val]) ax[1].set_ylabel("Membrane Potential ($V_t$)") if thr_line: ax[1].axhline(y=thr_line, alpha=0.25, linestyle="dashed", c="black", linewidth=2) plt.xlabel("Time Step $t$") spk_ = spike + 0 mask = 0. if ref != None: mask = np.greater_equal(ref,dt).astype(np.float32) spk_ = spk_ * (1.0 - mask) # Plot output spike using spikeplot _create_raster_plot(spk_, ax[2], s=100, c="black") plt.ylabel("Output Spikes") plt.yticks([]) if fname is None: fname = "lif_plot.png" plt.tight_layout() plt.savefig(fname) plt.clf()