Source code for ngclearn.utils.analysis.kmeans_probe

import jax
from ngcsimlib import deprecate_args
from ngclearn.utils.analysis.probe import Probe
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
from ngclearn.utils.metric_utils import measure_ARI

@bind(jax.jit, static_argnums=[2])
def _run_kmeans_probe(_embeddings, centroids, n_clusters):
    """
    Runs a single K-means iteration: every sample is assigned to its nearest
    centroid (squared L2 distance) and each centroid is then re-estimated as
    the mean of the samples assigned to it. A cluster that receives no
    samples keeps its previous centroid instead of collapsing to the origin.

    Args:
        _embeddings: 2D array of samples, (n_samples, n_features)

        centroids: 2D array of current centroids, (n_clusters, n_features)

        n_clusters: number of clusters (a static argument under jit)

    Returns:
        the re-estimated centroids, (n_clusters, n_features)
    """
    ## Broadcast distances: (n_samples, 1, n_features) - (1, n_clusters, n_features)
    distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    labels_pred = jnp.argmin(distances, axis=1)
    ## Re-estimate centroids/means
    one_hot_preds = labels_pred[:, None] == jnp.arange(n_clusters)
    raw_counts = one_hot_preds.sum(axis=0, keepdims=True).T  ## (n_clusters, 1)
    ## clamp to 1 so that empty clusters do not trigger a division by zero
    counts = jnp.maximum(raw_counts, 1.0)
    means = jnp.dot(one_hot_preds.T.astype(jnp.float32), _embeddings) / counts
    ## an empty cluster has nothing to average (its "mean" above is just 0),
    ## so carry its previous centroid forward rather than moving it to the origin
    return jnp.where(raw_counts > 0, means, centroids.astype(means.dtype))

@bind(jax.jit, static_argnums=[2])
def _predict_with_probe(_embeddings, centroids, n_clusters):
    """
    Assigns each sample to its nearest centroid under squared L2 distance.

    Args:
        _embeddings: 2D array of samples, (n_samples, n_features)

        centroids: 2D array of centroids, (n_clusters, n_features)

        n_clusters: number of clusters (a static argument under jit); sets
            the width of the one-hot encoding

    Returns:
        a tuple of (integer cluster index per sample, one-hot encoding of
        those indices with `n_clusters` columns)
    """
    ## pairwise offsets between every sample and every centroid:
    ## (n_samples, 1, n_features) - (1, n_clusters, n_features)
    offsets = _embeddings[:, None, :] - centroids[None, :, :]
    sq_dists = jnp.sum(offsets ** 2, axis=-1)
    ## index of the closest centroid for each sample
    nearest = jnp.argmin(sq_dists, axis=1)
    return nearest, nn.one_hot(nearest, n_clusters)

