# Source code for ngclearn.utils.analysis.knn_probe

import jax
import numpy as np
from ngcsimlib import deprecate_args
from ngclearn.utils.analysis.probe import Probe
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind


@bind(jax.jit, static_argnums=[2, 3, 4])
def _run_knn_probe(_embeddings, Wx, K, dist_order=2, dist_metric="minkowski"):
    """
    Retrieves, for every query embedding, the K best-scoring stored memories.

    Args:
        _embeddings: batch of query embeddings, (B x D)

        Wx: stored memory embeddings, (N x D)

        K: number of neighbors to retrieve (static under jit)

        dist_order: order p of the Minkowski norm (static; unused when
            dist_metric is "cosine")

        dist_metric: "cosine" selects cosine similarity; any other value
            selects the Minkowski distance of order `dist_order` (static)

    Returns:
        tuple (values, indices), each (B x K), sorted by descending score;
        scores are cosine similarities (cosine) or negated distances
        (Minkowski), so larger always means "closer"
    """
    if dist_metric != "cosine":
        ## broadcast (B x 1 x D) against (1 x N x D) -> pairwise differences (B x N x D)
        diffs = _embeddings[:, None] - Wx[None]
        ## negate the distances so the nearest memory receives the largest score
        scores = -jnp.linalg.norm(diffs, ord=dist_order, axis=2)  ## (B x N)
    else:
        eps = 1e-12  ## keeps all-zero rows from triggering a division by zero
        query_unit = _embeddings / (jnp.linalg.norm(_embeddings, axis=1, keepdims=True) + eps)
        memory_unit = Wx / (jnp.linalg.norm(Wx, axis=1, keepdims=True) + eps)
        scores = query_unit @ memory_unit.T  ## (B x D) @ (D x N) -> (B x N) similarities
    ## top_k picks the K largest scores per row, along with their memory indices
    return lax.top_k(scores, K)


