Source code for ngclearn.utils.metric_utils

"""
Metric and measurement routines and co-routines. These functions are useful for model-level/simulation analysis as well
as experimental inspection and probing (many of these are neuroscience-oriented measurement functions).
"""
from jax import numpy as jnp, jit
from functools import partial
from sklearn.metrics import confusion_matrix, precision_score, recall_score

@partial(jit, static_argnums=[1])
def measure_fanoFactor(spikes, preserve_batch=False):
    """
    Computes the Fano factor of a spike train, i.e., a second-order statistic
    (variance divided by mean, taken along the time axis) that probes the
    variability of spiking within the recorded interval.

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, returns one factor per neuron (column) of
            `spikes`; otherwise (Default: False), returns the scalar average
            of the per-neuron factors

    Returns:
        a 1 x D Fano factor vector (one factor per neuron) OR a single average
        Fano factor across the neuronal group
    """
    ## statistics are taken over axis 0 (time), one value per neuron
    spike_mean = jnp.mean(spikes, axis=0, keepdims=True)
    spike_var = jnp.square(jnp.std(spikes, axis=0, keepdims=True))
    factors = spike_var / spike_mean  ## no guard here: a zero mean is divided as-is
    return factors if preserve_batch else jnp.mean(factors)
@partial(jit, static_argnums=[1])
def measure_firingRate(spikes, preserve_batch=False):
    """
    Computes the firing rate(s) of a group of neurons given their full spike
    train(s): the spike count of each neuron divided by the number of time
    steps T.

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, returns one rate per neuron (column) of
            `spikes`; otherwise (Default: False), returns the scalar average
            of the per-neuron rates

    Returns:
        a 1 x D firing rate vector (one firing rate per neuron) OR a single
        average firing rate across the neuronal group
    """
    n_steps = float(spikes.shape[0])  ## T, the length of the spike train
    rates = jnp.sum(spikes, axis=0, keepdims=True) / n_steps
    return rates if preserve_batch else jnp.mean(rates)
@partial(jit, static_argnums=[1])
def measure_breadth_TC(spikes, preserve_batch=False):
    """
    Computes the breadth of tuning curve (BTC) of a group of neurons given
    their full spike train(s). BTC measures neural selectivity such that a
    sparse code distribution concentrates near zero with a heavy tail. For a
    neural layer where most of the neurons fire, the activity distribution is
    more uniformly spread and BTC > 0.5. When most of the neurons do not fire,
    the firing distribution is peaked at zero and BTC < 0.5.

    | Concretely, per neuron: C = Var(spikes) / Mean(spikes) (over time), and
    | BTC = 1 / (1 + C^2)

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, returns one score per neuron (column) of
            `spikes`; otherwise (Default: False), returns the scalar average
            of the per-neuron scores

    Returns:
        a 1 x D BTC vector (one score per neuron) OR a single average BTC
        score across the neuronal group
    """
    ## statistics are taken over axis 0 (time), one value per neuron
    spike_mean = jnp.mean(spikes, axis=0, keepdims=True)
    spike_var = jnp.square(jnp.std(spikes, axis=0, keepdims=True))
    ratio = spike_var / spike_mean  ## variance-to-mean ratio (C)
    breadth = 1. / (1 + jnp.square(ratio))
    return breadth if preserve_batch else jnp.mean(breadth)
@partial(jit, static_argnums=[2, 3])
def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=False):
    """
    Calculates the sparsity (ratio) of an input matrix, assuming each row
    within this matrix is a non-negative vector. Formally, this means we
    compute, per i-th row:

    | rho(x_i) = num_active(x_i) / dim(x_i)
    | where num_active counts the entries strictly greater than `tolerance`

    and for a global score for matrix X with N codes/rows, we measure:

    | rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i)

    where lower/closer to 0 means codes are more sparse and closer to 1 means
    codes are more dense. Note that this definition of sparsity aligns with
    Foldiak's definition of the ratio of active neurons to inactive ones
    (assuming binary coding):

    | Foldiak, Peter. "Sparse and explicit neural coding." Principles of neural
    | coding. CRC Press, 2013. 379-389.

    Args:
        codes: matrix (shape: N x D) of non-negative codes to measure
            sparsity of (per row)

        tolerance: lowest number to consider as "empty"/non-existent; entries
            must exceed this value to count as active (Default: 0.)

        preserve_batch: if True, will return one score per sample (N x 1) in
            batch (Default: True), otherwise, returns scalar average/mean score

        flip_measure: if True, will score sparsity via "1 - rho", so that
            closer to 1 means more sparse (Default: False)

    Returns:
        sparsity measurements per code (shape: N x 1) or a single scalar score
    """
    n_dims = codes.shape[1]
    active = (codes > tolerance).astype(jnp.float32)  ## 1 where a unit is "on"
    score = jnp.sum(active, axis=1, keepdims=True) / (n_dims * 1.)  ## per-code ratio
    if flip_measure:  ## closer to 1 = more sparse, closer to 0 = more dense
        score = 1. - score
    return score if preserve_batch else jnp.mean(score)
