Source code for ngclearn.utils.viz.compartment_raster

"""
Raster visualization functions/utilities.
"""
import matplotlib.pyplot as plt
import jax
from typing import Sequence

[docs] def create_raster_plot(spike_train: jax.Array, ax: plt.Axes | None = None, indices: Sequence[int] | None = None, s=0.5, c="black"): """ Generates a raster plot of a given (binary) spike train (row dimension corresponds to the discrete time dimension). Args: spike_train: a numpy binary array of shape (T x number_of_neurons) ax: a hook/pointer to a currently external plot that this raster plot should be made a sub-figure of indices: optional indices of neurons (row integer indices) to focus on plotting s: size of the spike scatter points (Default = 0.5) c: color of the spike scatter points (Default = black) """ step_count = spike_train.shape[0] n_count = spike_train.shape[1] if ax is None: nc = n_count if indices is None else len(indices) fig_size = 5 if nc < 25 else int(nc / 5) plt.figure(figsize=(fig_size, fig_size)) _ax = ax if ax is not None else plt events = [] for t in range(n_count): if indices is None or t in indices: e = spike_train[:, t].nonzero() events.append(e[0]) _ax.eventplot(events, linelengths=s, colors=c) if ax is None: _ax.yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))], labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)]) _ax.xticks(ticks=[i for i in range(0, step_count+1, max(int(step_count / 5), 1))]) else: _ax.set_yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))], labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)])