Source code for ngclearn.utils.metric_utils

from jax import numpy as jnp, jit
from functools import partial

@jit
def measure_sparsity(codes, tolerance=0.):
    """
    Calculates the sparsity (ratio) of an input matrix, assuming each row
    within it is a non-negative vector.

    Args:
        codes: matrix (shape: N x D) of non-negative codes to measure the
            sparsity of (per row)

        tolerance: values must exceed this threshold to count as
            active/non-zero (Default: 0.)

    Returns:
        sparsity measurements per code (output shape: N x 1)
    """
    m = (codes > tolerance).astype(jnp.float32)
    rho = jnp.sum(m, axis=1, keepdims=True) / (codes.shape[1] * 1.)
    return rho
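
# A minimal usage sketch (the `_demo_sparsity` helper and its toy values are
# illustrative, not part of ngclearn):
def _demo_sparsity():
    codes = jnp.asarray([[0., 1.2, 0., 0.3],
                         [0., 0., 0., 0.9]])
    rho = measure_sparsity(codes)  ## -> [[0.5], [0.25]]: 2/4 and 1/4 entries active
    return rho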
@jit
def measure_ACC(mu, y): ## measures/calculates accuracy
    """
    Calculates the accuracy (ACC) given a matrix of predictions and a matrix
    of targets.

    Args:
        mu: prediction (design) matrix

        y: target / ground-truth (design) matrix

    Returns:
        scalar accuracy score
    """
    guess = jnp.argmax(mu, axis=1)
    lab = jnp.argmax(y, axis=1)
    acc = jnp.sum(jnp.equal(guess, lab)) / (y.shape[0] * 1.)
    return acc
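
# A minimal usage sketch (the `_demo_acc` helper and its toy values are
# illustrative, not part of ngclearn):
def _demo_acc():
    mu = jnp.asarray([[0.1, 0.9],
                      [0.8, 0.2],
                      [0.3, 0.7]])  ## predicted class scores
    y = jnp.asarray([[0., 1.],
                     [0., 1.],
                     [0., 1.]])  ## one-hot targets
    acc = measure_ACC(mu, y)  ## row 2 is misclassified -> 2/3
    return acc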
def measure_KLD(p_xHat, p_x, preserve_batch=False):
    """
    Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the
    two input arguments contain valid probability distributions (in each row,
    if they are matrices).

    Note: If the batch is preserved, this returns a column vector where each
    row is the KLD(p_xHat, p_x) for that row's datapoint.

    | Formula:
    | KLD(p_xHat, p_x) = (1/C) [ sum_i(p_x * log(p_x)) - sum_i(p_x * log(p_xHat)) ]
    | where sum_i implies summing across the dimensions of the vector space of p_x

    Args:
        p_xHat: predicted probabilities; (N x C matrix, where C is the number of categories)

        p_x: ground-truth probabilities; (N x C matrix, where C is the number of categories)

        preserve_batch: if True, will return one score per sample in the batch
            (Default: False); otherwise, returns the scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    ## numerical stability step
    offset = 1e-6
    _p_x = jnp.clip(p_x, offset, 1. - offset)
    _p_xHat = jnp.clip(p_xHat, offset, 1. - offset)
    ## calculate raw KLD scores, normalized by the number of categories C
    C = p_x.shape[1]
    term1 = jnp.sum(_p_x * jnp.log(_p_x), axis=1, keepdims=True)
    term2 = -jnp.sum(_p_x * jnp.log(_p_xHat), axis=1, keepdims=True)
    kld = (term1 + term2) * (1. / C)
    if not preserve_batch:
        kld = jnp.mean(kld)
    return kld
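
# A minimal usage sketch (the `_demo_kld` helper and its toy values are
# illustrative, not part of ngclearn):
def _demo_kld():
    p_x = jnp.asarray([[0.5, 0.5]])     ## ground-truth distribution
    p_xHat = jnp.asarray([[0.9, 0.1]])  ## predicted distribution
    kld = measure_KLD(p_xHat, p_x)      ## > 0 since the distributions differ
    kld_zero = measure_KLD(p_x, p_x)    ## ~0 (up to the clipping offset)
    return kld, kld_zero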
@partial(jit, static_argnums=[3])
def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
    """
    Measures the negative Categorical log likelihood (Cat.NLL).

    Note: If the batch is preserved, this returns a column vector where each
    row is the Cat.NLL(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities; (N x C matrix, where C is the number of categories)

        x: true one-hot encoded targets; (N x C matrix, where C is the number of categories)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in the batch
            (Default: False); otherwise, returns the scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    p_ = jnp.clip(p, offset, 1.0 - offset)
    loss = -(x * jnp.log(p_))
    nll = jnp.sum(loss, axis=1, keepdims=True)
    if not preserve_batch:
        nll = jnp.mean(nll)
    return nll
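
# A minimal usage sketch (the `_demo_catnll` helper and its toy values are
# illustrative, not part of ngclearn):
def _demo_catnll():
    p = jnp.asarray([[0.7, 0.2, 0.1]])  ## predicted categorical distribution
    x = jnp.asarray([[1., 0., 0.]])     ## one-hot target (class 0)
    nll = measure_CatNLL(p, x)          ## -> -log(0.7) ~= 0.357
    return nll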
@partial(jit, static_argnums=[2])
def measure_MSE(mu, x, preserve_batch=False):
    """
    Measures the mean squared error (MSE), i.e., the negative Gaussian log
    likelihood with a variance of 1.0.

    Note: If the batch is preserved, this returns a column vector where each
    row is the MSE(mu, x) for that row's datapoint.

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in the batch
            (Default: False); otherwise, returns the scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    ## preserve_batch is marked static since it drives Python control flow under jit
    diff = mu - x
    se = jnp.square(diff)  ## squared error
    mse = jnp.sum(se, axis=1, keepdims=True)  ## per-row sum of squared errors
    if not preserve_batch:
        mse = jnp.mean(mse)  ## mean over the batch -> proper MSE
    return mse
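
# A minimal usage sketch (the `_demo_mse` helper and its toy values are
# illustrative, not part of ngclearn). Note that each row's squared errors
# are summed over D before the mean is taken over the batch:
def _demo_mse():
    mu = jnp.asarray([[1., 2.],
                      [0., 0.]])
    x = jnp.zeros((2, 2))
    per_row = measure_MSE(mu, x, True)  ## -> [[5.], [0.]]
    mean_score = measure_MSE(mu, x)     ## -> 2.5
    return per_row, mean_score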
@partial(jit, static_argnums=[3])
def measure_BCE(p, x, offset=1e-7, preserve_batch=False):
    """
    Calculates the negative Bernoulli log likelihood, i.e., the binary cross
    entropy (BCE).

    Note: If the batch is preserved, this returns a column vector where each
    row is the BCE(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities; (N x D matrix)

        x: target binary values (data); (N x D matrix)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in the batch
            (Default: False); otherwise, returns the scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    ## preserve_batch is marked static since it drives Python control flow under jit
    p_ = jnp.clip(p, offset, 1. - offset)
    bce = -jnp.sum(x * jnp.log(p_) + (1. - x) * jnp.log(1. - p_),
                   axis=1, keepdims=True)
    if not preserve_batch:
        bce = jnp.mean(bce)
    return bce
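
# A minimal usage sketch (the `_demo_bce` helper and its toy values are
# illustrative, not part of ngclearn):
def _demo_bce():
    p = jnp.asarray([[0.9, 0.1]])  ## predicted Bernoulli probabilities
    x = jnp.asarray([[1., 0.]])    ## binary targets
    bce = measure_BCE(p, x)        ## -> -2 * log(0.9) ~= 0.211
    return bce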