Source code for ngclearn.utils.model_utils

"""
General modeling utility routines and co-routines. This contains useful
commonly jit-i-fied mathematical functions and operations needed to design
and develop ngc-learn internal components.
"""
import jax
from jax import numpy as jnp, grad, jit, vmap, random, lax, nn
from jax.lax import scan as _scan
import os, sys
from functools import partial
import numpy as np

def _tensor_summary(values):
    """Builds the statistics dictionary for a single NumPy array."""
    return {
        'mean': values.mean(),
        'std': values.std(),
        'mag': np.abs(values).max(),
        'min': values.min(),
        'max': values.max(),
    }


def tensorstats(tensor):
    """
    Computes summary statistics of a tensor, or of a container of tensors
    (debugging tool).

    Args:
        tensor: a NumPy/JAX array, or a list/tuple/dict (pytree) whose leaves
            are arrays or scalars; leaves may have different sizes

    Returns:
        dictionary with keys 'mean', 'std', 'mag' (largest absolute value),
        'min' and 'max'; None if the argument type is unsupported or the
        container's leaves cannot be summarized numerically
    """
    if isinstance(tensor, (np.ndarray, jax.Array, jnp.ndarray)):
        return _tensor_summary(np.asarray(tensor))
    if isinstance(tensor, (list, tuple, dict)):
        try:
            ## flatten every leaf and join them end-to-end so that leaves of
            ## different sizes can still be summarized together
            leaves = jax.tree.leaves(tensor)
            values = np.concatenate(
                [np.ravel(np.asarray(leaf)) for leaf in leaves]
            )
            return _tensor_summary(values)
        except Exception:
            ## best-effort debugging helper: unsupported content -> None
            return None
    return None
def pull_equations(controller):
    """
    Extracts the dynamics string of this controller (model/system).

    Args:
        controller: model/system to extract dynamics equation(s) from; must
            expose a `components` container that is iterable over names and
            indexable by name

    Returns:
        string containing this model/system's dynamics equation(s), one
        "\\n<component name>: <equation>" entry per component that defines a
        non-callable, non-None `equation` attribute (empty string if none do)
    """
    eqn_set = ""
    for _name in controller.components:
        component = controller.components[_name]
        ## only components that carry an `equation` data attribute contribute
        eqn = getattr(component, "equation", None)
        if eqn is None or callable(eqn):
            continue
        ## append to what has been collected so far (do not overwrite it)
        eqn_set = "{}\n{}: {}".format(eqn_set, _name, eqn)
    return eqn_set
def create_function(fun_name, args=None):
    """
    Activation function creation routine.

    Args:
        fun_name: string name of activation function to produce; the names
            handled here are: "tanh", "bkwta" (binary K-winners-take-all),
            "sine", "sigmoid", "relu", "lrelu", "relu6", "elu", "silu",
            "gelu", "telu", "softplus", "softmax" (derivative not supported),
            "unit_threshold", any name containing "heaviside", "identity"

        args: not used by this routine

    Returns:
        function fx, first derivative of function (w.r.t. input) dfx

    Raises:
        RuntimeError: if `fun_name` is not one of the supported names
    """
    if fun_name == "tanh":
        return tanh, d_tanh
    if fun_name == "bkwta":
        ## NOTE: binary WTA only (uses bkwta's default number of winners);
        ## the function itself stands in as an improper derivative proxy
        return bkwta, bkwta
    if fun_name == "sine":
        return sine, d_sine
    if fun_name == "sigmoid":
        return sigmoid, d_sigmoid
    if fun_name == "relu":
        return relu, d_relu
    if fun_name == "lrelu":
        return lrelu, d_lrelu
    if fun_name == "relu6":
        return relu6, d_relu6
    if fun_name == "elu":
        return elu, d_elu
    if fun_name == "silu":
        return silu, d_silu
    if fun_name == "gelu":
        return gelu, d_gelu
    if fun_name == "telu":
        return telu, d_telu
    if fun_name == "softplus":
        return softplus, d_softplus
    if fun_name == "softmax":
        ## NOTE: improper derivative proxy; the correct dfx is the Jacobian
        ## of the softmax (not currently supported!)
        return softmax, d_identity
    if fun_name == "unit_threshold":
        ## default threshold is 1 (thus unit); derivative is an STE proxy
        return threshold, d_threshold
    if "heaviside" in fun_name:
        ## derivative is an STE approximation
        return heaviside, d_heaviside
    if fun_name == "identity":
        return identity, d_identity
    raise RuntimeError(
        "Activation function (" + fun_name + ") is not recognized/supported!"
    )
