Source code for ngclearn.utils.metric_utils

"""
Metric and measurement routines and co-routines. These functions are useful for model-level/simulation analysis as well
as experimental inspection and probing (many of these are neuroscience-oriented measurement functions).
"""
from jax import numpy as jnp, jit
from functools import partial
from sklearn.metrics import confusion_matrix, precision_score, recall_score

@partial(jit, static_argnums=[1])
def measure_fanoFactor(spikes, preserve_batch=False):
    """
    Calculates the Fano factor, i.e., a secondary statistic that probes the
    variability of a spike train within a particular time interval. For every
    neuron (column of `spikes`), the factor is the variance of its activity over
    time divided by the mean of its activity over time.

    Note: a neuron whose mean activity is zero causes a division by zero
    (NaN/inf), which also propagates into the group average.

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, will return one score per neuron as a 1 x D
            vector (Default: False), otherwise, returns scalar average score

    Returns:
        a 1 x D Fano factor vector (one factor per neuron) OR a single average
        Fano factor across the neuronal group
    """
    mean_activity = jnp.mean(spikes, axis=0, keepdims=True)
    variance = jnp.std(spikes, axis=0, keepdims=True) ** 2
    per_neuron = variance / mean_activity
    if preserve_batch:
        return per_neuron
    return jnp.mean(per_neuron)
@partial(jit, static_argnums=[1])
def measure_firingRate(spikes, preserve_batch=False):
    """
    Calculates the firing rate(s) of a group of neurons given a full spike
    train: the number of spikes each neuron emitted divided by the number of
    time steps T.

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, will return one rate per neuron as a 1 x D
            vector (Default: False), otherwise, returns scalar average score

    Returns:
        a 1 x D firing rate vector (one firing rate per neuron) OR a single
        average firing rate across the neuronal group
    """
    num_steps = spikes.shape[0] * 1.  ## force float division
    rates = jnp.sum(spikes, axis=0, keepdims=True) / num_steps
    return rates if preserve_batch else jnp.mean(rates)
@partial(jit, static_argnums=[1])
def measure_breadth_TC(spikes, preserve_batch=False):
    """
    Calculates the breadth tuning curve (BTC) of a group of neurons given a
    full spike train. BTC measures neural selectivity such that the sparse
    code distribution concentrates near zero with a heavy tail. For a neural
    layer where most of the neurons fire, the activity distribution is more
    uniformly spread and BTC > 0.5. When most of the neurons do not fire, the
    firing distribution is peaked at zero and BTC < 0.5.

    Per neuron this computes C = var/mean (over time) and BTC = 1 / (1 + C^2);
    a zero-mean neuron causes a division by zero in C.

    Args:
        spikes: full spike train matrix; shape is (T x D) where D is number of
            neurons in a group/cluster

        preserve_batch: if True, will return one score per neuron as a 1 x D
            vector (Default: False), otherwise, returns scalar average score

    Returns:
        a 1 x D BTC vector (one factor per neuron) OR a single average BTC
        across the neuronal group
    """
    mean_activity = jnp.mean(spikes, axis=0, keepdims=True)
    variance = jnp.std(spikes, axis=0, keepdims=True) ** 2
    dispersion = variance / mean_activity
    btc = 1. / (1 + dispersion ** 2)
    return btc if preserve_batch else jnp.mean(btc)
