# Source code for ngclearn.utils.viz.synapse_plot

"""
Synaptic/receptive field visualization functions/utilities.
"""
import math
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import imageio.v3 as iio
import jax.numpy as jnp



def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'):
    """
    Plots the columns of one or more weight matrices as a grid of image tiles
    and saves the resulting figure to disk.

    Each matrix in `thetas` is transposed so that each of its columns becomes
    one tile. The tiles of one matrix are laid out in ``ceil(sqrt(n))``
    columns (``n`` = number of columns of the matrix), and successive
    matrices are placed side by side, separated by one empty spacer column.

    Args:
        thetas: list of 2D weight matrices; column ``i`` of ``thetas[k]`` is
            reshaped into the ``i``-th tile of block ``k``

        sizes: list of pairs, one per matrix; column vectors of matrix ``k``
            are reshaped to shape ``(sizes[k][0], sizes[k][1])``

        prefix: output file path without the extension

        order: optional list of numpy reshape orders (e.g. ``'C'``/``'F'``),
            one per matrix (Default: ``'C'`` for every matrix)

        suffix: extension appended to `prefix` (Default: '.jpg')
    """
    if order is None:
        order = ['C' for _ in range(len(thetas))]
    Ts = [t.T for t in thetas]
    num_filters = [T.shape[0] for T in Ts]
    n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters]
    n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)]
    # first grid column of each block; "+ i" skips the spacer columns
    starts = [sum(n_cols[:i]) + i for i in range(len(n_cols))]
    spacers = len(sizes) - 1
    n_cols_total = sum(n_cols) + spacers
    n_rows_total = max(n_rows)

    fig = plt.figure(figsize=(n_cols_total, n_rows_total))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for idx, T in enumerate(Ts):
        size = n_cols[idx]
        start = starts[idx]
        # grid cells to skip when a block wraps onto its next row
        extra = n_cols_total - size
        shape = (sizes[idx][0], sizes[idx][1])
        for i in range(num_filters[idx]):
            r = i // size
            point = start + 1 + i + (r * extra)  # 1-based subplot index
            plt.subplot(n_rows_total, n_cols_total, point)
            patch = T[i, :]
            plt.imshow(np.reshape(patch, shape, order=order[idx]),
                       cmap=plt.cm.bone, interpolation='nearest')
            plt.axis("off")
    plt.subplots_adjust(top=0.9)
    fig.savefig(prefix + suffix, bbox_inches='tight')
    # close the figure created above (not whatever happens to be current)
    plt.close(fig)
def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None,
                     suffix='.jpg'):
    """
    Plots the columns of one or more weight matrices as labeled image tiles
    (letters along the top row, numbers down the left edge of each block)
    and saves the resulting figure to disk.

    Each matrix in `thetas` is transposed so that each of its columns becomes
    one tile. The tiles of one matrix are laid out in ``ceil(sqrt(n))``
    columns, and successive matrices are separated by one spacer column of
    a GridSpec.

    Args:
        thetas: list of 2D weight matrices; column ``i`` of ``thetas[k]`` is
            reshaped into the ``i``-th tile of block ``k``

        sizes: list of pairs, one per matrix; columns of matrix ``k`` are
            reshaped to ``(sizes[k][0], sizes[k][1])``. The figure height is
            set to the largest ``sizes[k][0]`` and its width to the largest
            ``sizes[k][1]``.

        prefix: output file path without the extension

        space_width: relative width of the spacer column between blocks
            (Default: half of the widest tile width, rounded up)

        widths: optional list of relative tile-column widths, one per block;
            when given it replaces ``sizes[k][1]`` in the GridSpec width
            ratios

        suffix: extension appended to `prefix` (Default: '.jpg')
    """
    Ts = [t.T for t in thetas]
    num_filters = [T.shape[0] for T in Ts]
    n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters]
    n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)]
    # first grid column of each block; "+ i" skips the spacer columns
    starts = [sum(n_cols[:i]) + i for i in range(len(n_cols))]
    spacers = len(sizes) - 1
    n_cols_total = sum(n_cols) + spacers
    n_rows_total = max(n_rows)
    max_height = max(sizes, key=lambda x: x[0])[0]
    max_width = max(sizes, key=lambda x: x[1])[1]

    fig = plt.figure()
    fig.set_figheight(max_height)
    fig.set_figwidth(max_width)

    # Relative GridSpec column widths: one entry per tile column of each
    # block, with a single spacer column inserted between consecutive blocks.
    if widths is None:
        block_widths = [s[1] for s in sizes]
        default_gap = math.ceil(max_width / 2)
    else:
        block_widths = widths
        default_gap = math.ceil(max(widths) / 2)
    gap = default_gap if space_width is None else space_width
    _widths = [block_widths[0] for _ in range(n_cols[0])]
    for i in range(1, len(n_cols)):
        _widths += [gap] + [block_widths[i] for _ in range(n_cols[i])]

    spec = gridspec.GridSpec(ncols=n_cols_total, nrows=n_rows_total,
                             width_ratios=_widths, wspace=0.1, hspace=0.1)
    for idx, T in enumerate(Ts):
        size = n_cols[idx]
        start = starts[idx]
        # grid cells to skip when a block wraps onto its next row
        extra = n_cols_total - size
        shape = (sizes[idx][0], sizes[idx][1])
        for i in range(num_filters[idx]):
            r = i // size
            point = start + i + (r * extra)  # 0-based flat GridSpec index
            ax = fig.add_subplot(spec[point])
            if r == 0:
                # column letters "A", "B", ... across the block's top row
                ax.set_title(str(chr(i + 65)), weight="bold")
            if i % size == 0:
                # row numbers "1", "2", ... down the block's left edge
                ax.set_ylabel(str(r + 1), rotation=0, labelpad=15,
                              size=max_height, weight="bold")
            patch = T[i, :]
            ax.imshow(np.reshape(patch, shape), cmap=plt.cm.bone,
                      interpolation='nearest')
            ax.axes.xaxis.set_ticks([])
            ax.axes.yaxis.set_ticks([])
    fig.subplots_adjust(top=0.9)
    fig.savefig(prefix + suffix, bbox_inches='tight')
    plt.close(fig)
