"""
Raster visualization functions/utilities.
"""
import math
import random
#import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import imageio.v3 as iio
import jax.numpy as jnp
#suffix = '.jpg' #.png
def _create_raster_plot(spike_train, ax=None, s=1.5, c="black", marker="|",
plot_fname=None, indices=None, suffix='.jpg'):
spk_ = spike_train
# Process spikes and create the binary coordinate grid
if len(spk_.shape) == 1:
spk_ = jnp.expand_dims(spk_,axis=1)
n_units = spk_.shape[1]
if indices is not None and indices.shape[0] < spk_.shape[1]:
spk_ = spk_[:,indices] # access specific neurons if requested
if len(spk_.shape) > 1:
n_units = spk_.shape[1]
else:
n_units = spk_.shape[0]
coords = spk_.nonzero() #jnp.where(spk_) #tf.where(spk_).numpy()
spk_x = coords[0] #coords[:,0]
spk_y = coords[1] #coords[:,1]
if ax is not None:
ax.scatter(spk_x, spk_y, s=s, c=c, marker=marker,linewidths=4)
yint = range(0, n_units)
ax.set_yticks(yint)
ax.set_yticklabels(yint, fontsize=12)
return ax
else:
if plot_fname is None:
plot_fname = "raster_plot" + suffix
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
ax.scatter(spk_x, spk_y, s=s, c=c, marker=marker)
yint = range(0, n_units)
ax.set_yticks(yint)
ax.set_yticklabels(yint, fontsize=12)
plt.title("Spike Train Raster Plot")
plt.xlabel("Time Step")
plt.ylabel("Neuron Index")
plt.savefig(plot_fname)
plt.clf()
plt.close()
def create_raster_plot(spike_train, ax=None, s=0.5, c="black",
                       plot_fname=None, indices=None, tag="", suffix='.jpg'):
    """
    Generates an event (raster) plot of a given (binary) spike train. Each row
    of ``spike_train`` is one neuron and each column is one discrete time step.

    Args:
        spike_train: a binary array of shape (number_neurons x T)

        ax: optional drawing target. Either a matplotlib Axes (ticks are set
            via ``set_yticks``/``set_xticks``) or a pyplot-like object that
            exposes ``yticks``/``xticks``. When None, a new figure is created,
            titled with ``tag`` and saved to ``plot_fname``.

        s: length of each spike line, passed to ``eventplot`` as
            ``linelengths`` (Default = 0.5)

        c: color of the spike lines (Default = black)

        plot_fname: if ax is None, the file name the plot is saved to
            (Default = "raster_plot" + suffix)

        indices: optional neuron (row) indices to plot; the i-th selected
            neuron is drawn at y tick i and labelled "N<index>"

        tag: text appended to the figure title (only used when ax is None)

        suffix: file extension used to build the default ``plot_fname``

    Raises:
        IndexError: if an entry of ``indices`` is not in [0, number_neurons)
    """
    n_count = spike_train.shape[0]
    step_count = spike_train.shape[1]
    # Walk the requested rows in the caller's order so that event row k always
    # lines up with tick label k (also for unsorted or repeated indices).
    rows = list(range(n_count)) if indices is None else [int(i) for i in indices]
    for r in rows:
        if not 0 <= r < n_count:
            raise IndexError(
                "neuron index {} out of range for {} neurons".format(r, n_count))

    save = ax is None
    if save:
        if plot_fname is None:
            plot_fname = "raster_plot" + suffix
        nc = len(rows)
        fig_size = 5 if nc < 25 else int(nc / 5)
        plt.figure(figsize=(fig_size, fig_size))
        plt.title("Spike Train Raster Plot, {}".format(tag))
        plt.xlabel("Time Step")
        ax = plt

    events = [spike_train[r, :].nonzero()[0] for r in rows]
    ax.eventplot(events, linelengths=s, colors=c)

    yticks = list(range(len(rows)))
    ylabels = ["N" + str(r) for r in rows]
    xticks = list(range(0, step_count + 1, max(step_count // 5, 1)))
    if hasattr(ax, "yticks"):
        # pyplot-style interface (the pyplot module itself when ax was None)
        ax.yticks(ticks=yticks, labels=ylabels)
        ax.xticks(ticks=xticks)
    else:
        # matplotlib Axes: only the set_* tick methods exist here
        ax.set_yticks(yticks)
        ax.set_yticklabels(ylabels)
        ax.set_xticks(xticks)

    if save:
        plt.savefig(plot_fname)
        plt.close()
def create_overlay_raster_plot(spike_train, targ_train, Y, idxs, s=1.5, c="black", marker="|",
                               plot_fname=None, indices=None, end_time=100, delay=10,
                               suffix='.jpg'):
    """
    For every sample index in ``idxs``, saves a raster plot that overlays the
    produced spikes on the target spikes. Spikes present in both are drawn in
    black, target spikes that were missed in blue, and extra (non-target)
    spikes in red.

    Args:
        spike_train: sequence over time steps of spike arrays, each indexed as
            ``spike_train[t][idx, :]`` (presumably binary arrays of shape
            (batch x number_neurons) -- confirm against callers)

        targ_train: target spikes with the same layout as ``spike_train``

        Y: array indexed as ``Y[idx, :]``; its argmax is shown as the label in
            the plot title

        idxs: iterable of sample indices to plot; one file is saved per index

        s: size of the spike scatter points (Default = 1.5)

        c: unused (colors are fixed to black/blue/red); kept for API
            compatibility

        marker: format of the marker used to represent each spike
            (Default = "|")

        plot_fname: base file name; each figure is saved as
            ``plot_fname + "_" + str(idx) + suffix`` (Default base =
            "raster_plot")

        indices: unused; kept for API compatibility

        end_time: x ticks are placed at ``arange(0, end_time + delay, delay)``

        delay: spacing between consecutive x ticks

        suffix: file extension appended to every saved file name
    """
    if plot_fname is None:
        # the suffix is appended per figure below; adding it here as well
        # produced names like "raster_plot.jpg_0.jpg"
        plot_fname = "raster_plot"
    for idx in idxs:
        # stack sample ``idx`` of every time step -> (T, number_neurons)
        spk_ = jnp.stack([step[idx, :] for step in spike_train], axis=0)
        trg_ = jnp.stack([step[idx, :] for step in targ_train], axis=0)
        tag = "Label: " + str(int(jnp.argmax(Y[idx, :])))
        n_units = spk_.shape[1]

        # (coordinates, color) per spike category; assumes binary spike values
        spike_groups = (
            (jnp.where(spk_ + trg_ == 2.), "black"),  # produced and targeted
            (jnp.where(spk_ < trg_), "blue"),         # targeted but missed
            (jnp.where(spk_ > trg_), "red"),          # produced but not targeted
        )

        fig = plt.figure(facecolor="w", figsize=(10, 5))
        ax = fig.add_subplot(111)
        for (t_coords, n_coords), color in spike_groups:
            ax.scatter(t_coords, n_coords, s=s, c=color, marker=marker)
        yint = range(0, n_units)
        ax.set_yticks(yint)
        ax.set_yticklabels(yint, fontsize=12)
        ax.xaxis.set_ticks(np.arange(0, end_time + delay, delay))
        ax.set_title("Overlay Raster Plot, {}".format(tag))
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Neuron Index")
        fig.savefig(plot_fname + '_' + str(idx) + suffix)
        plt.close(fig)