@partial(jit, static_argnums=[1])
def measure_gini_index(codes, preserve_batch=True):
    """
    Calculates the Gini index of a group of neurons represented as vector
    code samples. The Gini index measures the sparseness of the values within
    each vector code: values closer to 1 indicate higher sparsity (sparser
    codes) and values closer to 0 indicate lower sparsity (denser codes).

    This uses the sorted-rank form of the index,
    | gini(x) = sum_i (2i - D - 1) * x_(i) / (D * sum_i x_(i))
    where x_(i) are the absolute values of x sorted in ascending order, which
    avoids the quadratic mean-absolute-difference formulation.

    Args:
        codes: a batch of neural codes; shape is (N x D) where D is number of
            neurons in a group/cluster and N is number of samples

        preserve_batch: if True, will return one score per sample in batch
            (Default: True), otherwise, returns scalar average score

    Returns:
        a N x 1 Gini index vector (one score per sample/code) OR a single
        average Gini score for the whole sample/set of codes
    """
    num_units = codes.shape[1]
    ## rows whose sum is <= 0 get +1 added to every entry (presumably so the
    ## denominator below cannot vanish); an all-zero row thus scores as dense
    nonpositive_row = jnp.sum(codes, axis=1, keepdims=True) <= 0.
    shifted = codes + nonpositive_row + 1e-8
    ordered = jnp.sort(jnp.abs(shifted), axis=1)  ## ascending within each row
    rank_weights = 2 * jnp.arange(1, num_units + 1) - num_units - 1
    numerator = jnp.sum(rank_weights * ordered, axis=1, keepdims=True)
    denominator = num_units * jnp.sum(ordered, axis=1, keepdims=True)
    gini = numerator / denominator
    if preserve_batch:
        return gini
    return jnp.mean(gini)
@partial(jit, static_argnums=[2, 3])
def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=False):
    """
    Calculates the sparsity (ratio) of an input matrix, assuming each row
    within this matrix is a non-negative vector. An entry counts as "active"
    when it is strictly greater than `tolerance`. Formally, per i-th row:

    | rho(x_i) = num_active(x_i) / dim(x_i)

    and for a global score for matrix X with N codes/rows, we measure:

    | rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i)

    where lower/closer to 0 means codes are more sparse and closer to 1 means
    codes are more dense. Note that this definition of sparsity aligns with
    Foldiak's measure of the fraction of active neurons (assuming binary
    coding):

    | Foldiak, Peter. "Sparse and explicit neural coding." Principles of neural
    | coding. CRC Press, 2013. 379-389.

    Args:
        codes: matrix (shape: N x D) of non-negative codes to measure sparsity
            of (per row)

        tolerance: entries less than or equal to this value are treated as
            "empty"/inactive (Default: 0.)

        preserve_batch: if True, will return one score per sample (N x 1) in
            batch (Default: True), otherwise, returns scalar average/mean score

        flip_measure: if True, will score sparsity via "1 - num_active/dim" so
            that values closer to 1 mean sparser codes (Default: False)

    Returns:
        sparsity measurements per code (shape: N x 1) or a single mean score
    """
    active = (codes > tolerance).astype(jnp.float32)
    width = codes.shape[1] * 1.
    rho = jnp.sum(active, axis=1, keepdims=True) / width  ## active fraction per code
    if flip_measure:  ## closer to 1 = more sparse, closer to 0 = more dense
        rho = 1. - rho
    return rho if preserve_batch else jnp.mean(rho)
