"""
General modeling utility routines and co-routines. This contains useful
commonly jit-i-fied mathematical functions and operations needed to design
and develop ngc-learn internal components.
"""
import jax
from jax import numpy as jnp, grad, jit, vmap, random, lax, nn
from jax.lax import scan as _scan
import os, sys
from functools import partial
import numpy as np
[docs]
def tensorstats(tensor):
    """
    Computes summary statistics of a tensor or container of tensors
    (debugging tool).

    Args:
        tensor: a NumPy/JAX array, or a (possibly nested) list/tuple/dict
            whose leaves are array-like; leaves may differ in shape/size

    Returns:
        dict with keys 'mean', 'std', 'mag' (largest absolute value), 'min',
        and 'max' computed over all elements; None if `tensor` is of an
        unsupported type or if a container's statistics cannot be computed
        (e.g., it is empty or holds non-numeric leaves)
    """
    def _summarize(values):
        ## reduce a NumPy array down to the reported summary statistics
        return {
            'mean': values.mean(),
            'std': values.std(),
            'mag': np.abs(values).max(),
            'min': values.min(),
            'max': values.max(),
        }

    if isinstance(tensor, (np.ndarray, jax.Array, jnp.ndarray)):
        return _summarize(np.asarray(tensor))
    if isinstance(tensor, (list, tuple, dict)):
        try:
            leaves = jax.tree_util.tree_leaves(tensor)
            ## flatten every leaf and join end-to-end; unlike stacking, this
            ## also works when the leaves have different numbers of elements
            flat = [np.ravel(np.asarray(leaf)) for leaf in leaves]
            return _summarize(np.concatenate(flat))
        except Exception:
            ## best-effort debugging aid: report "no statistics" rather than
            ## crash (but do not swallow KeyboardInterrupt/SystemExit)
            return None
    return None
[docs]
def pull_equations(controller):
    """
    Extracts the dynamics string of this controller (model/system).

    Args:
        controller: model/system to extract dynamics equation(s) from; must
            expose a `components` mapping of component name to component

    Returns:
        string containing this model/system's dynamics equation(s), one
        "<component name>: <equation>" entry per line (each entry is preceded
        by a newline); empty string if no component defines an equation
    """
    entries = []
    for _name in controller.components:
        component = controller.components[_name]
        ## only components carrying a non-callable `equation` attribute
        ## contribute to the output
        if not hasattr(component, "equation"):
            continue
        eqn = getattr(component, "equation")
        if callable(eqn):
            continue
        entries.append("\n{}: {}".format(_name, eqn))
    return "".join(entries)
[docs]
def create_function(fun_name, args=None):
    """
    Activation function creation routine.

    Args:
        fun_name: string name of activation function to produce;
            Currently supports: "tanh", "bkwta" (binary K-winners-take-all),
            "sine", "sigmoid", "relu", "lrelu", "relu6", "elu", "silu",
            "gelu", "telu", "softplus", "softmax" (derivative not supported;
            an identity derivative is returned), "unit_threshold",
            "identity", and any name containing "heaviside"

        args: accepted for interface compatibility; not used by this routine

    Returns:
        function fx, first derivative of function (w.r.t. input) dfx
    """
    ## name -> (function, derivative) lookup for exact-name matches
    registry = {
        "tanh": (tanh, d_tanh),
        "bkwta": (bkwta, bkwta),  ## derivative slot re-uses bkwta itself
        "sine": (sine, d_sine),
        "sigmoid": (sigmoid, d_sigmoid),
        "relu": (relu, d_relu),
        "lrelu": (lrelu, d_lrelu),
        "relu6": (relu6, d_relu6),
        "elu": (elu, d_elu),
        "silu": (silu, d_silu),
        "gelu": (gelu, d_gelu),
        "telu": (telu, d_telu),
        "softplus": (softplus, d_softplus),
        ## TODO: currently Jacobian of softmax not supported!
        "softmax": (softmax, d_identity),
        ## default threshold is 1 (thus unit); derivative is an STE approximation
        "unit_threshold": (threshold, d_threshold),
        "identity": (identity, d_identity),
    }
    pair = registry.get(fun_name)
    if pair is None and "heaviside" in fun_name:
        pair = (heaviside, d_heaviside)  ## STE approximation for derivative
    if pair is None:
        raise RuntimeError(
            "Activation function (" + fun_name + ") is not recognized/supported!"
        )
    fx, dfx = pair
    return fx, dfx