#@partial(jit, static_argnums=[2])
def analyze_scores(mu, y, extract_label_indx=True): ## examines classification statistics
    """
    Analyzes a prediction matrix against a target/ground-truth matrix or
    label vector and reports a set of classification statistics.

    Args:
        mu: prediction (design) matrix; shape is (N x C) where C is number of
            classes and N is the number of patterns examined

        y: target / ground-truth (design) matrix; shape is (N x C) OR an array
            of class integers of length N (with "extract_label_indx = False")

        extract_label_indx: run an argmax to pull class integer indices from
            "y", assuming y is a one-hot binary encoding matrix
            (Default: True), otherwise, this assumes "y" is an array of class
            integer indices of length N

    Returns:
        confusion matrix, precision, recall, misses (empty predictions/all-zero
        rows), accuracy, adjusted-accuracy (counts all misses as incorrect)
    """
    ## a "miss" is a prediction row whose entries sum to exactly zero
    miss_mask = (jnp.sum(mu, axis=1) == 0.) * 1.
    misses = jnp.sum(miss_mask)  ## how many misses?

    labels = jnp.argmax(y, axis=1) if extract_label_indx else y
    guesses = jnp.argmax(mu, axis=1)  ## gather all model/output guesses

    ## macro-averaged classification statistics
    conf_matrix = confusion_matrix(labels, guesses)
    precision = precision_score(labels, guesses, average='macro')
    recall = recall_score(labels, guesses, average='macro')

    ## accuracy measurements
    n_samples = y.shape[0] * 1.
    hit_mask = jnp.equal(guesses, labels) * 1.
    acc = jnp.sum(hit_mask) / n_samples  ## raw accuracy
    ## adjusted accuracy: a correct argmax on a "miss" row does not count
    adj_acc = jnp.sum(hit_mask * (1. - miss_mask)) / n_samples
    return conf_matrix, precision, recall, misses, acc, adj_acc
@partial(jit, static_argnums=[2])
def measure_ACC(mu, y, extract_label_indx=True): ## measures/calculates accuracy
    """
    Calculates the accuracy (ACC) given a matrix of predictions and matrix of
    targets (or a vector of integer class labels).

    Args:
        mu: prediction (design) matrix; shape is (N x C) where C is number of
            classes and N is the number of patterns examined

        y: target / ground-truth (design) matrix; shape is (N x C) OR an array
            of class integers of length N (with "extract_label_indx = False")

        extract_label_indx: run an argmax to pull class integer indices from
            "y", assuming y is a one-hot binary encoding matrix
            (Default: True), otherwise, this assumes "y" is an array of class
            integer indices of length N

    Returns:
        scalar accuracy score
    """
    guess = jnp.argmax(mu, axis=1)
    if extract_label_indx:
        lab = jnp.argmax(y, axis=1)
    else:
        ## y already holds integer class indices; flatten so that an (N x 1)
        ## column compares element-wise against the (N,) guesses instead of
        ## broadcasting to an (N x N) comparison
        lab = jnp.ravel(y)
    acc = jnp.sum(jnp.equal(guess, lab)) / (y.shape[0] * 1.)
    return acc
@partial(jit, static_argnums=[3])
def measure_BIC(X, n_model_params, max_model_score, is_log=True):
    """
    Measures the Bayesian information criterion (BIC) with respect to the
    final score obtained by the model on a given dataset.

    | BIC = -2 ln(L) + K * ln(N);
    | where N is number of data-points/rows of design matrix X,
    | K is total number parameters of the model of interest, and
    | L is the max/best-found value of a likelihood-like score L of the model

    Args:
        X: dataset/design matrix that a model was fit to (max-likelihood
            optimized); only its number of rows is used

        n_model_params: total number of model parameters (int)

        max_model_score: max likelihood-like score obtained by model on X

        is_log: is supplied `max_model_score` a log-likelihood? if this is
            False, this metric will apply a natural logarithm to the score
            (Default: True)

    Returns:
        scalar for the Bayesian information criterion score
    """
    n_points = X.shape[0]  ## N, number of data-points
    ## work with the log of the likelihood-like score
    log_score = max_model_score if is_log else jnp.log(max_model_score)
    fit_term = -log_score * 2.  ## -2 ln(L)
    complexity_term = jnp.log(n_points * 1.) * n_model_params  ## K * ln(N)
    return fit_term + complexity_term
