Source code for ngclearn.utils.data_loader

"""
Data functions and utilities for data loading.
"""
import io
import sys  # , math
import time

from jax import random, numpy as jnp, jit, nn

class DataLoader(object):
    """
    A data loader object, meant to allow sampling w/o replacement of one or
    more named design matrices. Note that this object is iterable (and
    implements an __iter__() method); each full iteration is one pass
    (epoch) over the data.

    Args:
        design_matrices: list of named data design matrices -
            [("name", matrix), ...]; the number of samples is taken from the
            length of the first matrix, and the same sample indices are used
            to slice every matrix

        batch_size: number of samples to place inside a mini-batch (must be
            at least 1 and no larger than the number of samples)

        disable_shuffle: if True, turns off sample shuffling (samples are
            then visited in their stored order, each once per pass)

        ensure_equal_batches: if True, ensures sampled batches are equal in
            size (Default = True). Note that this means the very last batch,
            if it's not the same size as the rest, is padded with the first
            samples of the current pass's ordering (yielding a batch with a
            mix of vectors sampled with and without replacement).

        key: PRNG key to control determinism of any underlying random values
            associated with this data loader (if None, a key is seeded from
            the current time, so shuffling is not reproducible)

    Raises:
        ValueError: if design_matrices is empty, if batch_size < 1, or if
            batch_size exceeds the number of data samples
    """

    def __init__(self, design_matrices, batch_size, disable_shuffle=False,
                 ensure_equal_batches=True, key=None):
        self.key = key
        if self.key is None:
            # No key supplied: seed from the wall clock (non-deterministic).
            self.key = random.PRNGKey(time.time_ns())
        self.batch_size = batch_size
        self.ensure_equal_batches = ensure_equal_batches
        self.disable_shuffle = disable_shuffle
        self.design_matrices = design_matrices
        if len(design_matrices) < 1:
            raise ValueError(
                "design_matrices must contain at least one design matrix!")
        self.data_len = len(self.design_matrices[0][1])
        # Sampling pointers; replaced by a fresh permutation on each
        # iteration unless shuffling is disabled.
        self.ptrs = jnp.arange(0, self.data_len, 1)
        if self.batch_size < 1:
            # A non-positive batch size would never advance the sampling
            # pointer in __iter__ (infinite loop).
            raise ValueError(
                "batch size must be a positive integer, got {}".format(
                    self.batch_size))
        if self.data_len < self.batch_size:
            raise ValueError(
                "batch size {} is > total number data samples {}".format(
                    self.batch_size, self.data_len))

    def __iter__(self):
        """
        Yields a mini-batch of the form: [("name", batch),("name",batch),...]

        Unless shuffling is disabled, a new permutation of the sample indices
        is drawn (advancing self.key) at the start of every iteration.
        """
        if not self.disable_shuffle:
            self.key, subkey = random.split(self.key, 2)
            self.ptrs = random.permutation(subkey, self.data_len)
        n_samples = len(self.ptrs)
        for idx in range(0, n_samples, self.batch_size):
            # Clamp so the final slice never reaches past the dataset.
            e_idx = min(idx + self.batch_size, n_samples)
            indices = self.ptrs[idx:e_idx]
            if self.ensure_equal_batches and indices.shape[0] < self.batch_size:
                # Pad a short final batch with the first pointers of this
                # pass; diff < data_len is guaranteed by the __init__ checks.
                diff = self.batch_size - indices.shape[0]
                indices = jnp.concatenate((indices, self.ptrs[0:diff]))
            # Slice every design matrix with the same sample indices.
            yield [(dname, dmatrix[indices])
                   for dname, dmatrix in self.design_matrices]