# Not jit-compiled: the sklearn metrics below operate on concrete arrays.
def analyze_scores(mu, y, extract_label_indx=True):  ## examines classification statistics
    """
    Analyzes a prediction matrix against a target/ground-truth matrix or vector.

    A "miss" is a prediction row of `mu` that sums to exactly zero (an empty
    prediction). Raw accuracy compares argmax(mu) with the labels directly;
    adjusted accuracy additionally counts every miss as incorrect.

    Args:
        mu: prediction (design) matrix; shape is (N x C) where C is number of
            classes and N is the number of patterns examined

        y: target / ground-truth (design) matrix; shape is (N x C) when
            `extract_label_indx = True`, OR an array of class integers of
            length N when `extract_label_indx = False`

        extract_label_indx: when True, run an argmax to pull class integer
            indices from "y", assuming y is a one-hot binary encoding matrix
            (Default: True), otherwise, if False, this treats "y" as an array
            of class integer indices of length N

    Returns:
        confusion matrix, precision, recall, misses (empty predictions/all-zero
        rows), accuracy, adjusted-accuracy (counts all misses as incorrect);
        precision and recall are macro-averaged
    """
    num_patterns = y.shape[0] * 1.
    empty_rows = (jnp.sum(mu, axis=1) == 0.) * 1.
    misses = jnp.sum(empty_rows)

    labels = jnp.argmax(y, axis=1) if extract_label_indx else y
    guesses = jnp.argmax(mu, axis=1)  ## model's predicted class per pattern

    conf_matrix = confusion_matrix(labels, guesses)
    precision = precision_score(labels, guesses, average='macro')
    recall = recall_score(labels, guesses, average='macro')

    correct = jnp.equal(guesses, labels) * 1.
    acc = jnp.sum(correct) / num_patterns
    adj_acc = jnp.sum(correct * (1. - empty_rows)) / num_patterns  ## misses never count as hits
    return conf_matrix, precision, recall, misses, acc, adj_acc
@partial(jit, static_argnums=[2])
def measure_ACC(mu, y, extract_label_indx=True):  ## measures/calculates accuracy
    """
    Calculates the accuracy (ACC) given a matrix of predictions and matrix
    (or vector) of targets.

    Args:
        mu: prediction (design) matrix; shape is (N x C) where C is number of
            classes and N is the number of patterns examined

        y: target / ground-truth (design) matrix; shape is (N x C) when
            `extract_label_indx = True`, OR an array of class integers of
            length N when `extract_label_indx = False`

        extract_label_indx: if True, run an argmax to pull class integer
            indices from "y", assuming y is a one-hot binary encoding matrix
            (Default: True), otherwise, "y" is used directly as an array of
            class integer indices of length N

    Returns:
        scalar accuracy score (fraction of rows whose argmax prediction equals
        the target class)
    """
    guess = jnp.argmax(mu, axis=1)
    ## previously `lab` was only bound when extracting, so False raised an error
    lab = jnp.argmax(y, axis=1) if extract_label_indx else y
    acc = jnp.sum(jnp.equal(guess, lab)) / (y.shape[0] * 1.)
    return acc
@partial(jit, static_argnums=[3])
def measure_BIC(X, n_model_params, max_model_score, is_log=True):
    """
    Measures the Bayesian information criterion (BIC) with respect to the
    final score obtained by the model on a given dataset.

    | BIC = -2 ln(L) + K * ln(N);
    | where N is number of data-points/rows of design matrix X,
    | K is total number parameters of the model of interest, and
    | L is the max/best-found value of a likelihood-like score L of the model

    Args:
        X: dataset/design matrix that a model was fit to (max-likelihood
            optimized); only its number of rows is used

        n_model_params: total number of model parameters (int)

        max_model_score: max likelihood-like score obtained by model on X

        is_log: is supplied `max_model_score` a log-likelihood? if this is
            False, this metric will apply a natural logarithm to the score
            (Default: True)

    Returns:
        scalar for the Bayesian information criterion score
    """
    num_points = X.shape[0]
    log_likelihood = max_model_score if is_log else jnp.log(max_model_score)
    complexity_penalty = jnp.log(num_points * 1.) * n_model_params
    return -log_likelihood * 2. + complexity_penalty