def visualize_frame(frame, path='.', name='tmp', suffix='.jpg', **kwargs):
    """
    Writes one image frame to the file ``<path>/<name><suffix>``.

    The frame is converted with ``astype(jnp.uint8)`` before being written.

    Args:
        frame: image array (must provide ``astype``)

        path: output directory (Default: '.')

        name: output file stem (Default: 'tmp')

        suffix: output file extension (Default: '.jpg')

        **kwargs: extra keyword arguments forwarded to ``iio.imwrite``
    """
    target_file = path + '/' + name + suffix
    pixels = frame.astype(jnp.uint8)
    iio.imwrite(target_file, pixels, **kwargs)
def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs):
    """
    Writes a sequence of frames as an animated GIF to ``<path>/<name>.gif``.

    Every frame is converted with ``astype(jnp.uint8)`` before writing.

    Args:
        frames: iterable of image arrays (each must provide ``astype``)

        path: output directory (Default: '.')

        name: output file stem (Default: 'tmp')

        suffix: not used by this function; the output always ends in '.gif'

        **kwargs: extra keyword arguments forwarded to ``iio.imwrite``
    """
    uint8_frames = [frame.astype(jnp.uint8) for frame in frames]
    gif_file = path + '/' + name + '.gif'
    iio.imwrite(gif_file, uint8_frames, **kwargs)
def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs):
    """
    Assembles a numbered sequence of image files into ``<path>/training.gif``.

    Reads ``<path>/<prefix><i><suffix>`` for every ``i`` in
    ``range(f_start, f_end + 1, skip)`` and prints progress to stdout.

    Args:
        f_start: index of the first frame to read

        f_end: index of the last frame (inclusive when reached by the stride)

        path: directory holding the frames; the GIF is written here too

        prefix: file-name prefix of each frame

        suffix: file extension of each frame (Default: '.jpg')

        skip: stride between frame indices (Default: 1)

        **kwargs: extra keyword arguments forwarded to ``iio.imwrite``
    """
    frame_ids = range(f_start, f_end + 1, skip)
    collected = []
    for frame_id in frame_ids:
        print("Reading frame " + str(frame_id))
        frame_file = path + "/" + prefix + str(frame_id) + suffix
        collected.append(iio.imread(frame_file))
    print("writing gif")
    iio.imwrite(path + '/training.gif', collected, **kwargs)
# def visualize_norm(thetas, sizes, prefix, suffix='.jpg'):
def _tile_block(t, num_f, t_c, t_r, cols, rows, padding):
    """Arrange the first `num_f` columns of `t` into one padded tile mosaic."""
    weights = np.asarray(t)
    c_dim = (t_c * cols) + padding * (cols - 1)
    r_dim = (t_r * rows) + padding * (rows - 1)
    # padding and unused cells are filled with the block's maximum value
    full = np.full((r_dim, c_dim), np.amax(weights))
    step_r = t_r + padding
    step_c = t_c + padding
    for k in range(num_f):
        r, c = divmod(k, cols)
        r_start = r * step_r
        c_start = c * step_c
        # in-place slice assignment: no full-array copy per tile
        full[r_start:r_start + t_r, c_start:c_start + t_c] = np.reshape(
            weights[:, k], (t_r, t_c))
    return full


def viz_block(thetas, sizes, prefix, suffix=".jpg", padding=1, low_rez=True):
    """
    Draws each weight matrix as a single mosaic of its columns, with all
    mosaics side by side in one row of subplots, and saves the figure.

    Args:
        thetas: list of 2D weight matrices; column ``k`` of a matrix becomes
            tile ``k`` of that matrix's mosaic

        sizes: list of ``(t_c, t_r)`` pairs, one per matrix; each column is
            reshaped to ``(t_r, t_c)`` (note: width first, then height)

        prefix: output file path without the extension

        suffix: extension appended to `prefix` (Default: ".jpg")

        padding: number of cells between neighboring tiles (Default: 1)

        low_rez: if False, a new figure whose size (in inches) matches the
            mosaic cell counts is created first (Default: True)
    """
    num_filters = [T.shape[1] for T in thetas]
    n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters]
    n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)]
    if not low_rez:
        block_w = [(t_c * cols) + padding * (cols - 1)
                   for (t_c, _), cols in zip(sizes, n_cols)]
        block_h = [(t_r * rows) + padding * (rows - 1)
                   for (_, t_r), rows in zip(sizes, n_rows)]
        # mosaics share one subplot row: total width is the sum of the block
        # widths, but the height only needs to fit the tallest block
        plt.figure(figsize=(int(sum(block_w)), int(max(block_h, default=0))))
    blocks = zip(thetas, num_filters, sizes, n_cols, n_rows)
    for idx, (t, num_f, (t_c, t_r), cols, rows) in enumerate(blocks):
        full = _tile_block(t, num_f, t_c, t_r, cols, rows, padding)
        plt.subplot(1, len(thetas), idx + 1)
        plt.imshow(full, cmap=plt.cm.bone, interpolation='nearest')
        plt.axis("off")
    plt.savefig(prefix + suffix, bbox_inches='tight')
    plt.clf()
    plt.close()