[docs]
@partial(jit, static_argnums=[1])
def bkwta(x, nWTA=5):
    """
    Binary K-winners-take-all: marks, along the last axis of `x`, every entry
    that is at least as large as the K-th largest entry of that vector.

    Args:
        x: input (tensor) value; winners are selected along its last axis
            (i.e., per row/sample for a 2D mini-batch)

        nWTA: number of winners K to select per vector (static under jit)

    Returns:
        float32 tensor of the same shape as x containing 1 for winners and 0
        otherwise (ties with the K-th largest value are all marked as winners)
    """
    values, _ = lax.top_k(x, nWTA)  # the indices of the winners are not needed
    ## the smallest of the top-K values is the per-vector admission cut-off
    kth = jnp.min(values, axis=-1, keepdims=True)
    return jnp.greater_equal(x, kth).astype(jnp.float32)  # booleans -> floats
[docs]
@partial(jit, static_argnums=[2, 3, 4])
def normalize_matrix(data, wnorm, order=1, axis=0, scale=1.):
    """
    Normalizes the values in matrix to have a particular norm across each vector span.

    Args:
        data: (2D) data matrix to normalize

        wnorm: target norm for each row/column of data matrix

        order: order of norm to use in normalization (Default: 1);
            `order=2` results in the L2-norm, any other value results in
            the L1-norm

        axis: 0 (apply to column vectors), 1 (apply to row vectors)

        scale: step modifier to produce the projected matrix (Unused)

    Returns:
        a normalized value matrix
    """
    if order == 2:  ## denominator is L2 norm
        magnitude = jnp.sqrt(jnp.sum(jnp.square(data), axis=axis, keepdims=True))
    else:  ## denominator is L1 norm
        magnitude = jnp.sum(jnp.abs(data), axis=axis, keepdims=True)
    ## lower-bound the norm so all-zero vectors do not trigger a division by
    ## zero; this clamp alone guarantees a non-zero denominator
    magnitude = jnp.maximum(magnitude, 1e-8)
    return data * (wnorm / magnitude)
[docs]
@jit
def clamp_min(x, min_val):
    """
    Clamps values in data x that fall below a minimum value to that value.

    Args:
        x: data to lower-bound clamp

        min_val: minimum value threshold

    Returns:
        x with minimum clamped values (at least float32 precision)
    """
    ## an element-wise maximum (instead of mask arithmetic) keeps -inf inputs
    ## from turning into NaN via `-inf * 0`
    out_dtype = jnp.result_type(x, min_val, jnp.float32)
    return jnp.maximum(x, min_val).astype(out_dtype)
[docs]
@jit
def clamp_max(x, max_val):
    """
    Clamps values in data x that exceed a maximum value to that value.

    Args:
        x: data to upper-bound clamp

        max_val: maximum value threshold

    Returns:
        x with maximum clamped values (at least float32 precision)
    """
    ## an element-wise minimum (instead of mask arithmetic) keeps +inf inputs
    ## from turning into NaN via `inf * 0`
    out_dtype = jnp.result_type(x, max_val, jnp.float32)
    return jnp.minimum(x, max_val).astype(out_dtype)
[docs]
@jit
def one_hot(P):
    """
    Maps each row of a score/probability matrix to a binary one-hot row that
    marks the position of that row's largest entry.

    Args:
        P: a probability matrix where each row corresponds to a particular
            data probability vector

    Returns:
        the one-hot encoding (float32 matrix) of probabilities in P
    """
    winners = jnp.argmax(P, axis=1)  # index of the top entry per row
    return nn.one_hot(winners, num_classes=P.shape[1], dtype=jnp.float32)
[docs]
def binarize(data, threshold=0.5):
    """
    Converts real-valued *data* to its binary (0/1) equivalent.

    Args:
        data: the data to binarize (real-valued)

        threshold: the cut-off point; values strictly greater than the
            threshold are set to 1.0, all others (including values equal to
            the threshold) are set to 0

    Returns:
        the binarized (float32) equivalent of "data"
    """
    above = data > threshold
    return above.astype(jnp.float32)
[docs]
@jit
def identity(x):
    """
    The identity map, f(x) = x.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value, equal to the input
    """
    passthrough = x + 0  # adding zero leaves every value as-is
    return passthrough