@partial(jit, static_argnums=[2])
def measure_KLD(p_xHat, p_x, preserve_batch=False):
    """
    Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the two
    input arguments contain valid probability distributions (in each row, if
    they are matrices). Both inputs are first clipped to [1e-6, 1 - 1e-6] for
    numerical stability.

    Note: If batch is preserved, this returns a column vector where each row
    is the KLD(x_pred, x_true) for that row's datapoint. (Further note that
    this function does not assume any particular distribution when calculating
    KLD.)

    | Formula:
    | KLD(p_xHat, p_x) = (1/C) [ sum_i(p_x * log(p_x)) - sum_i(p_x * log(p_xHat)) ]
    | where sum_i sums across the C columns (categories) of p_x; note that the
    | per-row divergence is scaled by 1/C (number of categories), not batch size

    Args:
        p_xHat: predicted probabilities; (N x C matrix, where C is number of
            categories)

        p_x: ground true probabilities; (N x C matrix, where C is number of
            categories)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    eps = 1e-6
    target = jnp.clip(p_x, eps, 1. - eps)
    predicted = jnp.clip(p_xHat, eps, 1. - eps)
    num_categories = p_x.shape[1]
    neg_entropy = jnp.sum(target * jnp.log(target), axis=1, keepdims=True)
    cross_entropy = -jnp.sum(target * jnp.log(predicted), axis=1, keepdims=True)
    per_row = (neg_entropy + cross_entropy) * (1 / num_categories)
    if preserve_batch:
        return per_row
    return jnp.mean(per_row)
def measure_gaussian_KLD(mu1, Sigma1, mu2, Sigma2, use_chol_prec=True):
    """
    Calculates the Kullback-Leibler (KL) divergence between two multivariate
    Gaussian distributions, i.e., KL(N(mu1, Sigma1) || N(mu2, Sigma2)).
    Formally, this routine calculates:

    | KL(N1 || N2) = [log(det(Sigma2)/det(Sigma1)) + trace(Prec2 * Sigma1)
    |                 + (z * Prec2 * z) - D] * (1/2)
    | where N1 is the 1st Gaussian, i.e., N(mu1,Sigma1), and N2 is the 2nd
    | Gaussian, i.e., N(mu2,Sigma2);
    | and where: Prec2 = (Sigma2)^{-1}, z = mu2 - mu1, and D is the data
    | dimensionality

    Args:
        mu1: mean of first Gaussian distribution; a 2D array whose second axis
            has size D (e.g., a 1 x D row vector)

        Sigma1: covariance matrix (D x D) of first Gaussian distribution

        mu2: mean of second Gaussian distribution; same shape as mu1

        Sigma2: covariance matrix (D x D) of second Gaussian distribution

        use_chol_prec: should this routine use Cholesky-factor computation of
            the precision of Sigma2 (Default: True); otherwise a pseudo-inverse
            of Sigma2 is used

    Returns:
        a (rows of mu1) x 1 array of KL-divergence values; for 1 x D means this
        is a 1 x 1 array holding KL(N(mu1, Sigma1) || N(mu2, Sigma2))
    """
    D = mu1.shape[1]  ## dimensionality of data
    ## log(|Sigma2|/|Sigma1|) = log(|Sigma2|) - log(|Sigma1|)
    ## slogdet returns (sign, log|det|); for a positive-definite covariance the sign is +1
    sgn_s1, val_s1 = jnp.linalg.slogdet(Sigma1)
    log_detSigma1 = val_s1 * sgn_s1
    sgn_s2, val_s2 = jnp.linalg.slogdet(Sigma2)
    log_detSigma2 = val_s2 * sgn_s2
    if use_chol_prec:  ## Sigma2 = C C^T  =>  Sigma2^{-1} = C^{-T} C^{-1}
        C = jnp.linalg.cholesky(Sigma2)  ## cholesky factor matrix
        inv_C = jnp.linalg.pinv(C)
        Prec2 = jnp.matmul(inv_C.T, inv_C)
    else:
        Prec2 = jnp.linalg.pinv(Sigma2)  ## pseudo-inverse calc of (Sigma2)^{-1}
    trace_term = jnp.trace(jnp.dot(Prec2, Sigma1))  ## trace term of KL divergence
    delta_mu = mu2 - mu1
    ## row-wise z^T Prec2 z (one value per row of the mean arrays)
    quadratic_term = jnp.sum(jnp.matmul(delta_mu, Prec2) * delta_mu, axis=1, keepdims=True)
    ## full KL divergence; the quadratic term appears exactly once (it was
    ## previously summed twice, doubling the mean-difference contribution)
    kld = ((log_detSigma2 - log_detSigma1) + trace_term + quadratic_term - D) * 0.5
    return kld
@partial(jit, static_argnums=[3])
def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
    """
    Measures the negative Categorical log likelihood (Cat.NLL).

    Note: If batch is preserved, this returns a column vector where each row
    is the Cat.NLL(p, x) for that row's datapoint.

    Args:
        p: predicted probabilities; (N x C matrix, where C is number of
            categories); clipped to [offset, 1 - offset] before the log

        x: true one-hot encoded targets; (N x C matrix, where C is number of
            categories)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    safe_p = jnp.clip(p, offset, 1.0 - offset)
    per_entry = -(x * jnp.log(safe_p))
    per_row = jnp.sum(per_entry, axis=1, keepdims=True)  ## Cat.NLL per datapoint
    return per_row if preserve_batch else jnp.mean(per_row)
