Plotting and Visualization

Although some experimental setups call for custom, task-specific matplotlib code, ngc-learn already ships several useful visualization tools, organized under the package sub-directory ngclearn.utils.viz. These include utilities for generating raster plots and synaptic receptive field views (useful for biophysical models such as spiking neural networks) as well as t-SNE plots of model latent codes. Other lessons and tutorials demonstrate some of these routines (e.g., raster plots for spiking neuronal cells); this short lesson shows how to produce a t-SNE plot using ngc-learn’s built-in tool.

Generating a t-SNE Plot

Suppose you have a labeled five-dimensional (5D) dataset, which this lesson synthesizes artificially from an “unobserved” trio of multivariate Gaussians, and you want to visualize these “model outputs” and their corresponding labels in 2D via ngc-learn’s built-in t-SNE tool.
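
Since the data generator in this lesson relies on jax.random, it may help to recall how JAX handles randomness: PRNG keys are explicit, and drawing twice with the same key returns the same noise. The following is a minimal sketch (the variable names are illustrative and are not part of ngc-learn) that splits one key into independent subkeys and draws from a single 5D Gaussian with an isotropic covariance:

```python
from jax import numpy as jnp, random

key = random.PRNGKey(0)
## split into one carry-over key plus two independent subkeys
key, k1, k2 = random.split(key, 3)

mu = jnp.asarray([2.1, 3.2, 0.6, -4., -2.])  ## mean of a 5D Gaussian
cov = jnp.eye(5) * 0.78  ## isotropic covariance (variance 0.78 per dimension)

a = random.multivariate_normal(k1, mu, cov, shape=(4,))
b = random.multivariate_normal(k1, mu, cov, shape=(4,))  ## same key: identical draws
c = random.multivariate_normal(k2, mu, cov, shape=(4,))  ## different key: new draws

print(a.shape)                      ## (4, 5)
print(bool(jnp.array_equal(a, b)))  ## True
print(bool(jnp.array_equal(a, c)))  ## False
```

In practice this means each Gaussian in a mixture should be sampled with its own subkey; otherwise the clusters share the same underlying noise pattern.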

The following Python code sets up the data generator and produces the t-SNE plot:

from jax import numpy as jnp, random
from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents

dkey = random.PRNGKey(1234)

def gen_data(dkey, N): ## data generator (or proxy stochastic data generating process)
    mu1 = jnp.asarray([[2.1, 3.2, 0.6, -4., -2.]])
    cov1 = jnp.eye(5) * 0.78
    mu2 = jnp.asarray([[-1.8, 0.2, -0.1, 1.99, 1.56]])
    cov2 = jnp.eye(5) * 0.52
    mu3 = jnp.asarray([[0.3, -1.2, -0.56, -4., 3.6]])
    cov3 = jnp.eye(5) * 1.2

    dkey, *subkeys = random.split(dkey, 4) ## one independent subkey per Gaussian
    samp1 = random.multivariate_normal(subkeys[0], mu1, cov1, shape=(N,))
    samp2 = random.multivariate_normal(subkeys[1], mu2, cov2, shape=(N,))
    samp3 = random.multivariate_normal(subkeys[2], mu3, cov3, shape=(N,))
    X = jnp.concatenate((samp1, samp2, samp3), axis=0)
    y1 = jnp.ones((N, 3)) * jnp.asarray([[1., 0., 0.]])
    y2 = jnp.ones((N, 3)) * jnp.asarray([[0., 1., 0.]])
    y3 = jnp.ones((N, 3)) * jnp.asarray([[0., 0., 1.]])
    lab = jnp.concatenate((y1, y2, y3), axis=0) ## one-hot codes
    return X, lab

## generate data or theoretical "model outputs"
data, lab = gen_data(dkey, 400)

## visualize the above data via the t-SNE algorithm
print("data.shape = ", data.shape)
codes = extract_tsne_latents(data) ## map the 5D data to 2D codes
print("codes.shape = ", codes.shape)
plot_latents(codes, lab, plot_fname="codes.jpg")

Running this code should produce a plot, saved as codes.jpg, similar to the one below:

../../_images/simple_codes.jpg

In this example, the 5D model output data are mapped successfully to a plottable 2D space, which facilitates some level of downstream qualitative interpretation of the model.
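
If a rough numeric complement to the visual inspection is wanted, one simple option is to check how often each 2D code lies closest to the centroid of its own class. The sketch below is only an illustration: nearest_centroid_accuracy is a hypothetical helper written for this lesson, not an ngc-learn routine, and it assumes one-hot labels with at least one point per class.

```python
from jax import numpy as jnp

def nearest_centroid_accuracy(codes, lab):
    ## hypothetical helper (not part of ngc-learn)
    ## codes: (N, D) low-dimensional codes; lab: (N, C) one-hot labels
    counts = jnp.sum(lab, axis=0, keepdims=True)  ## (1, C) points per class
    centroids = (lab.T @ codes) / counts.T  ## (C, D) mean code per class
    ## squared Euclidean distance from every point to every centroid: (N, C)
    d2 = jnp.sum((codes[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    pred = jnp.argmin(d2, axis=1)  ## index of the nearest centroid
    true = jnp.argmax(lab, axis=1)  ## class index from the one-hot code
    return float(jnp.mean(pred == true))

## toy 2D codes: two well-separated groups of two points each
toy_codes = jnp.asarray([[0., 0.], [0., 1.], [10., 10.], [10., 11.]])
toy_lab = jnp.asarray([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
acc = nearest_centroid_accuracy(toy_codes, toy_lab)
print(acc)  ## 1.0
```

With the script from this lesson, the same helper could be called as nearest_centroid_accuracy(codes, lab). A value near 1 would suggest the classes form distinct groups in the 2D plot, although t-SNE distances should be interpreted with care.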