ngclearn.utils.analysis package
Submodules
ngclearn.utils.analysis.attentive_probe module
- class ngclearn.utils.analysis.attentive_probe.AttentiveProbe(dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64, target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32, use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002, eta_decay=0.0, min_eta=1e-05, **kwargs)
Bases: Probe
This implements a nonlinear attentive probe, which is useful for evaluating the quality of encodings/embeddings against some supervisory downstream data (e.g., one-hot label encodings or real-valued regression target vectors).
- Parameters:
dkey – init seed key
source_seq_length – length of the input sequence (e.g., height × width of an image feature map)
input_dim – input dimensionality of probe
out_dim – output dimensionality of probe
num_heads – number of cross-attention heads
attn_dim – output dimensionality of each cross-attention head
target_seq_length – length of the target (query) sequence; set it to 1 to pool the source sequence into a single vector
learnable_query_dim – dimensionality of the learnable target-sequence queries (the output dimension of the cross-attention portion of the probe)
batch_size – size of batches to process per internal call to update (or process)
hid_dim – dimensionality of the hidden layer(s) of the MLP portion of the probe
use_LN – should layer normalization be used within the MLP portion of the probe or not?
use_LN_input – should layer normalization be applied to the input encodings or not?
use_softmax – should a softmax be applied to the output of the probe or not?
dropout – if >0, triggers drop-out applied internally to cross-attention
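To make the role of target_seq_length concrete, here is a minimal NumPy sketch (not ngclearn's implementation) of attentive pooling: a learnable query attends over all T source positions, so each (T, D) sequence collapses to one vector per query row. The function name attentive_pool, the single head, and the omission of a separate query projection, layer normalization, and the MLP read-out are simplifying assumptions.

```python
import numpy as np


def attentive_pool(x, query, W_k, W_v):
    """Pool each sequence in a batch with learnable queries (single head).

    x:     (B, T, D)  source sequences, e.g. T = height * width feature positions
    query: (L, Dq)    learnable queries; L plays the role of target_seq_length
    W_k:   (D, Dq)    key projection
    W_v:   (D, Dv)    value projection
    returns (B, L, Dv)
    """
    k = x @ W_k                                              # (B, T, Dq)
    v = x @ W_v                                              # (B, T, Dv)
    scores = np.einsum("ld,btd->blt", query, k) / np.sqrt(query.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over the T positions
    return np.einsum("blt,btv->blv", weights, v)


rng = np.random.default_rng(0)
B, T, D, Dq, Dv = 4, 16, 8, 32, 32
x = rng.normal(size=(B, T, D))
query = rng.normal(size=(1, Dq))        # target_seq_length = 1
W_k = rng.normal(size=(D, Dq))
W_v = rng.normal(size=(D, Dv))
pooled = attentive_pool(x, query, W_k, W_v)
print(pooled.shape)  # (4, 1, 32)
```

Because the query has a single row, the output has one pooled vector per sequence; a query with k rows would instead produce k pooled vectors.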
- process(embeddings, dkey=None)
Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values
- update(embeddings, labels, dkey=None)
Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
labels – target values that map to embedding sequence; shape (B, target_value_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values
- ngclearn.utils.analysis.attentive_probe.cross_attention(dkey, params: tuple, x1: Array, x2: Array, mask: Array, n_heads: int = 8, dropout_rate: float = 0.0) → Array
Runs cross-attention given a tuple of parameters and two sequences (x1 and x2). The function takes a query sequence x1 and a key-value sequence x2 and returns an output of the same shape as x1. Here, T is the length of the query sequence, S is the length of the key-value sequence, Dq and Dkv are the feature dimensions of the query and key-value sequences, respectively, and H is the number of attention heads.
- Parameters:
dkey – JAX key to trigger any internal noise (drop-out)
params (tuple) – tuple of parameters
x1 (jax.Array) – query sequence. Shape: (B, T, Dq)
x2 (jax.Array) – key-value sequence. Shape: (B, S, Dkv)
mask (jax.Array) – mask tensor. Shape: (B, T, S)
n_heads (int, optional) – number of attention heads. Defaults to 8.
dropout_rate (float, optional) – dropout rate. Defaults to 0.0.
- Returns:
output of cross-attention
- Return type:
jax.Array
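The shape bookkeeping described above can be sketched in NumPy as follows. This is an illustrative, hypothetical cross_attention_sketch rather than the library function: the four-matrix params layout (Wq, Wk, Wv, Wout) without biases, the convention that True in mask marks blocked key positions, and the omission of drop-out are all assumptions.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # stable; the max is finite if any entry is
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention_sketch(params, x1, x2, mask, n_heads=8):
    """Multi-head cross-attention (no drop-out, no biases).

    params: (Wq, Wk, Wv, Wout) with shapes (Dq, A), (Dkv, A), (Dkv, A), (A, Dq)
    x1:     (B, T, Dq)  query sequence
    x2:     (B, S, Dkv) key-value sequence
    mask:   (B, T, S)   boolean; True marks key positions a query may NOT attend to
    returns (B, T, Dq)
    """
    Wq, Wk, Wv, Wout = params
    B, T, _ = x1.shape
    S = x2.shape[1]
    A = Wq.shape[1]
    assert A % n_heads == 0, "projection width must split evenly across heads"
    hd = A // n_heads                                       # width of each head
    # project, then split the feature axis into heads: (B, H, length, hd)
    q = (x1 @ Wq).reshape(B, T, n_heads, hd).transpose(0, 2, 1, 3)
    k = (x2 @ Wk).reshape(B, S, n_heads, hd).transpose(0, 2, 1, 3)
    v = (x2 @ Wv).reshape(B, S, n_heads, hd).transpose(0, 2, 1, 3)
    scores = np.einsum("bhtd,bhsd->bhts", q, k) / np.sqrt(hd)
    scores = np.where(mask[:, None, :, :], -np.inf, scores)  # same mask for every head
    attn = softmax(scores, axis=-1)                         # (B, H, T, S)
    out = np.einsum("bhts,bhsd->bhtd", attn, v)             # (B, H, T, hd)
    out = out.transpose(0, 2, 1, 3).reshape(B, T, A)        # merge heads back
    return out @ Wout


rng = np.random.default_rng(0)
B, T, S, Dq, Dkv, A, H = 2, 1, 5, 16, 8, 32, 4
params = (rng.normal(size=(Dq, A)), rng.normal(size=(Dkv, A)),
          rng.normal(size=(Dkv, A)), rng.normal(size=(A, Dq)))
x1 = rng.normal(size=(B, T, Dq))
x2 = rng.normal(size=(B, S, Dkv))
mask = np.zeros((B, T, S), dtype=bool)                      # nothing masked
y = cross_attention_sketch(params, x1, x2, mask, n_heads=H)
print(y.shape)  # (2, 1, 16)
```

If every key position in a query row is masked, the softmax receives only -inf scores and returns NaN, so each query should keep at least one visible key.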
- ngclearn.utils.analysis.attentive_probe.eval_attention_probe(dkey, params, encodings, labels, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True)
Runs and evaluates the nonlinear attentive probe given a paired set of encoding vectors and externally assigned labels/regression targets.
- Parameters:
dkey – JAX key to trigger any internal noise (as in drop-out)
params – parameters tuple/list of probe
encodings – input encoding vectors/data
labels – output target values (e.g., labels, regression target vectors)
mask – optional mask to be applied to internal cross-attention
n_heads – number of attention heads
dropout – if >0, triggers drop-out applied internally to cross-attention
use_LN – use layer normalization?
use_LN_input – use layer normalization on input encodings?
use_softmax – should softmax be applied to output of attention probe? (useful for classification)
- Returns:
current loss value, output scores/probabilities
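This page does not state which loss eval_attention_probe reports. As a hedged sketch only, a natural choice would be categorical cross-entropy when the outputs are softmax probabilities and mean squared error for real-valued regression targets; probe_loss below is a hypothetical helper illustrating that assumption.

```python
import numpy as np


def probe_loss(outputs, labels, use_softmax=True, eps=1e-7):
    """Hypothetical loss between probe outputs and targets.

    use_softmax=True:  outputs are class probabilities (B, C), labels are one-hot (B, C);
                       returns the mean categorical cross-entropy.
    use_softmax=False: outputs and labels are real-valued (B, D);
                       returns the mean squared error.
    """
    if use_softmax:
        p = np.clip(outputs, eps, 1.0 - eps)   # guard against log(0)
        return float(-np.mean(np.sum(labels * np.log(p), axis=-1)))
    return float(np.mean((outputs - labels) ** 2))


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
onehot = np.array([[1, 0, 0],
                   [0, 1, 0]])
print(probe_loss(probs, onehot))   # average of -log(0.7) and -log(0.8)
```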
- ngclearn.utils.analysis.attentive_probe.masked_fill(x: Array, mask: Array, value=0) → Array
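Returns a copy of x in which the entries selected by mask are replaced by value; entries outside the mask keep their original values from x.
- Parameters:
x (jax.Array) – input array
mask (jax.Array) – boolean mask selecting the entries of x to overwrite
value (int, optional) – fill value written at the masked entries. Defaults to 0.
- Returns:
an array of the same shape as x with the masked entries filled
- Return type:
jax.Array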
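Read as in the description above, masked_fill behaves like a conditional select. Here is a NumPy sketch, assuming that True entries of mask are the ones overwritten (the name masked_fill_sketch is hypothetical):

```python
import numpy as np


def masked_fill_sketch(x, mask, value=0):
    """Return a copy of x whose entries at True positions of mask are set to value."""
    return np.where(mask, value, x)


scores = np.array([[0.5, 1.5, -0.2]])
blocked = np.array([[False, True, False]])
filled = masked_fill_sketch(scores, blocked, value=-np.inf)
print(filled)   # the middle score becomes -inf; the others are unchanged
```

In attention code, this pattern is typically used to set blocked scores to -inf so they receive zero weight after the softmax.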
- ngclearn.utils.analysis.attentive_probe.run_attention_probe(dkey, params, encodings, mask, n_heads: int, dropout: float = 0.0, use_LN=False, use_LN_input=False, use_softmax=True)
Runs full nonlinear attentive probe on input encodings (typically embedding vectors produced by some other model).
- Parameters:
dkey – JAX key for any internal noise to be applied
params – parameters tuple/list of probe
encodings – input encoding vectors/data
mask – optional mask to be applied to internal cross-attention
n_heads – number of attention heads
dropout – if >0, triggers drop-out applied internally to cross-attention
use_LN – use layer normalization?
use_LN_input – use layer normalization on input encodings?
use_softmax – should softmax be applied to output of attention probe? (useful for classification)
- Returns:
output scores/probabilities, cross-attention (hidden) features
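After cross-attention, run_attention_probe still has to turn the pooled features into scores. The sketch below shows one plausible read-out that exercises the use_LN and use_softmax flags. The hypothetical probe_head function, the ReLU activation, the placement of layer normalization (without learned gain or bias), and the single hidden layer are assumptions rather than details taken from ngclearn.

```python
import numpy as np


def layer_norm(h, eps=1e-5):
    # normalize each feature vector to zero mean and unit variance (no learned gain/bias)
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)


def probe_head(features, W1, b1, W2, b2, use_LN=True, use_softmax=True):
    """Hypothetical MLP read-out applied to pooled cross-attention features.

    features: (B, 1, A) cross-attention output with target_seq_length = 1
    returns:  (B, out_dim) scores, or class probabilities if use_softmax
    """
    h = features.reshape(features.shape[0], -1)   # (B, A): drop the length-1 axis
    if use_LN:
        h = layer_norm(h)
    h = np.maximum(0.0, h @ W1 + b1)              # hidden layer of width hid_dim (ReLU assumed)
    if use_LN:
        h = layer_norm(h)
    out = h @ W2 + b2                             # (B, out_dim)
    if use_softmax:
        out = np.exp(out - out.max(axis=-1, keepdims=True))
        out = out / out.sum(axis=-1, keepdims=True)
    return out


rng = np.random.default_rng(1)
B, A, hid_dim, out_dim = 3, 32, 32, 10
feats = rng.normal(size=(B, 1, A))
probs = probe_head(feats, rng.normal(size=(A, hid_dim)), np.zeros(hid_dim),
                   rng.normal(size=(hid_dim, out_dim)), np.zeros(out_dim))
print(probs.shape, probs.sum(axis=-1))   # each row sums to 1
```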
ngclearn.utils.analysis.linear_probe module
- class ngclearn.utils.analysis.linear_probe.LinearProbe(dkey, source_seq_length, input_dim, out_dim, batch_size=1, use_LN=False, use_softmax=False, **kwargs)
Bases: Probe
This implements a regularized linear probe, which is useful for evaluating the quality of encodings/embeddings against some supervisory downstream data (e.g., one-hot label encodings or real-valued regression target vectors). Note that this probe allows for configurable elastic-net (L1+L2) regularization.
- Parameters:
dkey – init seed key
source_seq_length – length of the input sequence (e.g., height × width of an image feature map)
input_dim – input dimensionality of probe
out_dim – output dimensionality of probe
batch_size – size of batches to process per internal call to update (or process)
use_LN – should layer normalization be used on incoming input vectors given to this probe?
use_softmax – should a softmax be applied to output of probe or not?
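The elastic-net (L1+L2) regularization mentioned above can be illustrated with a hedged NumPy sketch of one (sub)gradient step for a linear regression read-out. The objective, the flattening of each (T, D) sequence into one feature vector, the penalty weights l1 and l2, and the name elastic_net_step are assumptions for illustration; this page does not document how LinearProbe configures or applies its penalties.

```python
import numpy as np


def elastic_net_step(W, b, X, Y, lr=0.1, l1=1e-3, l2=1e-3):
    """One (sub)gradient step on 0.5*MSE + l1*|W|_1 + 0.5*l2*|W|_2^2.

    X: (B, T, D) embeddings, flattened to (B, T*D) (flattening is an assumption here)
    Y: (B, out_dim) real-valued targets
    Returns the updated (W, b) and the loss evaluated before the update.
    """
    Xf = X.reshape(X.shape[0], -1)
    err = Xf @ W + b - Y                                       # (B, out_dim)
    loss = (0.5 * np.mean(np.sum(err ** 2, axis=-1))
            + l1 * np.abs(W).sum() + 0.5 * l2 * (W ** 2).sum())
    # np.sign(W) is a valid subgradient of |W|_1 (it is 0 where W is exactly 0)
    grad_W = Xf.T @ err / X.shape[0] + l1 * np.sign(W) + l2 * W
    grad_b = err.mean(axis=0)
    return W - lr * grad_W, b - lr * grad_b, loss


rng = np.random.default_rng(0)
B, T, D, out_dim = 64, 4, 3, 2
X = rng.normal(size=(B, T, D))
W_true = rng.normal(size=(T * D, out_dim))
Y = X.reshape(B, -1) @ W_true
W, b = np.zeros((T * D, out_dim)), np.zeros(out_dim)
losses = []
for _ in range(200):
    W, b, loss = elastic_net_step(W, b, X, Y)
    losses.append(loss)
print(losses[0], losses[-1])   # the loss should shrink substantially
```

The L1 term pushes uninformative weights toward exactly zero, while the L2 term keeps the remaining weights small; using both is what makes the penalty "elastic-net".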
- process(embeddings, dkey=None)
Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values
- update(embeddings, labels, dkey=None)
Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
labels – target values that map to embedding sequence; shape (B, target_value_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values
ngclearn.utils.analysis.probe module
- class ngclearn.utils.analysis.probe.Probe(dkey, batch_size=1, dev_batch_size=1, **kwargs)
Bases: object
General framework for an analysis probe (which may or may not be learnable in an iterative fashion).
- Parameters:
dkey – init seed key
batch_size – size of batches to process per internal call to update (or process)
- fit(dataset, dev_dataset=None, n_iter=50, patience=20)
Fits this probe to a pool of data.
- Parameters:
dataset – a dataset tuple containing two design tensors/matrices (X, Y): X contains the encoding vector sequences, of shape (N, T, embed_dim) or (N, embed_dim), and Y contains the corresponding labels/targets for the embedding data, of shape (N, target_dim)
dev_dataset – an optional development set tuple, with same format as dataset (Default: None)
n_iter – number of iterations to run model fitting (Default: 50 iterations)
patience – number of consecutive iterations without improvement (decrease) in loss before early stopping is triggered (Default: 20 iterations)
- Returns:
best accuracy found over fitting run
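A minimal sketch of the kind of early-stopping loop fit describes, written against two hypothetical callables (train_epoch and evaluate) so it runs without ngclearn. Training halts once patience consecutive iterations pass without a decrease in loss, and the best accuracy seen is returned. How ngclearn's fit actually schedules evaluation is not specified here.

```python
def fit_with_patience(train_epoch, evaluate, n_iter=50, patience=20):
    """Hypothetical fit loop with early stopping.

    train_epoch(): runs one pass of updates over the training data
    evaluate():    returns (loss, accuracy), e.g. on the development set
    Stops after `patience` consecutive iterations without a decrease in loss
    and returns the best accuracy observed.
    """
    best_loss, best_acc, stale = float("inf"), 0.0, 0
    for _ in range(n_iter):
        train_epoch()
        loss, acc = evaluate()
        best_acc = max(best_acc, acc)
        if loss < best_loss:
            best_loss, stale = loss, 0   # improvement: reset the counter
        else:
            stale += 1
            if stale >= patience:
                break
    return best_acc


# scripted (loss, accuracy) pairs standing in for real evaluations
history = iter([(1.0, 0.50), (0.8, 0.60), (0.9, 0.65),
                (0.85, 0.55), (0.95, 0.70), (0.7, 0.90)])
calls = []
best = fit_with_patience(lambda: calls.append(1), lambda: next(history),
                         n_iter=6, patience=3)
print(best, len(calls))   # 0.7 5
```

Here the loss last improves at the second iteration, so three stale iterations later the loop stops after five training passes and never sees the final (0.7, 0.90) pair.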
- predict(data, batch_size=None)
Runs this probe’s inference scheme over a pool of data.
- Parameters:
data – a dataset or design tensor/matrix containing encoding vector sequences; shape (N, T, embed_dim) or (N, embed_dim)
batch_size – optional batch-size argument (Default: None, will use training batch size)
- Returns:
the output scores/predictions made by this probe
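Batched inference as described for predict can be sketched as slicing the data, calling a per-batch process function, and concatenating the results. The helper predict_in_batches and its treatment of 2D input (adding a length-1 sequence axis) are assumptions for illustration.

```python
import numpy as np


def predict_in_batches(process, data, batch_size):
    """Hypothetical batched inference over a pool of data.

    process: callable mapping a (b, T, embed_dim) batch to (b, out_dim) outputs
    data:    (N, T, embed_dim) or (N, embed_dim); 2D input gets a length-1 sequence axis
    """
    if data.ndim == 2:
        data = data[:, None, :]                          # (N, embed_dim) -> (N, 1, embed_dim)
    outs = [process(data[i:i + batch_size])              # the last batch may be smaller
            for i in range(0, data.shape[0], batch_size)]
    return np.concatenate(outs, axis=0)


def mean_pool(batch):
    # stand-in for a probe's process(): average over the sequence axis
    return batch.mean(axis=1)


pool = np.arange(40, dtype=float).reshape(10, 4)          # N=10 encodings of width 4
preds = predict_in_batches(mean_pool, pool, batch_size=3)
print(preds.shape)   # (10, 4)
```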
- process(embeddings, dkey=None)
Runs the probe’s inference scheme given an input batch of sequences of encodings/embeddings.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values
- update(embeddings, labels, dkey=None)
Runs and updates this probe given an input batch of sequences of encodings/embeddings and their externally assigned labels/target vector values.
- Parameters:
embeddings – a 3D tensor containing a batch of encoding sequences; shape (B, T, embed_dim)
labels – target values that map to embedding sequence; shape (B, target_value_dim)
dkey – Optional JAX noise key
- Returns:
probe output scores/probability values