@partial(jit, static_argnums=[2])
def measure_RMSE(mu, x, preserve_batch=False):
    """
    Measures root mean squared error (RMSE) as the square root of the output
    of `measure_MSE` (this is a simple wrapper/extension of that routine).

    Note: If batch is preserved, this returns a column vector where each row
    is the RMSE(mu, x) for that row's datapoint, i.e., the square root of that
    row's squared error as computed by `measure_MSE`.

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    squared_err = measure_MSE(mu, x, preserve_batch=preserve_batch)
    return jnp.sqrt(squared_err)
@partial(jit, static_argnums=[2])
def measure_MSE(mu, x, preserve_batch=False):
    """
    Measures mean squared error (MSE), or the negative Gaussian log likelihood
    with variance of 1.0 (up to constants).

    Note: If batch is preserved, this returns a column vector where each row
    is the squared error of that row's datapoint summed (not averaged) across
    its D dimensions; otherwise the mean of those per-row sums is returned.

    Args:
        mu: predicted values (mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    residual = mu - x
    row_sq_err = jnp.sum(jnp.square(residual), axis=1, keepdims=True)
    if preserve_batch:
        return row_sq_err
    return jnp.mean(row_sq_err)  ## average over the batch of per-row sums
@partial(jit, static_argnums=[2])
def measure_MAE(shift, x, preserve_batch=False):
    """
    Measures mean absolute error (MAE), or the negative Laplacian log
    likelihood with scale of 1.0 (up to constants).

    Note: If batch is preserved, this returns a column vector where each row
    is the absolute error of that row's datapoint summed across its D
    dimensions; otherwise the mean of those per-row sums is returned.

    Args:
        shift: predicted values (location/mean); (N x D matrix)

        x: target values (data); (N x D matrix)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    abs_err = jnp.abs(shift - x)
    row_abs_err = jnp.sum(abs_err, axis=1, keepdims=True)  ## absolute error per datapoint
    return row_abs_err if preserve_batch else jnp.mean(row_abs_err)
@partial(jit, static_argnums=[3])
def measure_BCE(p, x, offset=1e-7, preserve_batch=False):
    """
    Calculates the negative Bernoulli log likelihood or binary cross entropy
    (BCE).

    Note: If batch is preserved, this returns a column vector where each row
    is the BCE(p, x) for that row's datapoint (summed across its D dimensions).

    Args:
        p: predicted probabilities; (N x D matrix); clipped to
            [offset, 1 - offset] before the logs

        x: target binary values (data); (N x D matrix)

        offset: factor to control for numerical stability (Default: 1e-7)

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    safe_p = jnp.clip(p, offset, 1 - offset)
    log_lik = x * jnp.log(safe_p) + (1.0 - x) * jnp.log(1.0 - safe_p)
    per_row = -jnp.sum(log_lik, axis=1, keepdims=True)  ## BCE per datapoint
    if not preserve_batch:
        per_row = jnp.mean(per_row)
    return per_row
@partial(jit, static_argnums=[1])
def measure_hoyer_sparsity(codes: jnp.ndarray, preserve_batch: bool=False) -> float:
    """
    Measures the Hoyer sparsity for a set of latent codes:

    | hoyer(x) = (sqrt(D) - ||x||_1 / ||x||_2) / (sqrt(D) - 1)

    For non-zero codes this lies in [0, 1], where 0.0 indicates a dense code
    and 1.0 indicates an extremely sparse (one-hot-like) code. A small epsilon
    (1e-8) is added inside the L2 norm to avoid division by zero.

    Args:
        codes: array whose first axis indexes samples (N); all remaining axes
            are flattened into a latent code of size D

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        a length-N vector of shape (N,) (if preserve_batch=True) OR scalar
        otherwise
    """
    flat = jnp.reshape(codes, (codes.shape[0], -1))  ## flatten past the batch axis
    sqrt_dim = jnp.sqrt(flat.shape[1])
    l1_norm = jnp.sum(jnp.abs(flat), axis=1)
    l2_norm = jnp.sqrt(jnp.sum(jnp.square(flat), axis=1) + 1e-8)
    scores = (sqrt_dim - (l1_norm / l2_norm)) / (sqrt_dim - 1.0)
    if preserve_batch:
        return scores
    return jnp.mean(scores)  ## average sparsity across set/batch
@partial(jit, static_argnums=[1])
def measure_excess_kurtosis(codes: jnp.ndarray, preserve_batch: bool=False) -> float:
    """
    Measures the peakedness and heavy-tailedness of a set of neural activation
    codes via per-sample excess kurtosis:

    | excess(x) = E[(x - mean)^4] / (Var[x]^2 + 1e-8) - 3

    Higher values (> 0) indicate sparse, localized "high-burst" activations.

    Args:
        codes: array whose first axis indexes samples (N); all remaining axes
            are flattened into a code per sample

        preserve_batch: if True, will return one score per sample in batch
            (Default: False), otherwise, returns scalar mean score

    Returns:
        an (N x 1) column vector (if preserve_batch=True) OR scalar otherwise
    """
    flat = jnp.reshape(codes, (codes.shape[0], -1))
    centered = flat - jnp.mean(flat, axis=1, keepdims=True)
    var = jnp.var(flat, axis=1, keepdims=True)  ## 2nd central moment
    m4 = jnp.mean(jnp.power(centered, 4), axis=1, keepdims=True)  ## 4th central moment
    excess = m4 / (jnp.square(var) + 1e-8) - 3.0  ## subtract Gaussian kurtosis of 3
    return excess if preserve_batch else jnp.mean(excess)
### class conformity metrics ###
@partial(jit, static_argnums=[2, 3])
def _compute_contingency_table(  ## vectorized construction of contingency matrix
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray,
    n_classes: int,
    n_clusters: int
) -> jnp.ndarray:
    """
    Builds the (n_classes x n_clusters) contingency matrix whose (i, j) entry
    counts the samples with true label i and predicted cluster j.

    Args:
        labels_true: 1D integer array (length N) of true class labels

        labels_pred: 1D integer array (length N) of predicted cluster labels

        n_classes: number of rows of the table (static; ids 0..n_classes-1)

        n_clusters: number of columns of the table (static; ids
            0..n_clusters-1)

    Returns:
        float32 contingency matrix of shape (n_classes, n_clusters); label ids
        outside the given ranges are not counted
    """
    ## one-hot indicators per sample, built by broadcasting against the id ranges
    class_onehot = (labels_true[:, None] == jnp.arange(n_classes)).astype(jnp.float32)
    cluster_onehot = (labels_pred[:, None] == jnp.arange(n_clusters)).astype(jnp.float32)
    ## product of indicators counts co-occurrences of (class, cluster)
    return jnp.matmul(class_onehot.T, cluster_onehot)
def measure_ARI(
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray
) -> jnp.ndarray:
    """
    Computes the adjusted Rand index (ARI), which measures similarity between
    two sets of indices (ground truth against a clustering's produced indices)
    by counting the pairs of data points assigned to the same or different
    clusters, adjusted for chance.

    The maximum value is `1` (perfect agreement); values near `0` indicate
    chance-level (random) agreement, and values can be negative when agreement
    is worse than chance.

    Args:
        labels_true: 1D array of shape (n_samples,) with true (non-negative)
            integer class labels.

        labels_pred: 1D array of shape (n_samples,) with predicted
            (non-negative) integer cluster labels.

    Returns:
        scalar ARI of these two sets of indices
    """
    ## table sizes must be concrete Python ints (static args of the jitted helper)
    n_classes, n_clusters = (int(jnp.max(ids) + 1) for ids in (labels_true, labels_pred))
    return _calc_adjusted_rand_index(labels_true, labels_pred, n_classes, n_clusters)
@partial(jit, static_argnums=[2, 3])
def _calc_adjusted_rand_index(  ## ARI
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray,
    n_classes: int,
    n_clusters: int
) -> jnp.ndarray:
    """
    Jitted core of the adjusted Rand index:

    | ARI = (sum_ij C(n_ij, 2) - E) / ((A + B) / 2 - E)
    | with A = sum_i C(a_i, 2), B = sum_j C(b_j, 2), E = A * B / C(n, 2)

    where n_ij are contingency counts and a_i / b_j its row / column sums.
    Returns 1.0 for n_samples <= 1, and also when the denominator is zero.
    """
    n_samples = labels_true.shape[0]
    if n_samples <= 1:  ## no pairs to compare
        return jnp.array(1.0)

    table = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)

    def _choose2(counts):  ## element-wise "counts choose 2"
        return (counts * (counts - 1.0)) / 2.0

    pair_index = jnp.sum(_choose2(table))
    row_pairs = jnp.sum(_choose2(jnp.sum(table, axis=1)))
    col_pairs = jnp.sum(_choose2(jnp.sum(table, axis=0)))
    all_pairs = (n_samples * (n_samples - 1.0)) / 2.0

    expected = (row_pairs * col_pairs) / all_pairs
    upper_bound = (row_pairs + col_pairs) / 2.0
    span = upper_bound - expected
    ## guard against a zero denominator (e.g., perfectly uniform labelings)
    return jnp.where(span == 0.0, 1.0, (pair_index - expected) / span)