@partial(jit, static_argnums=[1])
def bkwta(x, nWTA=5):
    """
    Binary K-winners-take-all: marks, per row, every entry that is at least
    as large as that row's nWTA-th largest value.

    Args:
        x: input matrix; the comparison is made along axis 1 (presumably
            x is batch x dim -- confirm against callers)

        nWTA: number of winners per row (static under jit)

    Returns:
        float32 binary mask with the shape of x (1 for winners, 0 otherwise)
    """
    top_vals, _ = lax.top_k(x, nWTA)  ## winner indices are not needed
    ## smallest of the winners = per-sample cut-off
    cutoff = jnp.min(top_vals, axis=1, keepdims=True)
    return jnp.greater_equal(x, cutoff).astype(jnp.float32)
@partial(jit, static_argnums=[2, 3, 4])
def normalize_matrix(data, wnorm, order=1, axis=0, scale=1.):
    """
    Rescales the vectors of a matrix so that each one has the target norm.

    Args:
        data: (2D) data matrix to normalize

        wnorm: target norm for each row/column of data matrix

        order: order of norm to use (Default: 1); `order=2` uses the L2-norm,
            any other value uses the L1-norm

        axis: axis the norm is summed over (0 normalizes column vectors,
            1 normalizes row vectors)

        scale: step modifier to produce the projected matrix (Unused)

    Returns:
        a normalized value matrix
    """
    if order == 2:
        span_norm = jnp.sqrt(jnp.sum(jnp.square(data), axis=axis, keepdims=True))
    else:
        span_norm = jnp.sum(jnp.abs(data), axis=axis, keepdims=True)
    span_norm = jnp.maximum(span_norm, 1e-8)  ## floor to avoid division by zero
    ## legacy zero-guard: any norm that is exactly zero is replaced by one
    zero_mask = (span_norm == 0.).astype(dtype=jnp.float32)
    span_norm = span_norm * (1. - zero_mask) + zero_mask
    return data * (wnorm / span_norm)
@jit
def clamp_min(x, min_val):
    """
    Lower-bounds data: any value of x below the minimum is raised to it.

    Args:
        x: data to lower-bound clamp

        min_val: minimum value threshold

    Returns:
        x with minimum clamped values (NaN entries stay NaN)
    """
    ## an element-wise maximum avoids the mask arithmetic `x * 0`, which
    ## produced NaN for infinite entries
    return jnp.maximum(x, min_val)
@jit
def clamp_max(x, max_val):
    """
    Upper-bounds data: any value of x above the maximum is lowered to it.

    Args:
        x: data to upper-bound clamp

        max_val: maximum value threshold

    Returns:
        x with maximum clamped values (NaN entries stay NaN)
    """
    ## an element-wise minimum avoids the mask arithmetic `x * 0`, which
    ## produced NaN for infinite entries
    return jnp.minimum(x, max_val)
@jit
def one_hot(P):
    """
    Turns each row of a score/probability matrix into a one-hot vector that
    marks the position of the row's largest entry.

    Args:
        P: a probability matrix where each row corresponds to a particular
            data probability vector

    Returns:
        the (float32) one-hot encoding (matrix) of probabilities in P
    """
    winners = jnp.argmax(P, axis=1)  ## index of the largest entry per row
    return nn.one_hot(winners, num_classes=P.shape[1], dtype=jnp.float32)
@partial(jit, static_argnums=[1, 2])
def chebyshev_norm(d, axis=-1, keepdims=False):
    """
    Computes the Chebyshev (maximum-magnitude) norm of a tensor along an axis.
    Passing a difference `d = a - b` yields the Chebyshev distance between
    `a` and `b`.

    Args:
        d: tensor d to measure against the origin

        axis: axis along which the maximum absolute value is taken

        keepdims: preserve dimensions of d

    Returns:
        the Chebyshev distance (values) within d
    """
    return jnp.max(jnp.abs(d), axis=axis, keepdims=keepdims)
