Source code for ngclearn.utils.model_utils

import jax
from jax import numpy as jnp, grad, jit, vmap, random, lax, nn
import os, sys
from functools import partial

def pull_equations(controller):
    """
    Extracts the dynamics string of this controller (model/system).

    Args:
        controller: model/system to extract dynamics equation(s) from

    Returns:
        string containing this model/system's dynamics equation(s)
    """
    eqn_set = ""
    for _name in controller.components:
        component = controller.components[_name]
        ## determine if component has an equation and pull it out if so
        for attr in dir(component):
            if not callable(getattr(component, attr)) and not attr.startswith("__"):
                if attr == "equation":
                    eqn = getattr(component, attr) ## extract defined equation
                    eqn_set = "{}\n{}: {}".format(eqn_set, _name, eqn) ## accumulate "name: equation" entries
    return eqn_set
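## Illustrative sketch (not part of this module): pull_equations() only assumes
## that `controller.components` is a dict of objects, some of which expose a
## string `equation` attribute. The stand-in classes and equation string below
## are hypothetical, for demonstration only.
def _demo_pull_equations():
    class _DemoComponent:
        equation = "tau dz/dt = -z + W x"  ## made-up dynamics string
    class _DemoController:
        components = {"cell0": _DemoComponent()}
    return pull_equations(_DemoController())  ## -> "\ncell0: tau dz/dt = -z + W x"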
@jit
def measure_ACC(mu, y): ## measures/calculates accuracy
    """
    Calculates the accuracy (ACC) given a matrix of predictions and matrix of targets.

    Args:
        mu: prediction (design) matrix

        y: target / ground-truth (design) matrix

    Returns:
        scalar accuracy score
    """
    guess = jnp.argmax(mu, axis=1)
    lab = jnp.argmax(y, axis=1)
    acc = jnp.sum( jnp.equal(guess, lab) )/(y.shape[0] * 1.)
    return acc
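## Illustrative sketch: measure_ACC() expects row-wise prediction and one-hot
## target matrices; the small arrays below are made-up values.
def _demo_measure_acc():
    mu = jnp.asarray([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  ## predictions
    y = jnp.asarray([[1., 0.], [0., 1.], [0., 1.]])          ## targets
    return measure_ACC(mu, y)  ## 2 of 3 rows agree -> 0.666...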
def measure_KLD(p_xHat, p_x, preserve_batch=False):
    """
    Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the two
    input arguments contain valid probability distributions (in each row, if
    they are matrices). Note: If batch is preserved, this returns a column
    vector where each row is the KLD(x_pred, x_true) for that row's datapoint.

    | Formula:
    | KLD(p_xHat, p_x) = (1/N) [ sum_i(p_x * jnp.log(p_x)) - sum_i(p_x * jnp.log(p_xHat)) ]
    | where sum_i implies summing across dimensions of the vector-space of p_x

    Args:
        p_xHat: predicted probabilities; (N x C matrix, where C is number of categories)

        p_x: ground true probabilities; (N x C matrix, where C is number of categories)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
    """
    ## numerical control step
    offset = 1e-6
    _p_x = jnp.clip(p_x, offset, 1. - offset)
    _p_xHat = jnp.clip(p_xHat, offset, 1. - offset)
    ## calc raw KLD scores
    N = p_x.shape[1]
    term1 = jnp.sum(_p_x * jnp.log(_p_x), axis=1, keepdims=True) # * (1/N)
    term2 = -jnp.sum(_p_x * jnp.log(_p_xHat), axis=1, keepdims=True) # * (1/N)
    kld = (term1 + term2) * (1/N)
    if preserve_batch == False:
        kld = jnp.mean(kld)
    return kld
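## Illustrative sketch: each row of the two (made-up) matrices below is a valid
## categorical distribution; preserve_batch=True keeps one KLD score per row.
def _demo_measure_kld():
    p_x = jnp.asarray([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])     ## ground truth
    p_xHat = jnp.asarray([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]])  ## predictions
    per_row = measure_KLD(p_xHat, p_x, preserve_batch=True)   ## (2 x 1) column vector
    mean_score = measure_KLD(p_xHat, p_x)                     ## scalar mean score
    return per_row, mean_score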
@partial(jit, static_argnums=[3])
def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
    """
    Measures the negative Categorical log likelihood (Cat.NLL). Note: If batch
    is preserved, this returns a column vector where each row is the
    Cat.NLL(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities; (N x C matrix, where C is number of categories)

        x: true one-hot encoded targets; (N x C matrix, where C is number of categories)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
    """
    p_ = jnp.clip(p, offset, 1.0 - offset)
    loss = -(x * jnp.log(p_))
    nll = jnp.sum(loss, axis=1, keepdims=True)
    if preserve_batch == False:
        nll = jnp.mean(nll)
    return nll

@partial(jit, static_argnums=[2]) ## preserve_batch is a static (Python) flag
def measure_MSE(mu, x, preserve_batch=False):
    """
    Measures mean squared error (MSE), or the negative Gaussian log likelihood
    with variance of 1.0. Note: If batch is preserved, this returns a column
    vector where each row is the MSE(mu, x) for that row's datapoint.

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
    """
    diff = mu - x
    se = jnp.square(diff) ## squared error
    mse = jnp.sum(se, axis=1, keepdims=True) ## per-row sum of squared errors at this point
    if preserve_batch == False:
        mse = jnp.mean(mse) ## this is the proper scalar mse
    return mse

@partial(jit, static_argnums=[3]) ## preserve_batch is a static (Python) flag
def measure_BCE(p, x, offset=1e-7, preserve_batch=False):
    """
    Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE).
    Note: If batch is preserved, this returns a column vector where each row is
    the BCE(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities of shape; (N x D matrix)

        x: target binary values (data) of shape; (N x D matrix)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
    """
    p_ = jnp.clip(p, offset, 1 - offset)
    bce = -jnp.sum(x * jnp.log(p_) + (1.0 - x) * jnp.log(1.0 - p_), axis=1, keepdims=True)
    if preserve_batch == False:
        bce = jnp.mean(bce)
    return bce
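## Illustrative sketch: the three likelihood-based measurements above share the
## same calling convention; the small matrices below are made-up values.
def _demo_measure_losses():
    p = jnp.asarray([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]])  ## predicted probabilities
    x = jnp.asarray([[1., 0., 0.], [0., 1., 0.]])         ## one-hot / binary targets
    nll = measure_CatNLL(p, x)  ## scalar mean categorical negative log likelihood
    mse = measure_MSE(p, x)     ## scalar mean squared error
    bce = measure_BCE(p, x)     ## scalar mean binary cross entropy
    return nll, mse, bce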
def create_function(fun_name, args=None):
    """
    Activation function creation routine.

    Args:
        fun_name: string name of activation function to produce
            (Currently supports: "tanh", "sigmoid", "relu", "lrelu", "relu6",
            "softplus", "unit_threshold", "heaviside", "identity")

    Returns:
        function fx, first derivative of function (w.r.t. input) dfx
    """
    fx = None
    dfx = None
    if fun_name == "tanh":
        fx = tanh
        dfx = d_tanh
    elif fun_name == "sigmoid":
        fx = sigmoid
        dfx = d_sigmoid
    elif fun_name == "relu":
        fx = relu
        dfx = d_relu
    elif fun_name == "lrelu":
        fx = lrelu
        dfx = d_lrelu
    elif fun_name == "relu6":
        fx = relu6
        dfx = d_relu6
    elif fun_name == "softplus":
        fx = softplus
        dfx = d_softplus
    elif fun_name == "unit_threshold":
        fx = threshold ## default threshold is 1 (thus unit)
        dfx = d_threshold ## STE approximation
    elif "heaviside" in fun_name:
        fx = heaviside
        dfx = d_heaviside ## STE approximation
    elif fun_name == "identity":
        fx = identity
        dfx = d_identity
    else:
        raise RuntimeError(
            "Activation function (" + fun_name + ") is not recognized/supported!"
        )
    return fx, dfx
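## Illustrative sketch: create_function() returns the activation and its
## derivative as a pair, which can then be applied elementwise; the input
## values below are arbitrary.
def _demo_create_function():
    fx, dfx = create_function("relu")
    x = jnp.asarray([[-1.0, 0.5, 2.0]])  ## made-up pre-activity values
    return fx(x), dfx(x)  ## ([[0., 0.5, 2.]], [[0., 1., 1.]])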
def initialize_params(dkey, initKernel, shape):
    """
    Creates the initial condition values for a parameter tensor.

    Args:
        dkey: PRNG key to control determinism of this routine

        initKernel: triplet/3-tuple with 1st element as a string calling the
            name of the initialization scheme to use

            :Note: Currently supported kernel schemes include:
                ("hollow", off_diagonal_scale, ~ignored~);
                ("eye", diagonal_scale, ~ignored~);
                ("uniform", min_val, max_val);
                ("gaussian", mu, sigma) OR ("normal", mu, sigma);
                ("constant", magnitude, ~ignored~)

        shape: tuple containing the dimensions/shape of the tensor to initialize

    Returns:
        output (tensor) value
    """
    initType, *args = initKernel ## unpack the arguments of the initialization kernel
    params = None
    if initType == "hollow": ## zero diagonal, constant off-diagonal values
        eyeScale, _ = args
        params = (1. - jnp.eye(N=shape[0], M=shape[1])) * eyeScale
    elif initType == "eye": ## scaled identity matrix
        eyeScale, _ = args
        params = jnp.eye(N=shape[0], M=shape[1]) * eyeScale
    elif initType == "uniform": ## uniformly distributed values
        lb, ub = args
        params = random.uniform(dkey, shape, minval=lb, maxval=ub)
    elif initType == "gaussian" or initType == "normal": ## gaussian distributed values
        mu, sigma = args
        params = random.normal(dkey, shape) * sigma + mu
    elif initType == "constant": ## constant value(s)
        scale, _ = args
        params = jnp.ones(shape) * scale
    else:
        raise RuntimeError(
            "Initialization scheme (" + initType + ") is not recognized/supported!"
        )
    return params
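## Illustrative sketch: each kernel triplet selects a different initialization
## scheme; the key, shapes, and hyper-parameter values below are arbitrary.
def _demo_initialize_params():
    dkey = random.PRNGKey(1234)
    W = initialize_params(dkey, ("uniform", -0.1, 0.1), (5, 3))  ## U(-0.1, 0.1)
    V = initialize_params(dkey, ("gaussian", 0., 0.02), (5, 3))  ## N(mu=0, sigma=0.02)
    H = initialize_params(dkey, ("hollow", 1., 0.), (4, 4))      ## zero diagonal, ones elsewhere
    return W, V, H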
@partial(jit, static_argnums=[2, 3])
def normalize_matrix(M, wnorm, order=1, axis=0):
    """
    Normalizes the values in matrix to have a particular norm across each vector span.

    Args:
        M: (2D) matrix to normalize

        wnorm: target norm for each vector span

        order: order of norm to use in normalization (Default: 1); note that
            `order=1` results in the L1-norm and `order=2` results in the L2-norm

        axis: 0 (apply to column vectors), 1 (apply to row vectors)

    Returns:
        a normalized value matrix
    """
    if order == 2: ## denominator is the L2 norm
        wOrdSum = jnp.sqrt(jnp.sum(jnp.square(M), axis=axis, keepdims=True))
    else: ## denominator is the L1 norm
        wOrdSum = jnp.sum(jnp.abs(M), axis=axis, keepdims=True)
    m = (wOrdSum == 0.).astype(dtype=jnp.float32)
    wOrdSum = wOrdSum * (1. - m) + m ## replace zero norms with 1 to avoid division by zero
    _M = M * (wnorm/wOrdSum)
    return _M

@jit
def clamp_min(x, min_val):
    """
    Clamps values in data x that fall below a minimum value to that value.

    Args:
        x: data to lower-bound clamp

        min_val: minimum value threshold

    Returns:
        x with minimum clamped values
    """
    mask = (x > min_val).astype(jnp.float32)
    _x = x * mask + (1. - mask) * min_val
    return _x

@jit
def clamp_max(x, max_val):
    """
    Clamps values in data x that exceed a maximum value to that value.

    Args:
        x: data to upper-bound clamp

        max_val: maximum value threshold

    Returns:
        x with maximum clamped values
    """
    mask = (x < max_val).astype(jnp.float32)
    _x = x * mask + (1. - mask) * max_val
    return _x

@jit
def one_hot(P):
    """
    Converts a matrix of probabilities to a corresponding binary one-hot matrix
    (each row is a one-hot encoding).

    Args:
        P: a probability matrix where each row corresponds to a particular
            data probability vector

    Returns:
        the one-hot encoding (matrix) of probabilities in P
    """
    nC = P.shape[1] ## compute number of dimensions/classes
    p_t = jnp.argmax(P, axis=1)
    return nn.one_hot(p_t, num_classes=nC, dtype=jnp.float32)
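## Illustrative sketch: constraining each column of a (made-up) matrix to unit
## L2 norm, then clamping values and one-hot encoding a probability matrix.
def _demo_matrix_utils():
    W = jnp.asarray([[3., 0.], [4., 2.]])
    W_unit = normalize_matrix(W, 1., 2, 0)  ## order=2 (L2 norm), axis=0 (columns)
    x = clamp_min(clamp_max(W, 3.5), 0.5)   ## values bounded to the range [0.5, 3.5]
    P = jnp.asarray([[0.2, 0.8], [0.9, 0.1]])
    codes = one_hot(P)                      ## -> [[0., 1.], [1., 0.]]
    return W_unit, x, codes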
def binarize(data, threshold=0.5):
    """
    Converts the vector *data* to its binary equivalent.

    Args:
        data: the data to binarize (real-valued)

        threshold: the cut-off point for 0, i.e., if threshold = 0.5, then any
            value inside of data <= 0.5 is set to 0, otherwise, it is set to 1.0

    Returns:
        the binarized equivalent of "data"
    """
    return (data > threshold).astype(jnp.float32)
@jit
def identity(x):
    """
    The identity function: x = f(x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return x + 0

@jit
def d_identity(x):
    """
    Derivative of the identity function.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return x * 0 + 1.

@jit
def relu(x):
    """
    The linear rectifier: max(0, x) = f(x).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.relu(x)

@jit
def d_relu(x):
    """
    Derivative of the linear rectifier.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return (x >= 0.).astype(jnp.float32)

@jit
def tanh(x):
    """
    The hyperbolic tangent function.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.tanh(x)

@jit
def d_tanh(x):
    """
    Derivative of the hyperbolic tangent function.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    tanh_x = nn.tanh(x)
    return -(tanh_x * tanh_x) + 1.0

@jit
def inverse_tanh(x):
    """
    The inverse hyperbolic tangent.

    Args:
        x: data to transform via the inverse hyperbolic tangent

    Returns:
        x transformed via the inverse hyperbolic tangent
    """
    return 0.5 * jnp.log((1. + x)/(1. - x)) ## arctanh(x) = 0.5 * log((1 + x)/(1 - x))

@jit
def lrelu(x): ## activation fx
    """
    The leaky linear rectifier: f(x) = x if x >= 0, else 0.01 * x.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.leaky_relu(x)

@jit
def d_lrelu(x): ## deriv of fx (dampening function)
    """
    Derivative of the leaky linear rectifier.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    m = (x >= 0.).astype(jnp.float32)
    dx = m + (1. - m) * 0.01
    return dx

@jit
def relu6(x):
    """
    The linear rectifier upper bounded at the value of 6: min(max(0, x), 6.).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.relu6(x)

@jit
def d_relu6(x):
    """
    Derivative of the bounded linear rectifier (upper bounded at 6).

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    ## df/dx = 1 if 0 < x <= 6 else 0 (an indicator function over the interval)
    Ix1 = (x > 0.).astype(jnp.float32)
    Ix2 = (x <= 6.).astype(jnp.float32)
    Ix = Ix1 * Ix2
    return Ix

@jit
def softplus(x):
    """
    The softplus elementwise function.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) value
    """
    return nn.softplus(x)

@jit
def d_softplus(x):
    """
    Derivative of the softplus function.

    Args:
        x: input (tensor) value

    Returns:
        output (tensor) derivative value (with respect to input argument)
    """
    return nn.sigmoid(x) ## d/dx of softplus = the logistic sigmoid

@jit
def threshold(x, thr=1.):
    """
    A hard threshold function: f(x) = 1 if x >= thr, else 0.
    """
    return (x >= thr).astype(jnp.float32)

@jit
def d_threshold(x, thr=1.):
    """
    Derivative of the hard threshold, approximated with the straight-through
    estimator (returns 1 everywhere).
    """
    return x * 0. + 1. ## straight-thru estimator

@jit
def heaviside(x):
    """
    The Heaviside step function: f(x) = 1 if x >= 0, else 0.
    """
    return (x >= 0.).astype(jnp.float32)

@jit
def d_heaviside(x):
    """
    Derivative of the Heaviside step function, approximated with the
    straight-through estimator (returns 1 everywhere).
    """
    return x * 0. + 1. ## straight-thru estimator

@jit
def sigmoid(x):
    """
    The logistic sigmoid function.
    """
    return nn.sigmoid(x)

@jit
def d_sigmoid(x):
    """
    Derivative of the logistic sigmoid function.
    """
    sigm_x = nn.sigmoid(x) ## pre-compute once
    return sigm_x * (1. - sigm_x)

@partial(jit, static_argnums=[1]) ## clip_bound is a static (Python) value
def inverse_logistic(x, clip_bound=0.03):
    """
    The inverse logistic link - the logit function.

    Args:
        x: data to transform via the inverse logistic function

        clip_bound: pre-processing lower/upper bounds to enforce on the data
            before applying the inverse logistic

    Returns:
        x transformed via the inverse logistic function
    """
    x_ = x
    if clip_bound > 0.0:
        x_ = jnp.clip(x_, clip_bound, 1.0 - clip_bound)
    return jnp.log( x_/((1.0 - x_) + 1e-6) )

@partial(jit, static_argnums=[1]) ## tau is a static (Python) hyper-parameter
def softmax(x, tau=0.0):
    """
    Softmax function with overflow control built in directly. Contains an
    optional temperature parameter to control sharpness (tau > 1 softens the
    distribution, tau < 1 sharpens it, and tau --> 0 approaches a point mass);
    tau = 0 disables temperature scaling.

    Args:
        x: a (N x D) input argument (pre-activity) to the softmax operator

        tau: probability sharpening/softening factor

    Returns:
        a (N x D) probability distribution output block
    """
    if tau > 0.0:
        x = x / tau
    max_x = jnp.max(x, axis=1, keepdims=True)
    exp_x = jnp.exp(x - max_x)
    return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
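## Illustrative sketch: the activation pairs above follow an fx/d_fx naming
## convention; the inputs and temperature value below are arbitrary.
def _demo_activations():
    x = jnp.asarray([[-2.0, 0.0, 2.0]])
    y = sigmoid(x)          ## squashes values into (0, 1)
    slope = d_sigmoid(x)    ## elementwise derivative, peaks at x = 0
    probs = softmax(jnp.asarray([[1.0, 2.0, 3.0]]))        ## each row sums to 1
    sharp = softmax(jnp.asarray([[1.0, 2.0, 3.0]]), 0.25)  ## tau < 1 sharpens the row
    return y, slope, probs, sharp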
def threshold_soft(x, lmbda):
    """
    A soft threshold routine applied to each dimension of input.

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of thresholding

    Returns:
        thresholded x
    """
    ## soft thresholding fx - S(x) = sign(x) * max(|x| - lmbda, 0)
    return jnp.maximum(x - lmbda, 0.) - jnp.maximum(-x - lmbda, 0.)
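## Illustrative sketch: soft thresholding shrinks values toward zero by lmbda
## and zeroes anything with magnitude below lmbda; the values are made up.
def _demo_threshold_soft():
    x = jnp.asarray([[-2.0, -0.3, 0.1, 1.5]])
    return threshold_soft(x, 0.5)  ## -> [[-1.5, 0., 0., 1.]]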
def threshold_cauchy(x, lmbda):
    """
    A Cauchy distributional threshold routine applied to each dimension of input.

    Args:
        x: data to apply threshold function over

        lmbda: scalar to control strength/influence of Cauchy thresholding

    Returns:
        thresholded x
    """
    ## threshold function based on that proposed in: https://arxiv.org/abs/2003.12507
    inner_term = jnp.sqrt(jnp.maximum(jnp.square(x) - lmbda, 0.))
    f = (x + inner_term) * 0.5
    g = (x - inner_term) * 0.5
    term1 = f * (x >= lmbda).astype(jnp.float32) ## f * (x >= lmbda)
    term2 = g * (x <= -lmbda).astype(jnp.float32) ## g * (x <= -lmbda)
    return term1 + term2
@jit
def drop_out(dkey, input, rate=0.0):
    """
    Applies a drop-out transform to an input matrix.

    Args:
        dkey: Jax randomness key for this operator

        input: data to apply random/drop-out mask to

        rate: probability of a dimension being dropped

    Returns:
        output as well as the (inverse-scaled) drop-out mask
    """
    eps = random.uniform(dkey, (input.shape[0], input.shape[1]),
                         minval=0.0, maxval=1.0)
    mask = (eps <= (1.0 - rate)).astype(jnp.float32)
    mask = mask * (1.0 / (1.0 - rate)) ## apply inverted dropout scheme
    output = input * mask
    return output, mask
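## Illustrative sketch: inverted dropout zeroes roughly `rate` of the entries
## and rescales the survivors by 1/(1 - rate); the key and shape are arbitrary.
def _demo_drop_out():
    dkey = random.PRNGKey(42)
    z = jnp.ones((4, 6))  ## made-up activity matrix
    z_dropped, mask = drop_out(dkey, z, rate=0.5)  ## surviving entries scaled to 2.0
    return z_dropped, mask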