ngclearn.utils package
Subpackages
- ngclearn.utils.analysis package
- ngclearn.utils.density package
- ngclearn.utils.diffeq package
- ngclearn.utils.feature_dictionaries package
- ngclearn.utils.filters package
- ngclearn.utils.masks package
- ngclearn.utils.optim package
- ngclearn.utils.viz package
- Submodules
- ngclearn.utils.viz.classification_analysis module
- ngclearn.utils.viz.compartment_plot module
- ngclearn.utils.viz.compartment_raster module
- ngclearn.utils.viz.dim_reduce module
- ngclearn.utils.viz.raster module
- ngclearn.utils.viz.spike_plot module
- ngclearn.utils.viz.synapse_plot module
- Module contents
Submodules
ngclearn.utils.JaxProcessesMixin module
- class ngclearn.utils.JaxProcessesMixin.JaxCompiledMethod(fn, fn_ast, auxiliary_ast, namespace, extra_globals)[source]
Bases:
CompiledMethod
A wrapper for a compiled method that includes JAX’s jit wrapping. Used exclusively by the mixin; it should not be used elsewhere.
- property source_fn
The source method, not wrapped in jit
- class ngclearn.utils.JaxProcessesMixin.JaxJointProcess(*args, **kwargs)[source]
Bases:
JaxProcessesMixin,JointProcess
- class ngclearn.utils.JaxProcessesMixin.JaxMethodProcess(*args, **kwargs)[source]
Bases:
JaxProcessesMixin,MethodProcess
- class ngclearn.utils.JaxProcessesMixin.JaxProcessesMixin(name, *args, use_jit=True, **kwargs)[source]
Bases:
object
A mixin for the base Process that adds JAX functionality, such as scan and implicit jit wrapping.
- property previous_result
Stores and returns the last result of scan (the second returned value)
- property previous_state
Stores and returns the last returned state of scan (the first returned value)
- scan(inputs, current_state=None, store_state: bool = True, store_results: bool = True)[source]
Runs the process through JAX’s scan method.
- Parameters:
inputs – the inputs for scan (use pack rows to generate); must be a JAX array
current_state – optional; the current state of the model; if None, the current global state is used
store_state – optional flag; whether the final state should be stored in the process
store_results – optional flag; whether the final result should be stored in the process
- Returns:
the final state, the final result
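The (state, result) pair returned by scan appears to mirror the (final carry, stacked per-step outputs) pair of plain jax.lax.scan. The sketch below calls jax.lax.scan directly rather than the mixin; the leaky-accumulator step function and the shapes are illustrative assumptions, and the mixin's own calling convention may differ.

```python
import jax
import jax.numpy as jnp

def step(state, x_t):
    # illustrative leaky accumulator: the carry is the running state
    new_state = 0.9 * state + x_t
    # jax.lax.scan expects (next carry, per-step output)
    return new_state, new_state

inputs = jnp.ones((5, 3))   # 5 time steps (rows), 3 units
init_state = jnp.zeros(3)
final_state, results = jax.lax.scan(step, init_state, inputs)
# final_state: shape (3,), the last carry ("state")
# results: shape (5, 3), per-step outputs stacked along the time axis ("result")
```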
ngclearn.utils.data_loader module
Data functions and utilities for data loading.
- class ngclearn.utils.data_loader.DataLoader(design_matrices, batch_size, disable_shuffle=False, ensure_equal_batches=True, key=None)[source]
Bases:
object
A data loader object, meant to allow sampling without replacement of one or more named design matrices. Note that this object is iterable (it implements an __iter__() method).
- Parameters:
design_matrices – list of named data design matrices - [(“name”, matrix), …]
batch_size – number of samples to place inside a mini-batch
disable_shuffle – if True, turns off sample shuffling (thus no sampling w/o replacement)
ensure_equal_batches – if True, ensures sampled batches are equal in size (Default = True). Note that this means the very last batch, if it’s not the same size as the rest, will reuse random samples from previously seen batches (yielding a batch with a mix of vectors sampled with and without replacement).
key – PRNG key to control determinism of any underlying random values associated with this data loader
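The sampling behavior described above (shuffle once, draw mini-batches without replacement, optionally top up a short final batch with rows already seen) can be sketched in plain JAX. This is a hypothetical helper, iterate_minibatches, not the ngclearn DataLoader API; the structure of what each iteration yields is an assumption.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(design_matrices, batch_size, key,
                        disable_shuffle=False, ensure_equal_batches=True):
    """Hypothetical sketch: yields a list of (name, batch) pairs per mini-batch."""
    n = design_matrices[0][1].shape[0]
    if disable_shuffle:
        order = jnp.arange(n)
    else:
        key, subkey = jax.random.split(key)
        order = jax.random.permutation(subkey, n)
    for start in range(0, n, batch_size):
        batch_idx = order[start:start + batch_size]
        short = batch_size - batch_idx.shape[0]
        if ensure_equal_batches and short > 0 and start > 0:
            # top up the final short batch with rows drawn from earlier batches
            key, subkey = jax.random.split(key)
            filler = jax.random.choice(subkey, order[:start], shape=(short,),
                                       replace=False)
            batch_idx = jnp.concatenate([batch_idx, filler])
        # the same row indices are applied to every named matrix
        yield [(name, mat[batch_idx]) for name, mat in design_matrices]
```

Indexing every named matrix with the same row indices keeps inputs and targets aligned within each batch.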
ngclearn.utils.distribution_generator module
- class ngclearn.utils.distribution_generator.DistributionGenerator[source]
Bases:
object
- static constant(value: float, **params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer for a constant distribution.
- Parameters:
value – the constant value to fill the array with
**params – the extra distribution parameters
- Returns:
a distribution initializer
- static fan_in_gaussian(**params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer using a fan-in Gaussian (normal) strategy. The values are sampled from a normal distribution with mean 0 and stddev = sqrt(1 / fan_in), where fan_in is inferred from the shape.
He, Kaiming, et al. “Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification.” Proceedings of the IEEE International Conference on Computer Vision. 2015.
- Parameters:
**params – extra distribution parameters
- Returns:
a distribution initializer
- static fan_in_uniform(**params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer using a fan-in uniform strategy. The values are sampled from a uniform distribution in the range [-limit, limit], where limit = sqrt(1 / fan_in), and fan_in is inferred from the shape.
Glorot, Xavier, and Yoshua Bengio. “Understanding the difficulty of training deep feedforward neural networks.” Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics. JMLR Workshop and Conference Proceedings, 2010.
- Parameters:
**params – extra distribution parameters
- Returns:
a distribution initializer
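The two fan-in strategies above can be sketched with jax.random, using the scales stated on this page (stddev = sqrt(1 / fan_in) and limit = sqrt(1 / fan_in)). Treating the first axis of shape as fan_in is an assumption, and these hypothetical helpers return samples directly rather than a DistributionInitializer.

```python
import math
import jax
import jax.numpy as jnp

def fan_in_gaussian_sketch(key, shape):
    # assumption: fan_in is the size of the first axis of an (in_dim, out_dim) matrix
    fan_in = shape[0]
    std = math.sqrt(1.0 / fan_in)
    return std * jax.random.normal(key, shape)

def fan_in_uniform_sketch(key, shape):
    fan_in = shape[0]
    limit = math.sqrt(1.0 / fan_in)
    # samples lie in [-limit, limit)
    return jax.random.uniform(key, shape, minval=-limit, maxval=limit)
```

Note that the cited papers also use gain or fan-out terms (He et al. use sqrt(2 / fan_in); Glorot and Bengio's uniform limit is sqrt(6 / (fan_in + fan_out))); the sketch follows the scales documented here.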
- static gaussian(mean: float = 0.0, std: float = 1.0, **params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer for a Gaussian (normal) distribution.
- Parameters:
mean – the mean of the normal distribution
std – the standard deviation of the normal distribution
**params – the extra distribution parameters
- Returns:
a distribution initializer
- static log_gaussian(sigma: float = 1.0, **params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer for a log-Gaussian/normal distribution. Note that this distribution is constrained to be centered (zero-mean); thus, only a scale/standard-deviation sigma can be provided as an argument. This is a useful distribution for producing non-negative/positive-valued sample values.
- Parameters:
sigma – standard deviation of the underlying normal distribution (Default: 1.)
**params – the extra distribution parameters
- Returns:
a distribution initializer
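A zero-mean log-normal sample is the exponential of a zero-mean normal sample scaled by sigma, which is why every value is positive. The hypothetical helper below sketches that idea directly with jax.random; it is not the library's initializer.

```python
import jax
import jax.numpy as jnp

def log_gaussian_sketch(key, shape, sigma=1.0):
    # exponentiate zero-mean normal samples with std sigma: outputs are strictly positive
    return jnp.exp(sigma * jax.random.normal(key, shape))
```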
- static uniform(low: float = 0.0, high: float = 1.0, **params: Unpack[DistributionParams]) DistributionInitializer[source]
Produces a distribution initializer for a uniform distribution.
- Parameters:
low – lower bound of the uniform distribution (inclusive)
high – upper bound of the uniform distribution (exclusive)
**params – the extra distribution parameters
- Returns:
a distribution initializer
- class ngclearn.utils.distribution_generator.DistributionInitializer(*args, **kwargs)[source]
Bases:
Protocol
- class ngclearn.utils.distribution_generator.DistributionParams[source]
Bases:
TypedDict
Extra parameters to be used when generating distributions (attributes listed below).
- Parameters:
amin – sets the lower bound of the distribution
amax – sets the upper bound of the distribution
lower_triangle – keeps the lower triangle, sets the rest to zero
upper_triangle – keeps the upper triangle, sets the rest to zero
hollow – produces a hollow distribution (zeros along the diagonal)
eye – produces an eye distribution (zeros the off-diagonal)
col_mask – single value, keeps n random columns; list values, keeps the provided column indices
row_mask – single value, keeps n random rows; list values, keeps the provided row indices
use_numpy – if True, uses default NumPy
- amax: float
- amin: float
- col_mask: int | List[int]
- dtype: dtype
- eye: bool
- hollow: bool
- lower_triangle: bool
- row_mask: int | List[int]
- upper_triangle: bool
- use_numpy: bool
ngclearn.utils.io_utils module
File and OS input/output (reading/writing) utilities.
- ngclearn.utils.io_utils.deserialize(fname)[source]
Deserialization (loading) routine
- Parameters:
fname – file name from disk to deserialize
- Returns:
deserialized object from disk
- ngclearn.utils.io_utils.makedir(directory)[source]
Creates a folder/directory on disk
- Parameters:
directory – string name of directory/folder to create on disk
ngclearn.utils.matrix_utils module
- ngclearn.utils.matrix_utils.decompose_to_mps(W, bond_dim=16)[source]
Decomposes a dense matrix W into two MPS cores using SVD.
- Parameters:
W – The dense matrix to decompose of shape (in_dim, out_dim).
bond_dim – The internal rank/bond-dimension of the MPS compression.
- Returns:
A tuple containing core1, the first tensor core of shape (1, in_dim, bond_dim), and core2, the second tensor core of shape (bond_dim, out_dim, 1).
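A truncated-SVD split into two cores with the documented shapes can be sketched as follows. Splitting the singular values symmetrically (sqrt(S) into each core) and capping the rank at min(bond_dim, rank) are assumptions of this hypothetical helper, not necessarily what decompose_to_mps does.

```python
import jax
import jax.numpy as jnp

def decompose_to_mps_sketch(W, bond_dim=16):
    U, S, Vt = jnp.linalg.svd(W, full_matrices=False)
    r = min(bond_dim, S.shape[0])          # cannot exceed the matrix rank
    sqrt_s = jnp.sqrt(S[:r])
    core1 = (U[:, :r] * sqrt_s[None, :])[None, :, :]   # (1, in_dim, r)
    core2 = (sqrt_s[:, None] * Vt[:r, :])[:, :, None]  # (r, out_dim, 1)
    return core1, core2
```

Contracting the shared bond index, jnp.einsum("aib,bjc->ij", core1, core2), recovers the rank-r approximation of W (exactly W when r equals its rank).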
ngclearn.utils.metric_utils module
Metric and measurement routines and co-routines. These functions are useful for model-level/simulation analysis as well as experimental inspection and probing (many of these are neuroscience-oriented measurement functions).
- ngclearn.utils.metric_utils.analyze_scores(mu, y, extract_label_indx=True)[source]
Analyzes a prediction matrix against a target/ground-truth matrix or vector.
- Parameters:
mu – prediction (design) matrix; shape is (N x C) where C is number of classes and N is the number of patterns examined
y – target / ground-truth (design) matrix; shape is (N x C) OR an array of class integers of length N (with “extract_label_indx = True”)
extract_label_indx – when True, runs an argmax to pull class integer indices from “y”, assuming y is a one-hot binary encoding matrix (Default: True); otherwise, if False, this treats “y” as an array of class integer indices of length N
- Returns:
confusion matrix, precision, recall, misses (empty predictions/all-zero rows), accuracy, adjusted-accuracy (counts all misses as incorrect)
- ngclearn.utils.metric_utils.measure_ACC(mu, y, extract_label_indx=True)[source]
Calculates the accuracy (ACC) given a matrix of predictions and matrix of targets.
- Parameters:
mu – prediction (design) matrix; shape is (N x C) where C is number of classes and N is the number of patterns examined
y – target / ground-truth (design) matrix; shape is (N x C) OR an array of class integers of length N (with “extract_label_indx = True”)
extract_label_indx – run an argmax to pull class integer indices from “y”, assuming y is a one-hot binary encoding matrix (Default: True), otherwise, this assumes “y” is an array of class integer indices of length N
- Returns:
scalar accuracy score
- ngclearn.utils.metric_utils.measure_ARI(labels_true: Array, labels_pred: Array) Array[source]
Computes the adjusted Rand index (ARI), which measures the similarity between two labelings (ground-truth labels against a clustering’s produced indices) by counting the pairs of data points assigned to the same or different clusters, adjusted for chance. Its maximum value is 1, indicating perfect agreement; values near 0 indicate a random labeling/assignment, and values can be negative when agreement is worse than chance.
- Parameters:
labels_true – 1D array of shape (n_samples,) with true integer class labels.
labels_pred – 1D array of shape (n_samples,) with predicted integer cluster labels.
- Returns:
scalar ARI of these two sets of indices
- ngclearn.utils.metric_utils.measure_BCE(p, x, offset=1e-07, preserve_batch=False)[source]
Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE). Note: If batch is preserved, this returns a column vector where each row is the BCE(p, x) for that row’s datapoint.
- Parameters:
p – predicted probabilities; (N x D matrix)
x – target binary values (data) of shape; (N x D matrix)
offset – factor to control for numerical stability (Default: 1e-7)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
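A sketch of the BCE computation described above. Using offset to clip p away from 0 and 1, summing over the D columns per row, and averaging rows for the scalar case are all assumptions about how the library combines these steps.

```python
import jax.numpy as jnp

def bce_sketch(p, x, offset=1e-7, preserve_batch=False):
    # assumption: offset clips p into [offset, 1 - offset] before taking logs
    p = jnp.clip(p, offset, 1.0 - offset)
    per_row = -jnp.sum(x * jnp.log(p) + (1.0 - x) * jnp.log(1.0 - p),
                       axis=1, keepdims=True)      # (N, 1): summed over D
    if preserve_batch:
        return per_row
    return jnp.mean(per_row, axis=0, keepdims=True)  # (1, 1) mean over the batch
```

With p = 0.5 everywhere, each entry contributes ln(2), so a row with D = 2 scores 2 ln(2) regardless of its targets.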
- ngclearn.utils.metric_utils.measure_BIC(X, n_model_params, max_model_score, is_log=True)[source]
Measures the Bayesian information criterion (BIC) with respect to the final score obtained by the model on a given dataset.
$$\mathrm{BIC} = -2 \ln(L) + K \ln(N)$$
where N is the number of data points (rows) of design matrix X, K is the total number of parameters of the model of interest, and L is the max/best-found value of a likelihood-like score of the model.
- Parameters:
X – dataset/design matrix that a model was fit to (max-likelihood optimized)
n_model_params – total number of model parameters (int)
max_model_score – max likelihood-like score obtained by model on X
is_log – is supplied max_model_score a log-likelihood? if this is False, this metric will apply a natural logarithm of the score (Default: True)
- Returns:
scalar for the Bayesian information criterion score
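The BIC formula above reduces to a couple of lines. This hypothetical helper takes the number of rows N directly instead of the design matrix X, and mirrors the is_log switch by taking a natural log when the score is not already a log-likelihood.

```python
import math

def bic_sketch(n_data, n_model_params, max_model_score, is_log=True):
    # n_data plays the role of N, the number of rows of X
    log_l = max_model_score if is_log else math.log(max_model_score)
    return -2.0 * log_l + n_model_params * math.log(n_data)
```

For example, N = 100 rows, K = 3 parameters and a best log-likelihood of -50 give 100 + 3 ln(100), roughly 113.82.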
- ngclearn.utils.metric_utils.measure_CatNLL(p, x, offset=1e-07, preserve_batch=False)[source]
Measures the negative Categorical log likelihood (Cat.NLL). Note: If batch is preserved, this returns a column vector where each row is the Cat.NLL(p, x) for that row’s datapoint.
- Parameters:
p – predicted probabilities; (N x C matrix, where C is number of categories)
x – true one-hot encoded targets; (N x C matrix, where C is number of categories)
offset – factor to control for numerical stability (Default: 1e-7)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
- ngclearn.utils.metric_utils.measure_FMI(labels_true: Array, labels_pred: Array) Array[source]
Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of indices - this score is the geometric mean of pair-wise recall and precision. This measurement lies in [0, 1], where higher is better (indicating greater similarity between two clustering sets of identifiers).
- Parameters:
labels_true – 1D array of shape (n_samples,) with true integer class labels.
labels_pred – 1D array of shape (n_samples,) with predicted integer cluster labels.
- Returns:
scalar FMI of these two sets of indices
- ngclearn.utils.metric_utils.measure_KLD(p_xHat, p_x, preserve_batch=False)[source]
Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the two input arguments contain valid probability distributions (in each row, if they are matrices). Note: If batch is preserved, this returns a column vector where each row is the KLD(p_xHat, p_x) for that row’s datapoint. (Further note that this function does not assume any particular distribution when calculating KLD.)
Formula:
$$\mathrm{KLD}(\hat{p}_x, p_x) = \frac{1}{N} \Big[ \sum_i p_x \log(p_x) - \sum_i p_x \log(\hat{p}_x) \Big]$$
where \hat{p}_x denotes p_xHat, and sum_i denotes summing across the dimensions of the vector space of p_x.
- Parameters:
p_xHat – predicted probabilities; (N x C matrix, where C is number of categories)
p_x – ground-truth probabilities; (N x C matrix, where C is number of categories)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
- ngclearn.utils.metric_utils.measure_MAE(shift, x, preserve_batch=False)[source]
Measures mean absolute error (MAE), or the negative Laplacian log likelihood with scale of 1.0. Note: If batch is preserved, this returns a column vector where each row is the MAE(shift, x) for that row’s datapoint.
- Parameters:
shift – predicted values (mean); (N x D matrix)
x – target values (data); (N x D matrix)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
- ngclearn.utils.metric_utils.measure_MSE(mu, x, preserve_batch=False)[source]
Measures mean squared error (MSE), or the negative Gaussian log likelihood with variance of 1.0. Note: If batch is preserved, this returns a column vector where each row is the MSE(mu, x) for that row’s datapoint.
- Parameters:
mu – predicted values (mean); (N x D matrix)
x – target values (data); (N x D matrix)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
- ngclearn.utils.metric_utils.measure_RMSE(mu, x, preserve_batch=False)[source]
Measures root mean squared error (RMSE). Note: If batch is preserved, this returns a column vector where each row is the RMSE(mu, x) for that row’s datapoint. (This is a simple wrapper/extension of the built-in MSE.)
- Parameters:
mu – predicted values (mean); (N x D matrix)
x – target values (data); (N x D matrix)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
- ngclearn.utils.metric_utils.measure_Vmeasure(labels_true: Array, labels_pred: Array, beta: float = 1.0) Array[source]
Calculates the V-Measure scoring metric for class conformity. This measurement compares predicted cluster indices (“labels_pred”) against ground truth indices (“labels_true”) and represents the harmonic mean of homogeneity (where each cluster contains only members of a single class) as well as completeness (where all members of a given class are assigned to the same cluster). This measurement (higher is better) lies in [0,1] where 1 indicates perfect, correct clustering.
- Parameters:
labels_true – 1D array of shape (n_samples,) with true integer class labels
labels_pred – 1D array of shape (n_samples,) with predicted integer cluster labels
beta – weight factor; values > 1.0 favor completeness, values < 1.0 favor homogeneity
- Returns:
scalar V-measure of these two sets of indices
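The V-measure combines homogeneity h = 1 - H(C|K)/H(C) and completeness c = 1 - H(K|C)/H(K) as v = (1 + beta) h c / (beta h + c). The NumPy sketch below computes these from a contingency table; it is an illustrative stand-in (v_measure_sketch is a hypothetical name), not the ngclearn implementation.

```python
import numpy as np

def _entropy(counts):
    p = counts[counts > 0] / counts.sum()
    return float(-np.sum(p * np.log(p)))

def v_measure_sketch(labels_true, labels_pred, beta=1.0):
    _, c_idx = np.unique(labels_true, return_inverse=True)
    _, k_idx = np.unique(labels_pred, return_inverse=True)
    cont = np.zeros((c_idx.max() + 1, k_idx.max() + 1))
    np.add.at(cont, (c_idx, k_idx), 1.0)          # contingency counts n_ck
    n = cont.sum()
    n_c = cont.sum(axis=1, keepdims=True)         # class sizes (C, 1)
    n_k = cont.sum(axis=0, keepdims=True)         # cluster sizes (1, K)
    with np.errstate(divide="ignore", invalid="ignore"):
        # conditional entropies H(C|K) and H(K|C); empty cells contribute 0
        h_c_given_k = -np.sum(np.where(cont > 0, cont / n * np.log(cont / n_k), 0.0))
        h_k_given_c = -np.sum(np.where(cont > 0, cont / n * np.log(cont / n_c), 0.0))
    h_c, h_k = _entropy(n_c.ravel()), _entropy(n_k.ravel())
    homogeneity = 1.0 if h_c == 0 else 1.0 - h_c_given_k / h_c
    completeness = 1.0 if h_k == 0 else 1.0 - h_k_given_c / h_k
    denom = beta * homogeneity + completeness
    return 0.0 if denom == 0 else (1 + beta) * homogeneity * completeness / denom
```

Relabeling clusters does not change the score, and splitting each class into singleton clusters keeps homogeneity at 1 while lowering completeness.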
- ngclearn.utils.metric_utils.measure_breadth_TC(spikes, preserve_batch=False)[source]
Calculates the breadth tuning curve (BTC) of a group of neurons given full spike train(s). BTC measures neural selectivity, i.e., the degree to which the sparse code distribution concentrates near zero with a heavy tail. For a neural layer where most of the neurons fire, the activity distribution is more uniformly spread and BTC > 0.5; when most of the neurons do not fire, the firing distribution is peaked at zero and BTC < 0.5.
- Parameters:
spikes – full spike train matrix; shape is (T x D) where D is number of neurons in a group/cluster
preserve_batch – if True, will return one score per neuron in train/window (Default: False), otherwise, returns scalar average score
- Returns:
a 1 x D BTC vector (one factor per neuron) OR a single average BTC across the neuronal group
- ngclearn.utils.metric_utils.measure_excess_kurtosis(codes: Array, preserve_batch: bool = False) float[source]
Measures the peakedness and heavy-tailedness of a set of neural activation codes. Note that higher values (> 0) indicate sparse, localized ‘high-burst’ activations.
- Parameters:
codes – matrix (shape: N x D) of non-negative codes to measure sparsity of (per row)
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
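Excess kurtosis is the fourth central moment divided by the squared variance, minus 3 (so a Gaussian scores 0 in expectation). The per-row sketch below is a hypothetical helper; the small eps guarding constant rows is an assumption.

```python
import jax.numpy as jnp

def excess_kurtosis_sketch(codes, preserve_batch=False, eps=1e-8):
    mu = jnp.mean(codes, axis=1, keepdims=True)
    dev = codes - mu
    m2 = jnp.mean(dev ** 2, axis=1, keepdims=True)   # per-row variance
    m4 = jnp.mean(dev ** 4, axis=1, keepdims=True)   # per-row 4th central moment
    kurt = m4 / (m2 ** 2 + eps) - 3.0                # (N, 1)
    return kurt if preserve_batch else jnp.mean(kurt, axis=0, keepdims=True)
```

A row with a single active unit among 100 zeros scores about 95, whereas [0, 0, 0, 1] scores -2/3, illustrating how longer, sparser codes push the value up.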
- ngclearn.utils.metric_utils.measure_fanoFactor(spikes, preserve_batch=False)[source]
Calculates the Fano factor, i.e., a secondary statistic that probes the variability of a spike train within a particular time interval.
- Parameters:
spikes – full spike train matrix; shape is (T x D) where D is number of neurons in a group/cluster
preserve_batch – if True, will return one score per neuron (Default: False), otherwise, returns scalar average score
- Returns:
a 1 x D Fano factor vector (one factor per neuron) OR a single average Fano factor across the neuronal group
- ngclearn.utils.metric_utils.measure_firingRate(spikes, preserve_batch=False)[source]
Calculates the firing rate(s) of a group of neurons given full spike train(s).
- Parameters:
spikes – full spike train matrix; shape is (T x D) where D is number of neurons in a group/cluster
preserve_batch – if True, will return one score per neuron (Default: False), otherwise, returns scalar average score
- Returns:
a 1 x D firing rate vector (one firing rate per neuron) OR a single average firing rate across the neuronal group
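The firing rate and Fano factor of a (T x D) spike train can be sketched as per-column statistics over time. Treating each row as one time bin, reporting rates in spikes per step (no dt scaling), and computing the Fano factor as variance over mean across bins are assumptions of these hypothetical helpers.

```python
import jax.numpy as jnp

def firing_rate_sketch(spikes, preserve_batch=False):
    rates = jnp.mean(spikes, axis=0, keepdims=True)    # (1, D): spikes per step
    return rates if preserve_batch else jnp.mean(rates)

def fano_factor_sketch(spikes, preserve_batch=False, eps=1e-8):
    mu = jnp.mean(spikes, axis=0, keepdims=True)
    var = jnp.var(spikes, axis=0, keepdims=True)
    ff = var / (mu + eps)                              # (1, D): variance-to-mean ratio
    return ff if preserve_batch else jnp.mean(ff)
```

A neuron that fires on every step has zero variance and therefore a Fano factor of 0; irregular firing raises it.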
- ngclearn.utils.metric_utils.measure_gaussian_KLD(mu1, Sigma1, mu2, Sigma2, use_chol_prec=True)[source]
Calculates the Kullback-Leibler (KL) divergence between two multivariate Gaussian distributions, i.e., KL(N(mu1, Sigma1) || N(mu2, Sigma2)). Formally, this means this routine calculates:
$$\mathrm{KL}(N_1 \,\|\, N_2) = \frac{1}{2}\left[\log\frac{\det(\Sigma_2)}{\det(\Sigma_1)} + \mathrm{tr}(\mathrm{Prec}_2\, \Sigma_1) + z^\top \mathrm{Prec}_2\, z - D\right]$$
where N1 is the first Gaussian, N(mu1, Sigma1), and N2 is the second Gaussian, N(mu2, Sigma2); Prec2 = (Sigma2)^{-1}, z = mu2 - mu1, and D is the data dimensionality.
- Parameters:
mu1 – mean vector of first Gaussian distribution
Sigma1 – covariance matrix of first Gaussian distribution
mu2 – mean vector of second Gaussian distribution
Sigma2 – covariance matrix of second Gaussian distribution
use_chol_prec – should this routine use Cholesky-factor computation of the precision of Sigma2 (Default: True)
- Returns:
scalar representing KL-divergence between N(mu1, Sigma1) and N(mu2, Sigma2)
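The formula above can be evaluated without forming Prec2 explicitly by solving against a Cholesky factor of Sigma2, in the spirit of the use_chol_prec option. The hypothetical helper below uses jax.scipy.linalg.cho_factor / cho_solve and slogdet; it is a sketch, not the library's routine.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def gaussian_kld_sketch(mu1, Sigma1, mu2, Sigma2):
    D = mu1.shape[0]
    c2 = cho_factor(Sigma2, lower=True)          # Cholesky factor of Sigma2
    z = mu2 - mu1
    trace_term = jnp.trace(cho_solve(c2, Sigma1))  # tr(Sigma2^{-1} Sigma1)
    quad_term = z @ cho_solve(c2, z)               # z^T Sigma2^{-1} z
    _, logdet1 = jnp.linalg.slogdet(Sigma1)
    _, logdet2 = jnp.linalg.slogdet(Sigma2)
    return 0.5 * (logdet2 - logdet1 + trace_term + quad_term - D)
```

As a check, KL(N(0, 1) || N(1, 1)) in one dimension is 0.5, and the divergence of any Gaussian from itself is 0.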
- ngclearn.utils.metric_utils.measure_gini_index(codes, preserve_batch=True)[source]
Calculates the Gini index of a group of neurons represented as vector code samples. The Gini index measures the sparseness of the values within each vector code, where a higher index value indicates higher sparsity and a lower index value indicates lower sparsity (higher density).
- Parameters:
codes – a batch of neural codes; shape is (N x D) where D is number of neurons in a group/cluster and N is number of samples
preserve_batch – if True, will return one score per sample in batch (Default: True), otherwise, returns scalar average score
- Returns:
an N x 1 Gini index vector (one score per sample/code) OR a single average Gini score for the whole sample/set of codes
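One common formulation of the Gini index for sparsity (Hurley and Rickard's) sorts each code in ascending order and weights the normalized entries by their rank. Whether ngclearn uses this exact formulation is an assumption; gini_sketch is a hypothetical helper.

```python
import jax.numpy as jnp

def gini_sketch(codes, preserve_batch=True):
    c = jnp.sort(jnp.abs(codes), axis=1)               # ascending per row
    N = c.shape[1]
    k = jnp.arange(1, N + 1)
    l1 = jnp.sum(c, axis=1, keepdims=True) + 1e-12     # guard all-zero rows
    weights = (N - k + 0.5) / N
    g = 1.0 - 2.0 * jnp.sum((c / l1) * weights[None, :], axis=1, keepdims=True)
    return g if preserve_batch else jnp.mean(g)        # (N, 1) or scalar
```

Under this formulation a uniform code scores 0 and a code with a single active unit out of D scores 1 - 1/D.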
- ngclearn.utils.metric_utils.measure_hoyer_sparsity(codes: Array, preserve_batch: bool = False) float[source]
Measures the Hoyer sparsity for a set of latent codes. Hoyer sparsity lies in [0, 1], where a value of 0.0 indicates a dense code and a value of 1.0 indicates an extremely sparse code.
- Parameters:
codes – matrix (shape: N x D) of non-negative codes to measure sparsity of (per row); D is flattened latent code size
preserve_batch – if True, will return one score per sample in batch (Default: False), otherwise, returns scalar mean score
- Returns:
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
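Hoyer's measure compares the L1 and L2 norms of each code: (sqrt(D) - ||x||_1 / ||x||_2) / (sqrt(D) - 1). The hypothetical per-row sketch below assumes non-zero rows; the small epsilon on the L2 norm only avoids division by zero.

```python
import jax.numpy as jnp

def hoyer_sketch(codes, preserve_batch=False):
    D = codes.shape[1]
    l1 = jnp.sum(jnp.abs(codes), axis=1, keepdims=True)
    l2 = jnp.sqrt(jnp.sum(codes ** 2, axis=1, keepdims=True)) + 1e-12
    h = (jnp.sqrt(D) - l1 / l2) / (jnp.sqrt(D) - 1.0)   # (N, 1)
    return h if preserve_batch else jnp.mean(h, axis=0, keepdims=True)
```

A one-hot code gives ||x||_1 = ||x||_2 and thus 1.0; a constant code gives ||x||_1 / ||x||_2 = sqrt(D) and thus 0.0.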
- ngclearn.utils.metric_utils.measure_sparsity(codes, tolerance=0.0, preserve_batch=True, flip_measure=False)[source]
Calculates the sparsity (ratio) of an input matrix, assuming each row within this matrix is a non-negative vector.
Formally, this means we compute, per i-th row:
$$\rho(x_i) = \frac{\mathrm{num\_nonzero}(x_i)}{\dim(x_i)}$$
and, for a global score over a matrix X with N codes/rows, we measure:
$$\rho_{mean}(X) = \frac{1}{N} \sum_{i=1}^{N} \rho(x_i)$$
where values closer to 0 mean the codes are more sparse and values closer to 1 mean the codes are more dense.
Note that this definition of sparsity aligns with Foldiak’s definition of the ratio of active neurons to inactive ones (assuming binary coding):
Foldiak, Peter. “Sparse and explicit neural coding.” Principles of Neural Coding. CRC Press, 2013. 379-389.
- Parameters:
codes – matrix (shape: N x D) of non-negative codes to measure sparsity of (per row)
tolerance – lowest number to consider as “empty”/non-existent (Default: 0.)
preserve_batch – if True, will return one score per sample (N x 1) in batch (Default: True), otherwise, returns scalar average/mean score
flip_measure – if True, will score sparsity via “1 - nzero/dim” (Default: False)
- Returns:
sparsity measurements per code (shape: N x 1) or a single score (shape: 1 x 1)
ngclearn.utils.model_utils module
General modeling utility routines and co-routines. This contains useful commonly jit-i-fied mathematical functions and operations needed to design and develop ngc-learn internal components.
- ngclearn.utils.model_utils.binarize(data, threshold=0.5)[source]
Converts the vector data to its binary equivalent.
- Parameters:
data – the data to binarize (real-valued)
threshold – the cut-off point for 0, i.e., if threshold = 0.5, then any number/value inside of data < 0.5 is set to 0, otherwise, it is set to 1.0
- Returns:
the binarized equivalent of “data”
- ngclearn.utils.model_utils.bkwta(x, nWTA=5)[source]
The binarized K-winners-take-all (K-WTA) function.
- Parameters:
x – input (tensor) value (real-valued)
nWTA – number of winners (Default: 5)
- Returns:
output (tensor) value (binary values)
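A binarized K-WTA can be sketched by placing a 1 at the positions of the K largest values in each row and 0 elsewhere. The hypothetical helper below uses jax.lax.top_k (which works along the last axis) and assumes a 2D (batch x units) input; tie handling may differ from the library's version.

```python
import jax
import jax.numpy as jnp

def bkwta_sketch(x, nWTA=5):
    _, idx = jax.lax.top_k(x, nWTA)            # indices of the K winners per row
    rows = jnp.arange(x.shape[0])[:, None]     # broadcast row indices
    return jnp.zeros_like(x).at[rows, idx].set(1.0)
```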
- ngclearn.utils.model_utils.chebyshev_norm(d, axis=-1, keepdims=False)[source]
Calculates the Chebyshev distance, i.e., the L-infinity norm of a tensor-array d (e.g., a difference between two tensor-arrays) measured against the origin.
- Parameters:
d – tensor d to measure against the origin
axis – axis along which to measure the distance
keepdims – if True, preserves the dimensions of d
- Returns:
the Chebyshev distance (values) within d
- ngclearn.utils.model_utils.clamp_max(x, max_val)[source]
Clamps values in data x that exceed a maximum value to that value.
- Parameters:
x – data to upper-bound clamp
max_val – maximum value threshold
- Returns:
x with maximum clamped values
- ngclearn.utils.model_utils.clamp_min(x, min_val)[source]
Clamps values in data x that fall below a minimum value to that value.
- Parameters:
x – data to lower-bound clamp
min_val – minimum value threshold
- Returns:
x with minimum clamped values
- ngclearn.utils.model_utils.create_block_matrix(map_matrix, group_shape, alpha_inh=-1.0, alpha_exc=1.0)[source]
- ngclearn.utils.model_utils.create_function(fun_name, args=None)[source]
Activation function creation routine.
- Parameters:
fun_name – string name of activation function to produce; Currently supports: “tanh”, “bkwta” (binary K-winners-take-all), “sigmoid”, “relu”, “lrelu”, “relu6”, “elu”, “silu”, “gelu”, “softplus”, “softmax”, “unit_threshold”, “heaviside”, “identity”
- Returns:
function fx, first derivative of function (w.r.t. input) dfx
- ngclearn.utils.model_utils.d_heaviside(x)[source]
Derivative of the Heaviside function; specifically, this employs the straight-through estimator (STE) as a proxy/surrogate derivative instead.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_identity(x)[source]
Derivative of the identity function.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_lkwta(x, m, nWTA=1, clipval=-1.0)[source]
Derivative of the group-based K-WTA function with respect to its input. This function, local K-WTA (LKWTA), was proposed in:
Ororbia, Alexander, Karl Friston, and Rajesh PN Rao. “Meta-representational predictive coding: biomimetic self-supervised learning.” arXiv preprint arXiv:2503.21796 (2025).
- Parameters:
x – data to apply the LKWTA function over
m – masking tensor
nWTA – number of winners
clipval
- Returns:
the derivative of LKWTA(x) with respect to x
- ngclearn.utils.model_utils.d_lrelu(x)[source]
Derivative of the leaky linear rectifier.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_quantile_lkwta(x, m, nWTA=20, clipval=-1.0)[source]
First derivative of quantile-based K-WTA function with respect to its input - NOTE: this is experimental and no guarantees are offered at this point.
- Parameters:
x – data to apply quantile-KWTA function over
m – masking tensor
nWTA – number of winners
clipval
- Returns:
the derivative of the quantile-based KWTA function with respect to x
- ngclearn.utils.model_utils.d_relu(x)[source]
Derivative of the linear rectifier.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_relu6(x)[source]
Derivative of the bounded linear rectifier (upper-bounded at 6).
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_sigmoid(x)[source]
Derivative of the sigmoid / logistic-link function.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_sine(x, omega_0=30)[source]
The derivative of the sine function:
$$f'(x) = \omega_0 \cos(\omega_0 x)$$
where the frequency is omega_0.
- Parameters:
x – input (tensor) value
omega_0 – frequency (Default: 30)
- Returns:
output (tensor) derivative value (with respect to input)
- ngclearn.utils.model_utils.d_softmax(x, tau=0.0, vmap_form=False)[source]
Derivative of the softmax function. Note that this returns specifically the Jacobian tensor Jx of softmax(x) w.r.t. potential batch set of vectors (one per row).
- Parameters:
x – input (tensor) value (B x D)
vmap_form – optional algorithm switch flag; if True, Jx is computed using Jax vmap (Default: False)
- Returns:
output (tensor) derivative values (Jacobian with respect to input argument; B x D x D)
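For a softmax output s, the Jacobian has the closed form J = diag(s) - s s^T, i.e., J[i, j] = s_i (delta_ij - s_j). The hypothetical sketch below builds one D x D Jacobian per row with jax.vmap; it does not model the tau argument, whose exact role is not documented here.

```python
import jax
import jax.numpy as jnp

def softmax_jacobian_sketch(x):
    s = jax.nn.softmax(x, axis=-1)                                   # (B, D)
    # per row: diag(s) - outer(s, s)
    return jax.vmap(lambda v: jnp.diag(v) - jnp.outer(v, v))(s)      # (B, D, D)
```

The closed form can be checked against automatic differentiation, e.g. jax.vmap(jax.jacobian(jax.nn.softmax))(x).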
- ngclearn.utils.model_utils.d_softplus(x)[source]
Derivative of the softplus function.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_tanh(x)[source]
Derivative of the hyperbolic tangent function.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.d_telu(x)[source]
Derivative of the hyperbolic tangent exponential linear (TeLU) function. Formally, this is:
$$f'(x) = \tanh(e^x) + x\, e^x \left(1 - \tanh^2(e^x)\right)$$
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input)
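The derivative above is what differentiating f(x) = x tanh(e^x) by the product and chain rules gives. The sketch below checks that closed form against jax.grad; telu_sketch and d_telu_closed are hypothetical names, and f(x) = x tanh(e^x) is taken as the TeLU definition consistent with the stated derivative.

```python
import jax
import jax.numpy as jnp

def telu_sketch(x):
    return x * jnp.tanh(jnp.exp(x))

def d_telu_closed(x):
    ex = jnp.exp(x)
    t = jnp.tanh(ex)
    # tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
    return t + x * ex * (1.0 - t ** 2)

xs = jnp.linspace(-3.0, 3.0, 13)
autodiff = jax.vmap(jax.grad(telu_sketch))(xs)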
- ngclearn.utils.model_utils.d_threshold(x, thr=1.0)[source]
Derivative of the threshold function; specifically, this employs the straight-through estimator (STE) as a proxy/surrogate derivative instead.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) derivative value (with respect to input argument)
- ngclearn.utils.model_utils.drop_out(dkey, data, rate=0.0)[source]
Applies a drop-out transform (i.e., a random number of elements will be dropped to zero) to an input matrix.
- Parameters:
dkey – Jax randomness key for this operator
data – input data to apply random/drop-out mask to
rate – probability of a dimension being dropped
- Returns:
output as well as binary mask
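A drop-out transform that returns both the output and its binary mask can be sketched with jax.random.bernoulli, keeping each entry with probability 1 - rate. Whether the library also rescales kept values by 1 / (1 - rate) is not stated here; this hypothetical sketch does not rescale.

```python
import jax
import jax.numpy as jnp

def drop_out_sketch(dkey, data, rate=0.0):
    keep = jax.random.bernoulli(dkey, p=1.0 - rate, shape=data.shape)
    mask = keep.astype(data.dtype)        # 1 = kept, 0 = dropped
    return data * mask, mask
```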
- ngclearn.utils.model_utils.elu(x, alpha=1.0)[source]
Applies the exponential linear unit (ELU) activation.
- Parameters:
x – data to transform via the ELU
alpha – coefficient that scales the negative (exponential) branch, i.e., alpha * (exp(x) - 1) for x < 0
- Returns:
output of the ELU activation
- ngclearn.utils.model_utils.eye_wrapped(N, k, values)[source]
Creates an N x N matrix with a wrapped off-diagonal.
- Parameters:
N – Size of the square matrix (N x N)
k – Diagonal offset (positive=above, negative=below)
values – Array of values to place (length should match N)
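One reading of a "wrapped" off-diagonal is that row i receives values[i] in column (i + k) mod N, so entries that would fall off the edge wrap around. That interpretation is an assumption; eye_wrapped_sketch is a hypothetical helper, not the library function.

```python
import jax.numpy as jnp

def eye_wrapped_sketch(N, k, values):
    rows = jnp.arange(N)
    cols = (rows + k) % N     # jnp's % follows Python semantics, so negative k wraps too
    return jnp.zeros((N, N)).at[rows, cols].set(jnp.asarray(values))
```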
- ngclearn.utils.model_utils.gelu(x)[source]
Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used via a weighted swish).
- Parameters:
x – data to transform via the GeLU
- Returns:
output of the GeLU activation
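A weighted-swish approximation of GeLU is x * sigmoid(1.702 x). The 1.702 weight is the commonly used constant for this approximation; whether ngclearn uses exactly this weight is an assumption. The sketch compares it with JAX's exact GeLU.

```python
import jax
import jax.numpy as jnp

def gelu_sigmoid_approx(x):
    # swish weighted by 1.702: x * sigmoid(1.702 * x)
    return x * jax.nn.sigmoid(1.702 * x)

xs = jnp.linspace(-4.0, 4.0, 81)
exact = jax.nn.gelu(xs, approximate=False)
```

On this range the approximation stays within a few hundredths of the exact curve, trading a little accuracy for a cheaper sigmoid.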
- ngclearn.utils.model_utils.heaviside(x)[source]
The Heaviside function:
$$f(x) = \begin{cases} 1 & \text{if } x \ge 0 \\ 0 & \text{if } x < 0 \end{cases}$$
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.identity(x)[source]
The identity function: f(x) = x.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.inverse_logistic(x, clip_bound=0.03)[source]
The inverse logistic link - the logit function.
- Parameters:
x – data to transform via inverse logistic function
clip_bound – pre-processing lower/upper bounds to enforce on data before applying inverse logistic
- Returns:
x transformed via inverse logistic function
- ngclearn.utils.model_utils.inverse_tanh(x)[source]
The inverse hyperbolic tangent.
- Parameters:
x – data to transform via inverse hyperbolic tangent
clip_bound – pre-processing lower/upper bounds to enforce on data before applying inverse hyperbolic tangent
- Returns:
x transformed via inverse hyperbolic tangent
- ngclearn.utils.model_utils.layer_normalize(x, shift=0.0, scale=1.0)[source]
Applies layer normalization to input data x.
- Parameters:
x – data to apply layer normalization to
shift – the compensating mean/shift factor/parameters (to undo mean subtraction)
scale – the compensating re-scaling factor/parameters (to undo standard deviation division)
- Returns:
layer-normalized data samples x
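The shift/scale description suggests the usual form scale * (x - mean) / std + shift, computed per sample. The NumPy sketch below assumes normalization over the last axis and adds a hypothetical small eps to avoid division by zero. It is not ngclearn's code, and `layer_normalize_sketch` is a hypothetical name.

```python
import numpy as np

def layer_normalize_sketch(x, shift=0.0, scale=1.0, eps=1e-6):
    """Standardize each row (last axis, assumed), then re-scale and shift."""
    x = np.asarray(x, dtype=float)
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    # eps is a hypothetical stabilizer, not a documented parameter
    return scale * (x - mu) / (sigma + eps) + shift
```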
- ngclearn.utils.model_utils.lkwta(x, m, nWTA=1, clipval=-1.0)[source]
A group-based K-WTA function, i.e., local K-WTA (LKWTA), as proposed in:
Ororbia, Alexander, Karl Friston, and Rajesh PN Rao. “Meta-representational predictive coding: biomimetic self-supervised learning.” arXiv preprint arXiv:2503.21796 (2025).
- Parameters:
x – data to apply the local K-WTA function over
m – masking tensor
nWTA – number of winners
clipval
- Returns:
y = LKWTA(x)
- ngclearn.utils.model_utils.lrelu(x)[source]
The leaky linear rectifier: f(x) = x if x >= 0, 0.01 * x if x < 0.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.normalize_block_matrix(matrix, block_size, order=2, axis=0, norm_targ=1.0)[source]
Normalizes columns of blocks within a matrix.
- Parameters:
matrix – 2D JAX Array (M, N)
block_size – Tuple (block_rows, block_cols)
order
axis – (relative) axis for normalization within block; 0 -> by rows, 1 -> by cols
norm_targ
- Returns:
block-normalized (M, N) matrix
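A minimal NumPy sketch of block-wise normalization follows. It tiles the matrix into (block_rows x block_cols) blocks and rescales the vectors inside each block to norm norm_targ. The assumptions, not confirmed by the docs, are: order is the vector-norm order, norm_targ is the target norm, axis=0 reduces over a block's rows (so each column of the block is rescaled), blocks tile the matrix exactly, and a small eps guards zero vectors. `normalize_block_matrix_sketch` is a hypothetical name.

```python
import numpy as np

def normalize_block_matrix_sketch(matrix, block_size, order=2, axis=0, norm_targ=1.0, eps=1e-8):
    """Rescale vectors inside each block to have `order`-norm `norm_targ` (assumed semantics)."""
    matrix = np.asarray(matrix, dtype=float)
    M, N = matrix.shape
    br, bc = block_size
    assert M % br == 0 and N % bc == 0, "blocks must tile the matrix exactly"
    out = np.empty_like(matrix)
    for i in range(0, M, br):
        for j in range(0, N, bc):
            block = matrix[i:i + br, j:j + bc]
            # axis=0 reduces over the block's rows, so each block column is rescaled
            n = np.linalg.norm(block, ord=order, axis=axis, keepdims=True)
            out[i:i + br, j:j + bc] = norm_targ * block / (n + eps)
    return out
```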
- ngclearn.utils.model_utils.normalize_matrix(data, wnorm, order=1, axis=0, scale=1.0)[source]
Normalizes the values in a matrix so that each vector (row or column) has a particular norm.
- Parameters:
data – (2D) data matrix to normalize
wnorm – target norm for each row/column of data matrix
order – order of norm to use in normalization (Default: 1); note that order=1 results in the L1-norm and order=2 results in the L2-norm
axis – 0 (apply to column vectors), 1 (apply to row vectors)
scale – step modifier to produce the projected matrix (Unused)
- Returns:
a normalized value matrix
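The column/row normalization described above can be sketched in NumPy by dividing each vector by its current norm and multiplying by wnorm. This follows the documented axis convention (0 for columns, 1 for rows). The eps guard and the name `normalize_matrix_sketch` are hypothetical, and the unused scale parameter is omitted.

```python
import numpy as np

def normalize_matrix_sketch(data, wnorm, order=1, axis=0, eps=1e-8):
    """Rescale each column (axis=0) or row (axis=1) of data to have `order`-norm wnorm."""
    data = np.asarray(data, dtype=float)
    norms = np.linalg.norm(data, ord=order, axis=axis, keepdims=True)
    # eps (hypothetical) avoids division by zero for all-zero vectors
    return data * (wnorm / (norms + eps))
```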
- ngclearn.utils.model_utils.one_hot(P)[source]
Converts a matrix of probabilities to a corresponding binary one-hot matrix (each row is a one-hot encoding).
- Parameters:
P – a probability matrix where each row corresponds to a particular data probability vector
- Returns:
the one-hot encoding (matrix) of probabilities in P
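A minimal NumPy sketch of this conversion marks the arg-max of each row. Using the arg-max (rather than, say, sampling from each row) is an assumption, ties resolve to the first maximum, and `one_hot_sketch` is a hypothetical name.

```python
import numpy as np

def one_hot_sketch(P):
    """Put a 1 at the most probable entry of each row (assumed arg-max rule)."""
    P = np.asarray(P)
    out = np.zeros_like(P, dtype=float)
    out[np.arange(P.shape[0]), np.argmax(P, axis=1)] = 1.0  # ties go to the first maximum
    return out
```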
- ngclearn.utils.model_utils.pull_equations(controller)[source]
Extracts the dynamics string of this controller (model/system).
- Parameters:
controller – model/system to extract dynamics equation(s) from
- Returns:
string containing this model/system’s dynamics equation(s)
- ngclearn.utils.model_utils.quantile_lkwta(x, m, nWTA=20, clipval=-1.0)[source]
A quantile-based K-WTA function - NOTE: this is experimental and no guarantees are offered at this point.
- Parameters:
x – data to apply quantile-KWTA function over
m – masking tensor
nWTA – number of winners
clipval
- Returns:
y = KWTA(x)
- ngclearn.utils.model_utils.relu(x)[source]
The linear rectifier: max(0, x) = f(x).
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.relu6(x)[source]
The linear rectifier upper bounded at the value of 6: min(max(0, x), 6.).
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.sigmoid(x)[source]
The sigmoid / logistic-link function:
f(x) = 1/(1 + exp(-x))
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.silu(x)[source]
Applies the sigmoid-weighted linear unit (SiLU or SiL) activation. Note that this is primarily a convenience wrapper function for the swish activation.
- Parameters:
x – data to transform via the SiLU activation
- Returns:
output of the SiLU (Swish) activation
- ngclearn.utils.model_utils.sine(x, omega_0=30)[source]
The sine function, parameterized by the frequency factor omega_0:
f(x) = sin(x * omega_0).
- Parameters:
x – input (tensor) value
omega_0 – frequency factor that scales the input (Default: 30)
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.softmax(x, tau=0.0)[source]
Softmax function with built-in overflow control. Contains an optional temperature parameter tau to control sharpness: tau > 1 softens the probabilities, tau < 1 sharpens them, and as tau approaches 0 (from above) the output approaches a point mass.
- Parameters:
x – a (N x D) input argument (pre-activity) to the softmax operator
tau – probability sharpening/softening factor, applied if tau > 0; a value <= 0 disables it (Default: 0.)
- Returns:
a (N x D) probability distribution output block
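The overflow control and temperature behavior described above can be sketched in NumPy: logits are divided by tau only when tau > 0, and the row maximum is subtracted before exponentiating. The max-subtraction trick is an assumed (standard) way of achieving overflow control, and `softmax_sketch` is a hypothetical name.

```python
import numpy as np

def softmax_sketch(x, tau=0.0):
    """Row-wise softmax; tau > 0 divides the logits by tau, tau <= 0 leaves them unchanged."""
    x = np.asarray(x, dtype=float)
    if tau > 0.0:
        x = x / tau
    z = x - x.max(axis=-1, keepdims=True)  # subtract the row max so exp() cannot overflow
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)
```

With x = [[1, 2, 3]], tau = 0.1 pushes almost all mass onto the last entry, while tau = 10 moves the output toward uniform.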
- ngclearn.utils.model_utils.softplus(x)[source]
The softplus elementwise function:
f(x) = ln(1 + exp(x))
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
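The softplus formula ln(1 + exp(x)) overflows if evaluated naively for large x. A numerically stable NumPy sketch uses np.logaddexp(0, x), which equals ln(exp(0) + exp(x)). This illustrates the formula only and is not ngclearn's code. `softplus_sketch` is a hypothetical name.

```python
import numpy as np

def softplus_sketch(x):
    """ln(1 + exp(x)), computed stably as logaddexp(0, x) = ln(exp(0) + exp(x))."""
    return np.logaddexp(0.0, np.asarray(x, dtype=float))
```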
- ngclearn.utils.model_utils.swish(x, beta)[source]
Applies the Swish parameterized activation, proposed in Ramachandran et al., 2017 (“Searching for Activation Functions”).
- Parameters:
x – data to transform via the Swish activation
beta – coefficient/parameters to weight input x by
- Returns:
output of the Swish activation
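Swish is x * sigmoid(beta * x), and SiLU is the beta = 1 special case, which matches the note that silu is a convenience wrapper for swish. The NumPy sketch below shows both. The names `swish_sketch` and `silu_sketch` are hypothetical, and the beta = 1 default is a choice made only for this sketch.

```python
import numpy as np

def swish_sketch(x, beta=1.0):
    """Swish: x * sigmoid(beta * x)."""
    x = np.asarray(x, dtype=float)
    return x / (1.0 + np.exp(-beta * x))

def silu_sketch(x):
    """SiLU: Swish with beta fixed to 1."""
    return swish_sketch(x, beta=1.0)
```

As beta grows, swish approaches the linear rectifier; for example, `swish_sketch(2.0, beta=50.0)` is essentially 2.0.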
- ngclearn.utils.model_utils.tanh(x)[source]
The hyperbolic tangent function.
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
- ngclearn.utils.model_utils.telu(x)[source]
The hyperbolic tangent exponential linear (TeLU) function:
f(x) = x * tanh(e^x)
This was proposed by Fernandez and Mali (2024).
- Parameters:
x – input (tensor) value
- Returns:
output (tensor) value
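A NumPy sketch of the TeLU formula follows. The exponent is clamped at 20 (an implementation choice for this sketch only), because tanh(exp(20)) already rounds to 1.0 in double precision, so the clamp avoids overflow without changing the result. `telu_sketch` is a hypothetical name.

```python
import numpy as np

def telu_sketch(x):
    """TeLU: x * tanh(exp(x))."""
    x = np.asarray(x, dtype=float)
    # clamp the exponent: tanh(exp(20)) is already 1.0 in float64, so larger inputs lose nothing
    return x * np.tanh(np.exp(np.minimum(x, 20.0)))
```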
- ngclearn.utils.model_utils.tensorstats(tensor)[source]
Prints tensor statistics (debugging tool).
- Parameters:
tensor – argument tensor object to examine
- Returns:
useful statistics to print to I/O
- ngclearn.utils.model_utils.threshold(x, thr=1.0)[source]
The threshold function (or Heaviside but with a non-zero boundary):
f(x) = 1 if x >= thr, otherwise 0 (for x < thr)
- Parameters:
x – input (tensor) value
thr – threshold value (Default: 1.0)
- Returns:
output (tensor) value
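The threshold rule above (and the Heaviside function as its thr = 0 case) can be sketched in NumPy as a boolean comparison cast to floats. This illustrates the formula only, and `threshold_sketch` is a hypothetical name.

```python
import numpy as np

def threshold_sketch(x, thr=1.0):
    """1.0 where x >= thr, else 0.0; thr = 0 recovers the Heaviside step."""
    return (np.asarray(x) >= thr).astype(float)
```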
- ngclearn.utils.model_utils.threshold_cauchy(x, lmbda)[source]
A Cauchy distributional threshold routine applied to each dimension of input. (Note that this function does not contain a complementary derivative.)
- Parameters:
x – data to apply threshold function over
lmbda – scalar to control strength/influence of Cauchy thresholding
- Returns:
thresholded x
- ngclearn.utils.model_utils.threshold_soft(x, lmbda)[source]
A soft threshold routine applied to each dimension of input. (Note that this function does not contain a complementary derivative.)
- Parameters:
x – data to apply threshold function over
lmbda – scalar to control strength/influence of thresholding
- Returns:
thresholded x
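For intuition, the classic soft-threshold (shrinkage) operator sign(x) * max(|x| - lmbda, 0) is sketched below in NumPy. That threshold_soft uses exactly this operator is an assumption, and `threshold_soft_sketch` is a hypothetical name. Entries with magnitude at most lmbda become zero, and larger entries are shrunk toward zero by lmbda.

```python
import numpy as np

def threshold_soft_sketch(x, lmbda):
    """Assumed soft-threshold (shrinkage) operator: sign(x) * max(|x| - lmbda, 0)."""
    x = np.asarray(x, dtype=float)
    return np.sign(x) * np.maximum(np.abs(x) - lmbda, 0.0)
```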
ngclearn.utils.patch module
- class ngclearn.utils.patch.PatchGenerator(patch_height: int, patch_width: int, horizontal_alignment: Literal['left', 'right', 'center', 'fit'] = None, vertical_alignment: Literal['top', 'bottom', 'center', 'fit'] = None, horizontal_stride: int | None = None, vertical_stride: int | None = None)[source]
Bases:
object
ngclearn.utils.patch_utils module
Image/tensor patching utility routines.
- class ngclearn.utils.patch_utils.Create_Patches(img, patch_shape, overlap_shape)[source]
Bases:
object
This class creates small patches out of the image based on the provided attributes.
- Parameters:
img – jax array of size (H, W)
patch_shape – (height_patch, width_patch)
overlap_shape – (height_overlap, width_overlap)
- Returns:
Array containing the patches, shape: (num_patches, patch_height, patch_width)
- Return type:
jnp.array
- create_patches(add_frame=False, center=True)[source]
This function will create small patches out of the image based on the provided attributes.
- Keyword Arguments:
add_frame – if True, the function adds a zero frame around the image (increasing its dimensions)
center
- Returns:
Array containing the patches, shape: (num_patches, patch_height, patch_width)
- Return type:
jnp.array
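The patching rule can be illustrated with a NumPy sliding-window sketch. It assumes the window moves with stride (patch_height - height_overlap, patch_width - width_overlap) and drops partial windows at the border. It ignores the add_frame and center options, and `create_patches_sketch` is a hypothetical helper, not the class's actual method.

```python
import numpy as np

def create_patches_sketch(img, patch_shape, overlap_shape):
    """Slide a (ph x pw) window over img with stride (ph - oh, pw - ow); no padding (assumed)."""
    img = np.asarray(img)
    H, W = img.shape
    ph, pw = patch_shape
    oh, ow = overlap_shape
    sh, sw = ph - oh, pw - ow  # assumed stride: patch size minus overlap
    assert sh > 0 and sw > 0, "overlap must be smaller than the patch"
    patches = [img[r:r + ph, c:c + pw]
               for r in range(0, H - ph + 1, sh)
               for c in range(0, W - pw + 1, sw)]
    return np.stack(patches)  # (num_patches, patch_height, patch_width)
```

On a 4 x 4 image, 2 x 2 patches give 4 patches with no overlap and 9 patches with an overlap of (1, 1).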
- ngclearn.utils.patch_utils.generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True)[source]
- ngclearn.utils.patch_utils.generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False, step_size=(1, 1))[source]
Generates a set of patches from an array/list of image arrays (via random sampling with replacement). This uses scikit-learn’s patch creation function to generate a set of (pH x pW) patches. Note: this routine also subtracts each patch’s mean from itself.
- Parameters:
x_batch – the array of image arrays to sample from
patch_size – a 2-tuple of the form (pH = patch height, pW = patch width)
max_patches – maximum number of patches to extract/generate from source images
center – centers each patch by subtracting the patch mean (per-patch)
seed – seed to control the random state of internal patch sampling
- Returns:
an array (D x (pH * pW)), where each row is a flattened patch sample
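The sampling-and-centering procedure is sketched below in plain NumPy rather than scikit-learn. It picks a random image and a random patch location with replacement, flattens the patch, and optionally subtracts the patch mean. It always returns exactly max_patches rows and ignores vis_mode and step_size. These simplifications, and the name `generate_patch_set_sketch`, are assumptions for illustration.

```python
import numpy as np

def generate_patch_set_sketch(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234):
    """Sample max_patches random (pH x pW) patches with replacement and flatten them."""
    rng = np.random.default_rng(seed)  # seed fixes the sampling, as in the documented routine
    pH, pW = patch_size
    rows = []
    for _ in range(max_patches):
        img = np.asarray(x_batch[rng.integers(len(x_batch))], dtype=float)
        r = rng.integers(img.shape[0] - pH + 1)  # top-left corner row
        c = rng.integers(img.shape[1] - pW + 1)  # top-left corner column
        patch = img[r:r + pH, c:c + pW].reshape(-1)
        if center:
            patch = patch - patch.mean()  # per-patch mean subtraction
        rows.append(patch)
    return np.stack(rows)  # (max_patches, pH * pW)
```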
ngclearn.utils.surrogate_fx module
A builder sub-package for spike emission functions that are mapped to surrogate (derivative) functions; these function builders are useful if differentiation through the discrete spike emission steps in spiking neuronal cells is required (e.g., cases of surrogate backprop, broadcast feedback alignment schemes, etc.). Calling the builder estimator functions below returns the following routines: 1) a spike emission routine spike_fx; 2) the surrogate function used to approximate spike emission surr_fx; 3) the corresponding surrogate derivative routine d_spike_fx.
- ngclearn.utils.surrogate_fx.arctan_estimator(get_surr_fx=False)[source]
The arctan surrogate gradient estimator for binary spike emission.
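E(x) = (1/pi) arctan(pi * x)
dE(x)/dx = (1/pi) (1/(1 + (pi * x)^2))
where x = v (membrane potential/voltage)
- Parameters:
get_surr_fx – if True, makes this function also return the surrogate function that the derivative corresponds to
- Returns: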
( spike_fx(x, thr), d_spike_fx(x, thr) ) OR ( spike_fx(x, thr), surr_fx(x, thr, args), d_spike_fx(x, thr, args) )
- ngclearn.utils.surrogate_fx.secant_lif_estimator(get_surr_fx=False)[source]
Surrogate function for computing derivative of (binary) spike function with respect to the input electrical current/drive to a leaky integrate-and-fire (LIF) neuron. (Note this is only useful for LIF neuronal dynamics.)
spike_fx(x) ~ E(x) = sech(x) = 1/cosh(x), cosh(x) = (e^x + e^(-x))/2
dE(x)/dx = (c1 c2) * sech^2(c2 * x) for x > 0 and 0 for x <= 0
where x = j (electrical current)
Reference:
Samadi, Arash, Timothy P. Lillicrap, and Douglas B. Tweed. “Deep learning with dynamic spiking neurons and fixed feedback weights.” Neural Computation 29.3 (2017): 578-602.
- Parameters:
get_surr_fx – if True, makes this function also return the surrogate function that the derivative corresponds to
- Returns:
( spike_fx(x, thr), d_spike_fx(x, thr) ) OR ( spike_fx(x, thr), surr_fx(x, thr, args), d_spike_fx(x, thr, args) )
- ngclearn.utils.surrogate_fx.straight_through_estimator(get_surr_fx=False)[source]
The straight-through estimator (STE) applied to binary spike emission (the Heaviside function).
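Reference:
Bengio, Yoshua, Nicholas Léonard, and Aaron Courville. “Estimating or propagating gradients through stochastic neurons for conditional computation.” arXiv preprint arXiv:1308.3432 (2013).
- Parameters:
get_surr_fx – if True, makes this function also return the surrogate function that the derivative corresponds to
- Returns: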
( spike_fx(x, thr), d_spike_fx(x, thr) ) OR ( spike_fx(x, thr), surr_fx(x, thr, args), d_spike_fx(x, thr, args) )
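To show the builder pattern, here is a minimal NumPy sketch of a straight-through estimator. The builder returns a Heaviside spike routine for the forward pass and a derivative routine that passes gradients through unchanged (a derivative of 1 everywhere). The name `straight_through_sketch`, the `x >= thr` firing convention, and the omission of surr_fx are assumptions. This is not ngclearn's implementation.

```python
import numpy as np

def straight_through_sketch():
    """Hypothetical builder returning (spike_fx, d_spike_fx) for a straight-through estimator."""
    def spike_fx(x, thr):
        # forward pass: binary Heaviside spike emission at threshold thr (>= assumed)
        return (np.asarray(x) >= thr).astype(float)

    def d_spike_fx(x, thr):
        # backward pass: the STE treats the step as the identity, so its derivative is 1
        return np.ones_like(np.asarray(x, dtype=float))

    return spike_fx, d_spike_fx
```

For example, `spike_fx, d_spike_fx = straight_through_sketch()` followed by `spike_fx([0.2, 1.5], 1.0)` yields `[0., 1.]`, while `d_spike_fx` returns ones of the same shape.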