class KMeansProbe(Probe):
    """
    This implements a K-means clustering probe, which is useful for evaluating
    the quality of encodings/embeddings in light of the ability to cluster
    downstream data. Currently, this probe only supports L2/Euclidean
    distance-based clustering.

    Args:
        dkey: init seed key

        source_seq_length: length of input sequence (e.g., height x width of
            the image feature)

        input_dim: input dimensionality of probe

        out_dim: output dimensionality of probe - number of clusters for this
            probe to create

        batch_size: <Unused>
    """

    def __init__(
            self,
            dkey,
            source_seq_length,
            input_dim,
            out_dim=2,  ## number of clusters/centroids to uncover
            batch_size=1,
            **kwargs
    ):
        ## NOTE(review): the parent constructor presumably stores `dkey` as
        ## `self.dkey`, since it is read immediately below -- confirm in Probe
        super().__init__(dkey, batch_size, **kwargs)
        ## advance the key; the sub-keys are not consumed by this constructor
        self.dkey, *subkeys = random.split(self.dkey, 3)
        self.source_seq_length = source_seq_length
        self.input_dim = input_dim
        self.n_clusters = self.out_dim = out_dim
        ## centroids that will be uncovered by this probe (None until `_init` runs)
        self.centroids : jax.Array = None

    @staticmethod
    def _flatten(embeddings):
        """
        Collapses all axes after the first into one, so that embeddings of
        rank greater than 2 become a 2D array; arrays of rank 2 or lower are
        returned untouched.
        """
        if len(embeddings.shape) > 2:
            ## explicit product (rather than -1) keeps zero-sample inputs reshapeable
            flat_dim = 1
            for dim in embeddings.shape[1:]:
                flat_dim *= dim
            return jnp.reshape(embeddings, (embeddings.shape[0], flat_dim))
        return embeddings

    def _init(self, embeddings):
        """
        Initializes the centroids by drawing `n_clusters` distinct samples
        (without replacement) from the provided embeddings.

        Args:
            embeddings: data to draw the initial centroids from; flattened to
                2D if it has more than two axes
        """
        _embeddings = self._flatten(embeddings)
        ## choose random data-points to serve as centroids at iteration 0
        ## (only the first sub-key is used; the split size is kept as-is so
        ## that the random stream is unchanged)
        self.dkey, *subkeys = random.split(self.dkey, 15)
        n_samples = _embeddings.shape[0]
        random_indices = random.choice(
            subkeys[0], n_samples, shape=(self.n_clusters,), replace=False
        )
        self.centroids = _embeddings[random_indices]

    def process(self, embeddings, dkey=None):
        """
        Assigns each embedding to its nearest centroid.

        Args:
            embeddings: data to cluster; flattened to 2D if it has more than
                two axes

            dkey: <Unused>

        Returns:
            one-hot cluster assignments, (B, C)
        """
        _embeddings = self._flatten(embeddings)
        ## nearest-centroid assignment; requires centroids to have been set
        ## (e.g., by `fit`, which calls `_init`)
        _, Y_pred = _predict_with_probe(_embeddings, self.centroids, self.n_clusters)
        return Y_pred  ## (B, C)

    def update(self, embeddings, labels, dkey=None):
        """
        Performs one K-means iteration on the given embeddings, updating the
        stored centroids, and then predicts cluster assignments with them.

        Args:
            embeddings: data to cluster; flattened to 2D if it has more than
                two axes

            labels: <Unused> (the centroid update is unsupervised)

            dkey: <Unused>

        Returns:
            a tuple of (loss placeholder, currently always 0., and one-hot
            cluster assignments)
        """
        _embeddings = self._flatten(embeddings)
        self.centroids = _run_kmeans_probe(_embeddings, self.centroids, self.n_clusters)
        L = 0.  ## FIXME: should be clustering loss
        predictions = self.process(_embeddings)
        return L, predictions

    def fit(self, dataset, dev_dataset=None, n_iter=20, patience=20):
        """
        Fits the centroids to a dataset by running `n_iter` K-means iterations
        over the full data, printing the score returned by `measure_ARI`
        after every iteration.

        Args:
            dataset: tuple of (data, labels); labels are reduced with an
                argmax over their last axis (presumably one-hot encoded --
                confirm against callers)

            dev_dataset: <Unused>

            n_iter: number of K-means iterations to run

            patience: <Unused>

        Returns:
            the highest score observed across iterations; the running best
            starts at 0., so 0. is returned if no iteration exceeds it
        """
        data, labels = dataset
        _labels = jnp.argmax(labels, axis=-1)
        self._init(data)  ## init K-means centroids
        ari = 0.
        for i in range(n_iter):  ## Run vectorized K-Means optimization loop
            _L, py = self.update(data, labels)
            labels_pred = jnp.argmax(py, axis=1)
            ari_i = measure_ARI(_labels, labels_pred)
            print(f"\r{i}: ARI = {ari_i}", end="")
            if ari_i > ari:
                ari = ari_i
        print()
        return ari