@partial(jit, static_argnums=[2])
def measure_KLD(p_xHat, p_x, preserve_batch=False):
    """
    Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the
    two input arguments contain valid probability distributions (in each row,
    if they are matrices). Note: If batch is preserved, this returns a column
    vector where each row is the KLD(x_pred, x_true) for that row's datapoint.
    (Further note that this function does not assume any particular
    distribution when calculating KLD.)

    | Formula:
    | KLD(p_xHat, p_x) = (1/N) [ sum_i(p_x * jnp.log(p_x)) - sum_i(p_x * jnp.log(p_xHat)) ]
    | where sum_i implies summing across dimensions of vector-space of p_x,
    | and N is the number of those dimensions (columns of p_x)

    Args:
        p_xHat: predicted probabilities; (N x C matrix, where C is number of
            categories)

        p_x: ground true probabilities; (N x C matrix, where C is number of
            categories)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    ## numerical control step: keep probabilities away from 0 and 1
    eps = 1e-6
    target = jnp.clip(p_x, eps, 1. - eps)
    pred = jnp.clip(p_xHat, eps, 1. - eps)
    n_cols = p_x.shape[1]
    ## sum_i p_x * log(p_x), and -sum_i p_x * log(p_xHat)
    entropy_term = jnp.sum(target * jnp.log(target), axis=1, keepdims=True)
    cross_term = -jnp.sum(target * jnp.log(pred), axis=1, keepdims=True)
    divergence = (entropy_term + cross_term) * (1 / n_cols)  ## KLD-per-datapoint
    return divergence if preserve_batch else jnp.mean(divergence)
def measure_gaussian_KLD(mu1, Sigma1, mu2, Sigma2, use_chol_prec=True):
    """
    Calculates the Kullback-Leibler (KL) divergence between two multivariate
    Gaussian distributions, i.e., KL(N(mu1, Sigma1) || N(mu2, Sigma2)).

    Formally, this means this routine calculates:

    | KL(N1 || N2) = [log(det(Sigma2)/det(Sigma1)) + trace(Prec2 * Sigma1) + (z * Prec2 * z) - D] * (1/2)
    | where N1 is the 1st Gaussian, i.e., N(mu1,Sigma1), and N2 is the 2nd Gaussian, i.e., N(mu2,Sigma2);
    | and where: Prec2 = (Sigma2)^{-1}, z = mu2 - mu1, and D is the data dimensionality

    Args:
        mu1: mean (row) vector of first Gaussian distribution; shape (1 x D)
            (D is read from its second axis)

        Sigma1: covariance matrix of first Gaussian distribution

        mu2: mean (row) vector of second Gaussian distribution

        Sigma2: covariance matrix of second Gaussian distribution

        use_chol_prec: should this routine use Cholesky-factor computation of
            the precision of Sigma2 (Default: True)

    Returns:
        KL-divergence between N(mu1, Sigma1) and N(mu2, Sigma2); one value
        per row of the mean vectors (shape: 1 x 1 for single row means)
    """
    D = mu1.shape[1]  ## dimensionality of data
    ## log(|Sigma2|/|Sigma1|) = log(|Sigma2|) - log(|Sigma1|)
    ## (slogdet returns the sign and log|det|; the sign is +1 for a
    ## positive-definite covariance, so the products are the log-determinants)
    sgn_s1, val_s1 = jnp.linalg.slogdet(Sigma1)
    log_detSigma1 = val_s1 * sgn_s1
    sgn_s2, val_s2 = jnp.linalg.slogdet(Sigma2)
    log_detSigma2 = val_s2 * sgn_s2
    if use_chol_prec:  ## use Cholesky-factor calc of (Sigma2)^{-1}
        ## Sigma2 = C C^T  =>  (Sigma2)^{-1} = C^{-T} C^{-1}
        C = jnp.linalg.cholesky(Sigma2)  ## cholesky factor matrix
        inv_C = jnp.linalg.pinv(C)
        Prec2 = jnp.matmul(inv_C.T, inv_C)
    else:
        Prec2 = jnp.linalg.pinv(Sigma2)  ## pseudo-inverse calc of (Sigma2)^{-1}
    trace_term = jnp.trace(jnp.dot(Prec2, Sigma1))  ## trace term of KL divergence
    delta_mu = mu2 - mu1
    ## quadratic (Mahalanobis) term z * Prec2 * z, computed per row of delta_mu
    quadratic_term = jnp.sum(
        jnp.matmul(delta_mu, Prec2) * delta_mu, axis=1, keepdims=True
    )
    ## calc full KL divergence -- the quadratic term enters exactly once
    kld = ((log_detSigma2 - log_detSigma1) + trace_term + quadratic_term - D) * 0.5
    return kld
@partial(jit, static_argnums=[3])
def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
    """
    Measures the negative Categorical log likelihood (Cat.NLL). Note: If batch
    is preserved, this returns a column vector where each row is the
    Cat.NLL(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities; (N x C matrix, where C is number of
            categories)

        x: true one-hot encoded targets; (N x C matrix, where C is number of
            categories)

        offset: factor to control for numerical stability; probabilities are
            clipped to [offset, 1 - offset] (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    clipped = jnp.clip(p, offset, 1.0 - offset)  ## keeps log() finite
    ## CatNLL-per-datapoint: -sum_c x_c * log(p_c)
    per_sample = -jnp.sum(x * jnp.log(clipped), axis=1, keepdims=True)
    return per_sample if preserve_batch else jnp.mean(per_sample)
@partial(jit, static_argnums=[2])
def measure_RMSE(mu, x, preserve_batch=False):
    """
    Measures root mean squared error (RMSE) as the square root of whatever
    `measure_MSE` returns for the same arguments. Note: If batch is preserved,
    this returns a column vector where each row is the square root of the
    MSE(mu, x) entry for that row's datapoint. (This is a simple
    wrapper/extension of the in-built MSE.)

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    ## sqrt(MSE) is the root-mean-squared-error
    return jnp.sqrt(measure_MSE(mu, x, preserve_batch=preserve_batch))
@partial(jit, static_argnums=[2])
def measure_MSE(mu, x, preserve_batch=False):
    """
    Measures mean squared error (MSE), or the negative Gaussian log likelihood
    with variance of 1.0. Squared errors are summed across each row's
    dimensions; the scalar score is the mean of those per-row sums. Note: If
    batch is preserved, this returns a column vector where each row is the
    (summed) squared error between mu and x for that row's datapoint.

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    sq_err = jnp.square(mu - x)  ## element-wise squared error
    ## technically the (summed) squared-error per data-point
    per_sample = jnp.sum(sq_err, axis=1, keepdims=True)
    ## averaging over data-points yields the scalar score
    return per_sample if preserve_batch else jnp.mean(per_sample)
@partial(jit, static_argnums=[2])
def measure_MAE(shift, x, preserve_batch=False):
    """
    Measures mean absolute error (MAE), or the negative Laplacian log
    likelihood with scale of 1.0. Absolute errors are summed across each
    row's dimensions; the scalar score is the mean of those per-row sums.
    Note: If batch is preserved, this returns a column vector where each row
    is the (summed) absolute error between shift and x for that row's
    datapoint.

    Args:
        shift: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    abs_err = jnp.abs(shift - x)  ## element-wise absolute error
    ## technically the (summed) abs-error per data-point
    per_sample = jnp.sum(abs_err, axis=1, keepdims=True)
    ## averaging over data-points yields the scalar score
    return per_sample if preserve_batch else jnp.mean(per_sample)
@partial(jit, static_argnums=[3])
def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
    """
    Calculates the negative Bernoulli log likelihood or binary cross entropy
    (BCE). Note: If batch is preserved, this returns a column vector where
    each row is the BCE(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities of shape; (N x D matrix)

        x: target binary values (data) of shape; (N x D matrix)

        offset: factor to control for numerical stability; probabilities are
            clipped to [offset, 1 - offset] (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR a scalar otherwise
    """
    clipped = jnp.clip(p, offset, 1 - offset)  ## keeps both log() calls finite
    ## element-wise Bernoulli log likelihood
    log_lik = x * jnp.log(clipped) + (1.0 - x) * jnp.log(1.0 - clipped)
    per_sample = -jnp.sum(log_lik, axis=1, keepdims=True)  ## BCE-per-datapoint
    return per_sample if preserve_batch else jnp.mean(per_sample)