class KNNProbe(Probe):
    """
    This implements a K-nearest neighbors (KNN) probe, which is useful for
    evaluating the quality of encodings/embeddings in light of some supervisory
    downstream data (e.g., label one-hot encodings or real-valued vector
    regression targets).

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of
            the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe

        batch_size: size of batches to process per internal call to update
            (or process)

        num_neighbors: number of nearest neighbors (K) used to estimate the
            output target

        K: (deprecated) older name for `num_neighbors`, remapped through
            `deprecate_args`

        distance_function: tuple (name, order) specifying the distance
            function and its order for finding nearest neighbors
            (Default: ("minkowski", 2)). Usage guide ("?" = any value, it is
            overridden):
            ("minkowski", 2) or ("euclidean", ?) => L2 norm (Euclidean) distance;
            ("minkowski", 1) or ("manhattan", ?) => L1 norm (taxi-cab/city-block) distance;
            ("minkowski", jnp.inf) or ("chebyshev", ?) => Chebyshev distance;
            ("minkowski", p > 2) => Minkowski distance of p-th order;
            ("cosine", ?) => cosine similarity.
            Unrecognized names fall back to a Minkowski distance of the given order.

        predictor_type: what type of problem this K-NN solves; "classifier"
            (default) returns one-hot class predictions, "regressor" returns
            the mean of the K neighbors' targets

        vote_style: how neighbor targets are aggregated for a classifier:
            "mode" (default) or "mean". Both select the same (argmax) class.
            A regressor always averages.
    """

    @deprecate_args(K="num_neighbors")
    def __init__(
            self,
            dkey,
            source_seq_length,
            input_dim,
            out_dim,
            batch_size=1,
            num_neighbors=1,  ## number of nearest neighbors (K) to find
            distance_function=("minkowski", 2),
            predictor_type="classifier",  ## "classifier"; "regressor"
            vote_style="mode",  ## "mode", "mean"
            **kwargs
    ):
        super().__init__(dkey, batch_size, **kwargs)
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.K = num_neighbors
        self.vote_fx = 0  ## 0 -> mode prediction; 1 -> mean prediction
        if vote_style == "mean":
            self.vote_fx = 1

        ## resolve the distance metric name and its order p
        self.distance_function = distance_function
        dist_fun, dist_order = distance_function
        dist_name = dist_fun.lower()
        self.dist_metric = "minkowski"  ## default tracker
        if "cosine" in dist_name:
            self.dist_metric = "cosine"
            dist_order = 2  ## fallback assignment (the cosine path ignores it)
        elif "euclidean" in dist_name:
            dist_order = 2
        elif "manhattan" in dist_name:
            dist_order = 1
        elif "chebyshev" in dist_name:
            dist_order = jnp.inf
        self.dist_order = dist_order  ## set distance order p

        self.predictor_type = predictor_type
        self.pred_fx = 0  ## 0 -> classifier; 1 -> regressor
        if predictor_type == "regressor":
            self.pred_fx = 1

        ## placeholder memories until `update` stores real data;
        ## Wy is assumed to hold one-hot encoded targets for classification
        Wx = Wy = jnp.ones((1, 1))
        self.probe_params = (Wx, Wy)

    @staticmethod
    def _flatten(embeddings):
        """Collapses every non-batch axis of an input with more than two axes into one feature axis."""
        if len(embeddings.shape) > 2:
            ## -1 lets reshape infer the feature size, so any rank > 2 works
            return jnp.reshape(embeddings, (embeddings.shape[0], -1))
        return embeddings

    def process(self, embeddings, dkey=None):
        """
        Predicts targets for a batch of embeddings from their K nearest stored memories.

        Args:
            embeddings: batch of embeddings; inputs with more than two axes
                are flattened to (B, features)

            dkey: unused here

        Returns:
            (B, C) predictions: one-hot class predictions (classifier) or the
            mean target of the K nearest neighbors (regressor)
        """
        _embeddings = self._flatten(embeddings)
        Wx, Wy = self.probe_params
        _, indices = _run_knn_probe(
            _embeddings, Wx, self.K, self.dist_order, self.dist_metric
        )  ## indices: (B x K)

        ## K-neighbor voting: gather all neighbors' targets at once,
        ## (B x K) -> (B x K x C), then sum over the neighbor axis -> (B x C)
        Y_counts = jnp.sum(Wy[indices], axis=1)

        if self.pred_fx == 1:  ## regressor (continuous outputs): mean of neighbors
            return Y_counts * (1. / self.K)

        ## classifier (discrete outputs); scaling by 1/K for "mean" voting does
        ## not change which class attains the argmax
        Y_pred = Y_counts
        if self.vote_fx == 1:
            Y_pred = Y_counts * (1. / self.K)
        return nn.one_hot(jnp.argmax(Y_pred, axis=1), num_classes=Wy.shape[1])  ## (B, C)

    def update(self, embeddings, labels, dkey=None):
        """
        Fits the probe by memorizing embeddings and their targets; replaces any
        previously stored memory.

        Args:
            embeddings: batch of embeddings; inputs with more than two axes
                are flattened to (N, features)

            labels: targets aligned row-wise with `embeddings` (one-hot rows
                for a classifier)

            dkey: unused here
        """
        ## a K-NN's learning phase is just storing the data internally directly
        self.probe_params = (self._flatten(embeddings), labels)
# if __name__ == '__main__':
#     seed = 42
#     D = 7
#     C = 5
#     dkey = random.PRNGKey(seed)
#     dkey, *subkeys = random.split(dkey, 3)
#     knn = KNNProbe(
#         subkeys[0], 1, input_dim=D, out_dim=C, num_neighbors=1,
#         distance_function=("euclidean", 2)
#     )
#     X = random.uniform(subkeys[1], shape=(10, D))
#     Y = jnp.concat(
#         [
#             jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]),
#             jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]),
#             jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]),
#             jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]),
#             jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]])
#         ],
#         axis=0
#     )
#     knn.update(X, Y)  ## fit KNN to data
#     ## with K=1 each sample's nearest stored neighbor is itself, so this
#     ## reproduces Y exactly (a "smeared" identity matrix: each row repeated twice)
#     print(knn.process(X))