def measure_FMI(
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray
) -> jnp.ndarray:
    """
    Calculates the Fowlkes-Mallows Index (FMI), which measures similarity
    between two sets of indices; this score is the geometric mean of pair-wise
    precision and recall. This measurement lies in `[0, 1]`, where higher is
    better (indicating greater similarity between two clustering sets of
    identifiers).

    Args:
        labels_true: 1D array of shape (n_samples,) with true (non-negative)
            integer class labels.

        labels_pred: 1D array of shape (n_samples,) with predicted
            (non-negative) integer cluster labels.

    Returns:
        scalar FMI of these two sets of indices
    """
    ## table sizes must be concrete Python ints (static args of the jitted helper)
    n_classes, n_clusters = (int(jnp.max(ids) + 1) for ids in (labels_true, labels_pred))
    return _measure_fowlkes_mallows_index(labels_true, labels_pred, n_classes, n_clusters)
@partial(jit, static_argnums=[2, 3])
def _measure_fowlkes_mallows_index(  ## FMI
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray,
    n_classes: int,
    n_clusters: int
) -> jnp.ndarray:
    """
    Jitted core of the Fowlkes-Mallows index, FMI = tk / sqrt(tr * tc), where
    tk counts pairs grouped together in both labelings, and tr / tc count pairs
    grouped together in the ground truth / prediction respectively. Each pair
    count uses the shortcut 0.5 * (sum(counts^2) - n). Returns 0.0 for
    n_samples <= 1 and when the denominator is zero.
    """
    n_samples = labels_true.shape[0]
    if n_samples <= 1:  ## single or empty sample: no pairs
        return jnp.array(0.0, dtype=jnp.float32)

    table = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)

    def _together_pairs(counts):  ## pairs sharing a group, from group counts
        return 0.5 * (jnp.sum(counts ** 2) - n_samples)

    tk = _together_pairs(table)
    tr = _together_pairs(jnp.sum(table, axis=1))
    tc = _together_pairs(jnp.sum(table, axis=0))
    denom = jnp.sqrt(tr * tc)
    return jnp.where(denom == 0.0, 0.0, tk / denom)