@jit
def binarize(data, threshold=0.5):
    """
    Converts real-valued *data* to a binary (0/1) equivalent.

    Args:
        data: the data to binarize (real-valued)

        threshold: the cut-off point; values strictly greater than it map to
            1.0, all other values map to 0

    Returns:
        the binarized (float32) equivalent of "data"
    """
    above = jnp.greater(data, threshold)
    return above.astype(jnp.float32)
@jit
def identity(x):
    """
    Identity activation, i.e., f(x) = x.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value, numerically equal to the input
    """
    return 0 + x
@jit
def d_identity(x):
    """
    First derivative of the identity activation (one everywhere).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument),
        shaped like x
    """
    return 1. + 0 * x
@jit
def relu(x):
    """
    Rectified linear unit: f(x) = max(0, x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jax.nn.relu(x)
@jit
def d_relu(x):
    """
    First derivative of the linear rectifier: 1 where x >= 0 (the value at
    x = 0 is taken to be 1), 0 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return jnp.greater_equal(x, 0.).astype(jnp.float32)
@jit
def telu(x):
    """
    The hyperbolic tangent exponential linear (TeLU) function:

    | f(x) = x * tanh(e^x)

    This was proposed by Fernandez and Mali 24 in:

    | https://arxiv.org/abs/2412.20269 and in,
    | https://arxiv.org/abs/2402.02790

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    gate = jnp.tanh(jnp.exp(x))
    return x * gate
@jit
def d_telu(x):
    """
    First derivative of the hyperbolic tangent exponential linear (TeLU)
    function:

    | f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input)
    """
    exp_x = jnp.exp(x)
    gate = jnp.tanh(exp_x)
    d_gate = 1.0 - gate ** 2  ## derivative of tanh evaluated at e^x
    return gate + x * exp_x * d_gate
@jit
def sine(x, omega_0=30):
    """
    Sinusoidal activation: f(x) = sin(omega_0 * x).

    Args:
        x: input (tensor) value

        omega_0: frequency factor that multiplies the input (Default: 30)

    Returns:
        output (tensor) value
    """
    phase = omega_0 * x
    return jnp.sin(phase)
@jit
def d_sine(x, omega_0=30):
    """
    First derivative of the sinusoidal activation:
    f'(x) = omega_0 * cos(omega_0 * x).

    Args:
        x: input (tensor) value

        omega_0: frequency factor that multiplies the input (Default: 30)

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    phase = omega_0 * x
    return omega_0 * jnp.cos(phase)
@jit
def tanh(x):
    """
    Hyperbolic tangent activation.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jnp.tanh(x)
@jit
def d_tanh(x):
    """
    First derivative of the hyperbolic tangent: f'(x) = 1 - tanh^2(x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    t = nn.tanh(x)
    return 1.0 - t * t
@jit
def inverse_tanh(x):
    """
    The inverse hyperbolic tangent, i.e., 0.5 * log((1 + x) / (1 - x)).

    Args:
        x: data to transform via inverse hyperbolic tangent; no clipping is
            applied, so values outside of (-1, 1) yield non-finite results

    Returns:
        x transformed via inverse hyperbolic tangent
    """
    ## the previous hand-written form dropped the 0.5 factor and therefore
    ## returned twice the inverse hyperbolic tangent
    return jnp.arctanh(x)
@jit
def lrelu(x):
    """
    Leaky linear rectifier: f(x) = x if x >= 0, 0.01 * x if x < 0.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jax.nn.leaky_relu(x)
@jit
def d_lrelu(x):
    """
    First derivative of the leaky linear rectifier: 1 where x >= 0 and 0.01
    where x < 0.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    is_pos = (x >= 0.).astype(jnp.float32)
    is_neg = 1. - is_pos
    return is_pos + is_neg * 0.01
@jit
def relu6(x):
    """
    Linear rectifier capped at six: f(x) = min(max(0, x), 6).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jax.nn.relu6(x)
@jit
def d_relu6(x):
    """
    First derivative of the linear rectifier upper bounded at 6: an indicator
    that is 1 where 0 <= x <= 6 (both end-points included) and 0 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    inside = jnp.logical_and(x >= 0., x <= 6.)
    return inside.astype(jnp.float32)
@jit
def softplus(x):
    """
    Element-wise softplus activation, i.e., log(1 + e^x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jax.nn.softplus(x)
@jit
def d_softplus(x):
    """
    First derivative of the softplus function, which is the logistic sigmoid.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return jax.nn.sigmoid(x)
