Source code for ngclearn.utils.analysis.knn_probe

import jax
import numpy as np
from ngcsimlib import deprecate_args
from ngclearn.utils.analysis.probe import Probe
from ngclearn.utils.model_utils import kwta
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
from ngclearn.utils.distribution_generator import DistributionGenerator

@bind(jax.jit, static_argnums=[2, 3])
def _run_knn_probe(_embeddings, Wx, K, dist_order=2):
    ## Notes:
    ### We do some 3D tensor math to handle a batch of predictions that need to be made
    ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
    _Wx = jnp.expand_dims(Wx, axis=0)  ## 3D tensor format of KNN params (1 x N x D)
    embed_tensor = jnp.expand_dims(_embeddings, axis=1)  ## 3D projection of input signals (B x 1 x D)
    D = embed_tensor - _Wx  ## compute 3D batched delta tensor (B x N x D)
    ## get batched (negative) distance measurements
    dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True)  ## (B x N x 1)
    ## else, default -> euclidean
    ### Note: negative distance allows us to find minimal points w/ maximal functions
    dist = -jnp.squeeze(dist, axis=2)  ## (B x N)
    ## now get K winners per sample in batch
    values, indices = lax.top_k(dist, K)
    return values, indices

[docs] class KNNProbe(Probe): """ This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of encodings/embeddings in light of some superivsory downstream data (e.g., label one-hot encodings or real-valued vector regression targets). Args: dkey: init seed key source_seq_length: length of input sequence (e.g., height x width of the image feature) input_dim: input dimensionality of probe out_dim: output dimensionality of probe num_neighbors: number of nearest neighbors to perform estimate of output target with batch_size: size of batches to process per internal call to update (or process) K: number of nearest neighbors to estimate output target distance_function: tuple specifying distance function and its order for calculating nearest neighbors (Default: ("minkowski", 2)). usage guide: ("minkowski", 2) or ("euclidean", ?) => use L2 norm (Euclidean) distance; ("minkowski", 1) or ("manhattan", ?) => use L1 norm (taxi-cab/city-block) distance; ("minkowksi", jnp.inf) or ("chebyshev", ?) => use Chebyshev distance; ("minkowski", p > 2) => use a Minkowski distance of p-th order predictor_type: Str what type of problem is this K-NN solving? vote_style: """ @deprecate_args(K="num_neighbors") def __init__( self, dkey, source_seq_length, input_dim, out_dim, batch_size=1, num_neighbors=1, ## number of nearest neighbors (K) to find distance_function=("minkowski", 2), predictor_type="classifier", ## "classifier"; "regressor" vote_style="mode", ## "mode", "mean" **kwargs ): super().__init__(dkey, batch_size, **kwargs) self.dkey, *subkeys = random.split(self.dkey, 3) self.source_seq_length = source_seq_length self.input_dim = input_dim self.out_dim = out_dim self.K = num_neighbors self.vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction if vote_style == "mean": self.vote_fx = 1 self.distance_function = distance_function dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean if "euclidean" in dist_fun.lower(): dist_order = 2 elif "manhattan" in dist_fun.lower(): dist_order = 1 elif "chebyshev" in dist_fun.lower(): dist_order = jnp.inf ## TODO: add in cosine-distance (and maybe Mahalanobis distance) self.dist_order = dist_order ## set distance order p self.predictor_type = predictor_type self.pred_fx = 0 if "regressor" == predictor_type: self.pred_fx = 1 #flat_input_dim = input_dim * source_seq_length #W = jnp.zeros((flat_input_dim, out_dim)) Wx = Wy = jnp.ones((1, 1)) ## Wy will be assumed to be one-hot encoded self.probe_params = (Wx, Wy)
[docs] def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this _embeddings = embeddings if len(_embeddings.shape) > 2: flat_dim = embeddings.shape[1] * embeddings.shape[2] _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) Wx, Wy = self.probe_params ## pull out KNN parameters values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_order) ## do K-neighbor voting scheme (find mode/frequency prediction) Y_counts = jnp.zeros((_embeddings.shape[0], Wy.shape[1])) for k in range(self.K): winner_k_indx = indices[:, k] ## batch of k-th set of K winners Y_k = Wy[winner_k_indx, :] ## predicted Y's of k-th winner batch Y_counts = Y_counts + Y_k ## do post-processing to conform to problem-type being solved by this K-NN if self.pred_fx == 1: ## (regressor, contus outputs) Y_pred = Y_counts * (1. / self.K) else: ## pred_fx == 0 (classifier, discrete outputs) Y_pred = Y_counts if self.vote_fx == 1: ## calc mean prediction Y_pred = Y_counts * (1. / self.K) ## vote_fx == 0 (mode prediction) Y_pred = nn.one_hot(jnp.argmax(Y_pred, axis=1), num_classes=Wy.shape[1]) # , keepdims=True) return Y_pred ## (B, C)
[docs] def update(self, embeddings, labels, dkey=None): _embeddings = embeddings if len(_embeddings.shape) > 2: flat_dim = embeddings.shape[1] * embeddings.shape[2] _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) ## a K-NN's learning phase is just storing the data internally directly Wx = _embeddings Wy = labels self.probe_params = (Wx, Wy)
if __name__ == '__main__': seed = 42 D = 7 C = 5 dkey = random.PRNGKey(seed) dkey, *subkeys = random.split(dkey, 3) knn = KNNProbe( subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean" ) X = random.uniform(subkeys[1], shape=(10, D)) Y = jnp.concat( [ jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]), jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]), jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]), jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]), jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]]) ], axis=0 ) knn.update(X, Y) ## fit KNN to data print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y