def measure_Vmeasure(  ## V-Measure
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray,
    beta: float = 1.0
) -> jnp.ndarray:
    """
    Calculates the V-Measure scoring metric for class conformity. This
    measurement compares predicted cluster indices ("labels_pred") against
    ground truth indices ("labels_true") and is the weighted harmonic mean of
    homogeneity (each cluster contains only members of a single class) and
    completeness (all members of a given class are assigned to the same
    cluster). This measurement (higher is better) lies in `[0,1]` where `1`
    indicates perfect, correct clustering.

    Args:
        labels_true: 1D array of shape (n_samples,) with true (non-negative)
            integer class labels

        labels_pred: 1D array of shape (n_samples,) with predicted
            (non-negative) integer cluster labels

        beta: weight factor; values > 1.0 favor completeness, < 1.0 favor
            homogeneity (Default: 1.0)

    Returns:
        scalar V-measure of these two sets of indices
    """
    ## table sizes must be concrete Python ints (static args of the jitted helper)
    n_classes, n_clusters = (int(jnp.max(ids) + 1) for ids in (labels_true, labels_pred))
    return _measure_v_measure_score(labels_true, labels_pred, n_classes, n_clusters, beta)
@partial(jit, static_argnums=[2, 3, 4])
def _measure_v_measure_score(  ## V-Measure
    labels_true: jnp.ndarray,
    labels_pred: jnp.ndarray,
    n_classes: int,
    n_clusters: int,
    beta: float = 1.0
) -> jnp.ndarray:
    """
    Jitted core of the V-measure. Using entropies of the class marginal
    H(T), the cluster marginal H(P), and the joint H(T, P):

    | homogeneity  h = 1 - H(T|P) / H(T),  H(T|P) = H(T,P) - H(P)
    | completeness c = 1 - H(P|T) / H(P),  H(P|T) = H(T,P) - H(T)
    | V = (1 + beta) * h * c / (beta * h + c)

    A zero base entropy makes the corresponding metric 1.0; a zero V
    denominator yields 0.0, as does n_samples <= 1.
    """
    n_samples = labels_true.shape[0]
    if n_samples <= 1:  ## single or empty sample
        return jnp.array(0.0, dtype=jnp.float32)

    table = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)

    def _entropy(counts):  ## Shannon entropy (nats) of counts/n_samples, with 0*log0 := 0
        probs = counts / n_samples
        return -jnp.sum(jnp.where(probs > 0.0, probs * jnp.log(probs), 0.0))

    h_true = _entropy(jnp.sum(table, axis=1))
    h_pred = _entropy(jnp.sum(table, axis=0))
    h_joint = _entropy(table)

    homogeneity = jnp.where(h_true == 0.0, 1.0, 1.0 - ((h_joint - h_pred) / h_true))
    completeness = jnp.where(h_pred == 0.0, 1.0, 1.0 - ((h_joint - h_true) / h_pred))

    denom = beta * homogeneity + completeness
    return jnp.where(
        denom == 0.0,
        0.0,
        (1.0 + beta) * homogeneity * completeness / denom
    )