@jit
def threshold(x, thr=1.):
    """
    Hard threshold activation: 1 where x >= thr, 0 elsewhere.

    Args:
        x: input (tensor) value

        thr: threshold value (Default: 1)

    Returns:
        float32 binary output (tensor) value
    """
    return jnp.greater_equal(x, thr).astype(jnp.float32)
@jit
def d_threshold(x, thr=1.):
    """
    Surrogate derivative of the hard threshold activation: a straight-through
    estimator that is one everywhere.

    Args:
        x: input (tensor) value

        thr: threshold value (not used by the estimator)

    Returns:
        output (tensor) derivative value, shaped like x
    """
    return 1. + 0. * x  ## straight-thru estimator
@jit
def heaviside(x):
    """
    Heaviside step activation: 1 where x >= 0, 0 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        float32 binary output (tensor) value
    """
    return jnp.greater_equal(x, 0.).astype(jnp.float32)
@jit
def d_heaviside(x):
    """
    Surrogate derivative of the Heaviside step activation: a straight-through
    estimator that is one everywhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value, shaped like x
    """
    return 1. + 0. * x  ## straight-thru estimator
@jit
def sigmoid(x):
    """
    Logistic sigmoid activation: f(x) = 1 / (1 + e^(-x)).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return jax.nn.sigmoid(x)
@jit
def d_sigmoid(x):
    """
    First derivative of the logistic sigmoid:
    f'(x) = sigmoid(x) * (1 - sigmoid(x)).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    s = nn.sigmoid(x)  ## evaluate the sigmoid only once
    return (1. - s) * s
def inverse_sigmoid(x, clip_bound=0.03):
    """
    Naming-convention alias that forwards to `inverse_logistic`.

    Args:
        x: data to transform via inverse logistic function

        clip_bound: pre-processing lower/upper bounds forwarded to
            `inverse_logistic`

    Returns:
        whatever `inverse_logistic` returns for the same arguments
    """
    result = inverse_logistic(x, clip_bound=clip_bound)
    return result
@jit
def inverse_logistic(x, clip_bound=0.03):
    """
    The inverse logistic link - the logit function.

    Args:
        x: data to transform via inverse logistic function

        clip_bound: pre-processing lower/upper bounds to enforce on data
            before applying inverse logistic, i.e., data is clipped to
            [clip_bound, 1 - clip_bound]; a value <= 0 disables clipping

    Returns:
        x transformed via inverse logistic function
    """
    ## `clip_bound` is a traced value whenever the caller passes it, so the
    ## choice between clipped/unclipped data is made with jnp.where instead
    ## of a Python `if` (which fails on traced values under jit)
    clipped = jnp.clip(x, clip_bound, 1.0 - clip_bound)
    x_ = jnp.where(clip_bound > 0.0, clipped, x)
    ## the small constant guards against division by zero at x_ = 1
    return jnp.log(x_ / ((1.0 - x_) + 1e-6))
@jit
def swish(x, beta):
    """
    Applies the Swish parameterized activation, proposed in Ramachandran et
    al., 2017 ("Searching for Activation Functions"), i.e.,
    f(x) = x * sigmoid(beta * x).

    Args:
        x: data to transform via the Swish activation

        beta: coefficient/parameters to weight input x by

    Returns:
        output of the Swish activation
    """
    gate = nn.sigmoid(x * beta)
    return x * gate
@jit
def d_swish(x, beta):
    """
    First derivative of the Swish activation f(x) = x * sigmoid(beta * x)
    with respect to its input:

    | f'(x) = s + beta * x * s * (1 - s),  where s = sigmoid(beta * x)

    Args:
        x: input (tensor) value

        beta: coefficient/parameters the input x is weighted by

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## beta must scale the sigmoid's argument (chain rule); multiplying the
    ## beta = 1 derivative by beta is only correct when beta == 1
    sig = nn.sigmoid(beta * x)
    return sig + beta * x * sig * (1. - sig)
@jit
def silu(x):
    """
    Applies the sigmoid-weighted linear unit (SiLU or SiL) activation, which
    is the Swish activation with its coefficient fixed to one.

    Args:
        x: data to transform via the SiLU activation

    Returns:
        output of the SiLU activation
    """
    return swish(x, 1.)
@jit
def d_silu(x):
    """
    First derivative of the SiLU activation, obtained as the Swish
    derivative routine evaluated with its coefficient fixed to one.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return d_swish(x, 1.)