[docs]
@jit
def d_identity(x):
    """
    First derivative of the identity map (one for every element of x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    zeros = x * 0  # zero tensor carrying the shape of x
    return zeros + 1.
[docs]
@jit
def relu(x):
    """
    The rectified linear unit, f(x) = max(0, x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    rectified = nn.relu(x)
    return rectified
[docs]
@jit
def d_relu(x):
    """
    First derivative of the rectified linear unit: 1 where x >= 0 and 0
    elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    non_negative = jnp.greater_equal(x, 0.)
    return non_negative.astype(jnp.float32)
[docs]
@jit
def telu(x):
    """
    The hyperbolic tangent exponential linear unit (TeLU):

    | f(x) = x * tanh(e^x)

    Proposed by Fernandez and Mali (2024) in:

    | https://arxiv.org/abs/2412.20269 and in,
    | https://arxiv.org/abs/2402.02790

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    gate = jnp.tanh(jnp.exp(x))  # multiplicative gate applied to the input
    return x * gate
[docs]
@jit
def d_telu(x):
    """
    First derivative of the hyperbolic tangent exponential linear unit (TeLU):

    | f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input)
    """
    exp_x = jnp.exp(x)
    gate = jnp.tanh(exp_x)
    d_gate = 1.0 - gate ** 2  # derivative of tanh evaluated at e^x
    return gate + x * exp_x * d_gate
[docs]
@jit
def sine(x, omega_0=30):
    """
    Sinusoidal activation, f(x) = sin(omega_0 * x).

    Args:
        x: input (tensor) value

        omega_0: frequency multiplier applied to x (Default: 30)

    Returns:
        output (tensor) value
    """
    phase = omega_0 * x
    return jnp.sin(phase)
[docs]
@jit
def d_sine(x, omega_0=30):
    """
    First derivative of the sinusoidal activation,
    f'(x) = omega_0 * cos(omega_0 * x).

    Args:
        x: input (tensor) value

        omega_0: frequency multiplier applied to x (Default: 30)

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    phase = omega_0 * x
    return omega_0 * jnp.cos(phase)
[docs]
@jit
def tanh(x):
    """
    Element-wise hyperbolic tangent.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    squashed = nn.tanh(x)
    return squashed
[docs]
@jit
def d_tanh(x):
    """
    First derivative of the hyperbolic tangent, 1 - tanh^2(x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    t = nn.tanh(x)  ## pre-compute once
    return 1.0 - t * t
[docs]
@jit
def inverse_tanh(x):
    """
    The inverse hyperbolic tangent:

    | atanh(x) = 0.5 * log((1 + x) / (1 - x))

    No clipping is applied; inputs outside of the open interval (-1, 1)
    yield infinite or NaN outputs.

    Args:
        x: data to transform via inverse hyperbolic tangent

    Returns:
        x transformed via inverse hyperbolic tangent
    """
    ## the built-in includes the 0.5 factor that the bare log-ratio omits
    return jnp.arctanh(x)
[docs]
@jit
def lrelu(x):  ## activation fx
    """
    The leaky rectified linear unit: f(x) = x if x >= 0, 0.01 * x otherwise.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.leaky_relu(x, negative_slope=0.01)
[docs]
@jit
def d_lrelu(x):  ## deriv of fx (dampening function)
    """
    First derivative of the leaky rectified linear unit: 1 where x >= 0 and
    0.01 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    unit_slope = jnp.float32(1.)
    leak_slope = jnp.float32(0.01)
    return jnp.where(x >= 0., unit_slope, leak_slope)
[docs]
@jit
def relu6(x):
    """
    The rectified linear unit capped at 6: f(x) = min(max(0, x), 6).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    bounded = nn.relu6(x)
    return bounded
[docs]
@jit
def d_relu6(x):
    """
    First derivative of the rectified linear unit capped at 6: an indicator
    that is 1 where 0 < x <= 6 and 0 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## indicator of the active (linear) region of relu6
    in_linear_region = jnp.logical_and(x > 0., x <= 6.)
    return in_linear_region.astype(jnp.float32)
[docs]
@jit
def softplus(x):
    """
    Element-wise softplus, f(x) = log(1 + e^x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    smoothed = nn.softplus(x)
    return smoothed
[docs]
@jit
def d_softplus(x):
    """
    First derivative of softplus, which is the logistic sigmoid.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    slope = nn.sigmoid(x)  ## d/dx of softplus = logistic sigmoid
    return slope
[docs]
@jit
def threshold(x, thr=1.):
    """
    Binary step at a threshold: 1 where x >= thr, 0 elsewhere.

    Args:
        x: input (tensor) value

        thr: threshold value (Default: 1)

    Returns:
        output (float32 tensor) value
    """
    fired = jnp.greater_equal(x, thr)
    return fired.astype(jnp.float32)
[docs]
@jit
def d_threshold(x, thr=1.):
    """
    Surrogate derivative of the threshold function using a straight-through
    estimator: one for every element of x.

    Args:
        x: input (tensor) value

        thr: threshold value (does not affect the result)

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    zeros = x * 0.
    return zeros + 1.  ## straight-thru estimator
[docs]
@jit
def heaviside(x):
    """
    Heaviside step function: 1 where x >= 0, 0 elsewhere.

    Args:
        x: input (tensor) value

    Returns:
        output (float32 tensor) value
    """
    non_negative = jnp.greater_equal(x, 0.)
    return non_negative.astype(jnp.float32)
[docs]
@jit
def d_heaviside(x):
    """
    Surrogate derivative of the Heaviside step function using a
    straight-through estimator: one for every element of x.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    zeros = x * 0.
    return zeros + 1.  ## straight-thru estimator
[docs]
@jit
def sigmoid(x):
    """
    Element-wise logistic sigmoid, f(x) = 1 / (1 + e^-x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    squashed = nn.sigmoid(x)
    return squashed
[docs]
@jit
def d_sigmoid(x):
    """
    First derivative of the logistic sigmoid, sigmoid(x) * (1 - sigmoid(x)).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    s = nn.sigmoid(x)  ## pre-compute once
    complement = 1. - s
    return s * complement
[docs]
def inverse_sigmoid(x, clip_bound=0.03):
    """
    Alias of `inverse_logistic` (wrapper call for naming convention ease).

    Args:
        x: data to transform

        clip_bound: forwarded unchanged to `inverse_logistic`

    Returns:
        whatever `inverse_logistic` returns for these arguments
    """
    transformed = inverse_logistic(x, clip_bound=clip_bound)
    return transformed
[docs]
@partial(jit, static_argnums=[1], static_argnames=["clip_bound"])
def inverse_logistic(x, clip_bound=0.03):
    """
    The inverse logistic link - the logit function.

    Args:
        x: data to transform via inverse logistic function

        clip_bound: pre-processing lower/upper bounds to enforce on data
            before applying inverse logistic, i.e., data is clipped to
            [clip_bound, 1 - clip_bound]; a value <= 0 disables clipping.
            This argument is static under jit (it must be a hashable Python
            scalar) because it drives Python-level control flow

    Returns:
        x transformed via inverse logistic function
    """
    x_ = x
    ## `clip_bound` is a concrete Python value here (static argument), so this
    ## branch is resolved at trace time instead of raising a tracer error
    if clip_bound > 0.0:
        x_ = jnp.clip(x_, clip_bound, 1.0 - clip_bound)
    ## small constant in the denominator guards against division by zero
    return jnp.log(x_ / ((1.0 - x_) + 1e-6))
[docs]
@jit
def swish(x, beta):
    """
    The Swish parameterized activation, f(x) = x * sigmoid(beta * x),
    proposed in Ramachandran et al., 2017 ("Searching for Activation Functions").

    Args:
        x: input (tensor) value

        beta: coefficient/parameters to weight input x by (inside the sigmoid)

    Returns:
        output of the Swish activation
    """
    gate = nn.sigmoid(x * beta)
    return x * gate
[docs]
@jit
def d_swish(x, beta):
    """
    First derivative of the Swish activation f(x) = x * sigmoid(beta * x):

    | f'(x) = s + beta * x * s * (1 - s),  where s = sigmoid(beta * x)

    Args:
        x: input (tensor) value

        beta: coefficient/parameters that weight input x inside the sigmoid

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## using the sigmoid primitive avoids the exp(-x) overflow (and resulting
    ## inf/inf = NaN) that an explicit exponential suffers for very negative x
    sig = nn.sigmoid(beta * x)
    return sig + beta * x * sig * (1. - sig)
[docs]
@jit
def silu(x):
    """
    The sigmoid-weighted linear unit (SiLU or SiL) activation, computed as
    the Swish activation with its coefficient fixed to 1.

    Args:
        x: input (tensor) value

    Returns:
        output of the SiLU activation
    """
    unit_beta = 1.
    return swish(x, beta=unit_beta)
[docs]
@jit
def d_silu(x):
    """
    Derivative routine paired with `silu`; delegates to `d_swish` with its
    coefficient fixed to 1.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    unit_beta = 1.
    return d_swish(x, beta=unit_beta)
[docs]
@jit
def gelu(x):
    """
    The Gaussian Error Linear Unit (GeLU) activation; specifically, a fast
    approximation is used: the Swish activation with coefficient 1.702.

    Args:
        x: input (tensor) value

    Returns:
        output of the (approximate) GeLU activation
    """
    gelu_beta = 1.702  ## approximate GeLU # beta=1.4
    return swish(x, beta=gelu_beta)
[docs]
@jit
def d_gelu(x):
    """
    Derivative routine paired with `gelu`; delegates to `d_swish` with the
    same coefficient (1.702) that `gelu` uses.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    gelu_beta = 1.702  # beta=1.4
    return d_swish(x, beta=gelu_beta)
[docs]
@jit
def elu(x, alpha=1.):
    """
    Applies the exponential linear unit (ELU) activation:
    f(x) = x if x >= 0, alpha * (e^x - 1) otherwise.

    Args:
        x: input (tensor) value

        alpha: coefficient/parameters that scale the negative (exponential)
            branch

    Returns:
        output of the ELU activation
    """
    ## the exponential is only evaluated on min(x, 0): for large positive x,
    ## e^x overflows to inf, and blending inf with a zero mask would give NaN
    negative_branch = (jnp.exp(jnp.minimum(x, 0.)) - 1) * alpha
    return jnp.where(x >= 0., x, negative_branch)
[docs]
@jit
def d_elu(x, alpha=1.):
    """
    First derivative of the exponential linear unit (ELU):
    1 where x >= 0, alpha * e^x otherwise.

    Args:
        x: input (tensor) value

        alpha: coefficient/parameters that scale the negative (exponential)
            branch

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## the exponential is only evaluated on min(x, 0): for large positive x,
    ## e^x overflows to inf, and blending inf with a zero mask would give NaN
    negative_slope = jnp.exp(jnp.minimum(x, 0.)) * alpha
    return jnp.where(x >= 0., 1., negative_slope)
[docs]
@jit
def softmax(x, tau=0.0):
    """
    Softmax function with overflow control built in directly. Contains optional
    temperature parameter to control sharpness
    (tau > 1 softens probs, 0 < tau < 1 sharpens them).

    Args:
        x: a (N x D) input argument (pre-activity) to the softmax operator

        tau: probability sharpening/softening factor; values <= 0 (the
            default) disable temperature scaling altogether

    Returns:
        a (N x D) probability distribution output block
    """
    ## select the divisor with array ops (not a Python `if`) so that `tau` may
    ## be passed as a jit-traced argument; non-positive tau maps to a divisor
    ## of 1, which leaves x unchanged and never divides by zero
    temperature = jnp.where(tau > 0.0, tau, 1.0)
    logits = x / temperature
    ## subtract the row-wise max before exponentiating for overflow control
    logits = logits - jnp.max(logits, axis=1, keepdims=True)
    exp_x = jnp.exp(logits)
    return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
[docs]
@jit
def threshold_soft(x, lmbda):
    """
    Soft-thresholding routine applied to each dimension of the input: values
    are shrunk toward zero by lmbda, and values with magnitude at most lmbda
    become zero (for non-negative lmbda).

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of thresholding

    Returns:
        thresholded x
    """
    ## soft thresholding, split into its positive and negative halves
    positive_part = jnp.maximum(x - lmbda, 0.)
    negative_part = jnp.maximum(-x - lmbda, 0.)
    return positive_part - negative_part
[docs]
@jit
def threshold_cauchy(x, lmbda):
    """
    A Cauchy distributional threshold routine applied to each dimension of input

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of Cauchy thresholding

    Returns:
        thresholded x
    """
    # threshold function based on that proposed in: https://arxiv.org/abs/2003.12507
    ## the radicand is floored at zero *before* the square root is taken, so
    ## that entries with x^2 < lmbda do not produce NaN
    inner_term = jnp.sqrt(jnp.maximum(jnp.square(x) - lmbda, 0.))
    f = (x + inner_term) * 0.5
    g = (x - inner_term) * 0.5
    term1 = f * (x >= lmbda).astype(jnp.float32)  ## f * (x >= lmbda)
    term2 = g * (x <= -lmbda).astype(jnp.float32)  ## g * (x <= -lmbda)
    return term1 + term2
[docs]
@jit
def layer_normalize(x, shift=0., scale=1.):
    """
    Applies layer normalization to input data `x`; each row (sample) is
    centered by its own mean and divided by its own standard deviation.

    Args:
        x: (N x D) data to normalize, one sample per row

        shift: the compensating mean/shift factor/parameters (to undo mean subtraction)

        scale: the compensating re-scaling factor/parameters (to undo standard deviation division)

    Returns:
        layer-normalized data samples `x`
    """
    centered = x - jnp.mean(x, axis=1, keepdims=True)
    ## variance is taken per row (same axis as the mean) so that samples in a
    ## mini-batch are normalized independently of one another
    variance = jnp.mean(jnp.square(centered), axis=1, keepdims=True)
    ## floor the variance to keep the division well-defined for constant rows
    xsigma = jnp.sqrt(jnp.maximum(variance, 1e-6))
    return (centered / xsigma) * scale + shift
[docs]
@jit
def drop_out(dkey, data, rate=0.0):
    """
    Applies a drop-out transform (i.e., a random number of elements will be
    dropped to zero) to an input matrix.

    Args:
        dkey: Jax randomness key for this operator

        data: input data to apply random/drop-out mask to

        rate: probability of a dimension being dropped; a rate of 1 (or more)
            drops every element

    Returns:
        output as well as the (inverted-dropout scaled) mask, whose entries
        are either 0 or 1 / (1 - rate)
    """
    eps = random.uniform(dkey, shape=data.shape, minval=0.0, maxval=1.0)
    keep_prob = 1.0 - rate
    ## guard the inverted-dropout scale: with rate >= 1 the plain reciprocal
    ## would divide by zero and inject NaN/inf into the output
    can_keep = keep_prob > 0.0
    safe_keep = jnp.where(can_keep, keep_prob, 1.0)
    inv_keep = jnp.where(can_keep, 1.0 / safe_keep, 0.0)
    mask = (eps <= keep_prob).astype(jnp.float32) * inv_keep  ## inverted dropout scheme
    output = data * mask
    return output, mask
[docs]
@jit
def clip(x, min_val, max_val):
    """
    Bounds the values of x to lie within [min_val, max_val].

    Args:
        x: input (tensor) value

        min_val: lower bound

        max_val: upper bound

    Returns:
        output (tensor) value with entries limited to the given range
    """
    bounded = jnp.clip(x, min_val, max_val)
    return bounded
[docs]
@jit
def d_clip(x, min_val, max_val):
    """
    Derivative routine paired with `clip`: 0 where x lies strictly outside
    of [min_val, max_val], 1 elsewhere (bounds included).

    Args:
        x: input (tensor) value

        min_val: lower bound

        max_val: upper bound

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    below = x < min_val
    above = x > max_val
    outside = below | above
    return jnp.where(outside, 0.0, 1.0)
# def scanner(fn):
# """
# A wrapper for Jax's scanner that handles the "getting" of the current
# state and "setting" of the final state to and from the model.
#
# | @scanner
# | def process(current_state, args):
# | t = args[0]
# | dt = args[1]
# | current_state = model.advance_state(current_state, t, dt)
# | current_state = model.evolve(current_state, t, dt)
# | return current_state, (current_state[COMPONENT.COMPARTMENT.path], ...)
# |
# | outputs = models.process(jnp.array([[ARG0, ARG1] for i in range(NUM_LOOPS)]))
#
# | Notes on the scanner function call:
# | 1) `current_state` is a hash-map mapped to all compartment values by path
# | 2) `args` is the external arguments defined in the passed Jax array
# | 3) `outputs` is a tuple containing time-concatenated Jax arrays of the
# | compartment statistics you want tracked
#
# Args:
# fn: function that is executed at every time step of a Jax-unrolled loop,
# it must take in the current state and external arguments
#
# Returns:
# wrapped (fast) function that is Jax-scanned/jit-i-fied
# """
# def _scanned(_xs):
# vals, stacked = _scan(fn, init=Get_Compartment_Batch(), xs=_xs)
# Set_Compartment_Batch(vals)
# return stacked
#
# if get_current_context() is not None:
# get_current_context().__setattr__(fn.__name__, _scanned)
# return _scanned