@jit
def gelu(x):
    """
    Applies the Gaussian Error Linear Unit (GeLU) activation; specifically, a
    fast approximation is used: the Swish activation with coefficient 1.702.

    Args:
        x: data to transform via the (approximate) GeLU activation

    Returns:
        output of the GeLU activation
    """
    return swish(x, 1.702)  ## approximate GeLU
@jit
def d_gelu(x):
    """
    First derivative of the (approximate) GeLU activation, obtained as the
    Swish derivative routine evaluated with coefficient 1.702.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return d_swish(x, 1.702)
@jit
def elu(x, alpha=1.):
    """
    Applies the exponential linear unit (ELU) activation:
    f(x) = x if x >= 0, alpha * (e^x - 1) if x < 0.

    Args:
        x: data to transform via the ELU activation

        alpha: coefficient/parameters to weight the negative branch by

    Returns:
        output of the ELU activation
    """
    ## the exponential is only evaluated on the non-positive part of x; for
    ## large positive x, e^x overflows to inf and the former mask arithmetic
    ## (inf * 0) turned the output into NaN
    neg_branch = (jnp.exp(jnp.minimum(x, 0.)) - 1) * alpha
    return jnp.where(x >= 0., x, neg_branch)
@jit
def d_elu(x, alpha=1.):
    """
    First derivative of the exponential linear unit (ELU) activation:
    f'(x) = 1 if x >= 0, alpha * e^x if x < 0.

    Args:
        x: input (tensor) value

        alpha: coefficient/parameters that weight the negative branch

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## the exponential is only evaluated on the non-positive part of x so
    ## that large positive inputs cannot overflow into NaN (inf * 0)
    neg_branch = jnp.exp(jnp.minimum(x, 0.)) * alpha
    return jnp.where(x >= 0., 1., neg_branch)
@jit
def softmax(x, tau=0.0):
    """
    Softmax function with overflow control built in directly. Contains
    optional temperature parameter to control sharpness (tau > 1 softens
    probs, tau < 1 sharpens them).

    Args:
        x: a (N x D) input argument (pre-activity) to the softmax operator

        tau: probability sharpening/softening factor; a value <= 0 (the
            default) applies no temperature scaling

    Returns:
        a (N x D) probability distribution output block
    """
    ## `tau` is a traced value whenever the caller passes it, so a Python
    ## `if tau > 0.0` fails under jit; instead, a non-positive temperature is
    ## mapped to 1 (i.e., no scaling), which also avoids dividing by zero
    safe_tau = jnp.where(tau > 0.0, tau, 1.0)
    x = x / safe_tau
    max_x = jnp.max(x, axis=1, keepdims=True)  ## subtracted for overflow control
    exp_x = jnp.exp(x - max_x)
    return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
@jit
def threshold_soft(x, lmbda):
    """
    Soft-thresholding (shrinkage) applied to each dimension of the input,
    computed as max(x - lmbda, 0) - max(-x - lmbda, 0).

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of thresholding

    Returns:
        thresholded x
    """
    upper = jnp.maximum(x - lmbda, 0.)   ## part of x above +lmbda
    lower = jnp.maximum(-x - lmbda, 0.)  ## part of x below -lmbda
    return upper - lower
@jit
def threshold_cauchy(x, lmbda):
    """
    A Cauchy distributional threshold routine applied to each dimension of
    input, based on the threshold function proposed in:
    https://arxiv.org/abs/2003.12507

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of Cauchy thresholding

    Returns:
        thresholded x
    """
    ## the radicand is floored at zero *before* taking the square root
    inner_term = jnp.sqrt(jnp.maximum(jnp.square(x) - lmbda, 0.))
    f = (x + inner_term) * 0.5
    g = (x - inner_term) * 0.5
    term1 = f * (x >= lmbda).astype(jnp.float32)   ## active where x >= lmbda
    term2 = g * (x <= -lmbda).astype(jnp.float32)  ## active where x <= -lmbda
    return term1 + term2
@jit
def layer_normalize(x, shift=0., scale=1.):
    """
    Applies layer normalization to input data `x`: each row (sample) is
    centered by its own mean and divided by its own standard deviation.

    Args:
        x: data to layer-normalize; statistics are taken over axis 1

        shift: the compensating mean/shift factor/parameters (to undo mean
            subtraction)

        scale: the compensating re-scaling factor/parameters (to undo
            standard deviation division)

    Returns:
        layer-normalized data samples `x`
    """
    xmu = jnp.mean(x, axis=1, keepdims=True)
    centered = x - xmu
    ## the variance is computed per sample, matching the per-sample mean
    ## (it was previously averaged over the entire batch)
    xvar = jnp.mean(jnp.square(centered), axis=1, keepdims=True)
    xsigma = jnp.sqrt(jnp.maximum(xvar, 1e-6))  ## floor avoids division by zero
    return (centered / xsigma) * scale + shift
@jit
def drop_out(dkey, data, rate=0.0):
    """
    Applies a drop-out transform (i.e., a random number of elements will be
    dropped to zero) to an input matrix, using the inverted drop-out scheme
    (kept elements are scaled by 1 / (1 - rate)).

    Args:
        dkey: Jax randomness key for this operator

        data: input data to apply random/drop-out mask to

        rate: probability of a dimension being dropped; a rate of 1 drops
            every element

    Returns:
        output as well as the (scaled) drop-out mask
    """
    keep_prob = 1.0 - rate
    eps = random.uniform(dkey, shape=data.shape, minval=0.0, maxval=1.0)
    keep = (eps <= keep_prob).astype(jnp.float32)
    ## inverted drop-out scaling; when nothing is kept (rate = 1) the scale
    ## is set to zero instead of 1/0, which used to produce NaN (0 * inf)
    safe_keep_prob = jnp.where(keep_prob > 0.0, keep_prob, 1.0)
    inv_scale = jnp.where(keep_prob > 0.0, 1.0 / safe_keep_prob, 0.0)
    mask = keep * inv_scale
    output = data * mask
    return output, mask
@jit
def clip(x, min_val, max_val):
    """
    Clips (bounds) the values of x to the range [min_val, max_val].

    Args:
        x: data to clip

        min_val: lower bound

        max_val: upper bound

    Returns:
        the clipped version of x
    """
    bounded = jnp.clip(x, min_val, max_val)
    return bounded
@jit
def d_clip(x, min_val, max_val):
    """
    First derivative of the clip function with respect to its input: 0 where
    x lies strictly outside of [min_val, max_val], 1 otherwise.

    Args:
        x: input (tensor) value

        min_val: lower bound of the clip

        max_val: upper bound of the clip

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    outside = jnp.logical_or(x < min_val, x > max_val)
    return jnp.where(outside, 0.0, 1.0)
## block-matrix generation routine
def create_block_matrix(map_matrix, group_shape, alpha_inh=-1., alpha_exc=1.):
    """
    Builds a block-structured matrix by tiling a template block according to
    a map: wherever the map entry is positive the template block is placed,
    otherwise an all-zero block of the same shape is placed.

    The template block has `alpha_exc` on its main diagonal and
    `-alpha_inh` everywhere else.

    Args:
        map_matrix: 2D array whose entries select (entry > 0) or suppress the
            template block at each block position

        group_shape: tuple (height, width) of each block

        alpha_inh: off-diagonal coefficient (it enters the block negated)

        alpha_exc: diagonal coefficient

    Returns:
        matrix of shape (rows * height, cols * width), where (rows, cols) is
        the shape of `map_matrix`
    """
    n_rows, n_cols = map_matrix.shape
    blk_h, blk_w = group_shape
    diag = jnp.eye(blk_h, blk_w)
    active = diag * alpha_exc - (1. - diag) * alpha_inh
    silent = active * 0  ## placeholder block for non-positive map entries
    strips = [
        jnp.concatenate(
            [active if map_matrix[r][c] > 0. else silent for c in range(n_cols)],
            axis=1
        )
        for r in range(n_rows)
    ]
    return jnp.concatenate(strips, axis=0)
@partial(jit, static_argnums=[1, 2, 3, 4])
def normalize_block_matrix(matrix, block_size, order=2, axis=0, norm_targ=1.):
    """
    Normalizes the vectors found inside each block of a block-structured
    matrix.

    Args:
        matrix: 2D JAX Array (M, N); M and N are presumably multiples of the
            block dimensions (the reshape below requires it)

        block_size: Tuple (block_rows, block_cols)

        order: order of the norm (forwarded to `jnp.linalg.norm` as `ord`)

        axis: (relative) axis for normalization within block; 0 -> the norm
            is taken over the block's rows, 1 -> over the block's columns;
            any other value -> the norm is taken over both block axes

        norm_targ: target norm value that the normalized vectors are scaled to

    Returns:
        block-normalized (M, N) matrix
    """
    ## map the block-relative axis onto the axes of the 4D block tensor
    norm_axes = {0: 2, 1: 3}.get(axis, (2, 3))
    n_rows, n_cols = matrix.shape
    blk_rows, blk_cols = block_size
    ## (num_blocks_row, block_rows, num_blocks_col, block_cols)
    ##   -> (num_blocks_row, num_blocks_col, block_rows, block_cols)
    blocks = jnp.transpose(
        matrix.reshape(n_rows // blk_rows, blk_rows, n_cols // blk_cols, blk_cols),
        (0, 2, 1, 3)
    )
    block_norms = jnp.linalg.norm(blocks, ord=order, axis=norm_axes, keepdims=True)
    rescaled = blocks * (norm_targ / (block_norms + 1e-8))  ## safe division
    ## undo the transpose, then fold back into the original 2D layout
    return jnp.transpose(rescaled, (0, 2, 1, 3)).reshape(n_rows, n_cols)
@partial(jit, static_argnums=[2, 3])
def quantile_lkwta(x, m, nWTA=20, clipval=-1.):  ## local k-WTA
    """
    A quantile-based K-WTA function - NOTE: this is experimental and no
    guarantees are offered at this point.

    Args:
        x: data to apply quantile-KWTA function over

        m: masking tensor (presumably one row per group over the dimensions
            of x -- confirm against callers)

        nWTA: number of winners

        clipval: not used by this routine

    Returns:
        y = KWTA(x)
    """
    ## lift to 3D so that all groups are processed at once
    x3d = jnp.expand_dims(x, axis=1)
    mask3d = jnp.expand_dims(m, axis=0)
    ## entries outside a group are pushed below the smallest input value
    filler = jnp.min(x3d) - 1.
    grouped = x3d * mask3d + (1. - mask3d) * filler
    n_dims = grouped.shape[2]
    cutoff = jnp.quantile(
        grouped, (n_dims - nWTA) / n_dims, method='linear', axis=2, keepdims=True
    )
    winners = (grouped >= cutoff) * 1.  ## thresholding
    return jnp.sum(winners, axis=1) * x
@partial(jit, static_argnums=[2, 3])
def d_quantile_lkwta(x, m, nWTA=20, clipval=-1.):  ## local k-WTA
    """
    First derivative of quantile-based K-WTA function with respect to its
    input, i.e., the winner indicator itself - NOTE: this is experimental and
    no guarantees are offered at this point.

    Args:
        x: data to apply quantile-KWTA function over

        m: masking tensor (presumably one row per group over the dimensions
            of x -- confirm against callers)

        nWTA: number of winners

        clipval: not used by this routine

    Returns:
        winner indicator(s) summed over the groups, shaped like x
    """
    ## lift to 3D so that all groups are processed at once
    x3d = jnp.expand_dims(x, axis=1)
    mask3d = jnp.expand_dims(m, axis=0)
    ## entries outside a group are pushed below the smallest input value
    filler = jnp.min(x3d) - 1.
    grouped = x3d * mask3d + (1. - mask3d) * filler
    n_dims = grouped.shape[2]
    cutoff = jnp.quantile(
        grouped, (n_dims - nWTA) / n_dims, method='linear', axis=2, keepdims=True
    )
    winners = (grouped >= cutoff) * 1.  ## thresholding
    return jnp.sum(winners, axis=1)
@jit
def group_mean(x, masks):
    """
    Computes, per sample, the mean of x within each masked group and writes
    that mean back to every dimension that belongs to the group.

    Args:
        x: data matrix (presumably batch x dim -- confirm against callers)

        masks: group mask(s) over the dimensions of x (presumably one binary
            row per group)

    Returns:
        matrix shaped like x holding the group means
    """
    ## work in 3D to avoid a for-loop over the groups in the mask
    x3d = jnp.expand_dims(x, axis=1)
    mask3d = jnp.expand_dims(masks, axis=0)
    group_sums = jnp.sum(x3d * mask3d, axis=2, keepdims=True)
    group_sizes = jnp.sum(mask3d, axis=2, keepdims=True)
    means = group_sums / group_sizes
    ## contract back to 2D, smearing each group mean over its dimensions
    return jnp.sum(means * mask3d, axis=1)
@partial(jit, static_argnums=[2])
def kwta(x, m, nWTA=1):
    """
    K-winners-take-all wrapper that forwards to `lkwta`.

    Args:
        x: data to apply the K-WTA function over

        m: masking tensor

        nWTA: number of winners (static under jit)

    Returns:
        whatever `lkwta` returns for the same arguments
    """
    return lkwta(x, m, nWTA)
@partial(jit, static_argnums=[2])
def d_kwta(x, m, nWTA=1):
    """
    K-winners-take-all derivative wrapper that forwards to `d_lkwta`.

    Args:
        x: data to apply the K-WTA derivative over

        m: masking tensor

        nWTA: number of winners (static under jit)

    Returns:
        whatever `d_lkwta` returns for the same arguments
    """
    return d_lkwta(x, m, nWTA)
@partial(jit, static_argnums=[2, 3])
def lkwta(x, m, nWTA=1, clipval=-1.):  ## local/group K-WTA
    """
    A group-based K-WTA function, i.e., local K-WTA (LKWTA), as proposed in:

    | Ororbia, Alexander, Karl Friston, and Rajesh PN Rao. "Meta-representational predictive coding: biomimetic
    | self-supervised learning." arXiv preprint arXiv:2503.21796 (2025).

    Args:
        x: data to apply the group K-WTA function over

        m: masking tensor (presumably one row per group over the dimensions
            of x -- confirm against callers)

        nWTA: number of winners

        clipval: if > 0, the output is clipped to [-clipval, clipval]

    Returns:
        y = LKWTA(x)
    """
    ## work in 3D to avoid a for-loop over the groups in the mask
    x3d = jnp.expand_dims(x, axis=1)
    mask3d = jnp.expand_dims(m, axis=0)
    ## entries outside a group are pushed below the smallest input value
    filler = jnp.min(x3d) - 1.
    grouped = x3d * mask3d + (1. - mask3d) * filler
    top_vals, _ = lax.top_k(grouped, nWTA)  ## winner indices are not needed
    ## back to 2D: one cut-off per sample
    ## NOTE(review): the cut-off is the minimum over ALL groups' top values,
    ## i.e., a single per-sample threshold -- confirm this is intended
    cutoff = jnp.min(top_vals, axis=(1, 2))[:, None]
    winners = jnp.greater_equal(x, cutoff).astype(jnp.float32)
    out = winners * x
    if clipval > 0.:
        out = jnp.clip(out, -clipval, clipval)
    return out
@partial(jit, static_argnums=[2, 3])
def d_lkwta(x, m, nWTA=1, clipval=-1.):  ## derivative of local/group K-WTA
    """
    Derivative of group-based K-WTA function with respect to its input, i.e.,
    the binary winner indicator. This function, local K-WTA (LKWTA), was
    proposed in:

    | Ororbia, Alexander, Karl Friston, and Rajesh PN Rao. "Meta-representational predictive coding: biomimetic
    | self-supervised learning." arXiv preprint arXiv:2503.21796 (2025).

    Args:
        x: data to apply the group K-WTA derivative over

        m: masking tensor (presumably one row per group over the dimensions
            of x -- confirm against callers)

        nWTA: number of winners

        clipval: not used by this routine

    Returns:
        float32 binary winner mask, shaped like x
    """
    ## work in 3D to avoid a for-loop over the groups in the mask
    x3d = jnp.expand_dims(x, axis=1)
    mask3d = jnp.expand_dims(m, axis=0)
    ## entries outside a group are pushed below the smallest input value
    filler = jnp.min(x3d) - 1.
    grouped = x3d * mask3d + (1. - mask3d) * filler
    top_vals, _ = lax.top_k(grouped, nWTA)  ## winner indices are not needed
    ## back to 2D: one cut-off per sample (minimum over all groups' top values)
    cutoff = jnp.min(top_vals, axis=(1, 2))[:, None]
    return jnp.greater_equal(x, cutoff).astype(jnp.float32)